I have been trying to export a tabular model trained with fast.ai to the ONNX format (needed for deployment). I have tried various versions of this:
target = torch.empty(1,512, dtype=torch.long, requires_grad=False).random_(28).cuda()
torch.onnx.export(learner.model, target, "/home/scott/model.onnx", verbose=True, export_params=True, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
Some variations also call:
learner.model.eval()
However, every variation I try fails with the same error:
TypeError Traceback (most recent call last)
in
1 target = torch.empty(1,512, dtype=torch.long, requires_grad=False).random_(28).cuda()
----> 2 torch.onnx.export(learner.model, target, "/home/scott/model.onnx", verbose=True, export_params=True, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
~/system/anaconda3/lib/python3.7/site-packages/torch/onnx/__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
141 operator_export_type, opset_version, _retain_param_name,
142 do_constant_folding, example_outputs,
--> 143 strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
144
145
~/system/anaconda3/lib/python3.7/site-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
64 _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
65 example_outputs=example_outputs, strip_doc_string=strip_doc_string,
---> 66 dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs)
67
68
~/system/anaconda3/lib/python3.7/site-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size)
380 example_outputs, propagate,
381 _retain_param_name, do_constant_folding,
--> 382 fixed_batch_size=fixed_batch_size)
383
384 # TODO: Don't allocate a in-memory string for the protobuf
~/system/anaconda3/lib/python3.7/site-packages/torch/onnx/utils.py in _model_to_graph(model, args, verbose, training, input_names, output_names, operator_export_type, example_outputs, propagate, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size)
247 model.graph, tuple(in_vars), False, propagate)
248 else:
--> 249 graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
250 state_dict = _unique_state_dict(model)
251 params = list(state_dict.values())
~/system/anaconda3/lib/python3.7/site-packages/torch/onnx/utils.py in _trace_and_get_graph_from_model(model, args, training)
204 # training mode was.)
205 with set_training(model, training):
--> 206 trace, torch_out, inputs_states = torch.jit.get_trace_graph(model, args, _force_outplace=True, _return_inputs_states=True)
207 warn_on_static_input_change(inputs_states)
208
~/system/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py in get_trace_graph(f, args, kwargs, _force_outplace, return_inputs, _return_inputs_states)
273 if not isinstance(args, tuple):
274 args = (args,)
--> 275 return LegacyTracedModule(f, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
276
277
~/system/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
539 result = self._slow_forward(*input, **kwargs)
540 else:
--> 541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
543 hook_result = hook(self, input, result)
~/system/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py in forward(self, *args)
350 # for now.
351 inputs_states = _unflatten(all_trace_inputs[:len(in_vars)], in_desc)
--> 352 out = self.inner(*trace_inputs)
353 if self._return_inputs_states:
354 inputs_states = (inputs_states, trace_inputs)
~/system/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
537 input = result
538 if torch._C._get_tracing_state():
--> 539 result = self._slow_forward(*input, **kwargs)
540 else:
541 result = self.forward(*input, **kwargs)
~/system/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
523 tracing_state._traced_module_stack.append(self)
524 try:
--> 525 result = self.forward(*input, **kwargs)
526 finally:
527 tracing_state.pop_scope()
TypeError: forward() missing 1 required positional argument: 'x_cont'
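The error indicates that the model's forward() takes a second positional input, `x_cont`, in addition to the single tensor I pass. Presumably the example input given to torch.onnx.export therefore needs to be a tuple containing every input the model expects, rather than one tensor. Has anyone got any thoughts on how I might be able to do this?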