期待値最大化手法の直感的な説明とは何ですか?[閉まっている]


109

期待値最大化(EM)は、データを分類するための一種の確率的手法です。分類器でない場合は、誤りがあれば訂正してください。

このEM技術の直感的な説明は何ですか?expectationここには何があり、何があるのmaximizedでしょうか。


12
期待値最大化アルゴリズムとは何ですか?ネイチャーバイオテクノロジー 26、897から899(2008)のアルゴリズムがどのように機能するかを示して素敵な絵を持っています。
12

パートで@chl B素敵な絵、どのように彼らはZ上の確率分布(すなわち、0.45xA、0.55xBなど)の値を取得するのですか?
Noobサイボット2013年


3
@chlが言及した画像へのリンク更新しました
n1k31t4

回答:


119

注:この回答の背後にあるコードはこちらにあります


赤と青の2つの異なるグループからサンプリングされたデータがあるとします。

ここに画像の説明を入力してください

ここでは、どのデータポイントが赤または青のグループに属しているかを確認できます。これにより、各グループを特徴付けるパラメーターを簡単に見つけることができます。たとえば、赤のグループの平均は約3、青のグループの平均は約7です(必要に応じて正確な平均を見つけることができます)。

これは、一般的に言えば、最尤推定として知られています。いくつかのデータが与えられると、そのデータを最もよく説明する1つまたは複数のパラメーターの値を計算します。

ここで、どのグループからどの値がサンプリングされたかを確認できないと想像してください。すべてが紫に見えます。

ここに画像の説明を入力してください

ここでは、値のグループが2つあることはわかっていますが、特定の値がどのグループに属しているかはわかりません。

このデータに最も適合する赤のグループと青のグループの平均を推定できますか?

はい、できます!期待値の最大化は、それを行う方法を提供します。アルゴリズムの背後にある非常に一般的な考え方は次のとおりです。

  1. 各パラメータが何であるかについての初期推定から始めます。
  2. 各パラメーターがデータポイントを生成する可能性を計算します。
  3. パラメータによって生成される可能性に基づいて、各データポイントの重みを計算して、赤か青かを示します。重みとデータを組み合わせます(期待値)。
  4. 重み調整されたデータを使用して、パラメーターのより良い推定を計算します(maximization)。
  5. パラメータ推定値が収束するまで、手順2〜4を繰り返します(プロセスは別の推定値の生成を停止します)。

これらの手順にはさらに説明が必要なため、上記の問題について説明します。

例:平均と標準偏差の推定

この例ではPythonを使用しますが、この言語に慣れていない場合、コードはかなり理解しやすいはずです。

赤と青の2つのグループがあり、値が上の図のように分布しているとします。具体的には、各グループには、次のパラメーターを持つ正規分布から抽出された値が含まれています。

import numpy as np
from scipy import stats

np.random.seed(110) # for reproducible results

# set parameters
red_mean = 3
red_std = 0.8

blue_mean = 7
blue_std = 2

# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)

both_colours = np.sort(np.concatenate((red, blue))) # for later use...

以下は、これらの赤と青のグループの画像です(上にスクロールする必要をなくすため)。

ここに画像の説明を入力してください

各ポイントの色(つまり、どのグループに属しているか)がわかると、各グループの平均と標準偏差を推定するのは非常に簡単です。赤と青の値をNumPyの組み込み関数に渡すだけです。例えば:

>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195

しかし、ポイントの色が見えない場合はどうでしょうか。つまり、赤や青ではなく、すべての点が紫に着色されています。

赤と青のグループの平均と標準偏差のパラメーターを回復するために、期待値最大化を使用できます。

最初のステップ(上記のステップ1)は、各グループの平均と標準偏差のパラメーター値を推測することです。インテリジェントに推測する必要はありません。私たちは好きな数字を選ぶことができます:

# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9

# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7

これらのパラメーター推定により、次のようなベル曲線が生成されます。

ここに画像の説明を入力してください

