PyTorch - Best way to get at intermediate layers in VGG and ResNet?

Just getting started with transfer learning in PyTorch and was wondering …

What is the recommended way(s) to grab output at intermediate layers (not just the last layer)?

In particular, how should one pre-compute the convolutional output for VGG16 … or get the output of ResNet50 BEFORE the global average pooling layer?


Here is at least one way to do it … thoughts?

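import torch.nn as nn
from torchvision import models

# newer torchvision (0.13+) prefers models.resnet50(weights="IMAGENET1K_V1")
res50_model = models.resnet50(pretrained=True)
res50_conv = nn.Sequential(*list(res50_model.children())[:-2])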

This grabs a pretrained ResNet-50 model courtesy of the torchvision package and then builds a sequential model from it that excludes the final two modules (i.e., the average-pooling layer and the fully connected layer).

for param in res50_conv.parameters():
    param.requires_grad = False

No need to backprop through the model since I’m using it purely for feature extraction.

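import torch

res50_conv.eval()  # BatchNorm uses its running statistics instead of per-batch ones
inputs, labels = next(iter(dataloaders['train']))
with torch.no_grad():  # no autograd graph needed for feature extraction
    outputs = res50_conv(inputs)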

To test, I grab 4 examples and run them through my modified model.

outputs.data.shape # => torch.Size([4, 2048, 7, 7])

And voila, I get the 2048x7x7 output I expected!

It feels both weird and cool to be able to pass images of any size into the network and have it just work. I burnt a few minutes here and there trying to get the model to tell me the output size of this or that layer, until I realized it only works for the fully connected layers because, I believe, those are the only ones with a fixed input and output shape. The convolutional layers' input/output shapes are dynamic, depending on the shape of your examples … which, like I said, feels weird coming from Theano/TF but is also very cool.


Here is another way, inspired by this forum post:

import torch.nn as nn
from torchvision import models

class ResNet50Bottom(nn.Module):
    def __init__(self, original_model):
        super(ResNet50Bottom, self).__init__()
        # every child module except the final avgpool and fc
        self.features = nn.Sequential(*list(original_model.children())[:-2])

    def forward(self, x):
        x = self.features(x)
        return x

res50_model = models.resnet50(pretrained=True)
res50_conv2 = ResNet50Bottom(res50_model)

outputs = res50_conv2(inputs)
outputs.data.shape  # => torch.Size([4, 2048, 7, 7])

Also, forgot to mention that I'm tweaking things from the transfer learning tutorial on the PyTorch website. Check it out to understand the dataset and other particulars.


To get the output of any layer during a single forward pass, you can use register_forward_hook.

import torch
from torchvision import models

outputs = []
def hook(module, input, output):
    outputs.append(output)

res50_model = models.resnet50(pretrained=True)
res50_model.layer4[0].conv2.register_forward_hook(hook)
out = res50_model(res)
out = res50_model(res1)
print(outputs)

Here, outputs is a list holding two tensors: the outputs of that particular layer for each of the two forward passes.


I tried to replace the first maxpool layer in resnet34 with an average-pooling layer, but all I got was the model with the maxpool missing. Could you tell me why it doesn't work, and how to do it?

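import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class resnet30_avg(nn.Module):
    def __init__(self, original_model):
        super(resnet30_avg, self).__init__()
        children = list(original_model.children())
        self.S1 = nn.Sequential(*children[:3])    # conv1, bn1, relu (maxpool at index 3 is skipped)
        self.S2 = nn.Sequential(*children[4:-1])  # layer1 ... layer4, avgpool
        self.fc = children[-1]

    def forward(self, x):
        x = F.avg_pool2d(self.S1(x), 3, stride=2, padding=1)
        x = torch.flatten(self.S2(x), 1)  # (N, 512, 1, 1) -> (N, 512) for fc
        return self.fc(x)

arch = resnet30_avg(models.resnet34(pretrained=True))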

arch:
resnet30_avg(
  (S1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (S2): Sequential(
    (0): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )

Hi @krishnavishalv, I use this method to get outputs, but how do I de-register the hook? Every time I run my script, the list of outputs I append to keeps growing, I think because the hook is still registered. Any ideas?

This is the code I’m using:

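import torch
from PIL import Image
from torchvision import transforms

# MODEL and DTYPE are defined elsewhere in the script;
# torch.zeros(512) below assumes a 512-dim avgpool output (e.g. ResNet-18/34)
class VectorGetter(object):
    def __init__(self):
        self.model = MODEL
        self.model.eval()
        self.scaler = transforms.Resize((224, 224))  # Resize replaces the removed Scale
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.to_tensor = transforms.ToTensor()
        self.layer = self.model._modules.get("avgpool")

    def get_vector(self, path):
        im_object = Image.open(path).convert("RGB")  # guard against grayscale/RGBA files

        # plain tensors work directly; Variable is no longer needed
        t_img = self.normalize(self.to_tensor(self.scaler(im_object))).unsqueeze(0)
        my_embedding = torch.zeros(512)

        def copy_data(m, i, o):
            my_embedding.copy_(o.data.squeeze())

        h = self.layer.register_forward_hook(copy_data)
        with torch.no_grad():
            self.model(t_img)
        h.remove()

        return path, my_embedding.data.numpy().astype(DTYPE)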

Does anyone know how to do this with multiple inputs at once, taking full advantage of the GPU?


Maybe create another instance of the model and run it? I don't know how to unregister the hook.

What are ‘res’ and ‘res1’? Input images?

With fastai, there are many ways to get intermediate layers without registering hooks as in plain PyTorch. For ResNet, I have used more than one method; one of them flattens the layers and extracts the output by index.

For VGG, it is even simpler if you want to extract the batch norm layers. The example is in this notebook.


You can store the intermediate result as a module attribute, so that you can access it as needed without keeping a list of all results.

See this link for details: https://github.com/utkuozbulak/pytorch-cnn-visualizations/blob/master/src/cnn_layer_visualization.py

He just passed two separate input tensors, called res and res1, to show that the method works.

This page shows how to remove a hook:

model = …
handle = model.register_forward_hook(…)
handle.remove()  # hook will no longer trigger
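To make sure a hook never outlives its use (the cause of the ever-growing outputs list mentioned earlier), the register/remove pair can be wrapped in a context manager. A sketch: `capture_output` is a hypothetical helper, and the tiny `nn.Sequential` model is only a stand-in:

```python
from contextlib import contextmanager

import torch
import torch.nn as nn

@contextmanager
def capture_output(module):
    """Hypothetical helper: record `module`'s outputs only inside a `with` block."""
    outputs = []
    handle = module.register_forward_hook(lambda m, i, o: outputs.append(o.detach()))
    try:
        yield outputs
    finally:
        handle.remove()  # runs even if the body raises, so hooks never pile up

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
with capture_output(model[0]) as outs:
    model(torch.randn(5, 4))
model(torch.randn(5, 4))  # outside the block: hook already removed
print(len(outs), tuple(outs[0].shape))  # 1 (5, 3)
```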

There is a PyTorch utility for easily getting any intermediate result: torch-intermediate-layer-getter (on PyPI).

Hi, how can I save these features to .txt or .csv files? Thank you.
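One option is to reduce each image's features to a single row and write the matrix with `numpy.savetxt`. A sketch, assuming a random stand-in for the `(4, 2048, 7, 7)` ResNet output shown earlier and global average pooling to get one 2048-dim row per image; `features.csv` is an arbitrary filename (use `delimiter=" "` for a plain .txt file):

```python
import numpy as np
import torch

# Stand-in for real extracted features, e.g. the (4, 2048, 7, 7) output of res50_conv
outputs = torch.randn(4, 2048, 7, 7)

# Global average pooling gives one 2048-dim row per image
features = outputs.mean(dim=(2, 3))  # shape (4, 2048)

# For real model outputs, call .detach().cpu() first if they require grad or live on the GPU
np.savetxt("features.csv", features.numpy(), delimiter=",")  # one image per line
loaded = np.loadtxt("features.csv", delimiter=",")
print(loaded.shape)  # (4, 2048)
```

To keep the 7x7 spatial maps instead, flatten them with `outputs.reshape(len(outputs), -1)` before saving; for large arrays a binary format such as `np.save` or `torch.save` is usually more compact than text.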
