Problems when exporting model by JIT tracing because of Mish

I created a U-Net model with this line:

learn = unet_learner(dls, resnet18, act_cls=Mish, opt_func=ranger, n_out=2)

Then I trained the model. Now I am trying to export it for inference with JIT tracing:

torch.jit.save(torch.jit.trace(learn.model, trans(image).float().unsqueeze_(0)), 'model.pt')

But it throws the following error:

RuntimeError: 
Could not export Python function call 'MishJitAutoFn'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:

According to the error, the function should have an @script annotation. I checked the fastai code, and the @script annotation is already there:

@script
def _mish_jit_fwd(x): return x.mul(torch.tanh(F.softplus(x)))
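Note that the error names `MishJitAutoFn`, not `_mish_jit_fwd`. My reading (an assumption, not confirmed against the fastai source here) is that `@script` compiles only the forward helper, while the `Mish` module calls it through a Python `torch.autograd.Function` (`MishJitAutoFn.apply`). Tracing records that `apply` call as an opaque Python call, which `torch.jit.save` cannot serialize. A minimal self-contained reproduction under that assumption, where `MishAutoFn`, `MishWithAutoFn`, `MishDirect` and `try_export` are hypothetical names standing in for the fastai pieces:

```python
import io
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.jit.script
def _mish_fwd(x):
    # Compiled by TorchScript: on its own, this helper is exportable.
    return x.mul(torch.tanh(F.softplus(x)))


class MishAutoFn(torch.autograd.Function):
    # Hypothetical stand-in for fastai's MishJitAutoFn: a Python-level
    # autograd.Function that merely calls the scripted helper.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return _mish_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        tsp = torch.tanh(F.softplus(x))
        # d/dx [x * tanh(softplus(x))]
        #   = tanh(softplus(x)) + x * sigmoid(x) * (1 - tanh(softplus(x))**2)
        return grad_output * (tsp + x * torch.sigmoid(x) * (1 - tsp * tsp))


class MishWithAutoFn(nn.Module):
    # Goes through autograd.Function.apply, i.e. a Python call.
    def forward(self, x):
        return MishAutoFn.apply(x)


class MishDirect(nn.Module):
    # Calls the scripted helper directly, with no autograd.Function around it.
    def forward(self, x):
        return _mish_fwd(x)


def try_export(module, example):
    """Trace `module` and save it to an in-memory buffer.

    Returns None on success, otherwise the error message.
    """
    try:
        traced = torch.jit.trace(module, example)
        torch.jit.save(traced, io.BytesIO())
        return None
    except Exception as e:  # surfaced as RuntimeError in the report above
        return str(e)


example = torch.randn(1, 3)
err_autofn = try_export(MishWithAutoFn(), example)
err_direct = try_export(MishDirect(), example)
print("autograd.Function version fails:", err_autofn is not None)
print("direct scripted call fails:", err_direct is not None)
```

If this reading is right, the @script annotation is necessary but not sufficient: whatever wraps the scripted helper must also be traceable.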

Can anybody tell me what’s wrong and how to solve this problem?
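One possible workaround, sketched under the same assumption: Mish has no learnable parameters, so every Mish module can be swapped for one written with plain tensor ops before tracing, leaving the trained weights untouched and giving the tracer only ordinary tensor operations to record. `replace_modules`, `TraceableMish`, `FnMish` and `_MishFn` are hypothetical names; `FnMish` plays the role of fastai's `Mish`.

```python
import io
import torch
import torch.nn as nn
import torch.nn.functional as F


class TraceableMish(nn.Module):
    """Mish written with plain tensor ops, so tracing records ordinary ops."""
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


def replace_modules(model, old_cls, make_new):
    """Recursively replace every submodule of type `old_cls` with `make_new()`, in place."""
    for name, child in model.named_children():
        if isinstance(child, old_cls):
            setattr(model, name, make_new())
        else:
            replace_modules(child, old_cls, make_new)
    return model


class _MishFn(torch.autograd.Function):
    # Hypothetical stand-in for fastai's MishJitAutoFn (a Python autograd.Function).
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * torch.tanh(F.softplus(x))

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        tsp = torch.tanh(F.softplus(x))
        return grad_output * (tsp + x * torch.sigmoid(x) * (1 - tsp * tsp))


class FnMish(nn.Module):
    # Plays the role of fastai's Mish: forward goes through the autograd.Function.
    def forward(self, x):
        return _MishFn.apply(x)


# Toy model standing in for learn.model; eval() so inference-mode behaviour is traced.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), FnMish(), nn.Conv2d(8, 2, 1)).eval()
example = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    expected = model(example)

replace_modules(model, FnMish, TraceableMish)

traced = torch.jit.trace(model, example)
buf = io.BytesIO()
torch.jit.save(traced, buf)  # expected to fail if FnMish modules were still present
buf.seek(0)
loaded = torch.jit.load(buf)
with torch.no_grad():
    actual = loaded(example)
print("outputs match:", torch.allclose(expected, actual, atol=1e-6))
```

On a real learner this would be roughly `replace_modules(learn.model, Mish, TraceableMish)` followed by `learn.model.eval()` before calling `torch.jit.trace`. On PyTorch 1.9 or later, `torch.nn.Mish` should also work in place of `TraceableMish`.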
