Pythonでセグメント化線形回帰を実行するライブラリはありますか?


7

Rでセグメント化されたという名前のパッケージがあります。Pythonにも同様のパッケージはありますか?


あなたが必要とする特定のものはありますか?既知のブレークポイントの場合、これはインジケーター関数(ブレーク前は0、ブレーク後は1)または線形スプラインとの相互作用によってモデル化できます。最初のアプローチにはジャンプがあり、2番目のアプローチでは接続された区分的回帰直線になります。
ジョセフ

回答:


7

いいえ、現在Pythonには、Rほど完全に線形回帰をセグメント化するパッケージはありません(例:このブログ投稿に記載されているRパッケージ)。または、Pythonでベイジアンマルコフ連鎖モンテカルロアルゴリズムを使用して、セグメント化されたモデルを作成できます。

上記のリンクのすべてのRパッケージで実装されているセグメント化線形回帰は、追加のパラメーター制約(つまり、事前)を許可していません。これらのパッケージは頻繁なアプローチを取るため、結果のモデルはモデルの確率分布を提供しませんパラメータ(ブレークポイント、スロープなど)。頻繁に使用されるstatsmodelsでセグメント化されたモデルを定義することは、モデルが固定されたx座標のブレークポイントを必要とするため、さらに制限されます。

ベイジアンマルコフ連鎖モンテカルロアルゴリズムemceeを使用して、Pythonでセグメントモデルを設計できます。Jake Vanderplasは、PyMCとPyStanとの比較によるemceeの実装方法について、有用なブログ投稿ペーパーを書きました。

例:

  • データを含むセグメント化されたモデル:

セグメント化された回帰

  • 適合パラメーターの確率分布:

ここに画像の説明を入力してください


0

ここに画像の説明を入力してください

これは私自身の実装です。

import numpy as np
import matplotlib.pylab as plt
from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import LinearRegression

# parameters for setup
n_data = 20

# segmented linear regression parameters
n_seg = 3

np.random.seed(0)
fig, (ax0, ax1) = plt.subplots(1, 2)

# example 1
#xs = np.sort(np.random.rand(n_data))
#ys = np.random.rand(n_data) * .3 + np.tanh(5* (xs -.5))

# example 2
xs = np.linspace(-1, 1, 20)
ys = np.random.rand(n_data) * .3 + np.tanh(3*xs)

dys = np.gradient(ys, xs)

rgr = DecisionTreeRegressor(max_leaf_nodes=n_seg)
rgr.fit(xs.reshape(-1, 1), dys.reshape(-1, 1))
dys_dt = rgr.predict(xs.reshape(-1, 1)).flatten()

ys_sl = np.ones(len(xs)) * np.nan
for y in np.unique(dys_dt):
    msk = dys_dt == y
    lin_reg = LinearRegression()
    lin_reg.fit(xs[msk].reshape(-1, 1), ys[msk].reshape(-1, 1))
    ys_sl[msk] = lin_reg.predict(xs[msk].reshape(-1, 1)).flatten()
    ax0.plot([xs[msk][0], xs[msk][-1]],
             [ys_sl[msk][0], ys_sl[msk][-1]],
             color='r', zorder=1)

ax0.set_title('values')
ax0.scatter(xs, ys, label='data')
ax0.scatter(xs, ys_sl, s=3**2, label='seg lin reg', color='g', zorder=5)
ax0.legend()

ax1.set_title('slope')
ax1.scatter(xs, dys, label='data')
ax1.scatter(xs, dys_dt, label='DecisionTree', s=2**2)
ax1.legend()

plt.show()
弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.