Hi,
I understand what both classes do, but I'm struggling to figure out the details.
Can anyone please explain what is going on in the forward method, especially the line `res.orig = x`?
class SequentialEx(Module):
    "Like `nn.Sequential`, but with ModuleList semantics, and can access module input"
    def __init__(self, *layers): self.layers = nn.ModuleList(layers)

    def forward(self, x):
        res = x
        for l in self.layers:
            res.orig = x
            nres = l(res)
            # We have to remove res.orig to avoid hanging refs and therefore memory leaks
            res.orig = None
            res = nres
        return res

class MergeLayer(Module):
    "Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`."
    def __init__(self, dense:bool=False): self.dense=dense

    def forward(self, x): return torch.cat([x,x.orig], dim=1) if self.dense else (x+x.orig)
Where does `x.orig` come from?
Is there a reference for this `.orig` trick?
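
For anyone who wants to experiment with this, below is a minimal standalone sketch of the two classes. It assumes plain `torch.nn.Module` (with explicit `super().__init__()` calls) in place of fastai's `Module`, and `Double` is a hypothetical toy layer added only for illustration. In this sketch, `.orig` is nothing more than an ordinary Python attribute that `SequentialEx.forward` attaches to the tensor just before each layer is called, so `MergeLayer` can read the block's input from the tensor it receives.

```python
import torch
import torch.nn as nn

class SequentialEx(nn.Module):
    "Sketch of SequentialEx on plain nn.Module (assumption: not fastai's Module)."
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        res = x
        for l in self.layers:
            res.orig = x      # attach the block's input to the current tensor
            nres = l(res)     # a layer may read res.orig (MergeLayer does)
            res.orig = None   # drop the reference again, as in the original code
            res = nres
        return res

class MergeLayer(nn.Module):
    "Adds (or concatenates, if dense=True) the incoming tensor and its .orig attribute."
    def __init__(self, dense=False):
        super().__init__()
        self.dense = dense

    def forward(self, x):
        return torch.cat([x, x.orig], dim=1) if self.dense else (x + x.orig)

class Double(nn.Module):
    "Hypothetical toy layer for illustration: multiplies its input by 2."
    def forward(self, x):
        return 2 * x

x = torch.ones(1, 3)

# Residual-style block: Double(x) + x
out = SequentialEx(Double(), MergeLayer())(x)

# Dense-style block: concatenate Double(x) and x along dim 1
dense_out = SequentialEx(Double(), MergeLayer(dense=True))(x)

print(out)
print(dense_out.shape)
```

With `dense=False` the block behaves like a residual connection (layer output plus block input), and with `dense=True` the two are concatenated along the channel dimension instead.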