Batch prediction / inference for GANs (super-resolution models)

How do I perform batch inference with the GAN model created for super-resolution in the lesson7-superres-gan.ipynb notebook? I get an error when I try the following:

```python
# Critic data plus the pretrained critic and generator learners (helpers from the notebook)
data_crit = get_crit_data(['images_data_200', 'images_data_600'], bs=bs, size=size)

learn_crit = create_critic_learner(data_crit, metrics=None).load('critic-pre2')
learn_gen = create_gen_learner().load('gen-pre2')

# Wrap both learners in a GANLearner, as in the notebook
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.,50.), show_img=False, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=wd)
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))

# Predicting through the GANLearner is the call that fails
p,img_hr,b = learn.predict(open_image('test_data/sample.jpg'))
```

I am getting the following error:

AttributeError: 'GANLearner' object has no attribute 'gen_mode'

```
AttributeError                            Traceback (most recent call last)
<ipython-input-84-36fff7c3f50b> in <module>
----> 1 p,img_hr,b = learn.predict(open_image('images_data_600/2642415632_2.jpg'))
      2 Image(img_hr).show(figsize=(10,10))

~/anaconda3/envs/fastai-latest/lib/python3.7/site-packages/fastai/ in predict(self, item, return_x, batch_first, with_dropout, **kwargs)
    371         "Return predicted class, label and probabilities for `item`."
    372         batch =
--> 373         res = self.pred_batch(batch=batch, with_dropout=with_dropout)
    374         raw_pred,x = grab_idx(res,0,batch_first=batch_first),batch[0]
    375         norm = getattr(,'norm',False)

~/anaconda3/envs/fastai-latest/lib/python3.7/site-packages/fastai/ in pred_batch(self, ds_type, batch, reconstruct, with_dropout, activ)
    347         else: xb,yb =, detach=False, denorm=False)
    348         cb_handler = CallbackHandler(self.callbacks)
--> 349         xb,yb = cb_handler.on_batch_begin(xb,yb, train=False)
    350         activ = ifnone(activ, _loss_func2activ(self.loss_func))
    351         with torch.no_grad():

~/anaconda3/envs/fastai-latest/lib/python3.7/site-packages/fastai/ in on_batch_begin(self, xb, yb, train)
    277         self.state_dict.update(dict(last_input=xb, last_target=yb, train=train, 
    278             stop_epoch=False, skip_step=False, skip_zero=False, skip_bwd=False))
--> 279         self('batch_begin', call_mets = not self.state_dict['train'])
    280         return self.state_dict['last_input'], self.state_dict['last_target']

~/anaconda3/envs/fastai-latest/lib/python3.7/site-packages/fastai/ in __call__(self, cb_name, call_mets, **kwargs)
    249         if call_mets:
    250             for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
--> 251         for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)
    253     def set_dl(self, dl:DataLoader):

~/anaconda3/envs/fastai-latest/lib/python3.7/site-packages/fastai/ in _call_and_update(self, cb, cb_name, **kwargs)
    239     def _call_and_update(self, cb, cb_name, **kwargs)->None:
    240         "Call `cb_name` on `cb` and update the inner state."
--> 241         new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
    242         for k,v in new.items():
    243             if k not in self.state_dict:

~/anaconda3/envs/fastai-latest/lib/python3.7/site-packages/fastai/vision/ in on_batch_begin(self, last_input, last_target, **kwargs)
    112             for p in self.critic.parameters():, self.clip)
    113         if last_input.dtype == torch.float16: last_target = to_half(last_target)
--> 114         return {'last_input':last_input,'last_target':last_target} if self.gen_mode else {'last_input':last_target,'last_target':last_input}
    116     def on_backward_begin(self, last_loss, last_output, **kwargs):

~/anaconda3/envs/fastai-latest/lib/python3.7/site-packages/fastai/ in __getattr__(self, k)
    441         setattr(self.learn, self.cb_name, self)
--> 443     def __getattr__(self,k): return getattr(self.learn, k)
    444     def __setstate__(self,data:Any): self.__dict__.update(data)

AttributeError: 'GANLearner' object has no attribute 'gen_mode'
```

Can someone help?

Hey, Dude!!

You have to predict with the generator that the GAN wraps! In this case, that would be learn_gen.predict().
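A sketch of why the GANLearner call fails, as far as the traceback shows: the last frame is a callback whose __getattr__ forwards any attribute it lacks to the learner, and gen_mode appears to be created only once training starts (in fastai v1's GAN trainer it seems to be set when training begins, which is an assumption here). The classes below (GANLearnerStub, DelegatingCallback, GANTrainerSketch) are hypothetical pure-Python stand-ins, not the real fastai classes:

```python
class GANLearnerStub:
    """Hypothetical stand-in for the learner; it has no `gen_mode` attribute."""


class DelegatingCallback:
    """Mimics the delegation shown in the traceback's last frame."""

    def __init__(self, learn):
        self.learn = learn

    def __getattr__(self, k):
        # Only called when normal lookup on the callback itself fails,
        # so a missing attribute is looked up on the learner instead.
        return getattr(self.learn, k)


class GANTrainerSketch(DelegatingCallback):
    """Hypothetical, simplified GAN trainer callback."""

    def on_train_begin(self):
        # Assumption: the mode flag only comes into existence when training starts.
        self.gen_mode = True

    def on_batch_begin(self, last_input, last_target):
        # Generator mode feeds the low-res input; critic mode swaps in the target.
        if self.gen_mode:
            return last_input, last_target
        return last_target, last_input


trainer = GANTrainerSketch(GANLearnerStub())

# Freshly built learner that was never fit: the lookup falls through and fails.
try:
    trainer.on_batch_begin("low_res", "high_res")
    err = None
except AttributeError as e:
    err = str(e)
print(err)  # 'GANLearnerStub' object has no attribute 'gen_mode'

# Once training has started, the flag exists and the batch goes through.
trainer.on_train_begin()
print(trainer.on_batch_begin("low_res", "high_res"))  # ('low_res', 'high_res')
```

Calling learn_gen.predict() presumably sidesteps this because the GAN trainer callback is attached to the GANLearner, not to the generator's own learner.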

Thank you @davilirio. But I have a question. We train the GANLearner for around 40 epochs and then save that learner object. So if I just load the saved GANLearner object, will learn_gen be updated too?

So, to load the GAN in a new notebook, you have to define your learn_gen/learn_crit as if you were creating them, create the GAN as usual, and then load the saved GAN model. The generator's parameters come along in the loading process (the GANLearner wraps the very same generator module that learn_gen holds, so loading updates learn_gen.model in place), and then it is as simple as learn_gen.predict() to get inferences on data!
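A minimal PyTorch sketch of why loading the GAN also updates learn_gen, assuming the GAN wrapper stores references to (not copies of) the generator and critic modules. GANModuleSketch, make_generator and make_critic are hypothetical toy stand-ins, and state_dict()/load_state_dict() only roughly mirror what learn.save()/learn.load() persist and restore:

```python
import torch
from torch import nn


class GANModuleSketch(nn.Module):
    """Hypothetical wrapper: holds the generator and critic objects themselves."""

    def __init__(self, generator, critic):
        super().__init__()
        self.generator, self.critic = generator, critic

    def forward(self, x):
        return self.generator(x)


def make_generator():
    # Toy "super-res" generator: same spatial size in and out.
    return nn.Conv2d(1, 1, kernel_size=3, padding=1)


def make_critic():
    # Toy critic: one logit per image.
    return nn.Sequential(nn.Conv2d(1, 1, 3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten())


torch.manual_seed(0)

# Training notebook: wrap the models and save the whole GAN's weights.
trained_gan = GANModuleSketch(make_generator(), make_critic())
saved = trained_gan.state_dict()

# New notebook: rebuild generator/critic "as if creating them", wrap, then load.
gen = make_generator()                 # plays the role of learn_gen.model
gan = GANModuleSketch(gen, make_critic())
gan.load_state_dict(saved)             # copies weights into gen's own parameters

# Predict with the generator directly; it now carries the trained weights.
x = torch.randn(1, 1, 8, 8)
with torch.no_grad():
    out_gen = gen(x)
    out_trained = trained_gan.generator(x)
print(gan.generator is gen, torch.allclose(out_gen, out_trained))  # True True
```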
Just be careful when exporting a GAN learner's generator for production: if you use a custom loss function, it has to live in its own importable module and be available when the app loads the exported learner!
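Why the custom loss needs its own module: exporting a learner relies on pickling (as far as I know, fastai v1's learn.export() goes through torch.save), and pickle stores a function only as a reference such as module.name, never its code. A standard-library sketch, where my_losses and feature_loss are hypothetical names for wherever your custom loss lives:

```python
import pickle
import sys
import types

# Build a throwaway module at runtime; "my_losses" is a hypothetical module name.
losses = types.ModuleType("my_losses")
exec(
    "def feature_loss(pred, target):\n"
    "    return abs(pred - target)\n",
    losses.__dict__,
)
sys.modules["my_losses"] = losses

# Pickle records only the reference "my_losses.feature_loss", not the function body.
payload = pickle.dumps({"loss_func": losses.feature_loss})

# Module importable (e.g. the training notebook): loading works.
restored = pickle.loads(payload)
print(restored["loss_func"](3, 5))  # 2

# Simulate the production app, where that module was never shipped.
del sys.modules["my_losses"]
try:
    pickle.loads(payload)
    loaded_ok = True
except ModuleNotFoundError:
    loaded_ok = False
print("loads without the module:", loaded_ok)  # False
```

Putting the loss in a module that both the notebook and the app import means the reference resolves in both places.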

Thanks for the response @davilirio. If that's the case, how does the learn object of the GANLearner work after the initial training? Shouldn't only learn_gen have worked?

[Screenshot, 2020-08-25]

When you call learn.show_results(), the GANLearner uses the generator from the learn_gen you passed to GANLearner.from_learners() to make those predictions! If you look at the source code for that learner, you should find it there!
Remember that the discriminator (critic) acts as a loss function during training: it pushes the generator toward better results by forcing learn_gen to learn how to "trick" the critic into believing the generated images are the ground-truth images!
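The "trick the critic" idea can be sketched as the generator's loss: an adversarial term that rewards the critic scoring the fake image as real, plus a pixel loss against the high-res target. This is a toy PyTorch sketch with made-up layer sizes; reading weights_gen=(1.,50.) from the question as (adversarial weight, pixel-loss weight) is an assumption about fastai v1's ordering, and MSE stands in for whatever generator loss you use:

```python
import torch
from torch import nn

torch.manual_seed(0)

generator = nn.Conv2d(3, 3, kernel_size=3, padding=1)   # toy generator
critic = nn.Sequential(                                   # toy critic: one logit per image
    nn.Conv2d(3, 1, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
adv_loss = nn.BCEWithLogitsLoss()   # how "fake" the critic thinks the output is
pixel_loss = nn.MSELoss()           # distance to the high-res target


def generator_loss(low_res, high_res, w_adv=1.0, w_pixel=50.0):
    fake = generator(low_res)
    fake_pred = critic(fake).view(-1)           # one logit per image
    wants_real = torch.ones_like(fake_pred)     # the generator wants "real" verdicts
    return w_adv * adv_loss(fake_pred, wants_real) + w_pixel * pixel_loss(fake, high_res)


# During a generator step the critic is frozen: it only acts as a loss.
for p in critic.parameters():
    p.requires_grad_(False)

low, high = torch.randn(2, 3, 16, 16), torch.randn(2, 3, 16, 16)
loss = generator_loss(low, high)
loss.backward()
print(loss.item())
```

Gradients still flow through the frozen critic into the generator's output, so only the generator's weights receive updates from this loss.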

@davilirio Thanks for clarifying. So after the initial training, the "learn" object uses learn_gen to make those predictions, but when we load the model for inference we have to use learn_gen directly for predictions? Did I get that right?


Yes! That's it! When showing results, the GANLearner uses the images generated by learn_gen, and to make direct predictions after loading you should call learn_gen.predict()!
