Yep, I want to update some parameter weights with a callback.
For example, fastai's weight decay does something similar, although it actually happens in the optimizer step rather than in a callback. I looked at the weight-decay code inside fast.ai, and to be honest, I don’t understand why it works in a multi-GPU environment:
def step(self)->None:
    "Set weight decay and step optimizer."
    # weight decay outside of optimizer step (AdamW)
    if self.true_wd:
        for lr,wd,pg1,pg2 in zip(self._lr,self._wd,self.opt.param_groups[::2],self.opt.param_groups[1::2]):
            for p in pg1['params']: p.data.mul_(1 - wd*lr)
            if self.bn_wd:
                for p in pg2['params']: p.data.mul_(1 - wd*lr)
        self.set_val('weight_decay', listify(0, self._wd))
    self.opt.step()
Now, I’m not doing weight decay, but something kind of like that: making sure the weights satisfy certain constraints I need (don’t worry about why).
I’ve tested the above code on a single GPU, and it works and does what I think it does. But on multiple GPUs, since I can’t use Jupyter, it’s harder to tell whether the two copies of the model are getting out of sync. I suspect the modification might only be applied to one of the two replicas. But then again, the fast.ai code above would suffer from the same problem… wouldn’t it?
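To make that concrete, this is the kind of sanity check I have in mind (toy model, made-up layer sizes, meant to be run on the multi-GPU box; none of these names come from my real code). As far as I understand, nn.DataParallel re-replicates the wrapped module on every forward pass, so if an in-place edit only reached one replica, rows of a batch handled by different GPUs would come out different:

import torch
import torch.nn as nn

# toy stand-in for the real model; layer sizes are made up
base = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
model = nn.DataParallel(base).cuda()

# the same row repeated, so every replica sees identical inputs
x = torch.randn(1, 8).repeat(8, 1).cuda()

with torch.no_grad():
    base[0].weight.mul_(0.5)   # in-place edit of the master parameters
    out = model(x)             # DataParallel re-replicates base during this forward

# if the edit only landed on one replica, these rows would no longer all match
print(torch.allclose(out, out[0].expand_as(out)))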
I also tried it like this:
import torch
from fastai.basic_train import Learner, LearnerCallback

class MyCallback(LearnerCallback):
    def __init__(self, learn:Learner):
        super().__init__(learn)

    def on_batch_end(self, **kwargs):
        # LearnerCallback stores the learner, so go through self.learn
        important_parameter = self.learn.model.module[17].weight  # or whatever
        with torch.no_grad():
            modify_in_place(important_parameter)  # placeholder for my constraint-enforcing code
But I’m seeing the same problem.
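For reference, this is roughly how I’m wiring it up (data and model are placeholders for my actual DataBunch and network, and the DataParallel line is the usual fastai v1 recipe, not necessarily exactly what my script does):

import torch
from fastai.basic_train import Learner

# data and model are placeholders for my real DataBunch / network
learn = Learner(data, model, callback_fns=[MyCallback])
learn.model = torch.nn.DataParallel(learn.model)  # multi-GPU wrapping
learn.fit_one_cycle(1)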
Thank you for your reply.
Notice that in this new version I had to go through self.learn.model.module[17], because nn.DataParallel apparently wraps the model in a “module” attribute and the actual model lives inside that… but I’m somewhat confused by this.
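In case it helps anyone reading, my current understanding (checked with a small self-contained snippet, toy model only, nothing like my real network) is that DataParallel simply keeps the original model as its .module attribute and builds the per-GPU replicas from it, which is why the layer indexing has to go through .module:

import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 1))  # toy model
wrapped = nn.DataParallel(net)

print(wrapped.module is net)           # True: the original model lives at wrapped.module
print(wrapped.module[1].weight.shape)  # so layer indexing goes through .module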