Hi All,
I’d like to program an initialization stage for DL models using LSUV, with batch norm, dropout, etc. During this init stage some of the model’s functions/layers should behave differently; for example, I want batch norm to do nothing at init.
Currently PyTorch has the train()/eval() flags, which alter the behavior of the BN layer, but not in the way I need it to change for init.
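For reference, the existing train()/eval() mechanism works by recursively setting the boolean `training` attribute on every submodule, which BatchNorm then reads to choose between batch stats and running stats. A minimal demonstration:

```python
import torch.nn as nn

# train()/eval() recursively flip the `training` attribute on every
# submodule; BatchNorm checks this flag in its forward pass.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

model.train()
print(all(m.training for m in model.modules()))   # True

model.eval()
print(any(m.training for m in model.modules()))   # False
```

This is exactly the kind of propagation I’d like to reproduce for a third, custom mode.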
What I need is a flag-setting function on the overall model that is visible to all of its submodules and switches some of them to a new behavior. Something like:
model = model.init()
init(model)
model = model.train()
fit(model)
model = model.eval()
where under the init() regime BN is not active yet.
model.train() then restores BN’s normal training behavior (which later changes again to using the stored running stats when model.eval() is set).
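One direction I’ve been considering, sketched below with made-up names (`init_mode`, `set_init`, `InitAwareBatchNorm1d` are all my own, not PyTorch API): propagate a custom boolean flag to every submodule with Module.apply(), and have a BatchNorm subclass act as the identity while the flag is set.

```python
import torch
import torch.nn as nn

class InitAwareBatchNorm1d(nn.BatchNorm1d):
    """BatchNorm that becomes a no-op while a custom `init_mode` flag is set."""
    def forward(self, x):
        if getattr(self, "init_mode", False):
            return x                      # inactive during the init stage
        return super().forward(x)

def set_init(model, flag=True):
    # apply() visits every submodule, mimicking how train()/eval() propagate
    model.apply(lambda m: setattr(m, "init_mode", flag))
    return model

model = nn.Sequential(nn.Linear(4, 4), InitAwareBatchNorm1d(4))
x = torch.randn(8, 4)

set_init(model, True)                     # init regime: BN does nothing
# ... run LSUV init here ...

set_init(model, False)                    # leave the init regime
model.train()                             # normal BN training behavior from here
```

But this requires swapping in custom subclasses for every affected layer, which is why I’m asking whether there is a cleaner, built-in way to add such a mode.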
Any advice on how to do this?