Learn.lr_find pickle error

Version: fastai-1.0.46

I just updated fastai from 1.0.42-1 to 1.0.46-1 and I’m having lr_find issues.

I’m using a custom head which requires flattening, so I made use of the lambda layer suggested here: Lambda Layer

Before updating (using version 1.0.42) everything worked fine (I could use lr_find, train, evaluate, etc.), but I was running into weird hiccups here and there, so I updated. Now that I’ve updated, I’m running into this serialization issue in the torch library, which originates from a call to lr_find (I can supply the full stacktrace if needed):

~/.conda/envs/fellowship_env/lib/python3.7/site-packages/torch/serialization.py in _save(obj, f, pickle_module, pickle_protocol)
290     pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
291     pickler.persistent_id = persistent_id
--> 292     pickler.dump(obj)
293 
294     serialized_storage_keys = sorted(serialized_storages.keys())

AttributeError: Can't pickle local object 'Flatten.<locals>.<lambda>'

I’m aware that lambda functions can’t be serialized so I went looking and found this workaround involving dill: import dill as pickle

But I assume the torch library imports pickle in the source code, so my question is as follows:

Is there something that can be changed in lr_find so that it works again as it did with fastai version 1.0.42? If not, is there a way to make the torch library use dill instead of its pickle import without modifying the library source code?
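The error message quotes the function’s qualified name. The standard pickle module serializes a function by reference (its module plus __qualname__), so a lambda created inside another function has no importable name to record. A minimal stdlib-only sketch; the factory is hypothetical and named Flatten only so that it reproduces the same qualified name as in the traceback (it is not fastai’s code):

```python
import pickle


def Flatten():
    # Hypothetical factory, named only to mirror the error message above.
    # The lambda is created inside the function body, so its qualified name
    # contains "<locals>" and pickle cannot look it up by name later.
    return lambda x: x


f = Flatten()
print(f.__qualname__)  # Flatten.<locals>.<lambda>

try:
    pickle.dumps(f)
except (pickle.PicklingError, AttributeError) as err:
    # Python 3.7 (as in the traceback) raises AttributeError
    # "Can't pickle local object ..."; other versions may raise
    # pickle.PicklingError instead, so both are caught here.
    print(type(err).__name__, err)
```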


Full stacktrace:

/home/constantin.baumgartner/.conda/envs/fellowship_env/lib/python3.7/site-packages/torch/serialization.py:251: UserWarning: Couldn't retrieve source code for container of type Lambda. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-18-d81c6bd29d71> in <module>
----> 1 learn.lr_find()

~/.conda/envs/fellowship_env/lib/python3.7/site-packages/fastai/train.py in lr_find(learn, start_lr, end_lr, num_it, stop_div, wd)
     30     cb = LRFinder(learn, start_lr, end_lr, num_it, stop_div)
     31     epochs = int(np.ceil(num_it/len(learn.data.train_dl)))
---> 32     learn.fit(epochs, start_lr, callbacks=[cb], wd=wd)
     33 
     34 def to_fp16(learn:Learner, loss_scale:float=None, max_noskip:int=1000, dynamic:bool=False, clip:float=None,

~/.conda/envs/fellowship_env/lib/python3.7/site-packages/fastai/basic_train.py in fit(self, epochs, lr, wd, callbacks)
    180         if defaults.extra_callbacks is not None: callbacks += defaults.extra_callbacks
    181         fit(epochs, self.model, self.loss_func, opt=self.opt, data=self.data, metrics=self.metrics,
--> 182             callbacks=self.callbacks+callbacks)
    183 
    184     def create_opt(self, lr:Floats, wd:Floats=0.)->None:

~/.conda/envs/fellowship_env/lib/python3.7/site-packages/fastai/utils/mem.py in wrapper(*args, **kwargs)
     87 
     88         try:
---> 89             return func(*args, **kwargs)
     90         except Exception as e:
     91             if ("CUDA out of memory" in str(e) or

~/.conda/envs/fellowship_env/lib/python3.7/site-packages/fastai/basic_train.py in fit(epochs, model, loss_func, opt, data, callbacks, metrics)
    101         exception = e
    102         raise
--> 103     finally: cb_handler.on_train_end(exception)
    104 
    105 loss_func_name2activ = {'cross_entropy_loss': F.softmax, 'nll_loss': torch.exp, 'poisson_nll_loss': torch.exp,

~/.conda/envs/fellowship_env/lib/python3.7/site-packages/fastai/callback.py in on_train_end(self, exception)
    289     def on_train_end(self, exception:Union[bool,Exception])->None:
    290         "Handle end of training, `exception` is an `Exception` or False if no exceptions during training."
--> 291         self('train_end', exception=exception)
    292 
    293 class AverageMetric(Callback):

~/.conda/envs/fellowship_env/lib/python3.7/site-packages/fastai/callback.py in __call__(self, cb_name, call_mets, **kwargs)
    212         "Call through to all of the `CallbakHandler` functions."
    213         if call_mets: [getattr(met, f'on_{cb_name}')(**self.state_dict, **kwargs) for met in self.metrics]
--> 214         return [getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs) for cb in self.callbacks]
    215 
    216     def set_dl(self, dl:DataLoader):

~/.conda/envs/fellowship_env/lib/python3.7/site-packages/fastai/callback.py in <listcomp>(.0)
    212         "Call through to all of the `CallbakHandler` functions."
    213         if call_mets: [getattr(met, f'on_{cb_name}')(**self.state_dict, **kwargs) for met in self.metrics]
--> 214         return [getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs) for cb in self.callbacks]
    215 
    216     def set_dl(self, dl:DataLoader):

~/.conda/envs/fellowship_env/lib/python3.7/site-packages/fastai/callbacks/lr_finder.py in on_train_end(self, **kwargs)
     43         # restore the valid_dl we turned off on `__init__`
     44         self.data.valid_dl = self.valid_dl
---> 45         self.learn.load('tmp')
     46         if hasattr(self.learn.model, 'reset'): self.learn.model.reset()
     47         print('LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.')

~/.conda/envs/fellowship_env/lib/python3.7/site-packages/fastai/basic_train.py in load(self, name, device, strict, with_opt, purge)
    241     def load(self, name:PathOrStr, device:torch.device=None, strict:bool=True, with_opt:bool=None, purge:bool=True):
    242         "Load model and optimizer state (if `with_opt`) `name` from `self.model_dir` using `device`."
--> 243         if purge: self.purge(clear_opt=ifnone(with_opt, False))
    244         if device is None: device = self.data.device
    245         state = torch.load(self.path/self.model_dir/f'{name}.pth', map_location=device)

~/.conda/envs/fellowship_env/lib/python3.7/site-packages/fastai/basic_train.py in purge(self, clear_opt)
    287         state['cb_state'] = {cb.__class__:cb.get_state() for cb in self.callbacks}
    288         if hasattr(self, 'opt'): state['opt'] = self.opt.get_state()
--> 289         torch.save(state, open(tmp_file, 'wb'))
    290         for a in attrs_del: delattr(self, a)
    291         gc.collect()

~/.conda/envs/fellowship_env/lib/python3.7/site-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol)
    217         >>> torch.save(x, buffer)
    218     """
--> 219     return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
    220 
    221 

~/.conda/envs/fellowship_env/lib/python3.7/site-packages/torch/serialization.py in _with_file_like(f, mode, body)
    142         f = open(f, mode)
    143     try:
--> 144         return body(f)
    145     finally:
    146         if new_fd:

~/.conda/envs/fellowship_env/lib/python3.7/site-packages/torch/serialization.py in <lambda>(f)
    217         >>> torch.save(x, buffer)
    218     """
--> 219     return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
    220 
    221 

~/.conda/envs/fellowship_env/lib/python3.7/site-packages/torch/serialization.py in _save(obj, f, pickle_module, pickle_protocol)
    290     pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    291     pickler.persistent_id = persistent_id
--> 292     pickler.dump(obj)
    293 
    294     serialized_storage_keys = sorted(serialized_storages.keys())

AttributeError: Can't pickle local object 'Flatten.<locals>.<lambda>'

Never mind, I just found fastai.layers, which has everything I need.

For future reference, this comes from purge, which tries to save your model to free GPU RAM for you. For now I have disabled it by default in master, but in general you should avoid lambda functions because of this (you can use the Lambda layer, but pass it a named function rather than a lambda).
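To illustrate the suggestion: a module-level named function is pickled by reference, so a layer that merely stores it can be saved and restored, while the same layer holding a lambda cannot. This sketch uses a hypothetical plain-Python Lambda class as a stand-in for fastai’s Lambda layer (the real one is a torch module, which is not imported here) and a hypothetical flatten helper that works on nested lists instead of tensors.

```python
import pickle


class Lambda:
    """Hypothetical stand-in for a Lambda layer: stores `func` and applies it.
    Whether it can be pickled depends only on what `func` is."""

    def __init__(self, func):
        self.func = func

    def __call__(self, x):
        return self.func(x)


def flatten(batch):
    # Named, module-level function: pickle records it as <module>.flatten.
    # Flattens each sample of a batch of nested lists into one flat list,
    # analogous to reshaping a tensor to (batch_size, -1).
    return [[v for row in sample for v in row] for sample in batch]


good = Lambda(flatten)
restored = pickle.loads(pickle.dumps(good))  # round-trips fine
print(restored([[[1, 2], [3, 4]]]))  # [[1, 2, 3, 4]]

bad = Lambda(lambda batch: batch)
try:
    pickle.dumps(bad)
except (pickle.PicklingError, AttributeError) as err:
    print("cannot pickle:", type(err).__name__)
```

The layer class is identical in both cases; only the kind of function stored inside changes the outcome.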


Gotcha, thanks for the clarification.

@sgugger, perhaps another approach would be a try_purge instead of purge? If it succeeds, great; if it doesn’t, everything just stays as is. Otherwise we are impacting 99% of users because 1% use something that isn’t picklable.
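A minimal sketch of the "try first, otherwise leave everything as is" idea, assuming plain pickle and a state dict. The name try_save_state and its signature are hypothetical, not fastai’s API. Serializing into an in-memory buffer first means a failed attempt never leaves a half-written file behind; the caller would only free memory when it returns True.

```python
import io
import os
import pickle
import tempfile
import warnings


def try_save_state(state: dict, path: str) -> bool:
    """Hypothetical 'try_purge'-style helper.

    Try to pickle `state` to `path`. On failure, warn and return False so the
    caller can skip the purge and keep everything in memory unchanged.
    """
    buffer = io.BytesIO()
    try:
        # Serialize in memory first, so a failure never leaves a partial file.
        pickle.dump(state, buffer)
    except (pickle.PicklingError, AttributeError, TypeError) as err:
        warnings.warn(f"Skipping purge, state is not picklable: {err}")
        return False
    with open(path, "wb") as fh:
        fh.write(buffer.getvalue())
    return True


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tmp.pkl")
    print(try_save_state({"weights": [0.1, 0.2]}, path))  # True
    print(try_save_state({"head": lambda x: x}, path))    # False, plus a warning
```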

And of course an even better, more transparent but more complex, solution is to test each segment for pickle-ability, pickling/restoring only what’s picklable and leaving the rest as is.
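One way the per-segment idea could look, as a hedged sketch on a plain dict with hypothetical helper names (split_picklable, partial_purge, restore): each entry is test-pickled on its own, the picklable entries are serialized together, and the rest stay in memory and are merged back on restore. It also hints at why this is trickier: every value is pickled twice, and an object shared between a picklable and a non-picklable entry ends up as two separate copies after restoring.

```python
import pickle


def split_picklable(state: dict):
    """Split `state` into (picklable, kept) dicts by test-pickling each value."""
    picklable, kept = {}, {}
    for key, value in state.items():
        try:
            pickle.dumps(value)
        except (pickle.PicklingError, AttributeError, TypeError):
            kept[key] = value  # leave as is, in memory
        else:
            picklable[key] = value
    return picklable, kept


def partial_purge(state: dict):
    """Serialize what can be serialized; return (blob, objects kept in memory)."""
    picklable, kept = split_picklable(state)
    blob = pickle.dumps(picklable)
    return blob, kept


def restore(blob: bytes, kept: dict) -> dict:
    """Rebuild the full state: unpickled entries plus the in-memory leftovers."""
    state = pickle.loads(blob)
    state.update(kept)
    return state


head = lambda x: x  # deliberately a lambda, so it is not picklable
blob, kept = partial_purge({"lr": 0.01, "weights": [1.0, 2.0], "head": head})
print(sorted(kept))  # ['head']
state = restore(blob, kept)
print(state["weights"], state["head"] is head)  # [1.0, 2.0] True
```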


Seems reasonable for the first part; I’ll do it later today. The second part is much trickier.