For torch.jit.trace - be extra skeptical about it actually working. I’ve run into numerical issues that distort the colors noticeably in DeOldify. It’s also harder to make sure it’s hardware agnostic (running trace on CPU gets different results from GPU). So I’ve abandoned that and haven’t had those same issues with torch.jit.script .
For torch.jit.script, if you’re using DynamicUnet you’ll be out of luck- it won’t be able to deal with things like hooks. You’ll need to basically rewrite the Unet to be more friendly for Torchscript, something more like this: https://github.com/dana-kelley/DeOldify/blob/master/fasterai/generators.py
That’s the original DeOldify from 2 years ago. Notice that in AbstractUnet it doesn’t use hooks. That’s not the end of the story though- you can’t do the array access at runtime in the forward on the encoder either. So you have to extract those parts in the constructor of the Unet instead.
And wait- but that’s not all (LOL). You’ll also want to make the model backwards compatible with the fastai learner code so you’ll still want to present something like a Sequential model to the Learner. So what I’ve done is made a wrapper Module just for this purpose, that just pretends that the model is sequential by implementing getitem . Like this:
class SequentialScriptUnet(Module):
def __init__(self, encoder:nn.Module, nf_factor:int=4, leaky:float=0.01):
self.core = ScriptUnet(encoder=encoder, nf_factor=nf_factor, leaky=leaky)
def forward(self, x):
return self.core(x)
def __getitem__(self,i):
switcher = {
0: self.core.encoder,
1: self.core.middle_conv,
2: self.core.unetblock1,
3: self.core.unetblock2,
4: self.core.unetblock3,
5: self.core.unetblock4,
6: self.core.output
}
return switcher.get(i)
def append(self,l):
raise 'this aint right'
def extend(self,l):
raise 'this aint right'
def insert(self,i,l):
raise 'this aint right'
So basically, you’re creating a Unet model and training from scratch just to get something that’s compatible with torch.jit.script at the end. You extract the self.core model as your actual torchscript model.
It’s super hacky. I know. But it turns out straight up torchscript (as opposed to tracing) is really restrictive, yet it’s the way to go if you want something that actually works.
edit I forgot to mention that torch.jit.script will also complain about x.orig being accessed at MergeLayer, so you’ll need to implement those concats/residuals on your own rather than using SequentialEx.