PyTorchのbackwardに引数を入れる

Nov. 4, 2021, 12:51 p.m. edited Nov. 6, 2021, 6:05 a.m.

#PyTorch  #Python 

PyTorch の backward についてメモ。

x, y がスカラー、引数なし

まずは引数なしの場合のシンプルな例:

import torch

x = torch.tensor(3., requires_grad=True)
y = 2 * x ** 2
y.backward()

print(f'x: {x}, y: {y}')
print(f'x.grad: {x.grad}')

とすると、出力は

x: 3.0, y: 18.0
x.grad: 12.0

となる。なぜなら、 \(y=2x^2\) に \(x=3\) を代入すると \(y=18\) となり、また、このとき \(\frac{dy}{dx}=4x=12\) となるため。

x, y がベクトル、引数にベクトル v

次に本題の引数ありの場合(式は同じ \(y=2x^2\) だが x にベクトルを用いている)の例:

import torch

x = torch.tensor([2., 3., 3.], requires_grad=True)
y = 2 * x ** 2
v = torch.tensor([1., 1., 2.])
y.backward(v)

print(f'x: {x}, y: {y},  v: {v}')
print(f'x.grad: {x.grad}')

出力は

x: tensor([2., 3., 3.], requires_grad=True), y: tensor([ 8., 18., 18.], grad_fn=<MulBackward0>),  v: tensor([1., 1., 2.])
x.grad: tensor([ 8., 12., 24.])

となる。それぞれの x の値に対して y および x.grad が結びついているが、同じ x=3. に対して 2 通りの x.grad (12.24.) がある。これは、 y.backward() に引数として v を渡すことでそれぞれのもともとの x.grad の値に要素積されたためである。つまり、 [8. * 1., 12. * 1., 12. * 2.] = [8., 12., 24.] ということ。

最後に、以下のような x の各要素の値が y の 1 つの要素に影響を与える場合を考える:

x = torch.tensor([2., 3., 3.], requires_grad=True)
y = 2 * x ** 2
y[0] = (x ** 2).sum()
v = torch.tensor([1., 1., 2.])
y.backward(v)
print(f'x: {x}, y: {y},  v: {v}')
print(f'x.grad: {x.grad}')

出力は

x: tensor([2., 3., 3.], requires_grad=True), y: tensor([22., 18., 18.], grad_fn=<CopySlices>),  v: tensor([1., 1., 2.])
x.grad: tensor([ 4., 18., 30.])

これは y の 1, 2 番目の要素は先程までと変わらず \(y=2x^2\) であるが、 0 番目は x の二乗和となっている。ゆえに、 y[0] = 2 ** 2 + 3 ** 2 + 3 ** 2 = 22

では、肝心の x.grad を見てみよう。まず x.grad[0] は \(y_0=x_0^2+x_1^2+x_2^2\) を \(x_0\) で偏微分したものなので、 \(\frac{\partial y_0}{\partial x_0}=2x_0\) に \(x_0=2\) を代入して x.grad[0] = 4 となる。これは x[0] しか使っていないので簡単。問題は次である。 x.grad[1]y[0]y[1] 両方に関係があり、その計算式は \(\frac{\partial y_0}{\partial x_1}+\frac{dy_1}{dx_1}=2x_1+4x_1=6x_1\)、ここに \(x_1=3\) を代入して x.grad[1] = 18 となる。そして最後に \(v_2=2\neq 1\) である x.grad[2] の計算式は \(\frac{\partial y_0}{\partial x_2}+\frac{dy_2}{dx_2}\times 2=2x_2+4x_2\times 2=10x_2\)、ここに \(x_2=3\) を代入して x.grad[2] = 30 となる。

つまり、 \(x\), \(y\), \(v\) が \(n\) 次元ベクトルのときの \(x.grad\) を得る計算式の一般形は

$$x.grad_i=\sum_{j=0}^{n-1}\frac{\partial y_j}{\partial x_i}\times v_j$$

ということになる。

はたしてこんなの使うケースがあるのかと疑問に思いがちだが、連鎖律を自分で使ったりするときなど思わぬところで有用なことがある。私も今日そんなケースにあったのでこの記事を書いた。