scikit-learnで分類子をディスクに保存する


191

トレーニング済みの単純ベイズ分類器ディスクに保存し、それを使用してデータを予測するにはどうすればよいですか?

scikit-learnのWebサイトにある次のサンプルプログラムがあります。

from sklearn import datasets
iris = datasets.load_iris()
from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()
y_pred = gnb.fit(iris.data, iris.target).predict(iris.data)
print "Number of mislabeled points : %d" % (iris.target != y_pred).sum()

回答:


201

分類子は、他のようにピクルスにしてダンプできるオブジェクトです。例を続けるには:

import cPickle
# save the classifier
with open('my_dumped_classifier.pkl', 'wb') as fid:
    cPickle.dump(gnb, fid)    

# load it again
with open('my_dumped_classifier.pkl', 'rb') as fid:
    gnb_loaded = cPickle.load(fid)

1
魅力的な作品!私はnp.savezを使用してそれをずっとロードしようとしていましたが、それは役に立ちませんでした。どうもありがとう。
Kartos 2014年

7
python3では、pickleモジュールを使用します。これは、このように機能します。
MCSH 2018年

212

デフォルトのpython picklerよりも数値配列の処理ではるかに効率的なjoblib.dumpjoblib.loadを使用することもできます。

Joblibはscikit-learnに含まれています。

>>> import joblib
>>> from sklearn.datasets import load_digits
>>> from sklearn.linear_model import SGDClassifier

>>> digits = load_digits()
>>> clf = SGDClassifier().fit(digits.data, digits.target)
>>> clf.score(digits.data, digits.target)  # evaluate training error
0.9526989426822482

>>> filename = '/tmp/digits_classifier.joblib.pkl'
>>> _ = joblib.dump(clf, filename, compress=9)

>>> clf2 = joblib.load(filename)
>>> clf2
SGDClassifier(alpha=0.0001, class_weight=None, epsilon=0.1, eta0=0.0,
       fit_intercept=True, learning_rate='optimal', loss='hinge', n_iter=5,
       n_jobs=1, penalty='l2', power_t=0.5, rho=0.85, seed=0,
       shuffle=False, verbose=0, warm_start=False)
>>> clf2.score(digits.data, digits.target)
0.9526989426822482

編集:Python 3.8以降では、pickleプロトコル5(デフォルトではない)を使用する場合、属性として大きな数値配列を持つオブジェクトを効率的に酸洗いするためにpickleを使用できるようになりました。


1
しかし、私の理解では、パイプラインは、単一のワークフローの一部である場合に機能します。モデルをビルドしたい場合は、ディスクに保存して、そこで実行を停止します。それから私は一週間後に戻ってくると、それは私にエラーをスローし、ディスクからモデルをロードしよう:
venuktan

2
これが目的のfit場合、メソッドの実行を停止して再開する方法はありません。そうは言っても、scikit-learnライブラリの同じバージョンを使用してPythonから呼び出しjoblib.loadjoblib.dump場合、成功後に例外を発生させるべきではありません。
ogrisel 2014年

10
IPythonを使用している場合は、--pylabコマンドラインフラグや%pylabマジックを使用しないでください。暗黙的な名前空間のオーバーロードにより、酸洗いプロセスが中断されることがわかっています。%matplotlib inline代わりに、明示的なインポートとマジックを使用してください。
ogrisel 2014年

2
参照用scikit-学ぶのドキュメントを参照してください。scikit-learn.org/stable/tutorial/basic/...
user1448319

1
以前に保存したモデルを再トレーニングすることは可能ですか?特にSVCモデル?
Uday Sawant 2017

108

あなたが探しているものはsklearnの言葉でモデルの永続性と呼ばれ、それは導入部モデルの永続性のセクションに記載されています。

したがって、分類子を初期化し、長い間それをトレーニングしました

clf = some.classifier()
clf.fit(X, y)

この後、2つのオプションがあります。

1)ピクルスの使用

import pickle
# now you can save it to a file
with open('filename.pkl', 'wb') as f:
    pickle.dump(clf, f)

# and later you can load it
with open('filename.pkl', 'rb') as f:
    clf = pickle.load(f)

2)Joblibの使用

from sklearn.externals import joblib
# now you can save it to a file
joblib.dump(clf, 'filename.pkl') 
# and later you can load it
clf = joblib.load('filename.pkl')

もう一度、上記のリンクを読むと役に立ちます


30

多くの場合、特にテキスト分類の場合、分類子を保存するだけでは十分ではありませんが、今後入力をベクトル化できるように、ベクトライザーも保存する必要があります。

import pickle
with open('model.pkl', 'wb') as fout:
  pickle.dump((vectorizer, clf), fout)

