Introduction to hooks

In PyTorch's autograd mechanics, setting a tensor's requires_grad to True makes every operation involving it compute gradients automatically during the backward pass. But the autograd machinery has a caveat worth noting: it only retains the gradients of leaf tensors; the gradients of intermediate variables are freed automatically once the backward computation finishes, in order to save memory.
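
A minimal sketch of this behavior: after backward(), the leaf tensor keeps its .grad, the intermediate tensor's gradient has been freed, and a tensor hook is one way to capture it before it disappears.

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)  # leaf tensor
y = x * 2                                          # intermediate (non-leaf) tensor
z = y.sum()

# Capture y's gradient via a hook, since autograd will not retain it
y.register_hook(lambda grad: print(f"grad of y: {grad}"))

z.backward()
print(x.grad)  # tensor([2., 2.]) -- the leaf gradient is retained
print(y.grad)  # None -- the intermediate gradient was freed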

We've inspected the weights and the gradients. But what about inspecting or modifying the output and grad_output of a layer? Hooks are introduced for exactly this purpose.

Types of hooks

  • TENSOR.register_hook(FUNCTION)

  • MODULE.register_forward_hook(FUNCTION)

  • MODULE.register_backward_hook(FUNCTION) (deprecated in recent PyTorch releases in favor of MODULE.register_full_backward_hook)

PyTorch hooks are registered for each Tensor or nn.Module object and are triggered by either the forward or backward pass of the object. They have the following function signatures:

from torch import nn, Tensor

def module_hook(module: nn.Module, input: Tensor, output: Tensor):
    # For nn.Module objects only.
    # (At runtime, `input` arrives as a tuple of the module's positional inputs.)
    ...

def tensor_hook(grad: Tensor):
    # For Tensor objects only.
    # Only executed during the *backward* pass!
    ...
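
Every register_* call above returns a torch.utils.hooks.RemovableHandle, so a hook can be detached once it has served its purpose. A minimal sketch:

import torch
from torch import nn

layer = nn.Linear(4, 2)
handle = layer.register_forward_hook(
    lambda module, input, output: print(f"output shape: {output.shape}")
)

_ = layer(torch.randn(1, 4))  # prints: output shape: torch.Size([1, 2])
handle.remove()               # detach the hook
_ = layer(torch.randn(1, 4))  # prints nothing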

Use cases for hooks

  1. debugging

Each hook can modify the input, output, or internal Module parameters. Most commonly, they are used for debugging purposes. But we will see that they have many other uses.
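
As a small illustration of the "modify the output" part: if a forward hook returns a value, that value replaces the layer's output. A sketch that forces non-negative activations:

import torch
from torch import nn

layer = nn.Linear(3, 3)

def clamp_output(module, input, output):
    # A forward hook that returns a value replaces the layer's output
    return output.clamp(min=0.0)

layer.register_forward_hook(clamp_output)
out = layer(torch.randn(2, 3))
print((out >= 0).all())  # tensor(True)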

  2. verbose model execution

You could use print statements, but that is hardly professional. Never again! Let's use hooks instead to debug models without modifying their implementation in any way. For example, suppose you want to know the shape of each layer's output. Here is a concrete case: a simple wrapper that prints every layer's output shape using hooks.

import torch
from torchvision.models import resnet50
from torch import nn, Tensor

class VerboseExecution(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

        # Register a hook for each top-level layer
        for name, layer in self.model.named_children():
            layer.__name__ = name  # attach the name so the hook can report it
            layer.register_forward_hook(
                lambda layer, _, output: print(f"{layer.__name__}: {output.shape}")
            )

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

verbose_resnet = VerboseExecution(resnet50())
dummy_input = torch.ones(10, 3, 224, 224)
_ = verbose_resnet(dummy_input)
# conv1: torch.Size([10, 64, 112, 112])
# bn1: torch.Size([10, 64, 112, 112])
# relu: torch.Size([10, 64, 112, 112])
# maxpool: torch.Size([10, 64, 56, 56])
# layer1: torch.Size([10, 256, 56, 56])
# layer2: torch.Size([10, 512, 28, 28])
# layer3: torch.Size([10, 1024, 14, 14])
# layer4: torch.Size([10, 2048, 7, 7])
# avgpool: torch.Size([10, 2048, 1, 1])
# fc: torch.Size([10, 1000])
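
Note that named_children() only visits the top-level blocks of the ResNet. If you also want the nested submodules (e.g. layer1.0.conv1), you could iterate named_modules() instead; a sketch reusing the imports above:

model = resnet50()
for name, layer in model.named_modules():
    if name:  # skip the root module, whose name is the empty string
        layer.__name__ = name
        layer.register_forward_hook(
            lambda layer, _, output: print(f"{layer.__name__}: {output.shape}")
        )
_ = model(dummy_input)  # prints one line per submodule, e.g. layer1.0.conv1: ...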

  3. feature extraction

Commonly, we want to generate features from a pre-trained network, and use them for another task (e.g. classification, similarity search, etc.). Using hooks, we can extract features without needing to re-create the existing model or modify it in any way.

Using an arbitrary pretrained model to extract feature embeddings:

from typing import Dict, Iterable, Callable

class FeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        self.layers = layers
        self._features = {layer: torch.empty(0) for layer in layers}

        for layer_id in layers:
            # Look up each requested submodule by its qualified name
            layer = dict(self.model.named_modules())[layer_id]
            layer.register_forward_hook(self.save_outputs_hook(layer_id))

    def save_outputs_hook(self, layer_id: str) -> Callable:
        def fn(_, __, output):
            self._features[layer_id] = output
        return fn

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        _ = self.model(x)
        return self._features
    
resnet_features = FeatureExtractor(resnet50(), layers=["layer4", "avgpool"])
features = resnet_features(dummy_input)
print({name: output.shape for name, output in features.items()})
# {'layer4': torch.Size([10, 2048, 7, 7]), 'avgpool': torch.Size([10, 2048, 1, 1])}

Here we extract the outputs of layer4 and avgpool; the latter is a 2048-dimensional embedding per image.
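
As a quick illustration of the similarity-search use mentioned above, the avgpool output can be flattened into a 2048-dimensional embedding per image and compared with cosine similarity; a minimal sketch:

import torch.nn.functional as F

emb = features["avgpool"].flatten(start_dim=1)    # shape: (10, 2048)
# Cosine similarity between the first image and every image in the batch
sims = F.cosine_similarity(emb[0:1], emb, dim=1)
print(sims.shape)  # torch.Size([10]); sims[0] is 1.0 (self-similarity)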

  4. gradient clipping

Gradient clipping is a well-known method for dealing with exploding gradients. PyTorch already provides utility methods for performing gradient clipping, but we can also easily do it with hooks. Any other method for gradient clipping/normalization/modification can be done the same way.

You can use PyTorch's built-in methods, or implement gradient clipping (a standard remedy for exploding gradients) with hooks, as shown below.
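
For reference, the built-in route uses torch.nn.utils.clip_grad_value_ (or clip_grad_norm_), which must be called after every backward(); the hook version that follows is registered once and applies automatically:

import torch
from torch import nn
from torch.nn.utils import clip_grad_value_

model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# Clamp every parameter gradient to [-0.01, 0.01], like the hook below
clip_grad_value_(model.parameters(), clip_value=0.01)
print(model.weight.grad.abs().max() <= 0.01)  # tensor(True)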

def gradient_clipper(model: nn.Module, val: float) -> nn.Module:
    for parameter in model.parameters():
        # Return a new clamped gradient; per the docs, a tensor hook should
        # return a new grad rather than modify its argument in place
        parameter.register_hook(lambda grad: grad.clamp(-val, val))
    return model

clipped_resnet = gradient_clipper(resnet50(), 0.01)
pred = clipped_resnet(dummy_input)
loss = pred.log().mean()
loss.backward()
print(clipped_resnet.fc.bias.grad[:25])
# tensor([-0.0010, -0.0047, -0.0010, -0.0009, -0.0015,  0.0027,  0.0017, -0.0023,
#          0.0051, -0.0007, -0.0057, -0.0010, -0.0039, -0.0100, -0.0018,  0.0062,
#          0.0034, -0.0010,  0.0052,  0.0021,  0.0010,  0.0017, -0.0100,  0.0021,
#          0.0020])

References

  • How to Use PyTorch Hooks

  • PyTorch Hook