これらは悪い見積もりです。両方の意味(縦の点線)は、たとえば、意味のあるポイントのグループについて、あらゆる「中間」から遠くを見ています。これらの見積もりを改善したいと考えています。

次のステップ(ステップ2)では、現在のパラメーターの推測の下に表示される各データポイントの尤度を計算します。

likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

ここでは、赤と青の平均と標準偏差での現在の推測を使用して、各データポイントを正規分布の確率密度関数に単純に入れました。これは、たとえば、現在の推測では、1.761のデータポイントは青(0.00003)よりも赤(0.189)である可能性がはるかに高いことを示しています。

各データポイントについて、これらの2つの尤度値を重みに変換し(ステップ3)、次のように合計して1になるようにします。

likelihood_total = likelihood_of_red + likelihood_of_blue

red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total

現在の推定値と新しく計算された重みを使用して、赤と青のグループの平均と標準偏差の新しい推定値を計算できます(ステップ4)。

すべてのデータポイントを使用して平均と標準偏差を2回計算しますが、重み付けは異なります。1回目は赤の重み付け、もう1回は青の重み付けです。

直感の重要な点は、データポイントの色の重みが大きいほど、データポイントがその色のパラメーターの次の推定に影響を与えることです。これには、パラメーターを正しい方向に「引き寄せる」という効果があります。

def estimate_mean(data, weight):
    """
    For each data point, multiply the point by the probability it
    was drawn from the colour's distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among our data points.
    """
    return np.sum(data * weight) / np.sum(weight)

def estimate_std(data, weight, mean):
    """
    For each data point, multiply the point's squared difference
    from a mean value by the probability it was drawn from
    that distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among the values for the difference of
    each data point from the mean.

    This is the estimate of the variance, take the positive square
    root to find the standard deviation.
    """
    variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
    return np.sqrt(variance)

# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)

# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)

パラメータの新しい推定値があります。それらを再び改善するために、ステップ2に戻ってプロセスを繰り返すことができます。これは、推定値が収束するまで、またはいくつかの反復が実行された後(ステップ5)に行われます。

私たちのデータでは、このプロセスの最初の5回の反復は次のようになります(最近の反復の方が見た目が強くなっています)。

ここに画像の説明を入力してください

平均値は既にいくつかの値に収束しており、曲線の形状(標準偏差によって制御される)もより安定していることがわかります。

20回反復すると、次のようになります。

ここに画像の説明を入力してください

EMプロセスは次の値に収束しました。これは実際の値に非常に近いことがわかります(色を見ることができます-非表示の変数はありません)。

          | EM guess | Actual |  Delta
----------+----------+--------+-------
Red mean  |    2.910 |  2.802 |  0.108
Red std   |    0.854 |  0.871 | -0.017
Blue mean |    6.838 |  6.932 | -0.094
Blue std  |    2.227 |  2.195 |  0.032

上記のコードでは、標準偏差の新しい推定値が、前の反復の平均値の推定値を使用して計算されていることに気付いたかもしれません。最終的には、いくつかの中心点の周りの値の(重み付けされた)分散を見つけるだけなので、最初に平均の新しい値を計算するかどうかは重要ではありません。パラメータの推定値は収束していることがわかります。


これが由来する正規分布の数がわからない場合はどうなりますか?ここでは、k = 2分布の例を取り上げましたが、kとkパラメータセットも推定できますか?
stackit 2017

1
@stackit:この場合、EMプロセスの一部としてkの最も可能性の高い値を計算する簡単な一般的な方法があるかどうかはわかりません。主な問題は、検索する各パラメーターの推定値を使用してEMを開始する必要があることと、開始​​する前にkを知っている/推定する必要があることです。ただし、ここでEMを介してグループに属するポイントの割合を推定することは可能です。多分私たちがkを過大評価すると、2つのグループを除くすべてのグループの比率がゼロ近くに低下します。私はこれを実験したことがないので、実際にどれほどうまくいくかわかりません。
Alex Riley

