DynamicUnet is not compatible with torch.jit.script

I tried to convert a DynamicUnet model into a ScriptModule using torch.jit.script.

However, I am getting this error:

/home/david/anaconda3/envs/pro/lib/python3.7/site-packages/torch/jit/__init__.py:1266: UserWarning: `optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead
  warnings.warn("`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead")
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-7f319ecfa8dd> in <module>
      6 with torch.no_grad():
      7     print(model(img))
----> 8     traced_cell = torch.jit.script(model, (img))
      9 traced_cell.save(traced_cell, modelname+"-torchscript.pth")
     10 

~/anaconda3/envs/pro/lib/python3.7/site-packages/torch/jit/__init__.py in script(obj, optimize, _frames_up, _rcb)
   1269 
   1270     if isinstance(obj, torch.nn.Module):
-> 1271         return torch.jit._recursive.create_script_module(obj, torch.jit._recursive.infer_methods_to_compile)
   1272 
   1273     qualified_name = _qualified_name(obj)

~/anaconda3/envs/pro/lib/python3.7/site-packages/torch/jit/_recursive.py in create_script_module(nn_module, stubs_fn, share_types)
    303         concrete_type = concrete_type_builder.build()
    304 
--> 305     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
    306 
    307 def create_script_module_impl(nn_module, concrete_type, stubs_fn):

~/anaconda3/envs/pro/lib/python3.7/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    357 
    358     # Actually create the ScriptModule, initializing it with the function we just defined
--> 359     script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
    360 
    361     # Compile methods if necessary

~/anaconda3/envs/pro/lib/python3.7/site-packages/torch/jit/__init__.py in _construct(cpp_module, init_fn)
   1642             """
   1643             script_module = RecursiveScriptModule(cpp_module)
-> 1644             init_fn(script_module)
   1645 
   1646             # Finalize the ScriptModule: replace the nn.Module state with our

~/anaconda3/envs/pro/lib/python3.7/site-packages/torch/jit/_recursive.py in init_fn(script_module)
    338             else:
    339                 # use the default recursive rule to compile the module
--> 340                 scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
    341             cpp_module.setattr(name, scripted)
    342             script_module._modules[name] = scripted

~/anaconda3/envs/pro/lib/python3.7/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    357 
    358     # Actually create the ScriptModule, initializing it with the function we just defined
--> 359     script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
    360 
    361     # Compile methods if necessary

~/anaconda3/envs/pro/lib/python3.7/site-packages/torch/jit/__init__.py in _construct(cpp_module, init_fn)
   1642             """
   1643             script_module = RecursiveScriptModule(cpp_module)
-> 1644             init_fn(script_module)
   1645 
   1646             # Finalize the ScriptModule: replace the nn.Module state with our

~/anaconda3/envs/pro/lib/python3.7/site-packages/torch/jit/_recursive.py in init_fn(script_module)
    338             else:
    339                 # use the default recursive rule to compile the module
--> 340                 scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
    341             cpp_module.setattr(name, scripted)
    342             script_module._modules[name] = scripted

~/anaconda3/envs/pro/lib/python3.7/site-packages/torch/jit/_recursive.py in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    361     # Compile methods if necessary
    362     if concrete_type not in concrete_type_store.methods_compiled:
--> 363         create_methods_from_stubs(concrete_type, stubs)
    364         torch._C._run_emit_module_hook(cpp_module)
    365         concrete_type_store.methods_compiled.add(concrete_type)

~/anaconda3/envs/pro/lib/python3.7/site-packages/torch/jit/_recursive.py in create_methods_from_stubs(concrete_type, stubs)
    277     rcbs = [m.resolution_callback for m in stubs]
    278     defaults = [get_default_args(m.original_method) for m in stubs]
--> 279     concrete_type._create_methods(defs, rcbs, defaults)
    280 
    281 def create_script_module(nn_module, stubs_fn, share_types=True):

RuntimeError: 
Module 'UnetBlock' has no attribute 'hook' (This attribute exists on the Python module, but we failed to convert Python type: 'Hook' to a TorchScript type.):
  File "/home/david/anaconda3/envs/pro/lib/python3.7/site-packages/fastai2/vision/models/unet.py", line 35
    def forward(self, up_in):
        s = self.hook.stored
            ~~~~~~~~~ <--- HERE
        up_out = self.shuf(up_in)
        ssh = s.shape[-2:]

Yes, this is a known issue.

1 Like

Ah, okay. Good to know! Would it be a good idea to use torch.jit.trace instead?

I don’t think it’s going to work either, because of the hooks.

Surprisingly, it works. However, I don’t understand why!