Change activation function in ResNet model

This function replaces every activation layer of a given type in any model. I recommend Mish since, of course, it's new and performs really well :wink:

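```python
import copy
import torch.nn as nn

def convert_act_cls(model, layer_type_old, layer_type_new):
    """Recursively replace every `layer_type_old` submodule with a copy of the `layer_type_new` instance."""
    for name, module in list(model.named_children()):
        if len(list(module.children())) > 0:
            # recurse into containers (Sequential, ResNet blocks, ...)
            convert_act_cls(module, layer_type_old, layer_type_new)

        if isinstance(module, layer_type_old):
            # deep-copy so each position gets its own module instead of one shared object
            setattr(model, name, copy.deepcopy(layer_type_new))

    return model
```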

This was adapted from `convert_MP_to_blurMP` in the ImageWoof/Nette code and generalized to replace any layer type. For instance:

```python
learn.model = convert_act_cls(learn.model, nn.ReLU, Mish())
```
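
To check that a swap actually reached every nested layer, you can count the layers of a given type before and after. Below is a minimal sketch; `count_layers` and the toy model are hypothetical helpers for illustration, assuming plain PyTorch.

```python
import torch.nn as nn

def count_layers(model: nn.Module, layer_type: type) -> int:
    """Count submodules (at any depth) that are instances of layer_type."""
    # model.modules() walks the whole module tree, including the model itself
    return sum(isinstance(m, layer_type) for m in model.modules())

# Hypothetical toy model with activations at two nesting levels
toy = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3), nn.ReLU()),
)

print(count_layers(toy, nn.ReLU))  # 3
```

Running it on `learn.model` before and after the conversion should show the ReLU count dropping to zero. Note that `modules()` yields a module object only once even if it is registered at several positions, so a single activation instance reused everywhere is counted once.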

Technically, you can pass `act_cls` to `unet_learner`, but that only changes the U-Net part (the decoder), not the encoder. This function changes both.
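
As an illustration of the whole-model swap reaching both halves, here is a sketch on a hypothetical `ToyEncoderDecoder` (a stand-in, not fastai's actual U-Net). It uses a non-recursive variant, `swap_layers` (also hypothetical), that collects matching names via `named_modules()` and replaces each one through its parent. It assumes PyTorch 1.9 or later for `nn.Mish` and `Module.get_submodule`.

```python
import copy
import torch
import torch.nn as nn

def swap_layers(model: nn.Module, old_type: type, new_layer: nn.Module) -> int:
    """Replace every old_type submodule with a fresh copy of new_layer; return how many."""
    # Snapshot the dotted names first so the tree isn't mutated while it is being walked
    targets = [name for name, m in model.named_modules() if name and isinstance(m, old_type)]
    for name in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, copy.deepcopy(new_layer))
    return len(targets)

class ToyEncoderDecoder(nn.Module):
    """Hypothetical stand-in for a U-Net: a ReLU encoder followed by a ReLU decoder."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.decoder = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
                                     nn.Conv2d(8, 1, 1))

    def forward(self, x):
        return self.decoder(self.encoder(x))

model = ToyEncoderDecoder()
n = swap_layers(model, nn.ReLU, nn.Mish())
print(n)  # 2
out = model(torch.randn(1, 3, 16, 16))
print(out.shape)  # torch.Size([1, 1, 16, 16])
```

Both the encoder and decoder ReLUs end up as separate `nn.Mish` instances, and the forward pass still runs with the same output shape.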