1
@AlexRiley新しい平均と標準偏差の推定値を計算するための式についてもう少し詳しく説明していただけますか?
レモン

2
@AlexRiley説明ありがとうございます。新しい標準偏差の見積もりが、平均の古い推測を使用して計算されるのはなぜですか?平均の新しい推定値が最初に見つかった場合はどうなりますか?
GoodDeeds 2018年

1
@Lemon GoodDeeds Kaushal-質問への返信が遅くなりましたことをお詫び申し上げます。私はあなたが提起したポイントに対処するために答えを編集しようとしました。また、この回答で使用されているすべてのコードを、ノートブックアクセスできるようにしました (これには、触れたいくつかのポイントの詳細な説明も含まれています)。
Alex Riley

36

EMは、モデル内の一部の変数が観測されない場合(つまり、潜在変数がある場合)に尤度関数を最大化するアルゴリズムです。

関数を最大化するだけの場合は、既存の機構を使用して関数を最大化するのはどうでしょうか。さて、導関数を取得してゼロに設定することでこれを最大化しようとすると、多くの場合、1次条件には解がありません。モデルパラメータを解くには、観察されていないデータの分布を知る必要があるという、鶏と卵の問題があります。ただし、観測されていないデータの分布は、モデルパラメータの関数です。

EMは、観測されていないデータの分布を繰り返し推測し、実際の尤度関数の下限となるものを最大化してモデルパラメーターを推定し、収束するまで繰り返すことで、これを回避しようとします。

EMアルゴリズム

モデルパラメータの値の推測から始めます

Eステップ:欠損値のある各データポイントについて、モデル方程式を使用して、モデルパラメーターの現在の推定値と観測データを指定して、欠損データの分布を解きます(欠損値ごとの分布を解くことに注意してください)値であり、期待値ではありません)。各欠損値の分布が得られたので、観測されていない変数に関する尤度関数の期待値を計算できます。モデルパラメーターの推測が正しかった場合、この予想される可能性は、観測されたデータの実際の可能性になります。パラメータが正しくなかった場合、それは単に下限になります。

Mステップ:観測されていない変数を含まない予測尤度関数が得られたので、完全に観測された場合と同じように関数を最大化して、モデルパラメーターの新しい推定値を取得します。

収束するまで繰り返します。


5
Eステップがわかりません。問題の一部は、私がこのことを学んでいるので、同じ用語を使用している人を見つけることができないということです。では、モデル方程式とはどういう意味ですか?確率分布を解くことで何を意味するのかわかりませんか?
user678392 2013年

27

期待値最大化アルゴリズムを理解するための簡単なレシピを次に示します。

1- DoとBatzoglouによるこのEMチュートリアルペーパーを読んでください。

2-あなたはあなたの頭に疑問符を持っているかもしれません、この数学スタック交換ページの説明を見てください

3-アイテム1のEMチュートリアルペーパーの例を説明する、Pythonで記述したこのコードを見てください。

警告:私はPython開発者ではないため、コードは乱雑で最適ではない可能性があります。しかし、それは仕事をします。

import numpy as np
import math

#### E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* #### 

def get_mn_log_likelihood(obs,probs):
    """ Return the (log)likelihood of obs, given the probs"""
    # Multinomial Distribution Log PMF
    # ln (pdf)      =             multinomial coeff            *   product of probabilities
    # ln[f(x|n, p)] = [ln(n!) - (ln(x1!)+ln(x2!)+...+ln(xk!))] + [x1*ln(p1)+x2*ln(p2)+...+xk*ln(pk)]     

    multinomial_coeff_denom= 0
    prod_probs = 0
    for x in range(0,len(obs)): # loop through state counts in each observation
        multinomial_coeff_denom = multinomial_coeff_denom + math.log(math.factorial(obs[x]))
        prod_probs = prod_probs + obs[x]*math.log(probs[x])

    multinomial_coeff = math.log(math.factorial(sum(obs))) -  multinomial_coeff_denom
    likelihood = multinomial_coeff + prod_probs
    return likelihood

