Extract features from second-to-last layer after cutting last layer


I’m using Resnet18 to do image classification based on three different classes. Instead of the actual predictions of the classes I’m interested in the resulting features (i.e. I want the input for the last linear layer instead of its output).

I found another topic on how to remove the last linear layer from the head of the model and it seems to work as the new model is now missing the last layer (and the new last layer is Dropout). However, when I now try to get the “predictions” (which to my understanding should be the features instead of the actual predictions?), then I get an IndexError (“IndexError: list index out of range”).

These are the lines of code I’m using to get the predictions (loading the image and showing it works just fine):

features_out = np.empty((len(train_dataset["id"]), 512))  # one 512-wide row per entry of train_dataset["id"]
features_out_img = train_dataset["file_name"]
imagery_path = "/Users/me/Documents/image_folder"

i = 6691
path_i = "282549.jpg"

temp_img = load_image(os.path.join(imagery_path, path_i))

model.predict(temp_img)  # the call that raises the IndexError shown in the traceback below


The full output is here:

IndexError                                Traceback (most recent call last)
Input In [56], in <cell line: 10>()
      6 path_i = "282549.jpg"
      8 temp_img = load_image(os.path.join(imagery_path, path_i))
---> 10 model.predict(temp_img)

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastai/learner.py:324, in Learner.predict(self, item, rm_type_tfms, with_input)
    322 i = getattr(self.dls, 'n_inp', -1)
    323 inp = (inp,) if i==1 else tuplify(inp)
--> 324 dec = self.dls.decode_batch(inp + tuplify(dec_preds))[0]
    325 dec_inp,dec_targ = map(detuplify, [dec[:i],dec[i:]])
    326 res = dec_targ,dec_preds[0],preds[0]

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastai/data/core.py:121, in TfmdDL.decode_batch(self, b, max_n, full)
    116 def decode_batch(self, 
    117     b, # Batch to decode
    118     max_n:int=9, # Maximum number of items to decode
    119     full:bool=True # Whether to decode all transforms. If `False`, decode up to the point the item knows how to show itself
    120 ): 
--> 121     return self._decode_batch(self.decode(b), max_n, full)

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastai/data/core.py:127, in TfmdDL._decode_batch(self, b, max_n, full)
    125 f1 = self.before_batch.decode
    126 f = compose(f1, f, partial(getcallable(self.dataset,'decode'), full = full))
--> 127 return L(batch_to_samples(b, max_n=max_n)).map(f)

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastcore/foundation.py:156, in L.map(self, f, *args, **kwargs)
--> 156 def map(self, f, *args, **kwargs): return self._new(map_ex(self, f, *args, gen=False, **kwargs))

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastcore/basics.py:840, in map_ex(iterable, f, gen, *args, **kwargs)
    838 res = map(g, iterable)
    839 if gen: return res
--> 840 return list(res)

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastcore/basics.py:825, in bind.__call__(self, *args, **kwargs)
    823     if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    824 fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 825 return self.func(*fargs, **kwargs)

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastcore/basics.py:850, in compose.<locals>._inner(x, *args, **kwargs)
    849 def _inner(x, *args, **kwargs):
--> 850     for f in funcs: x = f(x, *args, **kwargs)
    851     return x

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastai/data/core.py:466, in Datasets.decode(self, o, full)
--> 466 def decode(self, o, full=True): return tuple(tl.decode(o_, full=full) for o_,tl in zip(o,tuplify(self.tls, match=o)))

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastai/data/core.py:466, in <genexpr>(.0)
--> 466 def decode(self, o, full=True): return tuple(tl.decode(o_, full=full) for o_,tl in zip(o,tuplify(self.tls, match=o)))

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastai/data/core.py:381, in TfmdLists.decode(self, o, **kwargs)
--> 381 def decode(self, o, **kwargs): return self.tfms.decode(o, **kwargs)

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastcore/transform.py:216, in Pipeline.decode(self, o, full)
    215 def decode  (self, o, full=True):
--> 216     if full: return compose_tfms(o, tfms=self.fs, is_enc=False, reverse=True, split_idx=self.split_idx)
    217     #Not full means we decode up to the point the item knows how to show itself.
    218     for f in reversed(self.fs):

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastcore/transform.py:158, in compose_tfms(x, tfms, is_enc, reverse, **kwargs)
    156 for f in tfms:
    157     if not is_enc: f = f.decode
--> 158     x = f(x, **kwargs)
    159 return x

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastcore/transform.py:82, in Transform.decode(self, x, **kwargs)
---> 82 def decode  (self, x, **kwargs): return self._call('decodes', x, **kwargs)

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastcore/transform.py:91, in Transform._call(self, fn, x, split_idx, **kwargs)
     89 def _call(self, fn, x, split_idx=None, **kwargs):
     90     if split_idx!=self.split_idx and self.split_idx is not None: return x
---> 91     return self._do_call(getattr(self, fn), x, **kwargs)

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastcore/transform.py:97, in Transform._do_call(self, f, x, **kwargs)
     95     if f is None: return x
     96     ret = f.returns(x) if hasattr(f,'returns') else None
---> 97     return retain_type(f(x, **kwargs), x, ret)
     98 res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
     99 return retain_type(res, x)

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastcore/dispatch.py:120, in TypeDispatch.__call__(self, *args, **kwargs)
    118 elif self.inst is not None: f = MethodType(f, self.inst)
    119 elif self.owner is not None: f = MethodType(f, self.owner)
--> 120 return f(*args, **kwargs)

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastai/data/transforms.py:264, in Categorize.decodes(self, o)
--> 264 def decodes(self, o): return Category  (self.vocab    [o])

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastcore/foundation.py:88, in CollBase.__getitem__(self, k)
---> 88 def __getitem__(self, k): return self.items[list(k) if isinstance(k,CollBase) else k]

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastcore/foundation.py:112, in L.__getitem__(self, idx)
--> 112 def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)

File ~/opt/anaconda3/envs/fastai2-pytorch2/lib/python3.9/site-packages/fastcore/foundation.py:116, in L._get(self, i)
    115 def _get(self, i):
--> 116     if is_indexer(i) or isinstance(i,slice): return getattr(self.items,'iloc',self.items)[i]
    117     i = mask2idxs(i)
    118     return (self.items.iloc[list(i)] if hasattr(self.items,'iloc')
    119             else self.items.__array__()[(i,)] if hasattr(self.items,'__array__')
    120             else [self.items[i_] for i_ in i])

IndexError: list index out of range


The following should give you the features:

# Build a ResNet18 learner, cut the last layer of its head, and run a single
# image through the truncated network to get features instead of predictions.
learn = vision_learner(dls, resnet18)  # dls is your task-specific dataloader
# NOTE(review): cut_layer is not defined anywhere in this snippet — confirm where it comes from.
learn.model[-1] = cut_layer(learn.model[-1], -1)  # to remove the last linear layer; learn is Learner object and has model attribute
temp_img.shape  # torch.Size([3, 224, 224]) for example
learn.model.eval() # to put BatchNorm in evaluation mode
# unsqueeze(0) adds a leading batch dimension; the model is called directly, bypassing learn.predict().
features = learn.model(temp_img.unsqueeze(0))
features.shape # torch.Size([1, 512])

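learn.predict() returns the decoded predictions: it turns the model output into a class index and looks that index up in the class vocabulary (the `Categorize.decodes` step at the bottom of your traceback). Once the last linear layer is cut, the output has 512 entries instead of one per class, so the index can fall outside your three-class vocabulary, which is what raises the IndexError. The model does not know how to take the features and decode them to obtain predictions, so call learn.model directly as shown above.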

1 Like

Thanks for the reply! I had to adjust a few things, but it seems like now it’s working.

# Load one image, convert it to a float32 tensor on the "mps" device, and run
# it through the truncated network to obtain its feature vector.
temp_img = load_image(os.path.join(imagery_path, path_i))

temp_img.shape  # torch.Size([3, 224, 224]) for example

# NOTE(review): no scaling or normalisation of pixel values is visible here —
# confirm the input matches what the network received during training.
temp_img = image2tensor(temp_img).to("mps")
temp_img = torch.as_tensor(temp_img, dtype = torch.float32, device = "mps") # dtype has to be specified, otherwise I get an error
# `model` is presumably the fastai Learner here, so `model.model` is the underlying network.
model.model.eval() # to put BatchNorm in evaluation mode
# unsqueeze(0) adds a leading batch dimension before the forward pass.
features = model.model(temp_img.unsqueeze(0))
features.shape # torch.Size([1, 512])


I have gone through your query. You should use the model as a feature extractor; here’s how you can do it:

Python code:

import torch
import torchvision.transforms as transforms
from torchvision import models

# Load the pre-trained ResNet18 model.
# NOTE: recent torchvision releases deprecate `pretrained=True` in favour of
# the `weights=` argument; it is kept here so older installs still work.
model = models.resnet18(pretrained=True)

# Remove the last classification layer (the final fully connected layer);
# everything up to and including the global average pooling is kept.
model = torch.nn.Sequential(*list(model.children())[:-1])

# Set the model to evaluation mode so BatchNorm uses its running statistics
model.eval()

# Define a transformation for your input image.
# `temp_img` is expected to be a PIL image: ToTensor converts it to a float
# tensor in [0, 1], which Normalize requires (it cannot operate on a PIL image).
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load and preprocess your image
input_image = preprocess(temp_img).unsqueeze(0)  # Add a batch dimension

# Extract features (no gradients are needed for inference)
with torch.no_grad():
    features = model(input_image)

# The truncated model ends with the pooling layer, so its output has shape
# (1, 512, 1, 1); flatten it to (1, 512), i.e. the input of the removed layer.
features = torch.flatten(features, 1)

# Now, 'features' contains the features extracted from your image

This code loads the ResNet18 model, removes the last classification layer, preprocesses your image, and then extracts the features from the image using the modified model. The variable ‘features’ contains the extracted features that you can use for your purposes.