Callback for Drawing Training Information

Hello everyone,

I am new to fastai and currently working on training a VAE model using the MNIST dataset.

After training, I use the following snippet to display decoded images next to their target images:

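from fastai.vision.all import *

# probs[0] is the first element of the model's output tuple (the decoder logits)
probs, tars = learn.get_preds(dl=[dls.valid.one_batch()])
decoded_img, target_img = probs[0][:9].sigmoid(), tars[:9]
# 9 decoded images followed by 9 targets, laid out in 6 rows of 3
show_images(torch.cat([decoded_img, target_img], dim=0),
            nrows=6, imsize=1.5, titles=None);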

This code snippet worked correctly and produced the expected output.

Next, I tried to write a callback that shows the same images after each epoch. Here's the code for the callback:

class ShowImagesCallback(Callback):
    def __init__(self, n=9): self.n = n

    def after_epoch(self):
        print('before computing')
        # plt.clf()
        # Run a prediction pass over a single validation batch
        with torch.no_grad():
            probs, tars = self.learn.get_preds(
                dl=[self.learn.dls.valid.one_batch()],
                )
        print('after computing')
        decoded_img, target_img = probs[0][:self.n].sigmoid(), tars[:self.n]
        print('showing images')
        show_images(
            torch.cat([decoded_img, target_img], dim=0).cpu(),
            nrows=6, imsize=1.5, titles=None)
        print('completed')
        plt.show()

However, this callback does not work: during an epoch it prints “before computing” twice and then hangs.
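A plausible explanation, which I have not checked against every fastai version: get_preds is not a plain forward pass. As far as I can tell from the fastai v2 source, unless it is called with inner=True it enters the Learner as a context manager, which fires before_fit/before_epoch on entry and after_epoch/after_fit on exit, and in recent versions it also holds a non-reentrant lock (Learner.lock) while doing so. Calling it from after_epoch would then re-enter after_epoch (the second “before computing”), whose own get_preds waits forever for the lock that the outer call still holds. The pure-Python toy below reproduces that sequence; ToyLearner, NaiveShowCB and GuardedShowCB are hypothetical names, not fastai classes, and the lock timeout exists only so the demo reports the hang instead of blocking.

```python
import threading


class ToyLearner:
    """Hypothetical stand-in for a training loop (NOT fastai's Learner)."""

    def __init__(self, cbs):
        self.cbs = cbs
        self.lock = threading.Lock()  # non-reentrant: a second acquire blocks
        for cb in cbs:
            cb.learn = self

    def fire(self, event):
        # Call `event` on every callback that defines it
        for cb in self.cbs:
            handler = getattr(cb, event, None)
            if handler is not None:
                handler()

    def get_preds(self, timeout=0.1):
        # Modelled assumption: prediction holds a lock and re-fires
        # after_epoch before releasing it.
        if not self.lock.acquire(timeout=timeout):
            raise RuntimeError("get_preds re-entered while its lock was held")
        try:
            preds = [0.0, 0.5, 1.0]  # placeholder predictions
            self.fire("after_epoch")
            return preds
        finally:
            self.lock.release()

    def fit(self, n_epoch):
        for _ in range(n_epoch):
            self.fire("after_epoch")


class NaiveShowCB:
    """Mirrors the original callback: calls get_preds from after_epoch."""

    def __init__(self):
        self.calls = 0

    def after_epoch(self):
        self.calls += 1  # stands in for print('before computing')
        self.learn.get_preds()


class GuardedShowCB:
    """Ignores the after_epoch that its own get_preds call triggers."""

    def __init__(self):
        self.calls = 0
        self._busy = False

    def after_epoch(self):
        if self._busy:
            return
        self._busy = True
        try:
            self.calls += 1
            self.preds = self.learn.get_preds()
        finally:
            self._busy = False


def run(cb_cls):
    # Returns (number of after_epoch entries, outcome)
    cb = cb_cls()
    try:
        ToyLearner([cb]).fit(1)
        return cb.calls, "ok"
    except RuntimeError:
        return cb.calls, "stuck"


print(run(NaiveShowCB))    # (2, 'stuck')
print(run(GuardedShowCB))  # (1, 'ok')
```

The guard flag is only one way out of the toy. In real fastai, passing inner=True to get_preds (which, as far as I know, skips the epoch events) or reusing self.pred and self.yb, as the later edit does, avoids the nested prediction pass altogether.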

Any tips or suggestions would be greatly appreciated.
Thank you.


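Edit:
I fixed the callback by taking the predictions and targets from self.pred and self.yb (which, when after_epoch fires, hold the last validation batch) instead of calling get_preds: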

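from fastai.vision.all import *

class ShowImagesCallback(Callback):
    def __init__(self, n=5): self.n = n

    def after_epoch(self):
        # One row of predictions above one row of targets
        axs = subplots(nrows=2, ncols=self.n, imsize=1.5)[1].flat
        # self.pred / self.yb are the learner's most recent batch,
        # i.e. the last validation batch when after_epoch fires
        pred_img = self.pred[0][:self.n].sigmoid()
        target_img = self.yb[0][:self.n]
        imgs = torch.cat([pred_img, target_img], dim=0).cpu()
        for im, ax in zip(imgs, axs): show_image(im, ax=ax)
        plt.show()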

However, I don't know how to remove (or replace) the previous epoch's images when the callback runs again at the end of the next epoch.
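One option, sketched under the assumption that training runs in a Jupyter notebook: create a single output area with IPython's display(obj, display_id=True) and call update() on the returned handle each epoch, so every new figure replaces the previous one instead of being appended. EpochFigureDisplay is a hypothetical helper name, not part of fastai or IPython; the display function is injectable only so the sketch can be exercised without a notebook.

```python
class EpochFigureDisplay:
    """Hypothetical helper: one notebook output slot, overwritten each epoch."""

    def __init__(self, display_fn=None):
        if display_fn is None:
            # Real notebook case: IPython's display() returns a DisplayHandle
            # when called with display_id=True.
            from IPython.display import display as display_fn
        self.display_fn = display_fn
        self.handle = None

    def show(self, fig):
        if self.handle is None:
            # First call: create the output area and keep its handle
            self.handle = self.display_fn(fig, display_id=True)
        else:
            # Later calls: replace the previous figure in the same area
            self.handle.update(fig)


# Demo without a notebook: a stand-in display function that records calls
class _RecordingHandle:
    def __init__(self):
        self.updates = []

    def update(self, obj):
        self.updates.append(obj)


shown = []


def recording_display(obj, display_id=False):
    shown.append(obj)
    return _RecordingHandle()


out = EpochFigureDisplay(display_fn=recording_display)
for fig in ["epoch 0 figure", "epoch 1 figure", "epoch 2 figure"]:
    out.show(fig)

print(shown)               # ['epoch 0 figure']
print(out.handle.updates)  # ['epoch 1 figure', 'epoch 2 figure']
```

In the callback, one would create the helper in before_fit (self.out = EpochFigureDisplay()), keep the figure returned by subplots (fig, axs = subplots(...)), and replace plt.show() with self.out.show(fig) followed by plt.close(fig), so the inline backend does not also append a copy when the cell finishes. IPython's clear_output(wait=True) is the more common suggestion, but it clears the whole cell output and would likely wipe fastai's progress table along with the images.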