A callback that transfer the learning during training

I have a custom CIFAR10 dataset where the images are occluded by MNIST images.

The MNIST dataset is noise for the CIFAR task.

I want to train a resnet18 (pretrained=False) to learn the CIFAR task and at every epoch, use the same network trained in CIFAR to fine-tune to MNIST.

My hypothesis is that during training the CIFAR model gets better at learning a representation that ignores the MNIST dataset, therefore, the MNIST fine-tuning should get worse results in time.

The idea is to use a callback that at every epoch tries to train the head for the MNIST dataset while keeping the CIFAR backbone frozen.

path = Path('/home/fredguth/.fastai/data/cifar10_mnist')

def label_func(f):
    return (str(f).split("_y")[1:][0][0])

def noise_func(f):
    return (str(f).split("_n")[1:][0][0])

def get_dls(task="CIFAR"):
    dblock = DataBlock(blocks    = (ImageBlock, CategoryBlock),
                   splitter  = IndexSplitter(list(range(10000))),
                   get_items = get_image_files,
                   get_y = label_func if (task == 'CIFAR') else noise_func,
                   batch_tfms= [Normalize],
                   n_inp     = 1 )
    return dblock.dataloaders(path, num_workers=4)

cifar_learner= cnn_learner(dls=get_dls(task="CIFAR"),

class RunMNIST(Callback):
    def __init__(self, learner=None):
        self.l = learner
    def after_epoch(self):    
        print('self.l=', self.l)
        l = self.l
        l.dls = get_dls(task="MNIST")
        l.finetune(epochs = 1, base_lr=0.002, cbs=[CSVLogger(fname=f"mnist_after_cifar_e{self.epoch}.csv")])


The cifar_learner.fit throws an error: ModuleAttributeError: 'Sequential' object has no attribute 'finetune'

Despite print('self.l=', self.l) printing self.l= <fastai.learner.Learner object at 0x7f13216cf250>

What is going on here? How self.l learner somehow is transformed to a Sequential? What am I missing?

It seems to be something that happens when self.run is True.
Posting the complete error here:

ModuleAttributeError                      Traceback (most recent call last)
<ipython-input-8-52af6e76caee> in <module>
----> 1 cifar_learner.fit(1,0.002)

~/.miniconda/envs/infod/lib/python3.9/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    210             self.opt.set_hypers(lr=self.lr if lr is None else lr)
    211             self.n_epoch = n_epoch
--> 212             self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
    214     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

~/.miniconda/envs/infod/lib/python3.9/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
    161         except ex: self(f'after_cancel_{event_type}')
    162         self(f'after_{event_type}');  final()

~/.miniconda/envs/infod/lib/python3.9/site-packages/fastai/learner.py in _do_fit(self)
    201         for epoch in range(self.n_epoch):
    202             self.epoch=epoch
--> 203             self._with_events(self._do_epoch, 'epoch', CancelEpochException)
    205     def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):

~/.miniconda/envs/infod/lib/python3.9/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    160         try: self(f'before_{event_type}');  f()
    161         except ex: self(f'after_cancel_{event_type}')
--> 162         self(f'after_{event_type}');  final()
    164     def all_batches(self):

~/.miniconda/envs/infod/lib/python3.9/site-packages/fastai/learner.py in __call__(self, event_name)
    140     def ordered_cbs(self, event): return [cb for cb in self.cbs.sorted('order') if hasattr(cb, event)]
--> 141     def __call__(self, event_name): L(event_name).map(self._call_one)
    143     def _call_one(self, event_name):

~/.miniconda/envs/infod/lib/python3.9/site-packages/fastcore/foundation.py in map(self, f, gen, *args, **kwargs)
    152     def range(cls, a, b=None, step=None): return cls(range_of(a, b=b, step=step))
--> 154     def map(self, f, *args, gen=False, **kwargs): return self._new(map_ex(self, f, *args, gen=gen, **kwargs))
    155     def argwhere(self, f, negate=False, **kwargs): return self._new(argwhere(self, f, negate, **kwargs))
    156     def filter(self, f=noop, negate=False, gen=False, **kwargs):

~/.miniconda/envs/infod/lib/python3.9/site-packages/fastcore/basics.py in map_ex(iterable, f, gen, *args, **kwargs)
    664     res = map(g, iterable)
    665     if gen: return res
--> 666     return list(res)
    668 # Cell

~/.miniconda/envs/infod/lib/python3.9/site-packages/fastcore/basics.py in __call__(self, *args, **kwargs)
    649             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    650         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 651         return self.func(*fargs, **kwargs)
    653 # Cell

~/.miniconda/envs/infod/lib/python3.9/site-packages/fastai/learner.py in _call_one(self, event_name)
    143     def _call_one(self, event_name):
    144         if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
--> 145         for cb in self.cbs.sorted('order'): cb(event_name)
    147     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

~/.miniconda/envs/infod/lib/python3.9/site-packages/fastai/callback/core.py in __call__(self, event_name)
     42                (self.run_valid and not getattr(self, 'training', False)))
     43         res = None
---> 44         if self.run and _run: res = getattr(self, event_name, noop)()
     45         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     46         return res

<ipython-input-5-5e6267e99289> in after_epoch(self)
      6         l = self.l
      7         l.dls = get_dls(task="MNIST")
----> 8         l.finetune(epochs = 1, base_lr=0.002, cbs=[CSVLogger(fname=f"mnist_after_cifar_e{self.epoch}.csv")])

~/.miniconda/envs/infod/lib/python3.9/site-packages/fastcore/basics.py in __getattr__(self, k)
    386         if self._component_attr_filter(k):
    387             attr = getattr(self,self._default,None)
--> 388             if attr is not None: return getattr(attr,k)
    389         raise AttributeError(k)
    390     def __dir__(self): return custom_dir(self,self._dir())

~/.miniconda/envs/infod/lib/python3.9/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
    776             if name in modules:
    777                 return modules[name]
--> 778         raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
    779             type(self).__name__, name))

ModuleAttributeError: 'Sequential' object has no attribute 'finetune'

It is learn.fine_tune, I don’t know why I thought it was learn.finetune. :rofl: