Init()/train()/eval() custom functionality - I need a programmer's help!

Hi All,

I’d like to implement an initialization stage (LSUV) for deep learning models that contain batchnorm, dropout, etc. During the init stage, some of the model's layers should behave differently; for example, I want batchnorm to do nothing at all during init.

PyTorch currently has the train()/eval() flags, which alter the behavior of the batchnorm layer, but not in the way I need it to change for init.
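
For context, those calls just toggle a boolean training attribute, recursively, on every submodule, and batchnorm checks it in its forward pass:

import torch.nn as nn

bn = nn.BatchNorm2d(16)
print(bn.training)  # True: modules start in training mode (batch statistics)
bn.eval()
print(bn.training)  # False: bn now normalizes with its stored running stats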

What I need is a flag-setting function on the top-level model, which propagates to all its submodules and switches some of them to a new behavior. Something like:

model = model.init()   # switch all submodules to init behavior
init(model)            # run the LSUV initialization
model = model.train()  # restore normal training behavior
fit(model)
model = model.eval()

where under the init() regime batchnorm is not active yet.
model.train() then brings batchnorm back to its normal training behavior (which later switches again to using the stored running stats when model.eval() is set).
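
For concreteness, here is a minimal sketch of what such a batchnorm could look like, assuming an initializing flag like the one in my workaround below (the class name InitAwareBatchNorm2d is just illustrative):

import torch.nn as nn

class InitAwareBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d that becomes a pass-through while the model is in init mode."""
    def forward(self, x):
        # 'initializing' is the custom flag set on every module (see workaround
        # below); default to False if it was never set
        if getattr(self, "initializing", False):
            return x  # do nothing during init
        return super().forward(x)  # normal batchnorm behavior otherwise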

Any advice on how to do this?

For the time being, I solved it using:

import torch

def init_flag(self, mode=True):
    r"""Sets the module (and all its submodules) in init mode."""
    self.initializing = mode
    for module in self.children():
        module.init_flag(mode)  # recurse into child modules
    return self

# Attach the method to the base class so every module gets it
torch.nn.Module.init_flag = init_flag
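
A variant that avoids monkey-patching nn.Module, using the built-in nn.Module.apply(), which visits a module and all of its descendants (set_init_mode is just a name I made up):

def set_init_mode(model, mode=True):
    """Set the custom 'initializing' flag on the model and all its submodules."""
    def _set(module):
        module.initializing = mode
    model.apply(_set)  # apply() calls _set on model and every descendant
    return model

Usage would be set_init_mode(model, True) before running LSUV and set_init_mode(model, False) afterwards.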

If anyone has a better / more elegant way to do it, please tell!