How to perform a batch inference for the GANs model created for super resolution in lesson7-superres-gan.ipynb notebook? I am getting an error when I am trying to do the following
data_crit = get_crit_data(['images_data_200', 'images_data_600'], bs=bs, size=size)
learn_crit = create_critic_learner(data_crit, metrics=None).load('critic-pre2')
learn_gen = create_gen_learner().load('gen-pre2')
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.,50.), show_img=False, switcher=switcher,opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=wd)
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
learn.load('gan-1d')
p,img_hr,b = learn.predict(open_image('test_data/sample.jpg'))
Image(img_hr).show(figsize=(10,10))
I am getting the following error:
AttributeError: ‘GANLearner’ object has no attribute 'gen_mode’
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-84-36fff7c3f50b> in <module>
----> 1 p,img_hr,b = learn.predict(open_image('images_data_600/2642415632_2.jpg'))
2 Image(img_hr).show(figsize=(10,10))
~/anaconda3/envs/fastai-latest/lib/python3.7/site-packages/fastai/basic_train.py in predict(self, item, return_x, batch_first, with_dropout, **kwargs)
371 "Return predicted class, label and probabilities for `item`."
372 batch = self.data.one_item(item)
--> 373 res = self.pred_batch(batch=batch, with_dropout=with_dropout)
374 raw_pred,x = grab_idx(res,0,batch_first=batch_first),batch[0]
375 norm = getattr(self.data,'norm',False)
~/anaconda3/envs/fastai-latest/lib/python3.7/site-packages/fastai/basic_train.py in pred_batch(self, ds_type, batch, reconstruct, with_dropout, activ)
347 else: xb,yb = self.data.one_batch(ds_type, detach=False, denorm=False)
348 cb_handler = CallbackHandler(self.callbacks)
--> 349 xb,yb = cb_handler.on_batch_begin(xb,yb, train=False)
350 activ = ifnone(activ, _loss_func2activ(self.loss_func))
351 with torch.no_grad():
~/anaconda3/envs/fastai-latest/lib/python3.7/site-packages/fastai/callback.py in on_batch_begin(self, xb, yb, train)
277 self.state_dict.update(dict(last_input=xb, last_target=yb, train=train,
278 stop_epoch=False, skip_step=False, skip_zero=False, skip_bwd=False))
--> 279 self('batch_begin', call_mets = not self.state_dict['train'])
280 return self.state_dict['last_input'], self.state_dict['last_target']
281
~/anaconda3/envs/fastai-latest/lib/python3.7/site-packages/fastai/callback.py in __call__(self, cb_name, call_mets, **kwargs)
249 if call_mets:
250 for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
--> 251 for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)
252
253 def set_dl(self, dl:DataLoader):
~/anaconda3/envs/fastai-latest/lib/python3.7/site-packages/fastai/callback.py in _call_and_update(self, cb, cb_name, **kwargs)
239 def _call_and_update(self, cb, cb_name, **kwargs)->None:
240 "Call `cb_name` on `cb` and update the inner state."
--> 241 new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
242 for k,v in new.items():
243 if k not in self.state_dict:
~/anaconda3/envs/fastai-latest/lib/python3.7/site-packages/fastai/vision/gan.py in on_batch_begin(self, last_input, last_target, **kwargs)
112 for p in self.critic.parameters(): p.data.clamp_(-self.clip, self.clip)
113 if last_input.dtype == torch.float16: last_target = to_half(last_target)
--> 114 return {'last_input':last_input,'last_target':last_target} if self.gen_mode else {'last_input':last_target,'last_target':last_input}
115
116 def on_backward_begin(self, last_loss, last_output, **kwargs):
~/anaconda3/envs/fastai-latest/lib/python3.7/site-packages/fastai/basic_train.py in __getattr__(self, k)
441 setattr(self.learn, self.cb_name, self)
442
--> 443 def __getattr__(self,k): return getattr(self.learn, k)
444 def __setstate__(self,data:Any): self.__dict__.update(data)
445
AttributeError: 'GANLearner' object has no attribute 'gen_mode'
Can someone help?