How to load a WGAN model

Hello all,

I trained a WGAN network as in lesson7-wgan.ipynb using Colab.

I saved the model as 'stage-2'.

But when I try to load it back by recreating the learner with

learn = GANLearner.wgan(data, generator, critic, switch_eval=False,
                        opt_func=partial(optim.Adam, betas=(0., 0.99)), wd=0.)

and then calling learn.load('stage-2'), I get the following error:

AttributeError                            Traceback (most recent call last)
<ipython-input-60-78cd3789730e> in <module>()
----> 1 learn.load('stage-2')

/usr/local/lib/python3.6/dist-packages/fastai/ in load(self, name, device, strict, with_opt, purge)
    238     def load(self, name:PathOrStr, device:torch.device=None, strict:bool=True, with_opt:bool=None, purge:bool=True):
    239         "Load model and optimizer state (if `with_opt`) `name` from `self.model_dir` using `device`."
--> 240         if purge: self.purge(clear_opt=ifnone(with_opt, False))
    241         if device is None: device =
    242         state = torch.load(self.path/self.model_dir/f'{name}.pth', map_location=device)

/usr/local/lib/python3.6/dist-packages/fastai/ in purge(self, clear_opt)
    284         state['cb_state'] = {cb.__class__:cb.get_state() for cb in self.callbacks}
    285         if hasattr(self, 'opt'): state['opt'] = self.opt.get_state()
--> 286, open(tmp_file, 'wb'))
    287         for a in attrs_del: delattr(self, a)
    288         gc.collect()

/usr/local/lib/python3.6/dist-packages/torch/ in save(obj, f, pickle_module, pickle_protocol)
    217         >>>, buffer)
    218     """
--> 219     return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))

/usr/local/lib/python3.6/dist-packages/torch/ in _with_file_like(f, mode, body)
    142         f = open(f, mode)
    143     try:
--> 144         return body(f)
    145     finally:
    146         if new_fd:

/usr/local/lib/python3.6/dist-packages/torch/ in <lambda>(f)
    217         >>>, buffer)
    218     """
--> 219     return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))

/usr/local/lib/python3.6/dist-packages/torch/ in _save(obj, f, pickle_module, pickle_protocol)
    290     pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    291     pickler.persistent_id = persistent_id
--> 292     pickler.dump(obj)
    294     serialized_storage_keys = sorted(serialized_storages.keys())

AttributeError: Can't pickle local object 'AvgFlatten.<locals>.<lambda>'
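The last line of the traceback is the actual cause: Python's pickle module stores functions by reference to their qualified name, so a lambda created inside a function (here AvgFlatten.<locals>.<lambda>) cannot be pickled. Below is a minimal, stdlib-only sketch of the same failure that assumes nothing about fastai internals; make_avg_pool, avg_pool and try_pickle are hypothetical names used only for illustration.

```python
import pickle


def make_avg_pool():
    # Hypothetical factory in the spirit of AvgFlatten: the lambda is created
    # inside the function, so its qualified name contains "<locals>".
    return lambda xs: sum(xs) / len(xs)


def avg_pool(xs):
    # Module-level equivalent: pickle stores it as a reference to its name.
    return sum(xs) / len(xs)


def try_pickle(obj):
    """Return 'ok' if obj can be pickled, otherwise the error message."""
    try:
        pickle.dumps(obj)
        return "ok"
    except (AttributeError, pickle.PicklingError) as err:
        return f"{type(err).__name__}: {err}"


print(try_pickle(make_avg_pool()))  # reports an error for the local lambda
print(try_pickle(avg_pool))         # ok
```

Moving such a helper to module level is the usual way to make an object picklable; whether that is practical for the critic built in the notebook is not something this sketch checks.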

Could you please advise on how to save and load the model properly, so I can continue training after a break?

Thank you.


Ran into the exact same problem when trying to load my WGAN on a different machine.


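I took a dive into the traceback, and it turns out the problem comes from the purge function: the pickling error is raised inside purge, which calls torch.save on the learner's state before the weights are loaded. If you run learn.load('stage-2', purge=False), it should work, but I'm not sure exactly what purge does.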
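Why purge=False helps can be sketched with a toy mock of the chain in the traceback (load, then purge, then pickle). ToyModel and ToyLearner are hypothetical stand-ins, not fastai classes: saving only plain weights (as a state dict does) pickles fine, while the purge step pickles the whole model object, lambda included, and fails before any weights are restored.

```python
import pickle


class ToyModel:
    """Hypothetical stand-in for the WGAN critic (not a fastai class)."""

    def __init__(self):
        self.weights = {"w": 1.0, "b": 0.0}        # plain data, like a state dict
        self.pool = lambda xs: sum(xs) / len(xs)   # local lambda: not picklable


class ToyLearner:
    """Hypothetical mock of the load -> purge -> pickle chain in the traceback."""

    def __init__(self):
        self.model = ToyModel()

    def save(self):
        # Only the weights dict is pickled, so this succeeds.
        return pickle.dumps(self.model.weights)

    def purge(self):
        # Pickles the whole model object, lambda included, so this raises.
        pickle.dumps(self.model)

    def load(self, blob, purge=True):
        if purge:
            self.purge()  # fails before any weights are restored
        self.model.weights = pickle.loads(blob)


learner = ToyLearner()
blob = learner.save()
try:
    learner.load(blob)  # purge=True by default
except (AttributeError, pickle.PicklingError) as err:
    print("load with purge failed:", err)
learner.load(blob, purge=False)
print(learner.model.weights)
```

With purge=False the full-object pickle is skipped, so the weights load; the model itself is still not picklable, so anything else that pickles the whole model may hit the same error.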


Thanks for the answer. Now it works like a charm…


That didn't work for me. I get AttributeError: 'GANLearner' object has no attribute 'gen_mode' when calling learn.show_results() or any other inference method.

Hi, this forum post may be of use to you. It seems you are having the same problem as the user in that post. Try
I haven’t tested this myself though, so no guarantees.