# 1st:  Coin B, {HTTTHHTHTH}, 5H,5T
# 2nd:  Coin A, {HHHHTHHHHH}, 9H,1T
# 3rd:  Coin A, {HTHHHHHTHH}, 8H,2T
# 4th:  Coin B, {HTHTTTHHTT}, 4H,6T
# 5th:  Coin A, {THHHTHHHTH}, 7H,3T
# so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45

# represent the experiments
head_counts = np.array([5,9,8,4,7])
tail_counts = 10-head_counts
experiments = zip(head_counts,tail_counts)

# initialise the pA(heads) and pB(heads)
pA_heads = np.zeros(100); pA_heads[0] = 0.60
pB_heads = np.zeros(100); pB_heads[0] = 0.50

# E-M begins!
delta = 0.001  
j = 0 # iteration counter
improvement = float('inf')
while (improvement>delta):
    expectation_A = np.zeros((5,2), dtype=float) 
    expectation_B = np.zeros((5,2), dtype=float)
    for i in range(0,len(experiments)):
        e = experiments[i] # i'th experiment
        ll_A = get_mn_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) # loglikelihood of e given coin A
        ll_B = get_mn_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) # loglikelihood of e given coin B

        weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of A proportional to likelihood of A 
        weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of B proportional to likelihood of B                            

        expectation_A[i] = np.dot(weightA, e) 
        expectation_B[i] = np.dot(weightB, e)

    pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A)); 
    pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B)); 

    improvement = max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - np.array([pA_heads[j],pB_heads[j]]) ))
    j = j+1

あなたのプログラムはAとBの両方で0.66になると思います。私はscalaを使用してそれを実装します。また、結果が0.66であることもわかります。それを確認できますか?
zjffdu 2013年

スプレッドシートを使用して、最初の推測が等しい場合にのみ、0.66の結果を見つけます。それ以外の場合は、チュートリアルの出力を再現できます。
soakley

@ zjffdu、EMは0.66を返す前に何回反復を実行しますか?等しい値で初期化すると、ローカルの最大値でスタックする可能性があり、反復回数が非常に少ないことがわかります(改善がないため)。
Zhubarb 2013年


16

技術的には「EM」という用語は少々指定不足ですが、一般的なEM原理のであるガウス混合モデリングクラスター分析手法を参照していると思います。

実際、EMクラスター分析は分類子ではありません。クラスタリングを「教師なし分類」と考える人もいることは知っていますが、実際にはクラスター分析はまったく異なるものです。

主な違い、および人々がクラスター分析で常に持つ大きな誤解の分類は、次のとおりです。クラスター分析では、「正しい解決策」はありません。それは知識発見方法であり、実際には新しい何かを見つけることを意味します!これにより、評価が非常に難しくなります。多くの場合、既知の分類を参照として使用して評価されますが、常に適切であるとは限りません。現在の分類がデータの内容を反映している場合とそうでない場合があります。

例を挙げましょう。性別データを含む、顧客の大規模なデータセットがあります。このデータセットを「男性」と「女性」に分割する方法は、既存のクラスと比較する場合に最適です。「予測」の考え方では、これは良いことです。新しいユーザーは、性別を予測できるようになります。「知識発見」の考え方では、データの新しい構造を発見したかったため、これは実際には悪いことです。ただし、たとえばデータを高齢者と子供に分割する方法では、男性/女性のクラスに比べてスコアが低くなります。ただし、これは優れたクラスタリング結果です(年齢が指定されていない場合)。

EMに戻ります。基本的に、データが複数の多変量正規分布で構成されていることを前提としています(これは、特にクラスター数を修正する場合に非常に強力な前提になること注意してください)。次に、モデルとモデルへのオブジェクトの割り当てを交互に改善することにより、このためのローカル最適モデルを見つけようとします。

