Retrieve hook.stored tensor for Model Parallelism

Hello, I have searched, but I think I might be the first person trying to use the PyTorch RPC module to get model parallelism with FastAI :sweat_smile:

At work I’m writing a Python module to simplify data and model parallelism with PyTorch. Since some of my coworkers use FastAI, I thought “Ok, let’s try to include FastAI too”.

So far, a small change to SequentialEx is needed to be able to split the model for model parallelism (fastai/fastai Pull Request #4042 by Patataman, “Add x_orig param in SequentialEx to allow split models”). However, while I was finishing my tests, I discovered that hooks are another obstacle.
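To show why such a parameter is needed, here is a minimal pure-PyTorch sketch. It assumes a simplified SequentialEx that, like fastai's, attaches the original input to every intermediate tensor as `.orig`, so that a layer such as a dense MergeLayer can concatenate with it. `SequentialExSketch` and `MergeSketch` are hypothetical names, and this is not the code from the PR.

```python
import torch
from torch import nn

class SequentialExSketch(nn.Module):
    "Hypothetical simplified SequentialEx: every layer can read `res.orig`."
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x, x_orig=None):
        # A later chunk of a split model must be told the real model input;
        # otherwise it would treat its own input as the "original" one.
        orig = x if x_orig is None else x_orig
        res = x
        for layer in self.layers:
            res.orig = orig
            nres = layer(res)
            res.orig, nres.orig = None, None
            res = nres
        return res

class MergeSketch(nn.Module):
    "Concatenates the activation with the original input (like a dense MergeLayer)."
    def forward(self, x):
        return torch.cat([x, x.orig], dim=1)

torch.manual_seed(0)
layers = [nn.Conv2d(3, 3, 1), nn.ReLU(), nn.Conv2d(3, 3, 1), MergeSketch()]
full = SequentialExSketch(*layers)
part1 = SequentialExSketch(*layers[:2])
part2 = SequentialExSketch(*layers[2:])

x = torch.randn(2, 3, 4, 4)
with torch.no_grad():
    y_full = full(x)
    y_split = part2(part1(x), x_orig=x)  # chunk 2 receives the real input
    y_naive = part2(part1(x))            # chunk 2 merges with its own input instead
```

Only the call that passes `x_orig` reproduces the unsplit result.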

So far I have been testing with a Unet, because it’s the model we usually train. For example, using the same code as in the PR:

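from fastai.vision.all import *  # provides torch, SequentialEx, unet_learner, resnet34

# Passing x_orig requires the modified SequentialEx from the PR above
class SimplifiedModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        # Split the Unet's layers into three roughly equal chunks
        layers = list(model.layers)
        m_len = len(layers)
        self.layer1 = SequentialEx(*layers[:m_len//3])
        self.layer2 = SequentialEx(*layers[m_len//3:m_len//3*2])
        self.layer3 = SequentialEx(*layers[m_len//3*2:])

    def forward(self, x):
        _x = self.layer1(x)
        # Later chunks still need the original input for layers that use x.orig
        _x = self.layer2(_x, x_orig=x)
        return self.layer3(_x, x_orig=x)

model = resnet34
learn = unet_learner(dls, model, loss_func, [...])
newmodel = SimplifiedModel(learn.model)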

I have noticed that only the first 4 layers of the Unet (the layers Sequentials with BasicBlocks, BatchNorm2d, ReLU, and another Sequential with Conv2d and ReLU) trigger the hooks. The stored tensors are then consumed, starting with the first UnetBlock, and from that point on no hook is triggered.
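This mechanism can be reproduced in plain PyTorch. The following is a minimal sketch, assuming fastai's Hook behaves like a forward hook that saves the module output in `.stored`; `StoredHook` and `SkipBlock` are hypothetical stand-ins for fastai's Hook and UnetBlock, not their actual implementations.

```python
import torch
from torch import nn

class StoredHook:
    "Hypothetical stand-in for fastai's Hook: the last forward output goes to `.stored`."
    def __init__(self, module):
        self.stored, self.calls = None, 0
        self.handle = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, inp, out):
        self.stored = out
        self.calls += 1

class SkipBlock(nn.Module):
    "Decoder block that, like a UnetBlock, takes its skip connection from a hook."
    def __init__(self, hook, ch):
        super().__init__()
        self.hook = hook
        self.conv = nn.Conv2d(2 * ch, ch, 1)

    def forward(self, x):
        # Reads whatever the hook captured in *this* process; it fires no hook itself.
        return self.conv(torch.cat([x, self.hook.stored], dim=1))

encoder = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
hook = StoredHook(encoder[1])  # the hook sits on an encoder layer only
model = nn.Sequential(encoder, nn.Conv2d(8, 8, 1), SkipBlock(hook, 8))

with torch.no_grad():
    y = model(torch.randn(2, 3, 4, 4))
```

The hook fires once per forward pass, when the hooked encoder layer runs; the decoder block only reads `hook.stored`.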

Therefore, if I split the model as:

  • Node1: Layer1
  • Node2: Layer2, Layer3

The hooks are only triggered in layer1, so if I send that layer to another (physically separate) node, the hooks on the node running layer2 are never updated. Therefore, the data the first UnetBlock retrieves from hook.stored is wrong (link to code).
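A minimal sketch of that failure, assuming a process boundary can be imitated with `copy.deepcopy`: node1 gets its own copy of layer1, made before the hook is registered, so the hook on node2's copy never fires for the new batch. The `stored` dict and `decoder` function are hypothetical stand-ins for `hook.stored` and the first UnetBlock.

```python
import copy
import torch
from torch import nn

torch.manual_seed(0)
stored = {}  # hypothetical stand-in for hook.stored on node2

encoder = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
# "Send" layer1 to node1: a separate copy with no hooks attached
node1_encoder = copy.deepcopy(encoder)

def save_output(module, inp, out):
    stored["enc"] = out

encoder.register_forward_hook(save_output)  # node2's hook, on node2's copy

def decoder(x):
    "Stand-in for the first UnetBlock on node2: combines x with the hooked activation."
    return x + stored["enc"]

a, b = torch.randn(2, 3, 4, 4), torch.randn(2, 3, 4, 4)
with torch.no_grad():
    _ = encoder(a)                # an earlier local pass leaves a value in the hook
    act_b = node1_encoder(b)      # node1 computes layer1 for the new batch
    y_split = decoder(act_b)      # node2 silently reuses the activation from batch `a`
    y_local = decoder(encoder(b)) # what a single-process run computes
```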

I have tried to somehow get the data stored in the hooks before entering the first UnetBlock, but even though the hook is triggered, I have found it impossible to retrieve anything using hook_output or similar functions.

TL;DR
Is there a way to get the tensor stored in the hooks, so I can send it to the other node and (somehow, I still need to figure this out) update the stored tensor in the hook before continuing?
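One possible shape for this, as a minimal sketch: node1 returns the hooked tensors together with its output, and node2 overwrites its own hooks' `.stored` before running the decoder. It assumes the hook objects keep their tensor in a plain, writable `stored` attribute; `StoredHook`, `SkipBlock`, `run_stage1` and `run_stage2` are hypothetical names, and the two "nodes" are simulated in one process.

```python
import copy
import torch
from torch import nn

class StoredHook:
    "Hypothetical stand-in for fastai's Hook: the last output lives in `.stored`."
    def __init__(self, module):
        self.stored = None
        self.handle = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, inp, out):
        self.stored = out

class SkipBlock(nn.Module):
    "Like a UnetBlock, reads its skip connection from `hook.stored`."
    def __init__(self, hook):
        super().__init__()
        self.hook = hook

    def forward(self, x):
        return x + self.hook.stored

torch.manual_seed(0)
enc_node1 = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 8, 1))
enc_node2 = copy.deepcopy(enc_node1)      # node2's copy of layer1, never run there
hooks_node1 = [StoredHook(enc_node1[1])]  # fire on node1, where layer1 runs
hooks_node2 = [StoredHook(enc_node2[1])]  # what the decoder on node2 reads
block = SkipBlock(hooks_node2[0])

def run_stage1(x):
    "Node1: run layer1 and ship the hooked activations along with the output."
    out = enc_node1(x)
    return out, [h.stored for h in hooks_node1]

def run_stage2(x, stored_list):
    "Node2: overwrite the local hooks' stored tensors, then continue the forward pass."
    for h, s in zip(hooks_node2, stored_list):
        h.stored = s
    return block(x)

x = torch.randn(2, 3, 4, 4)
with torch.no_grad():
    act, stored = run_stage1(x)
    y = run_stage2(act, stored)
```

With `torch.distributed.rpc`, node2 could obtain the `(out, stored)` tuple by calling `run_stage1` on node1 through `rpc.rpc_sync`. That wiring, and whether fastai's Unet hooks tolerate having `stored` overwritten this way, are assumptions rather than something confirmed here.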

PS: There is a workaround, but I don’t like it because it doesn’t solve the real problem. In this very specific case, it consists of splitting the model as:

  • Node1: Layer1, Layer2 (this way, the tensors stored in the hooks are available wherever they are needed)
  • Node2: Layer3