The callback also logs the epoch.
When items have already been logged (for example, from a previous training loop), it wants to continue from the next epoch, so it tries to read the last logged epoch.
Maybe I could change the logic so that it doesn't assume an epoch has always been logged.
There should be no issue if you do some manual logging after at least one epoch value has been logged.
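The resume behaviour described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not WandbCallback's actual code: `next_epoch` and the list-of-dicts history are hypothetical stand-ins for whatever the callback reads back from the run, and the epoch is assumed to be logged as a running, possibly fractional, counter. The point is the fallback: if only manual entries without an `epoch` key exist, it starts at 0 instead of failing on a missing value.

```python
import math
from typing import Any, Dict, List


def next_epoch(history: List[Dict[str, Any]]) -> int:
    """Return the epoch to resume logging from (hypothetical helper).

    `history` stands in for the rows already logged to the run. Rows
    without an "epoch" key, e.g. from manual wandb.log calls, are skipped.
    The epoch is assumed to be a running, possibly fractional, counter.
    """
    epochs = [row["epoch"] for row in history if "epoch" in row]
    if not epochs:
        # No epoch logged yet: start at 0 instead of failing on a missing key.
        return 0
    # Round up to the next whole epoch boundary.
    return math.ceil(max(epochs))


print(next_epoch([]))                                # 0
print(next_epoch([{"loss": 0.3}]))                   # 0
print(next_epoch([{"epoch": 0.5}, {"epoch": 1.0}]))  # 1
```

Skipping rows without an epoch is what makes manual logging harmless in this sketch, whether or not it happens before the first epoch is recorded.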
(I realize this is a year later, but I see no other related posts on the forum.)
When I call learn.fit with WandbCallback on a learner created with GANLearner.wgan, I get the error "WandbCallback was not able to prepare a DataLoader for logging prediction samples -> list index out of range".
What exactly is out of range? I'm able to run learn.show_results() with no problem.
Update: I started writing my own callback for wandb, but I'm confused about how to get the output of the generator for "preds" instead of just the output of the critic. Here's where I'm at so far (it doesn't work); I'd welcome suggestions!
```python
from PIL import Image

class WandB_WGAN_Images(Callback):
    "Progress-like callback: log WGAN predictions to WandB"
    order = ProgressCallback.order+1

    def __init__(self, n_preds=6):
        store_attr()

    def after_epoch(self):
        if not self.learn.training:
            with torch.no_grad():
                self.learn.switch(gen_mode=True)
                inp,preds,targs,out = self.learn.pred
                b = tuplify(inp) + tuplify(targs)
                self.dl.show_results(b, out, show=False, max_n=self.n_preds)
                preds = preds.detach().permute(1, 2, 0).cpu().squeeze().numpy()
                images = [Image.fromarray(image) for image in preds]
                wandb.log({"examples": [wandb.Image(image) for image in images]})
                self.learn.switch(gen_mode=False)
```
It currently fails at the inp,preds,targs,out = self.learn.pred line with ValueError: too many values to unpack (expected 4).
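A likely reason for that ValueError: `learn.pred` holds the model output for the current batch as a single tensor, not an `(inp, preds, targs, out)` tuple, and unpacking a tensor iterates over its first (batch) dimension. A minimal reproduction in plain PyTorch, using a made-up batch of 8 outputs as a stand-in for `learn.pred`:

```python
import torch

# Stand-in for learn.pred: a single tensor holding outputs for a batch of 8
# (the shape is made up for illustration).
pred = torch.randn(8, 1)

# Iterating a tensor walks its first (batch) dimension, so the right-hand side
# yields 8 values, not the 4 names on the left-hand side.
try:
    inp, preds, targs, out = pred
except ValueError as err:
    message = str(err)

print(message)
```

Since `pred` is only the output, the inputs and targets of the current batch would have to come from `learn.xb` and `learn.yb` instead.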
I see that show_results() uses "samples" and "outs", but I can't figure out how to obtain samples and outs from inside a callback.
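```python
import wandb
from fastai.vision.all import *
from torchvision.utils import make_grid

class WandB_WGAN_Images(Callback):
    "Progress-like callback: log WGAN predictions to WandB"
    order = ProgressCallback.order+1

    def __init__(self, n_preds=10):
        store_attr()

    def after_epoch(self):
        if self.gen_mode:
            # Last generator output recorded by GANTrainer during training
            preds = self.learn.gan_trainer.last_gen.cpu()
            img_grid = make_grid(preds[:self.n_preds], nrow=5)
            # CHW -> HWC; pass a NumPy array, since wandb.Image treats torch tensors as CHW
            img_grid = img_grid.permute(1, 2, 0).squeeze().numpy()
            wandb.log({"examples": wandb.Image(img_grid)})
```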
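The image conversion in `after_epoch` can be factored into a small helper. This is a hypothetical sketch (`grid_to_hwc_uint8` is not a fastai or wandb function) that assumes the generator ends in `tanh`, so pixel values lie in [-1, 1]; it turns the CHW grid returned by `make_grid` into a contiguous HWC `uint8` NumPy array, a layout that both PIL's `Image.fromarray` and `wandb.Image` accept.

```python
import numpy as np
import torch


def grid_to_hwc_uint8(grid: torch.Tensor) -> np.ndarray:
    """Convert a CHW image grid with values in [-1, 1] to an HWC uint8 array.

    Hypothetical helper; assumes a tanh-activated generator output.
    """
    grid = grid.detach().cpu().float()
    # Map [-1, 1] -> [0, 255], clamping values that fall slightly outside.
    grid = ((grid.clamp(-1, 1) + 1) / 2 * 255).round()
    # CHW -> HWC, the channel-last layout PIL and wandb expect for arrays.
    return np.ascontiguousarray(grid.permute(1, 2, 0).to(torch.uint8).numpy())


arr = grid_to_hwc_uint8(torch.zeros(3, 4, 5))
print(arr.shape, arr.dtype)  # (4, 5, 3) uint8
```

In the callback it would be applied to the output of make_grid (still CHW, before any permute), e.g. wandb.Image(grid_to_hwc_uint8(img_grid)). Doing the rescaling explicitly avoids relying on how a given library normalizes float inputs.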
NB: This callback should be passed to fit(), not to the learner's definition. Otherwise you'll get an error if you call learn.show_results() after wandb.finish().
Thanks, you are right. It works for a simple example. Not sure why it doesn't work with the add-on library I was using; I'll investigate further and report back if I figure it out.