分類コンテキストで最良の結果を得るには、クラスの数より大きいクラスターの数を選択するか、クラスター化を単一のクラスのみに適用します(クラス内に構造があるかどうかを確認するため!)。

「車」、「自転車」、「トラック」を区別するために分類子を訓練したいとします。データが正確に3つの正規分布で構成されると仮定してもほとんど役に立ちません。ただし、複数のタイプの車(およびトラックとバイク)あり。したがって、これらの3つのクラスの分類子をトレーニングする代わりに、車、トラック、バイクをそれぞれ10個のクラスター(または、おそらく10台の車、3台のトラック、3つのバイクなど)にクラスター化し、これらの30個のクラスを区別するように分類子をトレーニングします。クラスの結果を元のクラスにマージします。また、たとえばTrikesなど、分類が特に困難なクラスターが1つあることもわかります。彼らはやや車であり、ややバイクです。または、トラックというよりも特大車のような配達用トラック。


EMはどのように過小評価されていますか?
sam boosalis 2013

複数のバージョンがあります。技術的には、ロイドスタイルのk-meansを「EM」と呼ぶこともできます。使用するモデルを指定する必要があります。
QUITあり-Anony-Mousse 2013年

2

他の答えは良いです、私は別の視点を提供し、質問の直感的な部分に取り組みます。

EM(期待値最大化)アルゴリズムは、双対性を使用する反復アルゴリズムのクラスのバリアントです。

抜粋(鉱山を強調):

数学では、一般的に言えば、双対性は概念、定理、または数学的構造を他の概念、定理、または構造に1対1の方法で、しばしば(常にではない)畳み込み演算によって変換します。 AはBであり、次にBの双対はAです。このようなインボリューションには固定点がある場合があります AのデュアルA自身であるように、

通常、オブジェクト Aのデュアル Bは、何らかの対称性または互換性を維持する何らかの方法でAに関連付けられています。たとえばAB = const

(以前の意味での)双対性を使用する反復アルゴリズムの例は次のとおりです。

  1. Greatest Common Divisorのユークリッドアルゴリズムとそのバリアント
  2. グラム・シュミットのベクトル基底アルゴリズムとバリアント
  3. 算術平均-幾何平均不等式、およびその変形
  4. 期待値最大化アルゴリズムとその変形情報幾何学的ビューについては、こちらご覧ください)
  5. (..他の同様のアルゴリズム..)

同様に、EMアルゴリズムは、2つの最大化ステップとしても見ることができます

.. [EM]は、パラメータと観測されていない変数の分布の結合関数を最大化すると見なされます。Eステップは、観測されていない変数の分布に関してこの関数を最大化します。パラメータに関するMステップ

双対性を使用する反復アルゴリズムでは、平衡(または固定)収束点の明示的(または暗黙的)仮定があります(EMの場合、これはJensenの不等式を使用して証明されます)

したがって、そのようなアルゴリズムの概要は次のとおりです。

  1. Eライクなステップ:与えられたyが一定に保たれるという点で、最良の解xを見つけます。
  2. M-like step(dual):一定に保たれているx(前のステップで計算された)に関して、最良の解yを見つけます。
  3. 終了/収束ステップの基準:収束(または指定された反復回数に達する)まで、xyの更新された値を使用してステップ1、2を繰り返します。

メモこのようなアルゴリズムの収束(グローバル)最適に、それが構成発見したことの両方の意味で最良の(両方で、すなわちXのドメイン/パラメータおよびYドメイン/パラメータ)。ただし、アルゴリズムはローカル最適を見つけるだけで、グローバル最適を見つけることはできません。

これはアルゴリズムの概要の直感的な説明だと思います

統計的な議論とアプリケーションについては、他の回答が良い説明を与えています(この回答の参照も確認してください)


2

受け入れられた回答は、EMを説明する適切な仕事をするChuong EM Paperを参照しています。紙をより詳細に説明するyoutubeビデオもあります。

要約すると、ここにシナリオがあります:

