新しい形状が元のテンソルの形状と互換性がある限り、またはにtorch.Tensor.view()
触発された簡単に言えば、テンソルの新しいビューが作成されます。numpy.ndarray.reshape()
numpy.reshape()
具体例を用いてこれを詳しく理解しましょう。
In [43]: t = torch.arange(18)
In [44]: t
Out[44]:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17])
このt
形状のテンソルを使用すると(18,)
、新しいビューは次の形状に対してのみ作成できます。
(1, 18)
または同等 (1, -1)
又は 又は同等 又は 又は同等 又は 又は同等 又は 又は同等 またはまたは同等 又は(-1, 18)
(2, 9)
(2, -1)
(-1, 9)
(3, 6)
(3, -1)
(-1, 6)
(6, 3)
(6, -1)
(-1, 3)
(9, 2)
(9, -1)
(-1, 2)
(18, 1)
(18, -1)
(-1, 1)
我々は既に上記形状タプルから観察できるように、形状タプル(例えば、の要素の乗算2*9
、3*6
等)が常になければならない(元のテンソルの要素の総数と等しくなります18
私たちの例では)。
もう1つ注意すべき点は-1
、各形状タプルの1つの場所でを使用したことです。を使用する-1
ことで、計算を自分で行うのが面倒になり、タスクをPyTorchに委任して、新しいビューを作成するときに形状のその値の計算を実行します。注意すべき重要な点の1つは、形状タプルでは1つしか使用できないこと-1
です。残りの値は、明示的に提供する必要があります。エルスPyTorchは投げることで文句を言うだろうRuntimeError
:
RuntimeError:推論できる次元は1つだけです
したがって、上記のすべての形状で、PyTorchは常に元のテンソルの新しいビューを返しt
ます。これは基本的に、要求された新しいビューごとにテンソルのストライド情報を変更するだけであることを意味します。
以下は、テンソルのストライドが新しいビューごとにどのように変更されるかを示すいくつかの例です。
# stride of our original tensor `t`
In [53]: t.stride()
Out[53]: (1,)
ここで、新しいビューのストライドを確認します。
# shape (1, 18)
In [54]: t1 = t.view(1, -1)
# stride tensor `t1` with shape (1, 18)
In [55]: t1.stride()
Out[55]: (18, 1)
# shape (2, 9)
In [56]: t2 = t.view(2, -1)
# stride of tensor `t2` with shape (2, 9)
In [57]: t2.stride()
Out[57]: (9, 1)
# shape (3, 6)
In [59]: t3 = t.view(3, -1)
# stride of tensor `t3` with shape (3, 6)
In [60]: t3.stride()
Out[60]: (6, 1)
# shape (6, 3)
In [62]: t4 = t.view(6,-1)
# stride of tensor `t4` with shape (6, 3)
In [63]: t4.stride()
Out[63]: (3, 1)
# shape (9, 2)
In [65]: t5 = t.view(9, -1)
# stride of tensor `t5` with shape (9, 2)
In [66]: t5.stride()
Out[66]: (2, 1)
# shape (18, 1)
In [68]: t6 = t.view(18, -1)
# stride of tensor `t6` with shape (18, 1)
In [69]: t6.stride()
Out[69]: (1, 1)
これがview()
関数の魔法です。それはちょうど新しいのそれぞれについて(オリジナル)テンソルのストライド変化するビューを限り新規の形状として、ビューは元の形状に対応しています。
stridesタプルから観察できるもう1つの興味深いことは、0 番目の位置にある要素の値が、シェイプタプルの1番目の位置にある要素の値と等しいことです。
In [74]: t3.shape
Out[74]: torch.Size([3, 6])
|
In [75]: t3.stride() |
Out[75]: (6, 1) |
|_____________|
それの訳は:
In [76]: t3
Out[76]:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17]])
ストライド(6, 1)
は、0 番目の次元に沿って1つの要素から次の要素に移動するには、ジャンプするか6つのステップを実行する必要があることを示しています。(つまり、からに移動する0
には6
、6つのステップを実行する必要があります。)1次元で1つの要素から次の要素に移動するには、ステップが1つだけ必要です(たとえば、からに移動2
します3
)。
したがって、ストライド情報は、計算を実行するためにメモリから要素にアクセスする方法の中心にあります。
この関数はビューを返しtorch.Tensor.view()
、新しい形状が元のテンソルの形状と互換性がある限り、を使用する場合とまったく同じです。それ以外の場合は、コピーを返します。
ただし、のメモは次のtorch.reshape()
ように警告しています。
隣接する入力と互換性のあるストライドを持つ入力は、コピーせずに再形成できますが、コピーと表示の動作に依存するべきではありません。
reshape
PyTorchでそれを呼び出さなかったのですか?