将来のユースケース:

with open('model.pkl', 'rb') as fin:
  vectorizer, clf = pickle.load(fin)

X_new = vectorizer.transform(new_samples)
X_new_preds = clf.predict(X_new)

ベクトライザーをダンプする前に、ベクトライザーのstop_words_プロパティを次の方法で削除できます。

vectorizer.stop_words_ = None

ダンプをより効率的にするため。また、(ほとんどのテキスト分類の例のように)分類子パラメーターがスパースである場合は、パラメーターをデンスからスパースに変換できます。これにより、メモリの消費、ロード、およびダンプの点で大きな違いが生じます。の方法でモデルをスパース化します。

clf.sparsify()

これは自動的にSGDClassifierで機能しますが、モデルがスパース(clf.coef_のゼロのロット)であることがわかっている場合は、clf.coef_をcsr scipyスパース行列に手動で変換できます。

clf.coef_ = scipy.sparse.csr_matrix(clf.coef_)

その後、より効率的に保管できます。


洞察に満ちた答え!SVCの場合に追加したいだけで、スパースモデルパラメーターを返します。
Shayan Amani

4

sklearn推定器は、推定器の関連するトレーニング済みプロパティを簡単に保存できるようにするメソッドを実装します。一部のエスティメータは__getstate__メソッド自体を実装しますが、オブジェクトの内部ディクショナリを保存するだけGMM基本実装を使用するようなものもあります。

def __getstate__(self):
    try:
        state = super(BaseEstimator, self).__getstate__()
    except AttributeError:
        state = self.__dict__.copy()

    if type(self).__module__.startswith('sklearn.'):
        return dict(state.items(), _sklearn_version=__version__)
    else:
        return state

モデルをディスクに保存するための推奨される方法は、pickleモジュールを使用することです。

from sklearn import datasets
from sklearn.svm import SVC
iris = datasets.load_iris()
X = iris.data[:100, :2]
y = iris.target[:100]
model = SVC()
model.fit(X,y)
import pickle
with open('mymodel','wb') as f:
    pickle.dump(model,f)

ただし、追加のデータを保存して、将来モデルを再トレーニングしたり、悲惨な結果(古いバージョンのsklearnにロックされているなど)に対応したりできるようにする必要があります。

ドキュメントから:

scikit-learnの将来のバージョンで同様のモデルを再構築するには、ピクルドモデルに沿って追加のメタデータを保存する必要があります。

トレーニングデータ、たとえば不変のスナップショットへの参照

モデルの生成に使用されたpythonソースコード

scikit-learnのバージョンとその依存関係

トレーニングデータで取得した相互検証スコア

これはtree.pyx、Cythonで記述されたモジュール(などIsolationForest)に依存するEnsemble推定器に特に当てはまります。これは、sklearnのバージョン間での安定性が保証されない実装への結合を作成するためです。過去に互換性のない変更がありました。

モデルが非常に大きくなり、ロードが煩わしい場合は、より効率的なを使用することもできますjoblib。ドキュメントから:

scikitの特定のケースでは、joblibのpicklejoblib.dumpjoblib.load)の置換を使用する方が興味深い場合があります。これは、フィッティングされたscikit-learn推定器の場合によくあるように、内部で大きなnumpy配列を運ぶオブジェクトでより効率的ですが、ピクルのみが可能です。文字列ではなくディスクに:


1
but can only pickle to the disk and not to a stringしかし、これをjoblibからStringIOにピクルすることができます。これは私がいつもしていることです。
マシュー

1

sklearn.externals.joblib以降廃止され0.21で削除されv0.23ます:

/usr/local/lib/python3.7/site-packages/sklearn/externals/joblib/ init .py:15:FutureWarning:sklearn.externals.joblibは0.21で廃止され、0.23で削除されます。この機能は、pib install joblibを使用してインストールできるjoblibから直接インポートしてください。ピクルドモデルのロード時にこの警告が表示された場合、scikit-learn 0.21+を使用してそれらのモデルを再シリアル化する必要がある場合があります。
warnings.warn(msg、category = FutureWarning)


したがって、インストールする必要がありますjoblib

pip install joblib

最後にモデルをディスクに書き込みます。

import joblib
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier


digits = load_digits()
clf = SGDClassifier().fit(digits.data, digits.target)

with open('myClassifier.joblib.pkl', 'wb') as f:
    joblib.dump(clf, f, compress=9)

ダンプファイルを読み取るために実行する必要があるのは、次のとおりです。

with open('myClassifier.joblib.pkl', 'rb') as f:
    my_clf = joblib.load(f)
弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.