Struggling to understand SequentialEx and MergeLayer

Hi,

I can understand the functionality of both classes. However, I’m struggling to figure out the details.

Can anyone please explain what is going on in the forward method, especially this line: res.orig = x

class SequentialEx(Module):
    "Like `nn.Sequential`, but with ModuleList semantics, and can access module input"
    def __init__(self, *layers): self.layers = nn.ModuleList(layers)

    def forward(self, x):
        res = x
        for l in self.layers:
            res.orig = x
            nres = l(res)
            # We have to remove res.orig to avoid hanging refs and therefore memory leaks
            res.orig = None
            res = nres
        return res

class MergeLayer(Module):
    "Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`."
    def __init__(self, dense:bool=False): self.dense=dense
    def forward(self, x): return torch.cat([x,x.orig], dim=1) if self.dense else (x+x.orig)

Where does x.orig come from?
Is there a reference for this .orig trick?


You can see where x.orig comes from in SequentialEx. It’s the value of x before passing through the layers.


I think @sgugger’s answer is clear. I can explain a little bit more about the scenario.
When SequentialEx is called to build a res_block, the code is:

SequentialEx(conv_layer(nf, nf_inner, norm_type=norm_type, **conv_kwargs),
             conv_layer(nf_inner, nf, norm_type=norm2, **conv_kwargs),
             MergeLayer(dense))

Then you have 3 layers in the SequentialEx. For example, if I call res_block(8), the resulting module is:

SequentialEx(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): ReLU(inplace)
      (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Sequential(
      (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): ReLU(inplace)
      (2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): MergeLayer()
  )
)
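As a quick sanity check (a hypothetical snippet, assuming the fastai v1 layout where res_block lives in fastai.layers), you can verify that this block keeps the input shape, since the two conv layers preserve nf channels and MergeLayer just adds the shortcut back:

import torch
from fastai.layers import res_block   # assumption: fastai v1 import path

block = res_block(8)                  # the module printed above
x = torch.randn(1, 8, 16, 16)         # a batch of one 8-channel feature map
out = block(x)
print(out.shape)                      # torch.Size([1, 8, 16, 16]) -- same shape as the input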

In SequentialEx, you can see these 3 layers stored in self.layers:

def __init__(self, *layers): self.layers = nn.ModuleList(layers)

In the forward function, there is a loop that goes over these layers:

 def forward(self, x):
        res = x
        for l in self.layers:
            res.orig = x
            nres = l(res)

On each iteration, the loop attaches the original input x to res as res.orig, then passes res to the layer module l, where res is the output of the previous layer (or x itself on the first iteration).

On the third iteration, MergeLayer’s forward function receives res as its input x. So x.orig is res.orig, which was set to the original x from SequentialEx’s forward function.
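Here is a minimal toy sketch of that last step (made-up numbers, no fastai needed), just to show what MergeLayer sees:

import torch

x = torch.ones(1, 2)
res = x * 3                # pretend this is the output of the two conv layers
res.orig = x               # SequentialEx attaches the original input before calling the layer
merged = res + res.orig    # what MergeLayer does when dense=False
print(merged)              # tensor([[4., 4.]])
# with dense=True it would be torch.cat([res, res.orig], dim=1) instead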

In Python you can do this because assigning a tensor to another variable only creates a new reference, not a copy. If you then attach a new attribute through that variable, the tensor itself carries the attribute, so you can read it through either name. You can verify this with the id() function, which returns the memory address. For example:

import torch

ta = torch.ones(2, 2)
tb = ta     # tb is just another reference to the same tensor object, not a copy
tb.xx = 1   # attach a new attribute through tb

Then check id(ta), id(tb), id(ta.xx), and id(tb.xx): ta and tb are the same object, and the value of ta.xx is 1.
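To make that check concrete, continuing the snippet above:

print(id(ta) == id(tb))   # True -- ta and tb refer to the same tensor object
print(ta.xx)              # 1 -- the attribute attached through tb is visible through ta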


Thanks to both of you.

Actually, this is exactly what I was looking for.
I was not aware of this possibility.

Currently, I’m going through part 2 of the MOOC and I noticed that this is used in many places. I assumed that tensors were like generic primitive data types, but they are actually handled as class instances, which is why you can attach attributes like .orig to them.