Hooks don't get reset after tensorboard add_graph call

Hi!

I tried using the tensorboard callback from the library with a dynamic unet, and it seems it doesn’t interact well with the hooks. Tensorboard makes a sample forward pass through the model to build the graph, and that pass of course triggers the hooks on the encoder outputs. It seems these stored outputs never get used or deleted: when training starts, I get an error saying that sizes 1 and 4 in dimension 0 cannot be concatenated, raised by this line from UnetBlock: cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1)).
s is a stored output, as we can see in the full forward pass:

def forward(self, up_in:Tensor) -> Tensor:
    s = self.hook.stored
    up_out = self.shuf(up_in)
    ssh = s.shape[-2:]
    if ssh != up_out.shape[-2:]:
        up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
    cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))

    return self.conv2(self.conv1(cat_x))

I don’t know how we can prevent that, as the whole hook storage process is quite obscure to me. If anyone has an idea, I’d be happy to read it!

PS: Tensorboard also doesn’t work well with mixed precision training: for instance, "norm_cpu" not implemented for 'Half' is raised when trying to store grads.
PPS: I don’t know why, but it actually doesn’t look like it is storing anything at all. I’m going to investigate that anyway.
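On the mixed precision error from the PS, one possible workaround is to cast half-precision gradients to float32 before taking their norm. This is only a minimal sketch: grad_norm_fp32 is a hypothetical helper, not something in fastai or tensorboard, and the callback would have to be changed to use it.

```python
import torch

def grad_norm_fp32(param):
    """Hypothetical helper: L2 norm of a parameter's gradient, computed in float32."""
    if param.grad is None:
        return None
    # Casting first avoids calling norm() directly on a CPU half tensor.
    return param.grad.detach().float().norm().item()

# Toy half-precision parameter with a hand-made gradient.
weight = torch.nn.Parameter(torch.zeros(4, dtype=torch.half))
weight.grad = torch.full((4,), 0.5, dtype=torch.half)

norm = grad_norm_fp32(weight)
```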


There is (still experimental) tensorboard support in PyTorch now, so in the callback you might try replacing from tensorboardX import SummaryWriter with from torch.utils.tensorboard import SummaryWriter. That said, the PyTorch support is a clone of tensorboardX, so it may not behave any differently.
If you haven’t already, you might also make sure you have the latest tensorboardX, and perhaps try installing it from git in case this has been fixed. The current code looks like it uses built-in PyTorch ops, which you’d hope would behave properly.

I already switched to PyTorch’s tensorboard, which does basically the same thing. I just installed the latest releases of pytorch and tensorboard, so I guess we’ll see if it works. (For now I just commented out the add_graph line.)


Thinking a bit more, it’s not especially clear how hooking during the model recording could interfere with fastai. Hooks in PyTorch are just functions that are called on every forward/backward pass. hook.stored is part of fastai (in fastai.callbacks.Hook), so it’s not clear how other hooks could interfere there, given that the model recording happens in tensorboard code, not in fastai.
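To make the hook mechanics concrete, here is a minimal sketch in plain PyTorch. StoreOutput is a hypothetical stand-in for fastai’s Hook (it is not the fastai class), and the single conv layer stands in for the encoder. It shows that the stored value is simply whatever the last forward pass produced, and that concatenating a batch-of-1 activation with a batch-of-4 one along dim=1 fails in the way described above.

```python
import torch
import torch.nn as nn

class StoreOutput:
    """Hypothetical stand-in for fastai's Hook: keeps the module's last output."""
    def __init__(self, module):
        self.stored = None
        self.handle = module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        # Called by PyTorch after every forward pass of `module`.
        self.stored = output

    def remove(self):
        self.handle.remove()

encoder = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # toy "encoder"
hook = StoreOutput(encoder)

with torch.no_grad():
    encoder(torch.randn(1, 3, 16, 16))  # e.g. a single-sample tracing pass
    stale = hook.stored
    encoder(torch.randn(4, 3, 16, 16))  # a training batch of 4
    fresh = hook.stored

stale_batch, fresh_batch = stale.shape[0], fresh.shape[0]

# Mixing activations from the two passes reproduces the size mismatch.
try:
    torch.cat([fresh, stale], dim=1)
    cat_failed = False
except RuntimeError:
    cat_failed = True

hook.remove()
```

So the error needs a decoder block to read a stored value from a different forward pass than the one that produced its input.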
I’d suspect it might in fact be the way LearnerTensorboardWriter logs the model, and specifically the timing of it. It looks like the forward is expecting the hook to have already been called. I’m not especially familiar with Unet, but I gather that would be a hook on the encoder, which is expected to have been called by the time the decoder does its forward. But LearnerTensorboardWriter is calling forward on the decoder before the encoder has been run.
You might be able to move the model logging to the first batch of the first epoch, i.e. move it from on_train_begin to on_batch_begin inside an if iteration == 0. The callback still might not play nicely with more custom models, as they might not like an extra forward pass, but it would also avoid all the code around fetching a single batch, since you can just reuse the first batch.
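A rough sketch of that idea follows, kept independent of fastai so it runs on its own. GraphOnFirstBatch and RecordingWriter are hypothetical names. The method name and the last_input / iteration arguments mirror the on_batch_begin callback mentioned above, and the writer is assumed to expose add_graph(model, input_to_model) like SummaryWriter does. Wiring this into LearnerTensorboardWriter is left out.

```python
class GraphOnFirstBatch:
    """Hypothetical callback: log the model graph once, reusing the first batch."""
    def __init__(self, model, writer):
        self.model = model
        self.writer = writer
        self.logged = False

    def on_batch_begin(self, last_input, iteration, **kwargs):
        # Only the very first training batch triggers the graph logging.
        if iteration == 0 and not self.logged:
            self.writer.add_graph(self.model, last_input)
            self.logged = True

class RecordingWriter:
    """Stub writer for the demo; a real run would pass a SummaryWriter."""
    def __init__(self):
        self.calls = []

    def add_graph(self, model, input_to_model):
        self.calls.append((model, input_to_model))

# Simulated training loop with placeholder batches.
writer = RecordingWriter()
callback = GraphOnFirstBatch(model="model", writer=writer)
for iteration, batch in enumerate(["batch0", "batch1", "batch2"]):
    callback.on_batch_begin(last_input=batch, iteration=iteration)
```

Assuming on_batch_begin runs before the training forward pass, the real forward on that batch would refresh hook.stored after the tracing pass, and the traced input would have the same batch size anyway.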

I think the main problem is that fastai’s callback writes to tensorboard asynchronously, which probably means that it sends the request and then goes on with the training. As the write is probably a bit slow, its forward pass ends up running through the model after training has begun, which causes a conflict on the value of self.hook.stored. I’ll try to make it synchronous to see if my intuition is correct, but for now I just told it not to add the graph, as I don’t really need it.
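To illustrate the kind of interleaving I have in mind, here is a toy sketch using only the standard library. Nothing here is fastai code: FakeHook, encoder_pass and train_step are hypothetical stand-ins, batch sizes stand in for activations, and the events deliberately force the bad ordering that a real race would only hit some of the time.

```python
import threading

class FakeHook:
    """Hypothetical stand-in for a hook holding the encoder output."""
    def __init__(self):
        self.stored = None

def encoder_pass(hook, batch_size):
    # Pretend forward pass: record only the batch size of the activation.
    hook.stored = batch_size

def train_step(hook, batch_size, after_encoder=None):
    encoder_pass(hook, batch_size)
    if after_encoder is not None:
        after_encoder()   # point where another thread may interleave
    return hook.stored    # what the decoder would read

def run_async():
    hook = FakeHook()
    encoder_done = threading.Event()
    trace_done = threading.Event()

    def trace_graph():
        encoder_done.wait()    # the slow writer only gets going mid-training
        encoder_pass(hook, 1)  # graph tracing uses a single sample
        trace_done.set()

    worker = threading.Thread(target=trace_graph)
    worker.start()

    def interleave():
        encoder_done.set()
        trace_done.wait()

    seen = train_step(hook, 4, after_encoder=interleave)
    worker.join()
    return seen

def run_sync():
    hook = FakeHook()
    encoder_pass(hook, 1)      # tracing finishes before training starts
    return train_step(hook, 4)

async_seen = run_async()  # decoder reads the batch-of-1 value
sync_seen = run_sync()    # decoder reads the batch-of-4 value
```

If the callback really behaves like run_async, making the graph write synchronous should be enough to make the error go away.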
