GELUアクティベーションとは何ですか?


17

GELU(x)= xP(X≤x)=xΦ(x)として方程式を述べる GELU(Gaussian Linear Linear Unit)を使用するBERTペーパー調べました。0.5x (1 + tanh [\ sqrt {2 /π}(x + 0.044715x ^ 3)])に 相当します。方程式を単純化し、どのように承認されたかを説明してください。

GELU(x)=xP(Xx)=xΦ(x).
0.5x(1+tanh[2/π(x+0.044715x3)])

回答:


16

GELU関数

次のようにN(0,1)累積分布、つまりΦ(x)を展開できます。

GELU(x):=xP(Xx)=xΦ(x)=0.5x(1+erf(x2))

これは定義であり、方程式(または関係)ではないことに注意してください。著者は、この提案に対していくつかの正当化、例えば確率論的な類似性を提供しているが、数学的にはこれは単なる定義に過ぎない。

GELUのプロットは次のとおりです。

タン近似

これらのタイプの数値近似の重要な考え方は、(主に経験に基づいて)同様の関数を見つけ、それをパラメーター化し、元の関数からのポイントのセットに適合させることです。

がに非常に近いことを知っているerf(x)tanh(x)

そして、の一次導関数のそれと一致において、これは、 (またはより多くの用語で)ポイントのセット。erf(x2)tanh(2πx)x=02π

tanh(2π(x+ax2+bx3+cx4+dx5))
(xi,erf(xi2))

この関数を間の20個のサンプルに適合させました(このサイトを使用)。係数は次のとおりです。(1.5,1.5)

と設定ことにより、はと推定され。より広い範囲からのサンプルが多いほど(そのサイトでは20しか許可されていません)、係数は紙の近くなります。最後にa=c=d=0b0.04495641b0.044715

GELU(x)=xΦ(x)=0.5x(1+erf(x2))0.5x(1+tanh(2π(x+0.044715x3)))

平均二乗誤差でのために。108x[10,10]

一次導関数間の関係を利用しなかった場合、用語がパラメーターに次のように含まれる ことに注意してください これはあまり美しくありません(分析的ではなく、数値的です)!2π

0.5x(1+tanh(0.797885x+0.035677x3))

パリティを活用する

@BookYourLuckが示唆するように、関数のパリティを利用して、検索する多項式の空間を制限できます。つまり、は奇数関数、つまりであり、も奇数関数であるため、多項式関数はは、を持つためには 奇数(奇数のべき乗のみ)でなければなりませんerff(x)=f(x)tanhpol(x)tanhx

erf(x)tanh(pol(x))=tanh(pol(x))=tanh(pol(x))erf(x)

以前は、私たちも、権力のために(ほぼ)ゼロ係数で終わることが幸運だったと、しかし一般的に、これは例えば、などの用語持っている、低品質の近似につながるかもしれないその単に選択する代わりに、余分な条件(偶数または奇数)によってキャンセルされます。x2x40.23x20x2

シグモイド近似

と(シグモイド)の間にも同様の関係があり平均二乗誤差のための。erf(x)2(σ(x)12)104x[10,10]

データポイントを生成し、関数を近似し、平均二乗誤差を計算するPythonコードを次に示します。

import math
import numpy as np
import scipy.optimize as optimize


def tahn(xs, a):
    return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]


def sigmoid(xs, a):
    return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]


print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
#       .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])

# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))

sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))

# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()

# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()

print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)

if print_points == 1:
    print(len(xs))
    for x, erf in zip(xs, erfs):
        print(x, erf)

出力:

Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05

2
なぜ近似が必要ですか?単にerf関数を使用できないのでしょうか?
SebiSebi

8

まず、パリティによります。私達は示す必要があること用。

Φ(x)=12erfc(x2)=12(1+erf(x2))
erf
erf(x2)tanh(2π(x+ax3))
a0.044715

値が大きい場合、両方の関数は区切られます。小さい場合、それぞれのテイラー級数はおよび 代わりに、 および 係数を すると、論文の近い見つかります。 x[1,1]x

tanh(x)=xx33+o(x3)
erf(x)=2π(xx33)+o(x3).
tanh(2π(x+ax3))=2π(x+(a23π)x3)+o(x3)
erf(x2)=2π(xx36)+o(x3).
x3
a0.04553992412
0.044715

弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.