1st:  {H,T,T,T,H,H,T,H,T,H} 5 Heads, 5 Tails; Did coin A or B generate me?
2nd:  {H,H,H,H,T,H,H,H,H,H} 9 Heads, 1 Tails
3rd:  {H,T,H,H,H,H,H,T,H,H} 8 Heads, 2 Tails
4th:  {H,T,H,T,T,T,H,H,T,T} 4 Heads, 6 Tails
5th:  {T,H,H,H,T,H,H,H,T,H} 7 Heads, 3 Tails

Two possible coins, A & B are used to generate these distributions.
A & B have an unknown parameter: their bias towards heads.

We don't know the biases, but we can simply start with a guess: A=60% heads, B=50% heads.

最初のトライアルの質問の場合、ヘッドの比率がBのバイアスに非常によく一致しているので、直感的にはBがそれを生成したと思います...しかし、その値は単なる推測でしたので、確信が持てません。

それを念頭に置いて、私は次のようなEMソリューションについて考えたいと思います。

  • フリップの各トライアルは、それが最も好きなコインに「投票」します
    • これは、各コインがその分布にどの程度適合しているかに基づいています
    • または、コインの観点から、他のコインと比較して(対数尤度に基づいて)このトライアルを見ることへの高い期待があります
  • 各トライアルが各コインをどの程度気に入っているかに応じて、そのコインのパラメーター(バイアス)の推測を更新できます。
    • トライアルがコインを好きであればあるほど、コインのバイアスを更新して、それを反映するようになります!
    • 基本的に、コインのバイアスは、すべての試行にわたってこれらの重み付けされた更新を組み合わせることによって更新されます。プロセスは(maximazation)と呼ばれ、一連の試行で各コインのバイアスについて最良の推測を試みます。

これは単純化しすぎるかもしれません(または、一部のレベルでは根本的に間違っているかもしれません)が、これが直感的なレベルで役立つことを願っています!


1

EMは、潜在変数Zを持つモデルQの可能性を最大化するために使用されます。

これは反復最適化です。

theta <- initial guess for hidden parameters
while not converged:
    #e-step
    Q(theta'|theta) = E[log L(theta|Z)]
    #m-step
    theta <- argmax_theta' Q(theta'|theta)

e-step:Zの現在の推定が与えられると、期待される対数尤度関数を計算します

m-step:このQを最大にするシータを見つける

GMMの例:

e-step:現在のgmmパラメーター推定値を指定して、各データポイントのラベル割り当てを推定します

m-step:新しいラベル割り当てを指定して新しいシータを最大化します

K-meansもEMアルゴリズムであり、K-meansに関するアニメーションの説明はたくさんあります。


1

Zhubarbの答えで引用されたDoとBatzoglouによる同じ記事を使用して、Javaでその問題にEMを実装しました。彼の回答へのコメントは、アルゴリズムがローカルの最適値で動かなくなることを示しています。これは、パラメーターthetaAとthetaBが同じ場合、私の実装でも発生します。

以下は、私のコードの標準出力で、パラメーターの収束を示しています。

thetaA = 0.71301, thetaB = 0.58134
thetaA = 0.74529, thetaB = 0.56926
thetaA = 0.76810, thetaB = 0.54954
thetaA = 0.78316, thetaB = 0.53462
thetaA = 0.79106, thetaB = 0.52628
thetaA = 0.79453, thetaB = 0.52239
thetaA = 0.79593, thetaB = 0.52073
thetaA = 0.79647, thetaB = 0.52005
thetaA = 0.79667, thetaB = 0.51977
thetaA = 0.79674, thetaB = 0.51966
thetaA = 0.79677, thetaB = 0.51961
thetaA = 0.79678, thetaB = 0.51960
thetaA = 0.79679, thetaB = 0.51959
Final result:
thetaA = 0.79678, thetaB = 0.51960

以下は(Do and Batzoglou、2008)の問題を解決するためのEMのJava実装です。実装の中核部分は、パラメーターが収束するまでEMを実行するループです。

