以下は、平均と標準偏差の推定に使用される期待値最大化(EM)の例です。コードはPythonで記述されていますが、言語に精通していなくても簡単に理解できるはずです。
EMの動機
以下に示す赤と青の点は、それぞれ特定の平均と標準偏差を持つ2つの異なる正規分布から描画されます。
赤分布の「真の」平均および標準偏差パラメーターの合理的な近似値を計算するには、赤点を非常に簡単に調べて各点の位置を記録し、おなじみの式を使用します(青のグループについても同様)。 。
次に、ポイントのグループが2つあることがわかっているが、どのポイントがどのグループに属しているかがわからない場合を考えます。つまり、色は隠されています:
ポイントを2つのグループに分割する方法はまったく明らかではありません。現在、位置を見て、赤分布または青分布のパラメーターの推定値を計算することはできません。
これは、EMを使用して問題を解決できる場所です。
EMを使用してパラメーターを推定する
上記のポイントを生成するために使用されるコードは次のとおりです。ポイントが引き出された正規分布の実際の平均と標準偏差を確認できます。変数red
とblue
は、それぞれ赤と青のグループの各ポイントの位置を保持します。
import numpy as np
from scipy import stats
np.random.seed(110) # for reproducible random results
# set parameters
red_mean = 3
red_std = 0.8
blue_mean = 7
blue_std = 2
# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)
both_colours = np.sort(np.concatenate((red, blue)))
各ポイントの色を確認できる場合、ライブラリ関数を使用して平均と標準偏差を回復しようとします。
>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195
しかし、色は私たちから隠されているので、EMプロセスを開始します...
まず、各グループのパラメーターの値を推測するだけです(ステップ1)。これらの推測は必ずしも良いものである必要はありません。
# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9
# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7
かなり悪い推測-平均点は、ポイントグループの「中間」から遠く離れているように見えます。
EMを続行し、これらの推測を改善するために、平均と標準偏差の推測の下に表示される各データポイント(その秘密の色に関係なく)の尤度を計算します(ステップ2)。
変数both_colours
は各データポイントを保持します。この関数stats.norm
は、指定されたパラメーターを使用して、正規分布の下でポイントの確率を計算します。
likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)
これは、たとえば、現在の推測では、1.761のデータポイントは青(0.00003)よりも赤(0.189)である可能性がはるかに高いことを示しています。
これらの2つの尤度値を重みに変換して(ステップ3)、次のように合計が1になるようにすることができます。
likelihood_total = likelihood_of_red + likelihood_of_blue
red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total
現在の推定値と新しく計算された重みを使用して、パラメータの新しい、おそらくより良い推定値を計算できます(ステップ4)。平均値の関数と標準偏差の関数が必要です。
def estimate_mean(data, weight):
return np.sum(data * weight) / np.sum(weight)
def estimate_std(data, weight, mean):
variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
return np.sqrt(variance)
これらは、データの平均と標準偏差に通常の関数に非常に似ています。違いは、weight
各データポイントに重みを割り当てるパラメーターの使用です。
この重み付けは、EMの鍵です。データポイントの色の重みが大きいほど、データポイントはその色のパラメータの次の推定値に大きく影響します。最終的に、これは各パラメーターを正しい方向に引く効果があります。
新しい推測は、次の関数を使用して計算されます。
# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)
EMプロセスは、ステップ2以降のこれらの新しい推測で繰り返されます。特定の反復回数(20回など)の間、またはパラメーターが収束するまでステップを繰り返すことができます。
5回の反復の後、最初の悪い推測が改善し始めるのがわかります。
20回の反復後、EMプロセスはほぼ収束しました。
比較のために、色情報が隠されていない場合に計算された値と比較したEMプロセスの結果を以下に示します。
| EM guess | Actual
----------+----------+--------
Red mean | 2.910 | 2.802
Red std | 0.854 | 0.871
Blue mean | 6.838 | 6.932
Blue std | 2.227 | 2.195
注:この回答は、スタックオーバーフローに関する私の回答をここに適用したものです。