Freezing weights mid-training

I am trying to freeze certain weights of my network (using PyTorch) dynamically in the middle of training. Specifically, the weights I am trying to freeze are tied entirely to certain inputs.

I.e., if I have a network with inputs a, b, c, d and output y, I am trying to freeze the first-layer weight(s) connected only to input b. I also want to freeze these weights for a single backward pass only; on the next batch, it may be the weights associated with input c that need to be frozen.
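Assuming the first layer is a plain `nn.Linear` (the post does not say which layer type is used), the weights connected to a single input are exactly one column of its weight matrix, because `nn.Linear` computes `x @ W.T + bias` with `W` shaped `(out_features, in_features)`. A small check of that claim:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
first = nn.Linear(4, 3)  # inputs (a, b, c, d) -> 3 hidden units
x = torch.randn(5, 4)    # a batch of 5 samples

# Input j only ever meets column j of first.weight
b_idx = 1
weights_for_b = first.weight[:, b_idx]  # the 3 weights connected to input b

out = first(x)
# Contribution of input b alone: (5, 1) * (3,) broadcasts to (5, 3)
contrib_b = x[:, b_idx : b_idx + 1] * weights_for_b
# Contribution of every other input, plus the bias
others = [0, 2, 3]
rest = x[:, others] @ first.weight[:, others].T + first.bias
```

So "the weights connected to input b" is `first.weight[:, 1]`, and `out` equals `contrib_b + rest` up to floating-point rounding.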

Does anyone have an idea how I might do this?

I know I need to override the backward function of the class I am using for my first layer, but I am not sure of the exact structure of the gradient tensor. I have a mask array with 1s for my frozen inputs and 0s for the rest. Currently I am simply doing the following and am unsure whether it is correct:

```python
def backward(self, grad_output):
    # multiply column j of grad_output by self.mask[j]
    grad_output = grad_output @ torch.diag(self.mask)
    return super().backward(grad_output)
```

where self.mask is simply a 1-D array marking which inputs should be masked and which should not.
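A hedged alternative to overriding `backward`: autograd never calls a `backward()` method defined on an `nn.Module` subclass, and the `grad_output` in such a method would be the gradient with respect to the layer's output, indexed by output units rather than inputs. A tensor hook on the first layer's weight instead receives the gradient with the same shape as the weight, `(out_features, in_features)`, so column j belongs to input j. The sketch below assumes an `nn.Linear` first layer and the post's convention (1 = frozen); `mask_weight_grad` and `weight_grad_with_frozen` are hypothetical helper names.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_inputs, n_hidden = 4, 3  # inputs a, b, c, d
first = nn.Linear(n_inputs, n_hidden)
model = nn.Sequential(first, nn.ReLU(), nn.Linear(n_hidden, 1))
loss_fn = nn.MSELoss()

# Per-input mask, 1 = frozen. It is updated in place, so the hook
# below always sees the mask for the current batch.
mask = torch.zeros(n_inputs)


def mask_weight_grad(grad: torch.Tensor) -> torch.Tensor:
    # grad has the shape of first.weight, (n_hidden, n_inputs); column j
    # belongs to input j, so (1 - mask) broadcasts over the last dimension
    return grad * (1 - mask)


hook_handle = first.weight.register_hook(mask_weight_grad)  # hook_handle.remove() undoes it

x, y = torch.randn(8, n_inputs), torch.randn(8, 1)


def weight_grad_with_frozen(frozen_idx=None):
    # One forward/backward pass with at most one input frozen
    mask.zero_()
    if frozen_idx is not None:
        mask[frozen_idx] = 1.0
    model.zero_grad()
    loss_fn(model(x), y).backward()
    return first.weight.grad.clone()


g_free = weight_grad_with_frozen()  # nothing frozen
g_b = weight_grad_with_frozen(1)    # this batch: input b frozen
g_c = weight_grad_with_frozen(2)    # next batch: input c frozen instead
```

One caveat: a zero gradient does not guarantee a zero update. Optimizers with momentum, Adam-style moment estimates, or weight decay can still move a weight whose gradient is zero on that step.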

The implementation uses the optimizer to set requires_grad_ on the model parameters.

From what I can see, this seems to operate on a per-layer basis rather than a per-weight basis.

It can be set per tensor. See example below. Can you break your first layer out into separate modules, one per input? I think that would be the best way to do it, if possible. If not, you could mask the desired parameter updates during your parameter-update step, but what that looks like depends on your input and first layer. Depending on how they are composed, it may or may not be possible.
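A minimal sketch of the "separate module per input" idea, assuming a fully connected first layer. `PerInputLinear` is a hypothetical name, not a PyTorch or fastai class: each input gets its own bias-free `nn.Linear(1, n_out)`, so each input's weights are a separate tensor and `requires_grad_` can be toggled per input before every batch. Calling `zero_grad(set_to_none=True)` matters here: a frozen tensor's `.grad` then stays `None`, and `torch.optim.SGD` and `torch.optim.Adam` skip parameters whose `.grad` is `None`, so momentum does not move them either.

```python
import torch
import torch.nn as nn


class PerInputLinear(nn.Module):
    """Hypothetical layer computing x @ W.T + bias with one weight tensor per input."""

    def __init__(self, n_inputs: int, n_out: int):
        super().__init__()
        # One bias-free Linear(1, n_out) per input feature, plus a shared bias
        self.cols = nn.ModuleList([nn.Linear(1, n_out, bias=False) for _ in range(n_inputs)])
        self.bias = nn.Parameter(torch.zeros(n_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_inputs); feed column j to sub-layer j and sum the results
        out = self.bias
        for j, col in enumerate(self.cols):
            out = out + col(x[:, j : j + 1])
        return out

    def freeze_inputs(self, frozen: list) -> None:
        # frozen[j] is True -> the weights connected to input j get no gradient
        for col, is_frozen in zip(self.cols, frozen):
            col.weight.requires_grad_(not is_frozen)


torch.manual_seed(0)
layer = PerInputLinear(4, 3)  # inputs a, b, c, d
opt = torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0.9)
x = torch.randn(8, 4)

unchanged = []
for frozen_idx in [1, 2]:  # freeze input b on this batch, input c on the next
    layer.freeze_inputs([j == frozen_idx for j in range(4)])
    before = layer.cols[frozen_idx].weight.detach().clone()
    opt.zero_grad(set_to_none=True)
    layer(x).pow(2).sum().backward()
    opt.step()
    unchanged.append(torch.equal(before, layer.cols[frozen_idx].weight))
```

The default initialization differs from a single `nn.Linear(4, 3)`, since each sub-layer sees a fan-in of 1; copy the weights over from an existing layer if that matters.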

This seems to work if you want to freeze a parameter for a whole training run, but I am looking at freezing and unfreezing parameters during training. I may be wrong here, so please let me know if so.

Why do you think that? You should be able to set requires_grad or mask the parameter updates for each batch or optimizer step. If you’re using the fastai Learner, you would use callbacks; if you’re writing your own PyTorch training loop, you can add it to the loop manually.
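For the plain PyTorch loop case, one way to mask the parameter updates for a single step, assuming an `nn.Linear` first layer and a per-input mask with 1 = frozen, is to copy the frozen columns before `optimizer.step()` and write them back afterwards. Restoring values rather than only zeroing gradients is deliberate: it also cancels movement caused by momentum, Adam's moment estimates, or weight decay. The `masks` schedule and the random batches are made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_inputs = 4  # inputs a, b, c, d
first = nn.Linear(n_inputs, 3)
model = nn.Sequential(first, nn.ReLU(), nn.Linear(3, 1))
opt = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-4)
loss_fn = nn.MSELoss()

# Made-up per-batch schedule: freeze input b on batch 1, input c on batch 2 (1 = frozen)
masks = [torch.tensor([0.0, 1.0, 0.0, 0.0]), torch.tensor([0.0, 0.0, 1.0, 0.0])]
history = []

for mask in masks:
    x, y = torch.randn(16, n_inputs), torch.randn(16, 1)  # stand-in batch
    frozen = mask.bool()
    before = first.weight.detach().clone()

    opt.zero_grad()
    loss_fn(model(x), y).backward()
    opt.step()

    # Write the frozen columns back, undoing whatever the optimizer did to them
    with torch.no_grad():
        first.weight[:, frozen] = before[:, frozen]

    history.append((frozen, before, first.weight.detach().clone()))
```

The optimizer state for the frozen weights (for example Adam's running averages) is still updated from that step's gradient; if that matters, also zero those gradient columns before `opt.step()`.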

Some callback examples: Understanding callbacks in fastai • Pierre Ouannes

Callbacks in the course: Practical Deep Learning for Coders - 16: The Learner framework

Oh OK, thanks for the info, I was unaware. I was trying to solve the problem by modifying the backward pass of my first layer, but this looks more promising. Cheers