【PyTorch】register_hook的使用

首先看一下正常的梯度计算例子:

#正常求导情况v = torch.randn((1, 3), dtype=torch.float32, requires_grad=True)z = v.sum()z.backward()print(v.grad)

输出:

tensor([[1., 1., 1.]])

上面的代码中,当执行到z.backward()这一句代码的时候,就是计算变量z的偏导数,因为v=torch.randn(1,3)也就是1行3列,所以可以假设v=(v1, v2, v3),那么z=v1+v2+v3。所以z对v的偏微分就是:

图片[1] - 【PyTorch】register_hook的使用 - MaxSSL

其中:

图片[1] - 【PyTorch】register_hook的使用 - MaxSSL

所以可以得出上面的偏微分结果为 :tensor([[1., 1., 1.]])。

如果我们需要对导数进行2倍的操作:

v = torch.randn((1, 3), dtype=torch.float32, requires_grad=True)z = v.sum()# lambda grad: grad*2是一个函数,即:# def lambda(grad):#return grad*2v.register_hook(lambda grad: grad*2)z.backward()print(v.grad)

输出为:

tensor([[2., 2., 2.]])

可以看出v.register_hook()的作用是将反向传播过程中关于v的梯度给取出来,同时进行一些操作,上面代码所进行的操作是对关于v的梯度乘以2,当然,这里的梯度只是暂时取出来了,如果需要“长久的”保存梯度信息方便后续的计算的话,则可以如下代码所示:

grad_store = []def function(grad):grad_store.append(grad)v = torch.randn((1, 3), dtype=torch.float32, requires_grad=True)z = v.sum()v.register_hook(function)z.backward()

上面的代码即可将梯度保存到变量grad_store中,方便后面计算Grad-CAM等等。

© 版权声明
THE END
喜欢就支持一下吧
点赞0分享