L1 cost penalty for a specific layer

I am a bit of a newbie with fastai and pytorch, so my apologies if this is a silly question.

I am trying to replicate some results from a paper in which an L1 penalty is added for a specific layer of the model. Say I have a model with a couple of linear layers followed by a couple of convolutional layers, and I would like my loss function to put an extra L1 penalty on the output of a hidden layer, in addition to MSELoss on the output. So my network is something like this (a dummy example, just to show the general idea):

import numpy as np
import torch
import torch.nn as nn

class MyNetwork(nn.Module):
    def __init__(self, samples_in, matrix_out):
        super(MyNetwork, self).__init__()
        
        self.samples_in = samples_in
        self.matrix_out = matrix_out
        self.samples_out = np.prod(self.matrix_out)
        
        self.fc1 = nn.Sequential(nn.Linear(self.samples_in*2, self.samples_out), nn.Tanh())
        self.cnv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, stride=1, padding=2), nn.ReLU())
        self.dcnv = nn.Sequential(nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=7, stride=1, padding=3))

    def forward(self, x):
        batch_size = x.shape[0]
        
        x = self.fc1(x)
        x = x.reshape((batch_size,1) + self.matrix_out)
        x = self.cnv1(x)

        # Calculating the L1 norm of the output of the convolutional layer:
        l1_term = torch.mean(torch.abs(x))
  
        x = self.dcnv(x)
        x = x.reshape((batch_size, self.samples_out))

        return x

I would like the cost function to be something like:

def mycost(pred, target):
    cost = ((pred-target)**2).mean() + 0.001*l1_term
    return cost

So in order to do that, I could change the forward function to return l1_term as well, something like:

def forward(self, x):
    # ... other stuff in forward

    return x, l1_term

and then have cost function:

def mycost(pred, target):
    cost = ((pred[0]-target)**2).mean() + 0.001*pred[1]
    return cost

This actually seems to work, in the sense that the training runs, but when I go to make a prediction with the model afterwards, I get an error. It is probably predictable that there would be some issues, but as a newbie it is pretty opaque to me what is happening. Here is the error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~/.conda/envs/ai/lib/python3.8/site-packages/fastai/torch_core.py in to_concat(xs, dim)
    236     #   in this case we return a big list
--> 237     try:    return retain_type(torch.cat(xs, dim=dim), xs[0])
    238     except: return sum([L(retain_type(o_.index_select(dim, tensor(i)).squeeze(dim), xs[0])

RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-21-3f52613b143e> in <module>
      3 img = read_crop_image(testpath,size=image_size).reshape((image_size,image_size))
      4 raw = read_transform_image(testpath,size=image_size,sampling_pattern=sp)
----> 5 pred = learn.predict(raw)

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/learner.py in predict(self, item, rm_type_tfms, with_input)
    246     def predict(self, item, rm_type_tfms=None, with_input=False):
    247         dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms, num_workers=0)
--> 248         inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
    249         i = getattr(self.dls, 'n_inp', -1)
    250         inp = (inp,) if i==1 else tuplify(inp)

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/learner.py in get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, n_workers, **kwargs)
    233         if with_loss: ctx_mgrs.append(self.loss_not_reduced())
    234         with ContextManagers(ctx_mgrs):
--> 235             self._do_epoch_validate(dl=dl)
    236             if act is None: act = getattr(self.loss_func, 'activation', noop)
    237             res = cb.all_tensors()

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_validate(self, ds_idx, dl)
    186         if dl is None: dl = self.dls[ds_idx]
    187         self.dl = dl
--> 188         with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
    189 
    190     def _do_epoch(self):

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
--> 157         finally:   self(f'after_{event_type}')        ;final()
    158 
    159     def all_batches(self):

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/learner.py in __call__(self, event_name)
    131     def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]
    132 
--> 133     def __call__(self, event_name): L(event_name).map(self._call_one)
    134 
    135     def _call_one(self, event_name):

~/.conda/envs/ai/lib/python3.8/site-packages/fastcore/foundation.py in map(self, f, *args, **kwargs)
    394              else f.format if isinstance(f,str)
    395              else f.__getitem__)
--> 396         return self._new(map(g, self))
    397 
    398     def filter(self, f, negate=False, **kwargs):

~/.conda/envs/ai/lib/python3.8/site-packages/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    340     @property
    341     def _xtra(self): return None
--> 342     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    343     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    344     def copy(self): return self._new(self.items.copy())

~/.conda/envs/ai/lib/python3.8/site-packages/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
     49             return x
     50 
---> 51         res = super().__call__(*((x,) + args), **kwargs)
     52         res._newchk = 0
     53         return res

~/.conda/envs/ai/lib/python3.8/site-packages/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    331         if items is None: items = []
    332         if (use_list is not None) or not _is_array(items):
--> 333             items = list(items) if use_list else _listify(items)
    334         if match is not None:
    335             if is_coll(match): match = len(match)

~/.conda/envs/ai/lib/python3.8/site-packages/fastcore/foundation.py in _listify(o)
    244     if isinstance(o, list): return o
    245     if isinstance(o, str) or _is_array(o): return [o]
--> 246     if is_iter(o): return list(o)
    247     return [o]
    248 

~/.conda/envs/ai/lib/python3.8/site-packages/fastcore/foundation.py in __call__(self, *args, **kwargs)
    307             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    308         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 309         return self.fn(*fargs, **kwargs)
    310 
    311 # Cell

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/learner.py in _call_one(self, event_name)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name), event_name
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/learner.py in <listcomp>(.0)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name), event_name
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/callback/core.py in __call__(self, event_name)
     42                (self.run_valid and not getattr(self, 'training', False)))
     43         res = None
---> 44         if self.run and _run: res = getattr(self, event_name, noop)()
     45         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     46         return res

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/callback/core.py in after_validate(self)
    118         if not hasattr(self, 'preds'): return
    119         if self.with_input:     self.inputs  = detuplify(to_concat(self.inputs, dim=self.concat_dim))
--> 120         if not self.save_preds: self.preds   = detuplify(to_concat(self.preds, dim=self.concat_dim))
    121         if not self.save_targs: self.targets = detuplify(to_concat(self.targets, dim=self.concat_dim))
    122         if self.with_loss:      self.losses  = to_concat(self.losses)

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/torch_core.py in to_concat(xs, dim)
    231     "Concat the element in `xs` (recursively if they are tuples/lists of tensors)"
    232     if not xs: return xs
--> 233     if is_listy(xs[0]): return type(xs[0])([to_concat([x[i] for x in xs], dim=dim) for i in range_of(xs[0])])
    234     if isinstance(xs[0],dict):  return {k: to_concat([x[k] for x in xs], dim=dim) for k in xs[0].keys()}
    235     #We may receive xs that are not concatenable (inputs of a text classifier for instance),

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/torch_core.py in <listcomp>(.0)
    231     "Concat the element in `xs` (recursively if they are tuples/lists of tensors)"
    232     if not xs: return xs
--> 233     if is_listy(xs[0]): return type(xs[0])([to_concat([x[i] for x in xs], dim=dim) for i in range_of(xs[0])])
    234     if isinstance(xs[0],dict):  return {k: to_concat([x[k] for x in xs], dim=dim) for k in xs[0].keys()}
    235     #We may receive xs that are not concatenable (inputs of a text classifier for instance),

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/torch_core.py in to_concat(xs, dim)
    236     #   in this case we return a big list
    237     try:    return retain_type(torch.cat(xs, dim=dim), xs[0])
--> 238     except: return sum([L(retain_type(o_.index_select(dim, tensor(i)).squeeze(dim), xs[0])
    239                           for i in range_of(o_)) for o_ in xs], L())
    240 

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/torch_core.py in <listcomp>(.0)
    237     try:    return retain_type(torch.cat(xs, dim=dim), xs[0])
    238     except: return sum([L(retain_type(o_.index_select(dim, tensor(i)).squeeze(dim), xs[0])
--> 239                           for i in range_of(o_)) for o_ in xs], L())
    240 
    241 # Cell

~/.conda/envs/ai/lib/python3.8/site-packages/fastcore/utils.py in range_of(x)
    198 def range_of(x):
    199     "All indices of collection `x` (i.e. `list(range(len(x)))`)"
--> 200     return list(range(len(x)))
    201 
    202 # Cell

~/.conda/envs/ai/lib/python3.8/site-packages/torch/tensor.py in __len__(self)
    443     def __len__(self):
    444         if self.dim() == 0:
--> 445             raise TypeError("len() of a 0-d tensor")
    446         return self.shape[0]
    447 

TypeError: len() of a 0-d tensor

So I am thinking that I am probably trying to shoehorn this in the wrong way, and I am wondering if there is a typical pattern one can/should use with fastai. Any help/guidance/commentary would be much appreciated.

Thanks,
Michael

Maybe I can have a go at answering this myself :wink: One possible pattern is to register a forward hook on that convolution layer; in the hook, grab the output of the layer, calculate its L1 norm, and store it in a variable, which we can then add in when calculating the cost. Something like this:

# Forward hook: stash the L1 norm of the layer's output in a (global) variable
def myfwdhook(module, input_, output):
    global l1_term
    l1_term = output.abs().mean()

# MSE on the output plus the L1 penalty captured by the hook
def mycustomloss(pred, target):
    return ((pred-target)**2).mean() + 0.0001*l1_term

class MyNetwork(nn.Module):
    def __init__(self, samples_in, matrix_out):
        super(MyNetwork, self).__init__()
        
        self.samples_in = samples_in
        self.matrix_out = matrix_out
        self.samples_out = np.prod(self.matrix_out)
        
        self.fc1 = nn.Sequential(nn.Linear(self.samples_in*2, self.samples_out), nn.Tanh())
        self.cnv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, stride=1, padding=2), nn.ReLU())

        self.cnv1.register_forward_hook(myfwdhook)        

        self.dcnv = nn.Sequential(nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=7, stride=1, padding=3))

    def forward(self, x):
        batch_size = x.shape[0]
        
        x = self.fc1(x)
        x = x.reshape((batch_size,1) + self.matrix_out)
        x = self.cnv1(x)
        x = self.dcnv(x)
        x = x.reshape((batch_size, self.samples_out))

        return x

Or something like that.
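
To train with it, the custom loss can then be passed straight to the Learner; something along these lines (just a sketch, assuming dls and a mynetwork instance are already set up, and reusing the names above):

learn = Learner(dls, mynetwork, opt_func=RMSProp, loss_func=mycustomloss,
                metrics=nn.MSELoss(reduction='mean'))
learn.fit_one_cycle(5)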

In my continued conversation with myself: I can maybe try another approach, which I think is the more fastai way to do this, using a Learner callback. So I could do something like:

class L1RegCallback(Callback):
    def __init__(self, reglambda=0.0001):
        self.reglambda = reglambda

    def before_backward(self):
        # L1 penalty on the weights of the convolution layer (cnv1 in the model above)
        regularization_loss = 0.0
        for param in self.learn.model.cnv1.parameters():
            regularization_loss += torch.mean(torch.abs(param))

        self.learn.loss += self.reglambda*regularization_loss

And then something like:

learn = Learner(dls, mynetwork, opt_func=RMSProp, loss_func=nn.MSELoss(reduction='mean'),
                metrics=nn.MSELoss(reduction='mean'), cbs=[L1RegCallback()])

And it would add the L1 penalty to the loss (just before backpropagation).
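
One thing worth noting: the callback above puts the L1 penalty on the layer's weights, whereas the paper (and the hook version) penalizes the layer's output. If the activation penalty is what is wanted, a variant could reuse the l1_term stored by the forward hook instead (again just a sketch, with a made-up callback name):

class L1ActivationRegCallback(Callback):
    def __init__(self, reglambda=0.0001):
        self.reglambda = reglambda

    def before_backward(self):
        # l1_term was stored by myfwdhook during this batch's forward pass
        self.learn.loss += self.reglambda*l1_term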

@sachinruk and @sgugger, you have had some other threads on this:

But I think some of that was based on a previous version of fastai, so I was wondering what your thoughts are: is the approach above correct for adding an L1 loss (for a specific layer)?

Hey, just to quickly answer your question: I found better (faster) convergence when updating the weights directly instead of adding the penalty to the loss, as you seem to be doing here.

The only thing I'd change in my answer is that it ought to be param.data = param.data - learning_rate * self.beta * sign.
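
Something along these lines, for example (just a sketch; which layer to target, the beta value, and pulling the learning rate from opt.hypers are assumptions, and I'm taking sign to mean the sign of the weight itself):

class L1DirectUpdateCallback(Callback):
    def __init__(self, beta=0.0001):
        self.beta = beta

    def after_step(self):
        # Shrink the chosen layer's weights towards zero by lr * beta * sign(w),
        # mirroring param.data = param.data - learning_rate * self.beta * sign
        lr = self.learn.opt.hypers[-1]['lr']
        for param in self.learn.model.cnv1.parameters():
            param.data = param.data - lr * self.beta * param.data.sign()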
