短い答え:
- 多くのビッグデータ設定(たとえば数百万のデータポイント)では、すべてのデータポイントを合計する必要があるため、コストまたは勾配の計算に非常に長い時間がかかります。
- 特定の反復でコストを削減するために、正確な勾配を持つ必要はありません。勾配の近似は問題なく機能します。
- 確率的勾配降下(SGD)は、1つのデータポイントのみを使用して勾配を近似します。そのため、勾配を評価すると、すべてのデータを合計するのに比べて時間を大幅に節約できます。
- 「合理的な」反復回数(この回数は数千で、データポイントの数(数百万)よりもはるかに少ない可能性があります)では、確率的勾配まともな妥当な解決策が得られます。
長い答え:
私の表記法はAndrew NGの機械学習Courseraコースに従います。よく知らない場合は、こちらで講義シリーズを確認できます。
二乗損失の回帰を想定してみましょう、コスト関数は
J(θ )= 12 メートル∑i = 1m(hθ(x(i ))− y(i ))2
勾配は
dJ(θ )dθ= 1m∑i = 1m(hθ(x(i ))− y(i ))x(i )
勾配まとも(GD)の場合、次のようにパラメータを更新します
θN E W= θO l d- α 1m∑i = 1m(hθ(x(i ))− y(i ))x(私)
1 / mバツ(i )、y(i )
θN E W= θO l d- α ⋅ (Hθ(x(i ))− y(i ))x(i )
これが時間を節約する理由です。
10億のデータポイントがあるとします。
GDでは、パラメーターを1回更新するために、(正確な)勾配が必要です。これには、1つの更新を実行するためにこれらの10億のデータポイントを合計する必要があります。
SGD では、厳密な勾配の代わりに近似勾配を取得しようとしていると考えることができます。近似は、1つのデータポイント(またはミニバッチと呼ばれる複数のデータポイント)からのものです。したがって、SGDでは、パラメーターを非常に迅速に更新できます。さらに、すべてのデータ(1エポックと呼ばれる)を「ループ」すると、実際には10億の更新があります。
秘trickは、SGDでは10億回の反復/更新は必要ないが、反復/更新ははるかに少ない、たとえば100万であり、使用するのに十分なモデルがあるということです。
アイデアをデモするためのコードを書いています。最初に正規方程式で線形システムを解き、次にSGDで解きます。次に、パラメータ値と最終目的関数値に関して結果を比較します。後で視覚化するために、調整する2つのパラメーターがあります。
set.seed(0);n_data=1e3;n_feature=2;
A=matrix(runif(n_data*n_feature),ncol=n_feature)
b=runif(n_data)
res1=solve(t(A) %*% A, t(A) %*% b)
sq_loss<-function(A,b,x){
e=A %*% x -b
v=crossprod(e)
return(v[1])
}
sq_loss_gr_approx<-function(A,b,x){
# note, in GD, we need to sum over all data
# here i is just one random index sample
i=sample(1:n_data, 1)
gr=2*(crossprod(A[i,],x)-b[i])*A[i,]
return(gr)
}
x=runif(n_feature)
alpha=0.01
N_iter=300
loss=rep(0,N_iter)
for (i in 1:N_iter){
x=x-alpha*sq_loss_gr_approx(A,b,x)
loss[i]=sq_loss(A,b,x)
}
結果:
as.vector(res1)
[1] 0.4368427 0.3991028
x
[1] 0.3580121 0.4782659
パラメーターは近すぎませんが、損失値は 124.1343 そして 123.0355 とても近いです。
反復にわたるコスト関数の値は次のとおりです。損失を効果的に減らすことができることがわかります。これは、データのサブセットを使用して勾配を近似し、「十分な」結果を得ることができるという考えを示しています。
次に、2つのアプローチ間の計算の労力を確認しましょう。実験では、1000データポイントは、SDを使用して、データを合計する必要がある場合に勾配を評価します。sq_loss_gr_approx
ただし、SGDでは、関数は1データポイントのみを合計します。全体的に見て、アルゴリズムは以下よりも収束します300 繰り返し(注意、ではありません 1000 反復。)これは計算上の節約です。