private Parameters _parameters;

public Parameters run()
{
    while (true)
    {
        expectation();

        Parameters estimatedParameters = maximization();

        if (_parameters.converged(estimatedParameters)) {
            break;
        }

        _parameters = estimatedParameters;
    }

    return _parameters;
}

以下はコード全体です。

import java.util.*;

/*****************************************************************************
This class encapsulates the parameters of the problem. For this problem posed
in the article by (Do and Batzoglou, 2008), the parameters are thetaA and
thetaB, the probability of a coin coming up heads for the two coins A and B,
respectively.
*****************************************************************************/
class Parameters
{
    double _thetaA = 0.0; // Probability of heads for coin A.
    double _thetaB = 0.0; // Probability of heads for coin B.

    double _delta = 0.00001;

    public Parameters(double thetaA, double thetaB)
    {
        _thetaA = thetaA;
        _thetaB = thetaB;
    }

    /*************************************************************************
    Returns true if this parameter is close enough to another parameter
    (typically the estimated parameter coming from the maximization step).
    *************************************************************************/
    public boolean converged(Parameters other)
    {
        if (Math.abs(_thetaA - other._thetaA) < _delta &&
            Math.abs(_thetaB - other._thetaB) < _delta)
        {
            return true;
        }

        return false;
    }

    public double getThetaA()
    {
        return _thetaA;
    }

    public double getThetaB()
    {
        return _thetaB;
    }

    public String toString()
    {
        return String.format("thetaA = %.5f, thetaB = %.5f", _thetaA, _thetaB);
    }

}


/*****************************************************************************
This class encapsulates an observation, that is the number of heads
and tails in a trial. The observation can be either (1) one of the
experimental observations, or (2) an estimated observation resulting from
the expectation step.
*****************************************************************************/
class Observation
{
    double _numHeads = 0;
    double _numTails = 0;

    public Observation(String s)
    {
        for (int i = 0; i < s.length(); i++)
        {
            char c = s.charAt(i);

            if (c == 'H')
            {
                _numHeads++;
            }
            else if (c == 'T')
            {
                _numTails++;
            }
            else
            {
                throw new RuntimeException("Unknown character: " + c);
            }
        }
    }

    public Observation(double numHeads, double numTails)
    {
        _numHeads = numHeads;
        _numTails = numTails;
    }

    public double getNumHeads()
    {
        return _numHeads;
    }

    public double getNumTails()
    {
        return _numTails;
    }

    public String toString()
    {
        return String.format("heads: %.1f, tails: %.1f", _numHeads, _numTails);
    }

}

/*****************************************************************************
This class runs expectation-maximization for the problem posed by the article
from (Do and Batzoglou, 2008).
*****************************************************************************/
public class EM
{
    // Current estimated parameters.
    private Parameters _parameters;

    // Observations from the trials. These observations are set once.
    private final List<Observation> _observations;

    // Estimated observations per coin. These observations are the output
    // of the expectation step.
    private List<Observation> _expectedObservationsForCoinA;
    private List<Observation> _expectedObservationsForCoinB;

    private static java.io.PrintStream o = System.out;

    /*************************************************************************
    Principal constructor.
    @param observations The observations from the trial.
    @param parameters The initial guessed parameters.
    *************************************************************************/
    public EM(List<Observation> observations, Parameters parameters)
    {
        _observations = observations;
        _parameters = parameters;
    }

    /*************************************************************************
    Run EM until parameters converge.
    *************************************************************************/
    public Parameters run()
    {

        while (true)
        {
            expectation();

            Parameters estimatedParameters = maximization();

            o.printf("%s\n", estimatedParameters);

            if (_parameters.converged(estimatedParameters)) {
                break;
            }

            _parameters = estimatedParameters;
        }

        return _parameters;

    }

