Hooks in PyTorch
Why we need hooks
In PyTorch's automatic differentiation (autograd) mechanics, once a tensor's requires_grad is set to True, every operation involving it is tracked so that gradients are computed automatically during the backward pass. One caveat deserves attention: autograd only keeps the gradients of leaf tensors; the gradients of intermediate tensors are freed once they have been used, in order to save memory.
We’ve inspected the weights and the gradients. But how about inspecting or modifying the output and grad_output of a layer? We introduce hooks for this purpose.
In other words, hooks are introduced so that we can inspect or modify the output or grad_output of a layer.
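A minimal sketch of this behaviour (the tensors x, y, z below are purely illustrative): the intermediate gradient is gone after backward(), but a Tensor hook lets us capture it.

```python
import torch

x = torch.tensor([2.0], requires_grad=True)  # leaf tensor
y = x * 3                                    # intermediate (non-leaf) tensor
z = (y ** 2).sum()

grads = {}
def save_grad(grad):
    # Store the gradient for later inspection; returning None leaves it unchanged.
    grads["y"] = grad

y.register_hook(save_grad)
z.backward()

print(x.grad)      # tensor([36.]) -- leaf gradients are kept
print(y.grad)      # None -- intermediate gradients are freed by autograd
print(grads["y"])  # tensor([12.]) -- recovered through the hook
```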
Types of hooks
- TENSOR.register_hook(FUNCTION)
- MODULE.register_forward_hook(FUNCTION)
- MODULE.register_backward_hook(FUNCTION)
PyTorch hooks are registered for each Tensor or nn.Module object and are triggered by either the forward or backward pass of the object. They have the following function signatures:
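A sketch of the three signatures, following the PyTorch documentation (the hook names below are placeholders):

```python
from torch import Tensor, nn

# Tensor hook: receives the gradient of the tensor.
# Returning a Tensor replaces the gradient; returning None leaves it as-is.
def tensor_hook(grad: Tensor):
    ...

# Forward hook: called after the module's forward() has run.
# Returning a value replaces the module's output.
def forward_hook(module: nn.Module, input, output):
    ...

# Backward hook: called when gradients with respect to the module are computed.
# Returning new gradients replaces grad_input.
def backward_hook(module: nn.Module, grad_input, grad_output):
    ...
```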
Use cases for hooks
- debug
Each hook can modify the input, output, or internal Module parameters. Most commonly, they are used for debugging purposes. But we will see that they have many other uses.
- verbose model execution
We could scatter print statements through the model, but that is not very professional.
Never again! Let’s use hooks instead to debug models without modifying their implementation in any way. For example, suppose you want to know the shape of each layer’s output. We can create a simple wrapper that prints the output shapes using hooks.
Here is a concrete case: print the output shape of every layer in the model, as shown in the sketch below.
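A minimal sketch, assuming a torchvision ResNet-50 as the model under inspection (any nn.Module works); the VerboseExecution wrapper name is illustrative:

```python
import torch
from torch import nn
from torchvision.models import resnet50

class VerboseExecution(nn.Module):
    """Wrap a model and print each child layer's output shape during forward."""
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        # Register a forward hook on every immediate child layer.
        for name, layer in self.model.named_children():
            layer.__name__ = name
            layer.register_forward_hook(
                lambda layer, _, output: print(f"{layer.__name__}: {output.shape}")
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

# Example: wrap a ResNet-50 and run one dummy batch through it.
verbose_resnet = VerboseExecution(resnet50())
_ = verbose_resnet(torch.rand(1, 3, 224, 224))
# prints e.g. conv1: torch.Size([1, 64, 112, 112]) ... fc: torch.Size([1, 1000])
```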
- feature extraction
Commonly, we want to generate features from a pre-trained network, and use them for another task (e.g. classification, similarity search, etc.). Using hooks, we can extract features without needing to re-create the existing model or modify it in any way.
With hooks we can take any pretrained model and extract feature embeddings from it, as in the sketch below.
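A sketch of such an extractor, assuming a torchvision ResNet-50; the FeatureExtractor class and the chosen layer names are just one possible setup:

```python
from typing import Callable, Dict, Iterable
import torch
from torch import Tensor, nn
from torchvision.models import resnet50

class FeatureExtractor(nn.Module):
    """Collect the outputs of the requested layers through forward hooks."""
    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        self._features: Dict[str, Tensor] = {layer: torch.empty(0) for layer in layers}
        named = dict(self.model.named_modules())
        for layer_id in layers:
            named[layer_id].register_forward_hook(self._save_output(layer_id))

    def _save_output(self, layer_id: str) -> Callable:
        def hook(_module, _input, output):
            self._features[layer_id] = output
        return hook

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        _ = self.model(x)
        return self._features

# In practice the model would be loaded with pretrained weights.
extractor = FeatureExtractor(resnet50(), layers=["layer4", "avgpool"])
features = extractor(torch.rand(1, 3, 224, 224))
print({name: f.shape for name, f in features.items()})
# e.g. layer4 -> (1, 2048, 7, 7), avgpool -> (1, 2048, 1, 1)
```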
Here we extract the outputs of layer4 and avgpool; the latter gives a 2048-dimensional embedding.
- gradient clipping
Gradient clipping is a well-known method for dealing with exploding gradients. PyTorch already provides utility methods for performing gradient clipping, but we can also easily do it with hooks. Any other method for gradient clipping/normalization/modification can be done the same way.
We can use PyTorch's built-in utilities, or implement gradient clipping (used to tame exploding gradients) ourselves with hooks.
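A minimal sketch using Tensor hooks on the parameters; the gradient_clipper helper and the clipping value 0.01 are illustrative:

```python
import torch
from torch import nn

def gradient_clipper(model: nn.Module, val: float) -> nn.Module:
    """Clip every parameter gradient to [-val, val] via Tensor hooks."""
    for parameter in model.parameters():
        # The tensor returned by the hook replaces the original gradient.
        parameter.register_hook(lambda grad: grad.clamp(-val, val))
    return model

# Hypothetical tiny model, just to show the effect of the hook.
clipped_model = gradient_clipper(nn.Linear(10, 1), 0.01)
loss = clipped_model(torch.rand(4, 10)).sum()
loss.backward()
print(clipped_model.weight.grad)  # every entry now lies in [-0.01, 0.01]
```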
Author: jijeng
Last updated: 2020-11-29