pytorchでtorch.no_gradの使用は何ですか?


19

私はpytorchを初めて使い、この githubコードから始めました。コードの60-61行目のコメントがわかりません"because weights have requires_grad=True, but we don't need to track this in autograd"requires_grad=Trueautogradを使用するための勾配を計算する必要がある変数について言及していることを理解しましたが、それはどういう意味"tracked by autograd"ですか?

回答:


21

ラッパー "with torch.no_grad()"は、すべてのrequire_gradフラグを一時的にfalseに設定します。公式PyTorchチュートリアル(https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients)の例:

x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

でる:

True
True
False

上記のWebサイトからすべてのチュートリアルを読むことをお勧めします。

あなたの例では、著者は、PyTorchが新しく定義された変数w1とw2の勾配を計算することを望んでいないと思います。なぜなら、彼は値を更新したいだけだからです。


5
with torch.no_grad()

ブロック内のすべての操作に勾配がありません。

pytorchでは、w1とw2のインプレースメント変更はできません。w1とw2は、2つの変数ですrequire_grad = True。w1とw2のインプレースメントの変更を避けるのは、逆伝播計算でエラーが発生するためだと思います。インプレースメントの変更はw1とw2を完全に変更するためです。

ただし、これを使用no_grad()すると、操作によって生成されるため、新しいw1と新しいw2には勾配がありません。つまり、勾配部分ではなくw1とw2の値のみを変更できます。そして、逆伝播を継続できます。

弊社のサイトを使用することにより、あなたは弊社のクッキーポリシーおよびプライバシーポリシーを読み、理解したものとみなされます。
Licensed under cc by-sa 3.0 with attribution required.