Torchvision Segmentation Models with FastAI 2

I made my first attempt with the following code:

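import torchvision
from fastai2.vision.all import *  # Learner, Dice, JaccardCoeff, ...

# model: a torchvision segmentation model (its definition is not shown in the original post)
learn = Learner(dls=dls, model=model, metrics=[Dice(), JaccardCoeff()], wd=1e-2).to_fp16()

learn.lr_find()                 # find a learning rate
learn.recorder.plot_lr_find()   # plot the learning-rate curve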

However, it throws the following error:

AttributeError                            Traceback (most recent call last)
<ipython-input-23-0a497bfa9bea> in <module>
----> 1 learn.lr_find() # find learning rate
      2 learn.recorder # plot learning rate graph

~/anaconda3/envs/seg/lib/python3.7/site-packages/fastai2/callback/ in lr_find(self, start_lr, end_lr, num_it, stop_div, show_plot, suggestions)
    226     n_epoch = num_it//len(self.dls.train) + 1
    227     cb=LRFinder(start_lr=start_lr, end_lr=end_lr, num_it=num_it, stop_div=stop_div)
--> 228     with self.no_logging(): self.fit(n_epoch, cbs=cb)
    229     if show_plot: self.recorder.plot_lr_find()
    230     if suggestions:

~/anaconda3/envs/seg/lib/python3.7/site-packages/fastcore/ in _f(*args, **kwargs)
    428         init_args.update(log)
    429         setattr(inst, 'init_args', init_args)
--> 430         return inst if to_return else f(*args, **kwargs)
    431     return _f

~/anaconda3/envs/seg/lib/python3.7/site-packages/fastai2/ in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    198                     try:
    199                         self.epoch=epoch;          self('begin_epoch')
--> 200                         self._do_epoch_train()
    201                         self._do_epoch_validate()
    202                     except CancelEpochException:   self('after_cancel_epoch')

~/anaconda3/envs/seg/lib/python3.7/site-packages/fastai2/ in _do_epoch_train(self)
    173         try:
    174             self.dl = self.dls.train;                        self('begin_train')
--> 175             self.all_batches()
    176         except CancelTrainException:                         self('after_cancel_train')
    177         finally:                                             self('after_train')

~/anaconda3/envs/seg/lib/python3.7/site-packages/fastai2/ in all_batches(self)
    151     def all_batches(self):
    152         self.n_iter = len(self.dl)
--> 153         for o in enumerate(self.dl): self.one_batch(*o)
    155     def one_batch(self, i, b):

~/anaconda3/envs/seg/lib/python3.7/site-packages/fastai2/ in one_batch(self, i, b)
    159             self.pred = self.model(*self.xb);                self('after_pred')
    160             if len(self.yb) == 0: return
--> 161             self.loss = self.loss_func(self.pred, *self.yb); self('after_loss')
    162             if not self.training: return
    163             self.loss.backward();                            self('after_backward')

~/anaconda3/envs/seg/lib/python3.7/site-packages/fastai2/ in __call__(self, inp, targ, **kwargs)
    290     def __call__(self, inp, targ, **kwargs):
--> 291         inp  = inp .transpose(self.axis,-1).contiguous()
    292         targ = targ.transpose(self.axis,-1).contiguous()
    293         if self.floatify and targ.dtype!=torch.float16: targ = targ.float()

AttributeError: 'dict' object has no attribute 'transpose'
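The last frame shows the cause: the flattened loss starts with inp.transpose(self.axis, -1), but torchvision's segmentation models return an OrderedDict (the logits under "out", plus "aux" when the auxiliary head is enabled), so the loss receives a dict instead of a tensor. The stdlib-only sketch below reproduces that mismatch without torch; FakeTensor, toy_seg_model and flattened_loss_prefix are hypothetical stand-ins, and n_classes=2 is an assumption.

```python
from collections import OrderedDict

class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only tracks a shape."""
    def __init__(self, shape):
        self.shape = tuple(shape)

    def transpose(self, dim0, dim1):
        # Swap two dimensions of the shape, as torch.Tensor.transpose does.
        s = list(self.shape)
        s[dim0], s[dim1] = s[dim1], s[dim0]
        return FakeTensor(s)

def toy_seg_model(xb, n_classes=2):
    # Like torchvision's segmentation models, return a dict keyed by "out".
    n, _, h, w = xb.shape
    return OrderedDict(out=FakeTensor((n, n_classes, h, w)))

def flattened_loss_prefix(inp, axis=1):
    # First step of the loss in the traceback: move the class axis last.
    return inp.transpose(axis, -1)

pred = toy_seg_model(FakeTensor((2, 1, 501, 501)))
try:
    flattened_loss_prefix(pred)
except AttributeError as e:
    print(e)  # the dict has no 'transpose' attribute
print(flattened_loss_prefix(pred["out"]).shape)  # (2, 501, 501, 2)
```

Indexing the prediction with "out" before it reaches the loss is what makes the shapes line up again.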

I would like to print self.pred to inspect the shape and type of what the model returns, but I don't know how to do that.
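One way to see what the model returns, without going through the loss, is to inspect its output directly, for example on learn.model(xb) with xb, yb = dls.one_batch() (model and batch must be on the same device), or from inside an after_pred callback. The helper below is a hypothetical stdlib-only sketch: describe_output walks nested dicts, lists and tuples and reports each leaf's type and shape, assuming only that leaves expose .shape as torch tensors do. FakeTensor is a stand-in so the example runs without torch.

```python
def describe_output(obj, path="pred"):
    """Return one line per leaf: its path, type name and shape (None if it has no .shape)."""
    if isinstance(obj, dict):
        lines = []
        for key, value in obj.items():
            lines += describe_output(value, f"{path}[{key!r}]")
        return lines
    if isinstance(obj, (list, tuple)):
        lines = []
        for i, value in enumerate(obj):
            lines += describe_output(value, f"{path}[{i}]")
        return lines
    shape = getattr(obj, "shape", None)
    return [f"{path}: {type(obj).__name__} shape={None if shape is None else tuple(shape)}"]

class FakeTensor:
    """Hypothetical stand-in for a tensor; only the shape matters here."""
    def __init__(self, shape):
        self.shape = tuple(shape)

for line in describe_output({"out": FakeTensor((2, 2, 501, 501)), "aux": FakeTensor((2, 2, 501, 501))}):
    print(line)
# pred['out']: FakeTensor shape=(2, 2, 501, 501)
# pred['aux']: FakeTensor shape=(2, 2, 501, 501)
```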

I have tried to pass an item to the model as follows:

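# x: an input item; how it was built is not shown in the original post
for f in dls.train.after_item:
  name = ...  # truncated in the original post
  x = f(x)
for f in dls.train.after_batch:
  name = ...  # truncated in the original post
  x = f(x)
print(x.shape)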

It returns:

torch.Size([2, 1, 501, 501])

Is there a way to make the Learner's training loop take the result out of the OrderedDict the model returns?

I'd try a custom callback that customizes after_pred. You may need to set a specific order (see any callback implementation) so that it runs before a certain other callback.

From the docs: "after_pred: called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss."
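To make the suggestion concrete, here is a stdlib-only toy loop that mimics the pattern described: callbacks sorted by an order attribute (in this toy, lower values run first) and an after_pred event that can replace the prediction before the loss is computed. It is not fastai's implementation. ToyLoop, ToyCallback, GetOut, PredTypeRecorder, toy_model and l1_loss are hypothetical, and handlers receive the loop explicitly where real fastai callbacks reach it through self.learn.

```python
from collections import OrderedDict

class ToyCallback:
    order = 0  # in this toy, callbacks with a lower order run first

    def __call__(self, event, loop):
        handler = getattr(self, event, None)
        if handler is not None:
            handler(loop)

class GetOut(ToyCallback):
    order = -1  # run before callbacks that keep the default order
    def after_pred(self, loop):
        # Replace the dict prediction with its "out" entry before the loss sees it.
        if isinstance(loop.pred, dict):
            loop.pred = loop.pred["out"]

class PredTypeRecorder(ToyCallback):
    def after_pred(self, loop):
        # Record the type the loss is about to receive.
        loop.seen_type = type(loop.pred).__name__

class ToyLoop:
    def __init__(self, model, loss_func, cbs):
        self.model, self.loss_func = model, loss_func
        self.cbs = sorted(cbs, key=lambda cb: cb.order)

    def event(self, name):
        for cb in self.cbs:
            cb(name, self)

    def one_batch(self, xb, yb):
        self.pred = self.model(xb)
        self.event("after_pred")
        self.loss = self.loss_func(self.pred, yb)
        self.event("after_loss")
        return self.loss

def toy_model(xb):
    # Hypothetical model that, like torchvision's segmentation models, returns a dict.
    return OrderedDict(out=[2 * v for v in xb], aux=[0 for _ in xb])

def l1_loss(pred, targ):
    # Breaks on a dict: iterating it yields its keys, not the predictions.
    return sum(abs(p - t) for p, t in zip(pred, targ))

loop = ToyLoop(toy_model, l1_loss, [PredTypeRecorder(), GetOut()])
print(loop.one_batch([1, 2, 3], [2, 4, 5]))  # 1
print(loop.seen_type)  # list
```

Because GetOut has the lower order, the recorder sees a list rather than the dict even though it was passed first. Dropping GetOut makes l1_loss fail on the dict, which is the toy analogue of the AttributeError in the traceback.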


I managed to get it working as follows:

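import torchvision
from fastai2.vision.all import *  # Learner, Callback, Dice, JaccardCoeff, ...

class GetResult(Callback):
    def after_pred(self):
        # torchvision segmentation models return a dict; keep only the "out" logits
        self.learn.pred = self.pred["out"]

learn = Learner(dls=dls, model=model, metrics=[Dice(), JaccardCoeff()], wd=1e-2, cbs=GetResult()).to_fp16()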
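An alternative to the callback, not the approach taken in this thread, is to wrap the model so that it returns only the "out" entry. With PyTorch that would be an nn.Module whose forward returns self.model(x)["out"]; the sketch below shows the same idea as a plain callable so it runs without torch. OutputSelector and toy_model are hypothetical names.

```python
from collections import OrderedDict

class OutputSelector:
    """Hypothetical wrapper: call the wrapped model and keep one entry of its dict output."""
    def __init__(self, model, key="out"):
        self.model, self.key = model, key

    def __call__(self, *args, **kwargs):
        out = self.model(*args, **kwargs)
        # Pass non-dict outputs through unchanged.
        return out[self.key] if isinstance(out, dict) else out

def toy_model(xb):
    # Hypothetical stand-in for a torchvision segmentation model.
    return OrderedDict(out=[2 * v for v in xb], aux=[0 for _ in xb])

wrapped = OutputSelector(toy_model)
print(wrapped([1, 2, 3]))  # [2, 4, 6]
```

The trade-off: the callback leaves the torchvision model untouched, while a wrapper makes every caller (training, inference, export) see a plain tensor, at the cost of nesting the original module inside another one, which in PyTorch prefixes its state_dict keys.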