Hello all,
I have trained a WGAN network as in lesson7-wgan.ipynb using Colab.
I saved the model using:
learn.save('stage-2')
but when I try to load the model back using
learn = GANLearner.wgan(data, generator, critic, switch_eval=False, opt_func = partial(optim.Adam, betas = (0.,0.99)), wd=0.)
learn.load('stage-2')
I get the following error:
--------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-60-78cd3789730e> in <module>()
----> 1 learn.load('stage-2')
/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py in load(self, name, device, strict, with_opt, purge)
238 def load(self, name:PathOrStr, device:torch.device=None, strict:bool=True, with_opt:bool=None, purge:bool=True):
239 "Load model and optimizer state (if `with_opt`) `name` from `self.model_dir` using `device`."
--> 240 if purge: self.purge(clear_opt=ifnone(with_opt, False))
241 if device is None: device = self.data.device
242 state = torch.load(self.path/self.model_dir/f'{name}.pth', map_location=device)
/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py in purge(self, clear_opt)
284 state['cb_state'] = {cb.__class__:cb.get_state() for cb in self.callbacks}
285 if hasattr(self, 'opt'): state['opt'] = self.opt.get_state()
--> 286 torch.save(state, open(tmp_file, 'wb'))
287 for a in attrs_del: delattr(self, a)
288 gc.collect()
/usr/local/lib/python3.6/dist-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol)
217 >>> torch.save(x, buffer)
218 """
--> 219 return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
220
221
/usr/local/lib/python3.6/dist-packages/torch/serialization.py in _with_file_like(f, mode, body)
142 f = open(f, mode)
143 try:
--> 144 return body(f)
145 finally:
146 if new_fd:
/usr/local/lib/python3.6/dist-packages/torch/serialization.py in <lambda>(f)
217 >>> torch.save(x, buffer)
218 """
--> 219 return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
220
221
/usr/local/lib/python3.6/dist-packages/torch/serialization.py in _save(obj, f, pickle_module, pickle_protocol)
290 pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
291 pickler.persistent_id = persistent_id
--> 292 pickler.dump(obj)
293
294 serialized_storage_keys = sorted(serialized_storages.keys())
AttributeError: Can't pickle local object 'AvgFlatten.<locals>.<lambda>'
Could you please advise on how to save and load the model properly so I can continue training after some time?
Thank you.