2番目の方法、およびy = s i n (α )の予測はまったく問題ありません。x = c o s (α )y= s i n (α )
はい、予測ベクトルのノルムは1に近いことが保証されていません。しかし、特にシグモイド活性化関数(それらは自然に制限されている)を使用する場合、および/またはモデルを適切に正則化する場合、爆発する可能性はありません。なぜ、あなたのモデルは、すべての訓練サンプルがでた場合は、大きな値を予測する必要があり、[ - 1 、1 ]?(x 、y)1[ - 1 、1 ]
別の側面は、ベクターである近すぎると(0 、0 )。これは時々発生する可能性があり、実際に間違った角度を予測する可能性があります。しかし、それはあなたのモデルの利点とみなされるかもしれません-あなたのモデルの信頼性の尺度としてのノルム(x 、y )を考慮することができます。実際、0に近いノルムは、モデルが正しい方向がどこにあるのかわからないことを意味します。(x 、y)(0 、0 )(x 、y)
これは、sinとcosを予測する方が角度を直接予測する方が良いことを示すPythonの小さな例です。
# predicting the angle (in radians)
import numpy as np
from sklearn.neural_network import MLPRegressor
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import r2_score
# generate toy data
np.random.seed(1)
X = np.random.normal(size=(100, 2))
y = np.arctan2(np.dot(X, [1,2]), np.dot(X, [3,0.4]))
# simple prediction
model = MLPRegressor(random_state=42, activation='tanh', max_iter=10000)
y_simple_pred = cross_val_predict(model, X, y)
# transformed prediction
joint = cross_val_predict(model, X, np.column_stack([np.sin(y), np.cos(y)]))
y_trig_pred = np.arctan2(joint[:,0], joint[:,1])
# compare
def align(y_true, y_pred):
""" Add or remove 2*pi to predicted angle to minimize difference from GT"""
y_pred = y_pred.copy()
y_pred[y_true-y_pred > np.pi] += np.pi*2
y_pred[y_true-y_pred < -np.pi] -= np.pi*2
return y_pred
print(r2_score(y, align(y, y_simple_pred))) # R^2 about 0.57
print(r2_score(y, align(y, y_trig_pred))) # R^2 about 0.99
次に、予測をプロットして、サイン-コサインモデルの予測がほぼ正しいことを確認しますが、さらにキャリブレーションが必要になる場合があります。
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 3))
plt.subplot(1,4,1)
plt.scatter(X[:,0], X[:,1], c=y)
plt.title('Data (y=color)'); plt.xlabel('x1'); plt.ylabel('x2')
plt.subplot(1,4,2)
plt.scatter(y_simple_pred, y)
plt.title('Direct model'); plt.xlabel('prediction'); plt.ylabel('actual')
plt.subplot(1,4,3)
plt.scatter(y_trig_pred, y)
plt.title('Sine-cosine model'); plt.xlabel('prediction'); plt.ylabel('actual')
plt.subplot(1,4,4)
plt.scatter(joint[:,0], joint[:,1], s=5)
plt.title('Predicted sin and cos'); plt.xlabel('cos'); plt.ylabel('sin')
plt.tight_layout();
πN2αcos(α )罪(α )z= 罪(α + π4)w = cos(α + π4)
(x 、y)(z、w )(x 、y)arctan2