    /*************************************************************************
    Given the observations and current estimated parameters, compute new
    estimated completions (distribution over the classes) and observations.
    *************************************************************************/
    private void expectation()
    {

        _expectedObservationsForCoinA = new ArrayList<Observation>();
        _expectedObservationsForCoinB = new ArrayList<Observation>();

        for (Observation observation : _observations)
        {
            int numHeads = (int)observation.getNumHeads();
            int numTails = (int)observation.getNumTails();

            double probabilityOfObservationForCoinA=
                binomialProbability(10, numHeads, _parameters.getThetaA());

            double probabilityOfObservationForCoinB=
                binomialProbability(10, numHeads, _parameters.getThetaB());

            double normalizer = probabilityOfObservationForCoinA +
                                probabilityOfObservationForCoinB;

            // Compute the completions for coin A and B (i.e. the probability
            // distribution of the two classes, summed to 1.0).

            double completionCoinA = probabilityOfObservationForCoinA /
                                     normalizer;
            double completionCoinB = probabilityOfObservationForCoinB /
                                     normalizer;

            // Compute new expected observations for the two coins.

            Observation expectedObservationForCoinA =
                new Observation(numHeads * completionCoinA,
                                numTails * completionCoinA);

            Observation expectedObservationForCoinB =
                new Observation(numHeads * completionCoinB,
                                numTails * completionCoinB);

            _expectedObservationsForCoinA.add(expectedObservationForCoinA);
            _expectedObservationsForCoinB.add(expectedObservationForCoinB);
        }
    }

    /*************************************************************************
    Given new estimated observations, compute new estimated parameters.
    *************************************************************************/
    private Parameters maximization()
    {

        double sumCoinAHeads = 0.0;
        double sumCoinATails = 0.0;
        double sumCoinBHeads = 0.0;
        double sumCoinBTails = 0.0;

        for (Observation observation : _expectedObservationsForCoinA)
        {
            sumCoinAHeads += observation.getNumHeads();
            sumCoinATails += observation.getNumTails();
        }

        for (Observation observation : _expectedObservationsForCoinB)
        {
            sumCoinBHeads += observation.getNumHeads();
            sumCoinBTails += observation.getNumTails();
        }

        return new Parameters(sumCoinAHeads / (sumCoinAHeads + sumCoinATails),
                              sumCoinBHeads / (sumCoinBHeads + sumCoinBTails));

        //o.printf("parameters: %s\n", _parameters);

    }

    /*************************************************************************
    Since the coin-toss experiment posed in this article is a Bernoulli trial,
    use a binomial probability Pr(X=k; n,p) = (n choose k) * p^k * (1-p)^(n-k).
    *************************************************************************/
    private static double binomialProbability(int n, int k, double p)
    {
        double q = 1.0 - p;
        return nChooseK(n, k) * Math.pow(p, k) * Math.pow(q, n-k);
    }

    private static long nChooseK(int n, int k)
    {
        long numerator = 1;

        for (int i = 0; i < k; i++)
        {
            numerator = numerator * n;
            n--;
        }

        long denominator = factorial(k);

        return (long)(numerator / denominator);
    }

    private static long factorial(int n)
    {
        long result = 1;
        for (; n >0; n--)
        {
            result = result * n;
        }

        return result;
    }

    /*************************************************************************
    Entry point into the program.
    *************************************************************************/
    public static void main(String argv[])
    {
        // Create the observations and initial parameter guess
        // from the (Do and Batzoglou, 2008) article.

        List<Observation> observations = new ArrayList<Observation>();
        observations.add(new Observation("HTTTHHTHTH"));
        observations.add(new Observation("HHHHTHHHHH"));
        observations.add(new Observation("HTHHHHHTHH"));
        observations.add(new Observation("HTHTTTHHTT"));
        observations.add(new Observation("THHHTHHHTH"));

        Parameters initialParameters = new Parameters(0.6, 0.5);

        EM em = new EM(observations, initialParameters);

        Parameters finalParameters = em.run();

        o.printf("Final result:\n%s\n", finalParameters);
    }
}
弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.