Using wandb.fastai with gan_critic() learner results in an error

Hello friends!

I was wondering if anyone could provide some advice on how to correctly use wandb.fastai with a Learner that uses gan_critic(). I’ve successfully used wandb to visualize my networks with unet_learner and cnn_learner. Now I am trying to recreate the GAN example from Lesson 7, and I am adding the wandb callback as shown in this example. My final critic code is the following:

critic =  Learner(data_critic, gan_critic(), metrics=metrics, loss_func=loss_critic, wd=wd)
critic.fit_one_cycle(5, 1e-3, callbacks=[WandbCallback(critic, input_type='images', log='all', monitor='accuracy_thresh_expand')])

Everything else is the same as in Lesson 7. When the first epoch completes, I get an `IndexError: list index out of range`. I tried using %debug, but that just prints the whole stack trace, which I already had. I’ve included the stack trace below.

Stack trace
/usr/local/lib/python3.6/dist-packages/fastai/train.py in fit_one_cycle(learn, cyc_len, max_lr, moms, div_factor, pct_start, final_div, wd, callbacks, tot_epochs, start_epoch)
     21     callbacks.append(OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor, pct_start=pct_start,
     22                                        final_div=final_div, tot_epochs=tot_epochs, start_epoch=start_epoch))
---> 23     learn.fit(cyc_len, max_lr, wd=wd, callbacks=callbacks)
     24 
     25 def fit_fc(learn:Learner, tot_epochs:int=1, lr:float=defaults.lr,  moms:Tuple[float,float]=(0.95,0.85), start_pct:float=0.72,

/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py in fit(self, epochs, lr, wd, callbacks)
    198         else: self.opt.lr,self.opt.wd = lr,wd
    199         callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
--> 200         fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
    201 
    202     def create_opt(self, lr:Floats, wd:Floats=0.)->None:

/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py in fit(epochs, learn, callbacks, metrics)
    106                                        cb_handler=cb_handler, pbar=pbar)
    107             else: val_loss=None
--> 108             if cb_handler.on_epoch_end(val_loss): break
    109     except Exception as e:
    110         exception = e

/usr/local/lib/python3.6/dist-packages/fastai/callback.py in on_epoch_end(self, val_loss)
    315         "Epoch is done, process `val_loss`."
    316         self.state_dict['last_metrics'] = [val_loss] if val_loss is not None else [None]
--> 317         self('epoch_end', call_mets = val_loss is not None)
    318         self.state_dict['epoch'] += 1
    319         return self.state_dict['stop_training']

/usr/local/lib/python3.6/dist-packages/fastai/callback.py in __call__(self, cb_name, call_mets, **kwargs)
    249         if call_mets:
    250             for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
--> 251         for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)
    252 
    253     def set_dl(self, dl:DataLoader):

/usr/local/lib/python3.6/dist-packages/fastai/callback.py in _call_and_update(self, cb, cb_name, **kwargs)
    239     def _call_and_update(self, cb, cb_name, **kwargs)->None:
    240         "Call `cb_name` on `cb` and update the inner state."
--> 241         new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
    242         for k,v in new.items():
    243             if k not in self.state_dict:

/usr/local/lib/python3.6/dist-packages/wandb/fastai/__init__.py in on_epoch_end(self, epoch, smooth_loss, last_metrics, **kwargs)
    142 
    143             for x, y in self.validation_data:
--> 144                 pred = self.learn.predict(x)
    145 
    146                 # scalar -> likely to be a category

/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py in predict(self, item, return_x, batch_first, with_dropout, **kwargs)
    380         pred = ds.y.analyze_pred(raw_pred, **kwargs)
    381         x = ds.x.reconstruct(grab_idx(x, 0))
--> 382         y = ds.y.reconstruct(pred, x) if has_arg(ds.y.reconstruct, 'x') else ds.y.reconstruct(pred)
    383         return (x, y, pred, raw_pred) if return_x else (y, pred, raw_pred)
    384 

/usr/local/lib/python3.6/dist-packages/fastai/data_block.py in reconstruct(self, t)
    384 
    385     def reconstruct(self, t):
--> 386         return Category(t, self.classes[t])
    387 
    388 class MultiCategoryProcessor(CategoryProcessor):

IndexError: list index out of range

Any help is appreciated - thanks in advance!

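At the moment, there are a few cases where logging images doesn’t work, because the callback relies on the learner’s `predict` method. In your trace, `learn.predict` produces a prediction that `reconstruct` then uses as an index into the label classes (`self.classes[t]`), and that index is out of range. The callback should now show you more details when this fails; see this issue.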

You can still use the callback, but you need to remove the `input_type` argument, e.g. `WandbCallback(critic, log='all', monitor='accuracy_thresh_expand')`.