L1 cost penalty for specific layer

I am a bit of a newbie with fastai and pytorch, so my apologies if this is a silly question.

I am trying to replicate some results from a paper in which an L1 penalty is added for a specific layer of the model. Say I have a model with a couple of linear layers followed by a couple of convolutional layers, and I would like my loss function to add an L1 penalty on the output of a hidden layer in addition to the MSELoss on the final output. My network is something like this (a dummy example, to show the general idea):

class MyNetwork(nn.Module):
    """Dummy network: a linear layer, a convolution, then a transposed convolution.

    Args:
        samples_in: ``fc1`` expects ``samples_in * 2`` input features.
        matrix_out: 2-D spatial shape the ``fc1`` output is reshaped to before
            the convolutions (a tuple; a list is also accepted).

    After each forward pass, ``self.l1_term`` holds the mean absolute value of
    the ``cnv1`` output. A loss function can therefore add the L1 penalty
    without ``forward`` having to return a tuple.
    """

    def __init__(self, samples_in, matrix_out):
        super().__init__()

        self.samples_in = samples_in
        self.matrix_out = matrix_out
        # np.prod returns a NumPy scalar; cast to a plain int for layer sizes.
        self.samples_out = int(np.prod(self.matrix_out))

        self.fc1 = nn.Sequential(nn.Linear(self.samples_in*2, self.samples_out), nn.Tanh())
        self.cnv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, stride=1, padding=2), nn.ReLU())
        self.dcnv = nn.Sequential(nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=7, stride=1, padding=3))

        # L1 term from the most recent forward pass (None until the first call).
        self.l1_term = None

    def forward(self, x):
        """Map ``(batch, samples_in*2)`` inputs to ``(batch, samples_out)`` outputs.

        Side effect: sets ``self.l1_term`` to the mean absolute value of the
        ``cnv1`` activations for this batch.
        """
        batch_size = x.shape[0]

        x = self.fc1(x)
        # Add a singleton channel dimension so the result can be fed to Conv2d.
        x = x.reshape((batch_size, 1) + tuple(self.matrix_out))
        x = self.cnv1(x)

        # Stored on the module rather than in a discarded local, so the loss
        # can use it, e.g. ``mse + 0.001 * model.l1_term``.
        self.l1_term = torch.mean(torch.abs(x))

        x = self.dcnv(x)
        x = x.reshape((batch_size, self.samples_out))

        return x

I would like the cost function to be something like:

def mycost(pred, target):
    """Return the mean squared error plus a 0.001-weighted L1 term.

    NOTE(review): ``l1_term`` is not a parameter. It is looked up in module
    scope at call time and must be defined there before the first call.
    """
    squared_error = (pred - target) ** 2
    return squared_error.mean() + 0.001 * l1_term

So, to do that, I could change the forward function to return `l1_term` as well, something like:

def forward(self, x):
   """Variant of MyNetwork.forward that also returns the L1 term.

   Returns a ``(prediction, l1_term)`` tuple instead of a single tensor.
   NOTE(review): fastai's get_preds/predict concatenates the per-batch
   predictions, and the 0-d ``l1_term`` cannot be concatenated. This is the
   "zero-dimensional tensor ... cannot be concatenated" / "len() of a 0-d
   tensor" error in the traceback of this thread.
   """
   #... other stuff in forward (same as the version that computes l1_term)

   return x, l1_term

and then have a cost function like:

def mycost(pred, target):
    """Loss for a model whose forward returns ``(output, l1_term)``.

    Combines the mean squared error of ``pred[0]`` against ``target`` with the
    L1 term in ``pred[1]``, weighted by 0.001.
    """
    output, penalty = pred[0], pred[1]
    mse = ((output - target) ** 2).mean()
    return mse + 0.001 * penalty

This actually seems to work, in the sense that training runs, but when I then try to make a prediction with the model, I get an error. It is probably predictable that there would be some issues, but as a newbie I find it pretty opaque what is happening. Here is the error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~/.conda/envs/ai/lib/python3.8/site-packages/fastai/torch_core.py in to_concat(xs, dim)
    236     #   in this case we return a big list
--> 237     try:    return retain_type(torch.cat(xs, dim=dim), xs[0])
    238     except: return sum([L(retain_type(o_.index_select(dim, tensor(i)).squeeze(dim), xs[0])

RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-21-3f52613b143e> in <module>
      3 img = read_crop_image(testpath,size=image_size).reshape((image_size,image_size))
      4 raw = read_transform_image(testpath,size=image_size,sampling_pattern=sp)
----> 5 pred = learn.predict(raw)

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/learner.py in predict(self, item, rm_type_tfms, with_input)
    246     def predict(self, item, rm_type_tfms=None, with_input=False):
    247         dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms, num_workers=0)
--> 248         inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
    249         i = getattr(self.dls, 'n_inp', -1)
    250         inp = (inp,) if i==1 else tuplify(inp)

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/learner.py in get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, n_workers, **kwargs)
    233         if with_loss: ctx_mgrs.append(self.loss_not_reduced())
    234         with ContextManagers(ctx_mgrs):
--> 235             self._do_epoch_validate(dl=dl)
    236             if act is None: act = getattr(self.loss_func, 'activation', noop)
    237             res = cb.all_tensors()

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_validate(self, ds_idx, dl)
    186         if dl is None: dl = self.dls[ds_idx]
    187         self.dl = dl
--> 188         with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
    189 
    190     def _do_epoch(self):

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
--> 157         finally:   self(f'after_{event_type}')        ;final()
    158 
    159     def all_batches(self):

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/learner.py in __call__(self, event_name)
    131     def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]
    132 
--> 133     def __call__(self, event_name): L(event_name).map(self._call_one)
    134 
    135     def _call_one(self, event_name):

~/.conda/envs/ai/lib/python3.8/site-packages/fastcore/foundation.py in map(self, f, *args, **kwargs)
    394              else f.format if isinstance(f,str)
    395              else f.__getitem__)
--> 396         return self._new(map(g, self))
    397 
    398     def filter(self, f, negate=False, **kwargs):

~/.conda/envs/ai/lib/python3.8/site-packages/fastcore/foundation.py in _new(self, items, *args, **kwargs)
    340     @property
    341     def _xtra(self): return None
--> 342     def _new(self, items, *args, **kwargs): return type(self)(items, *args, use_list=None, **kwargs)
    343     def __getitem__(self, idx): return self._get(idx) if is_indexer(idx) else L(self._get(idx), use_list=None)
    344     def copy(self): return self._new(self.items.copy())

~/.conda/envs/ai/lib/python3.8/site-packages/fastcore/foundation.py in __call__(cls, x, *args, **kwargs)
     49             return x
     50 
---> 51         res = super().__call__(*((x,) + args), **kwargs)
     52         res._newchk = 0
     53         return res

~/.conda/envs/ai/lib/python3.8/site-packages/fastcore/foundation.py in __init__(self, items, use_list, match, *rest)
    331         if items is None: items = []
    332         if (use_list is not None) or not _is_array(items):
--> 333             items = list(items) if use_list else _listify(items)
    334         if match is not None:
    335             if is_coll(match): match = len(match)

~/.conda/envs/ai/lib/python3.8/site-packages/fastcore/foundation.py in _listify(o)
    244     if isinstance(o, list): return o
    245     if isinstance(o, str) or _is_array(o): return [o]
--> 246     if is_iter(o): return list(o)
    247     return [o]
    248 

~/.conda/envs/ai/lib/python3.8/site-packages/fastcore/foundation.py in __call__(self, *args, **kwargs)
    307             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    308         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 309         return self.fn(*fargs, **kwargs)
    310 
    311 # Cell

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/learner.py in _call_one(self, event_name)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name), event_name
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/learner.py in <listcomp>(.0)
    135     def _call_one(self, event_name):
    136         assert hasattr(event, event_name), event_name
--> 137         [cb(event_name) for cb in sort_by_run(self.cbs)]
    138 
    139     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/callback/core.py in __call__(self, event_name)
     42                (self.run_valid and not getattr(self, 'training', False)))
     43         res = None
---> 44         if self.run and _run: res = getattr(self, event_name, noop)()
     45         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     46         return res

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/callback/core.py in after_validate(self)
    118         if not hasattr(self, 'preds'): return
    119         if self.with_input:     self.inputs  = detuplify(to_concat(self.inputs, dim=self.concat_dim))
--> 120         if not self.save_preds: self.preds   = detuplify(to_concat(self.preds, dim=self.concat_dim))
    121         if not self.save_targs: self.targets = detuplify(to_concat(self.targets, dim=self.concat_dim))
    122         if self.with_loss:      self.losses  = to_concat(self.losses)

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/torch_core.py in to_concat(xs, dim)
    231     "Concat the element in `xs` (recursively if they are tuples/lists of tensors)"
    232     if not xs: return xs
--> 233     if is_listy(xs[0]): return type(xs[0])([to_concat([x[i] for x in xs], dim=dim) for i in range_of(xs[0])])
    234     if isinstance(xs[0],dict):  return {k: to_concat([x[k] for x in xs], dim=dim) for k in xs[0].keys()}
    235     #We may receive xs that are not concatenable (inputs of a text classifier for instance),

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/torch_core.py in <listcomp>(.0)
    231     "Concat the element in `xs` (recursively if they are tuples/lists of tensors)"
    232     if not xs: return xs
--> 233     if is_listy(xs[0]): return type(xs[0])([to_concat([x[i] for x in xs], dim=dim) for i in range_of(xs[0])])
    234     if isinstance(xs[0],dict):  return {k: to_concat([x[k] for x in xs], dim=dim) for k in xs[0].keys()}
    235     #We may receive xs that are not concatenable (inputs of a text classifier for instance),

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/torch_core.py in to_concat(xs, dim)
    236     #   in this case we return a big list
    237     try:    return retain_type(torch.cat(xs, dim=dim), xs[0])
--> 238     except: return sum([L(retain_type(o_.index_select(dim, tensor(i)).squeeze(dim), xs[0])
    239                           for i in range_of(o_)) for o_ in xs], L())
    240 

~/.conda/envs/ai/lib/python3.8/site-packages/fastai/torch_core.py in <listcomp>(.0)
    237     try:    return retain_type(torch.cat(xs, dim=dim), xs[0])
    238     except: return sum([L(retain_type(o_.index_select(dim, tensor(i)).squeeze(dim), xs[0])
--> 239                           for i in range_of(o_)) for o_ in xs], L())
    240 
    241 # Cell

~/.conda/envs/ai/lib/python3.8/site-packages/fastcore/utils.py in range_of(x)
    198 def range_of(x):
    199     "All indices of collection `x` (i.e. `list(range(len(x)))`)"
--> 200     return list(range(len(x)))
    201 
    202 # Cell

~/.conda/envs/ai/lib/python3.8/site-packages/torch/tensor.py in __len__(self)
    443     def __len__(self):
    444         if self.dim() == 0:
--> 445             raise TypeError("len() of a 0-d tensor")
    446         return self.shape[0]
    447 

TypeError: len() of a 0-d tensor

So I think I am probably trying to shoehorn this in the wrong way, and I am wondering if there is a typical pattern one can/should use with fastai. Any help/guidance/commentary would be much appreciated.

Thanks,
Michael

Maybe I can have a go at answering this myself :wink: One possible pattern is to register a forward hook on that convolutional layer. The hook grabs the output of the layer, calculates its L1 norm (here, the mean absolute value) and stores it in a variable, which we can then add in when calculating the cost. Because `forward` still returns a single tensor, this avoids the tuple output that broke `learn.predict` above. So something like this:

def myfwdhook(module, input_, output):
    """Forward hook that records the mean absolute activation of ``output``.

    The value is written to the module-level global ``l1_term`` so that the
    loss function can read it. ``module`` and ``input_`` are unused. Returns
    None.
    """
    global l1_term
    magnitudes = output.abs()
    l1_term = magnitudes.mean()

def mycustomloss(pred, target):
    """MSE between ``pred`` and ``target`` plus 0.0001 times the global ``l1_term``.

    ``l1_term`` is read from module scope at call time; ``myfwdhook`` is what
    assigns that global during the model's forward pass.
    """
    residual = pred - target
    penalty = 0.0001 * l1_term
    return (residual ** 2).mean() + penalty

class MyNetwork(nn.Module):
    """Dummy network whose L1 term is captured by a forward hook.

    ``myfwdhook`` is registered on ``cnv1`` (Conv2d followed by ReLU). Each
    forward pass therefore updates the global ``l1_term`` from cnv1's
    post-ReLU output, while ``forward`` keeps returning a single tensor.

    Args:
        samples_in: ``fc1`` expects ``samples_in * 2`` input features.
        matrix_out: 2-D spatial shape the ``fc1`` output is reshaped to before
            the convolutions (a tuple; a list is also accepted).
    """

    def __init__(self, samples_in, matrix_out):
        super().__init__()

        self.samples_in = samples_in
        self.matrix_out = matrix_out
        # np.prod returns a NumPy scalar; cast to a plain int for layer sizes.
        self.samples_out = int(np.prod(self.matrix_out))

        self.fc1 = nn.Sequential(nn.Linear(self.samples_in*2, self.samples_out), nn.Tanh())
        self.cnv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, stride=1, padding=2), nn.ReLU())

        # Hooking the whole Sequential means the hook sees the post-ReLU output.
        self.cnv1.register_forward_hook(myfwdhook)

        self.dcnv = nn.Sequential(nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=7, stride=1, padding=3))

    def forward(self, x):
        """Map ``(batch, samples_in*2)`` inputs to ``(batch, samples_out)`` outputs."""
        batch_size = x.shape[0]

        x = self.fc1(x)
        # Add a singleton channel dimension so the result can be fed to Conv2d.
        x = x.reshape((batch_size, 1) + tuple(self.matrix_out))
        x = self.cnv1(x)  # the registered hook records l1_term here
        x = self.dcnv(x)
        x = x.reshape((batch_size, self.samples_out))

        return x

Or something like that.

Continuing my conversation with myself: another approach, which I think is the fastai way to do this, is to use a Learner callback. So I could do something like:

class L1RegCallback(Callback):
    """Add an L1 penalty on one layer's parameters to the training loss.

    Args:
        reglambda: weight of the penalty.
        layer_name: attribute name, on ``learn.model``, of the layer whose
            parameters are penalized. Defaults to ``"cnv2"`` as in the
            original snippet. NOTE(review): the MyNetwork shown in this thread
            has ``fc1``/``cnv1``/``dcnv`` but no ``cnv2``; pass the correct
            name for your model.

    This penalizes the layer's parameters (weights and biases), not its
    output activations.
    """

    def __init__(self, reglambda=0.0001, layer_name="cnv2"):
        self.reglambda = reglambda
        self.layer_name = layer_name

    def before_backward(self):
        layer = getattr(self.learn.model, self.layer_name)
        # Mean |value| of each parameter tensor, summed over the layer's tensors.
        regularization_loss = sum(
            (param.abs().mean() for param in layer.parameters()), 0.0
        )
        penalty = self.reglambda * regularization_loss

        # Out-of-place add, so the loss tensor produced by loss_func is not mutated.
        self.learn.loss = self.learn.loss + penalty

        # NOTE(review): newer fastai releases call backward() on
        # ``learn.loss_grad`` and keep ``learn.loss`` only for reporting.
        # When that attribute exists, penalize it too, so the regularizer
        # actually reaches the gradients. Confirm against your fastai version.
        loss_grad = getattr(self.learn, "loss_grad", None)
        if loss_grad is not None:
            self.learn.loss_grad = loss_grad + penalty

And then something like:

# Plain MSE is the loss function; L1RegCallback adds the L1 penalty during
# training. The metric is a separate MSELoss instance, presumably so the
# reported value is MSE without the penalty. TODO confirm.
learn = Learner(
    dls,
    mynetwork,
    opt_func=RMSProp,
    loss_func=nn.MSELoss(reduction='mean'),
    metrics=nn.MSELoss(reduction='mean'),
    cbs=[L1RegCallback()],
)

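This adds the L1 penalty to the loss before backpropagation. Note that, unlike the hook approach, this penalizes the layer's parameters (its weights) rather than its output activations.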

@sachinruk and @sgugger you have had some other threads on this:

But I think some of that was based on a previous version of fastai. So I was wondering what your thoughts are: is the approach above correct for adding an L1 loss for a specific layer?

Hey, just to quickly answer your question: I found better (faster) convergence when updating the weights directly instead of adding the penalty to the loss, as you seem to be doing here.

The only thing I’d change in my answer is that it ought to be param.data = param.data - learning_rate * self.beta * sign.

1 Like