ピクルス Pythonライブラリを実装シリアライズとPythonオブジェクトをデシリアライズするためのバイナリプロトコル。
あなたimport torch
(またはPyTorchを使用するとき)はそれをimport pickle
実行し、オブジェクトを保存およびロードするメソッドであるpickle.dump()
andをpickle.load()
直接呼び出す必要はありません。
実際に、torch.save()
そしてtorch.load()
ラップされますpickle.dump()
と、pickle.load()
あなたのために。
state_dict
他の答えはわずか数より多くのノートに値する言及しました。
state_dict
PyTorchの内部には何がありますか?実際には2つstate_dict
のがあります。
PyTorchモデルには、学習可能なパラメーター(wおよびb)を取得するための呼び出しがtorch.nn.Module
ありmodel.parameters()
ます。これらの学習可能なパラメータは、ランダムに設定されると、時間の経過とともに更新されます。学習可能なパラメータが最初state_dict
です。
2つ目state_dict
は、オプティマイザの状態辞書です。オプティマイザーは学習可能なパラメーターを改善するために使用されることを思い出してください。しかし、オプティマイザstate_dict
は修正されています。そこで学ぶことは何もありません。
state_dict
オブジェクトはPython辞書であるため、簡単に保存、更新、変更、復元でき、PyTorchモデルとオプティマイザに大幅なモジュール性を追加します。
これを説明するための超シンプルなモデルを作成しましょう:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
このコードは以下を出力します:
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
これは最小限のモデルであることに注意してください。シーケンシャルのスタックを追加してみてください
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
モデルのにエントリがあるのは、学習可能なパラメーターを持つレイヤー(たたみ込みレイヤー、線形レイヤーなど)と登録されたバッファー(batchnormレイヤー)のみであることに注意してくださいstate_dict
。
学習不能なものは、オプティマイザーオブジェクトに属しますstate_dict
。オプティマイザーオブジェクトには、オプティマイザーの状態に関する情報と、使用されるハイパーパラメーターが含まれています。
物語の残りは同じです。予測のための推論フェーズ(これはトレーニング後にモデルを使用するフェーズです)。学習したパラメータに基づいて予測します。したがって、推論のために、パラメーターを保存する必要があるだけですmodel.state_dict()
。
torch.save(model.state_dict(), filepath)
そして、後でmodel.load_state_dict(torch.load(filepath))model.eval()を使用する
注:model.eval()
モデルをロードした後、これは重要な最後の行を忘れないでください。
また、保存しようとしないでくださいtorch.save(model.parameters(), filepath)
。これmodel.parameters()
は単なるジェネレータオブジェクトです。
一方、torch.save(model, filepath)
モデルオブジェクト自体は保存されますが、モデルにはオプティマイザのがないことに注意してstate_dict
ください。@Jadiel de Armasによる他の優れた回答をチェックして、オプティマイザの状態の辞書を保存してください。