How to load a WGAN model

Hello all,

I have trained a WGAN as in lesson7-wgan.ipynb using Colab.

I saved the model using:

learn.save('stage-2')

but when I try to load the model back using

learn = GANLearner.wgan(data, generator, critic, switch_eval=False,
                   opt_func = partial(optim.Adam, betas = (0.,0.99)), wd=0.)
learn.load('stage-2')

I get the following error:

--------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-60-78cd3789730e> in <module>()
----> 1 learn.load('stage-2')

/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py in load(self, name, device, strict, with_opt, purge)
    238     def load(self, name:PathOrStr, device:torch.device=None, strict:bool=True, with_opt:bool=None, purge:bool=True):
    239         "Load model and optimizer state (if `with_opt`) `name` from `self.model_dir` using `device`."
--> 240         if purge: self.purge(clear_opt=ifnone(with_opt, False))
    241         if device is None: device = self.data.device
    242         state = torch.load(self.path/self.model_dir/f'{name}.pth', map_location=device)

/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py in purge(self, clear_opt)
    284         state['cb_state'] = {cb.__class__:cb.get_state() for cb in self.callbacks}
    285         if hasattr(self, 'opt'): state['opt'] = self.opt.get_state()
--> 286         torch.save(state, open(tmp_file, 'wb'))
    287         for a in attrs_del: delattr(self, a)
    288         gc.collect()

/usr/local/lib/python3.6/dist-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol)
    217         >>> torch.save(x, buffer)
    218     """
--> 219     return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
    220 
    221 

/usr/local/lib/python3.6/dist-packages/torch/serialization.py in _with_file_like(f, mode, body)
    142         f = open(f, mode)
    143     try:
--> 144         return body(f)
    145     finally:
    146         if new_fd:

/usr/local/lib/python3.6/dist-packages/torch/serialization.py in <lambda>(f)
    217         >>> torch.save(x, buffer)
    218     """
--> 219     return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
    220 
    221 

/usr/local/lib/python3.6/dist-packages/torch/serialization.py in _save(obj, f, pickle_module, pickle_protocol)
    290     pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    291     pickler.persistent_id = persistent_id
--> 292     pickler.dump(obj)
    293 
    294     serialized_storage_keys = sorted(serialized_storages.keys())

AttributeError: Can't pickle local object 'AvgFlatten.<locals>.<lambda>'

Could you please advise on how to save and load the model properly, so that I can continue training later?

Thank you.

Ran into the exact same problem when trying to load my WGAN on a different machine.

I dug into the traceback, and it turns out the problem comes from the purge step that load runs by default. If you run learn.load('stage-2', purge=False), it should work, but I’m not sure what purge does exactly.
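
For what it's worth, the last line of the traceback points at a lambda defined inside AvgFlatten, and the purge frame shows the learner state being passed to torch.save, which pickles it. Below is a minimal standard-library sketch of that failure mode, not fastai code: make_flatten and can_pickle are hypothetical names made up for illustration, and math.fsum just stands in for any function that lives at module level.

```python
import math
import pickle


def make_flatten():
    # Hypothetical stand-in for a layer factory that returns a lambda
    # defined inside another function (a "local object" to pickle).
    return lambda rows: [v for row in rows for v in row]


def can_pickle(obj):
    # Return True if pickle can serialize obj, False if it refuses.
    # Depending on the Python version the refusal is raised as
    # AttributeError or pickle.PicklingError, so both are caught.
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        return False


# A lambda created inside a function cannot be found by name, so pickling fails.
local_lambda_ok = can_pickle(make_flatten())

# A function importable by name from a module pickles by reference.
module_function_ok = can_pickle(math.fsum)

print("local lambda picklable:", local_lambda_ok)
print("module-level function picklable:", module_function_ok)
```

This would explain why skipping purge avoids the error: nothing tries to pickle the model's lambda in that case.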

Thanks for the answer. Now it works like a charm…

Didn’t work. I get the error AttributeError: 'GANLearner' object has no attribute 'gen_mode' on learn.show_results() or any other inference method.

Hi, this forum post may be of use to you. It seems you are having the same problem as the user in that post. Try
learn_gen.predict(img)
I haven’t tested this myself though, so no guarantees.