Can't Pickle TabNet Models Inside Optuna Optimizer

Hi all,

I implemented TabNet in my Optuna optimizer to find the best hyperparameters. The same was done for the tabular_learner from fastai, and everything worked fine. This time, however, I am getting an error when attempting to export my model.

AttributeError: Can't pickle local object 'fit_with.<locals>.TabNetModel'

The weirdest part is that if I set up everything the same way outside of my Optuna optimizer, I am able to export the model with:

net = TabNetModel(emb_szs, len(to_nn.cont_names), dls.c, n_d=n_d, n_a=n_a, n_steps=n_steps, gamma=gamma, n_independent=n_independent, n_shared=num_shared, epsilon=epsilon, momentum=mom)
learn = Learner(dls, net, CrossEntropyLossFlat(), metrics=accuracy, opt_func=Adam)
learn.fit_one_cycle(3, 1e-1)

fn = 'Pkl {}'.format('Ignore')
fn = datapath+'Models/'+fn
learn.export(fn)

The function where the export fails is below; please note that optunalearn is simply the name for “learn” in the case above. If anybody would like all of the Optuna functions, I can provide them, but perhaps that is for another forum, since this appears to be a PyTorch or fastai issue.

def FitAndEval(trial,optunalearn,roundepochs,lr,wd):

    fn = 'Pkl {}'.format(trial.number)
    fn = datapath+'Models/'+fn
    optunalearn.export(fn)  # this is the call that raises (see traceback below)

    acc = optunalearn.validate(dl=test_dl)

    return 1-acc[1]


AttributeError Traceback (most recent call last)
7 study = optuna.create_study(study_name='TabNetStudy2', storage='sqlite:///F:/Linux Desktop/My Notebooks/CURRENT_OPTUNA_STUDY.db', load_if_exists=True) #Loads the study
----> 8 study.optimize(objective,n_trials=10) #Runs optimization

~\anaconda3\envs\FastTab\lib\site-packages\optuna\ in optimize(self, func, n_trials, timeout, n_jobs, catch, callbacks, gc_after_trial, show_progress_bar)
374 If nested invocation of this method occurs.
375 """
--> 376 _optimize(
377 study=self,
378 func=func,

~\anaconda3\envs\FastTab\lib\site-packages\ in _optimize(study, func, n_trials, timeout, n_jobs, catch, callbacks, gc_after_trial, show_progress_bar)
61 try:
62 if n_jobs == 1:
---> 63 _optimize_sequential(
64 study,
65 func,

~\anaconda3\envs\FastTab\lib\site-packages\ in _optimize_sequential(study, func, n_trials, timeout, catch, callbacks, gc_after_trial, reseed_sampler_rng, time_start, progress_bar)
163 try:
--> 164 trial = _run_trial(study, func, catch)
165 except Exception:
166 raise

~\anaconda3\envs\FastTab\lib\site-packages\ in _run_trial(study, func, catch)
261 if state == TrialState.FAIL and func_err is not None and not isinstance(func_err, catch):
--> 262 raise func_err
263 return trial

~\anaconda3\envs\FastTab\lib\site-packages\ in _run_trial(study, func, catch)
210 try:
--> 211 value_or_values = func(trial)
212 except exceptions.TrialPruned as e:
213 # TODO(mamu): Handle multi-objective cases.

in objective(trial)
17 wd = trial.suggest_loguniform('wd',.6,.999999999)
---> 19 return fit_with(trial,lr,mom,n_d,n_a,n_steps,gamma,n_independent,num_shared,epsilon,epochs,wd)

in fit_with(trial, lr, mom, n_d, n_a, n_steps, gamma, n_independent, num_shared, epsilon, epochs, wd)
43 roundepochs = round(epochs)
---> 45 return FitAndEval(trial,optunalearn,roundepochs,lr,wd)

in FitAndEval(trial, optunalearn, roundepochs, lr, wd)
8 # (datapath,'/Models/',fn).save(bayeslearn)
9 # learnerout = optunalearn
---> 10 optunalearn.export(fn)
11 # preds,targs = bayeslearn.get_preds(dl=dls.valid)
12 # acc = Accuracy(preds,targs)

~\anaconda3\envs\FastTab\lib\site-packages\fastai\ in export(self, fname, pickle_module, pickle_protocol)
543 after_validate = "Log loss and metric values on the validation set",
544 after_cancel_train = "Ignore training metrics for this epoch",
--> 545 after_cancel_validate = "Ignore validation metrics for this epoch",
546 plot_loss = "Plot the losses from skip_start and onward")

~\anaconda3\envs\FastTab\lib\site-packages\torch\ in save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization)
370 if _use_new_zipfile_serialization:
371 with _open_zipfile_writer(opened_file) as opened_zipfile:
--> 372 _save(obj, opened_zipfile, pickle_module, pickle_protocol)
373 return
374 _legacy_save(obj, opened_file, pickle_module, pickle_protocol)

~\anaconda3\envs\FastTab\lib\site-packages\torch\ in _save(obj, zip_file, pickle_module, pickle_protocol)
474 pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
475 pickler.persistent_id = persistent_id
--> 476 pickler.dump(obj)
477 data_value = data_buf.getvalue()
478 zip_file.write_record('data.pkl', data_value, len(data_value))

AttributeError: Can't pickle local object 'fit_with.<locals>.TabNetModel'

I think this is likely a limitation of Python pickling rather than a specific issue with fastai.
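
To illustrate the limitation: pickle stores a class by its importable name rather than by value, so a class defined inside a function cannot be looked up again, and pickling an instance of it fails. Here is a minimal sketch using only the standard library; make_model, LocalModel and try_pickle are made-up names for illustration, not anything from the code above.

```python
import pickle


def make_model():
    # Hypothetical class defined inside a function: its qualified name
    # contains "<locals>", so pickle cannot find it again by name.
    class LocalModel:
        def __init__(self, n_steps):
            self.n_steps = n_steps

    return LocalModel(n_steps=3)


def try_pickle(obj):
    """Return None if obj pickles fine, otherwise the exception raised."""
    try:
        pickle.dumps(obj)
        return None
    except (AttributeError, pickle.PicklingError) as err:
        # Which of the two is raised depends on the Python version.
        return err


model = make_model()
error = try_pickle(model)

print(type(model).__qualname__)  # make_model.<locals>.LocalModel
print(repr(error))               # the "can't pickle" error for the local class
```

The same object pickles without complaint once its class is reachable by name from a module, which is probably why the identical setup works outside the optimizer function.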

Are you trying to save the fastai model at the end of each Optuna trial?

Hey Dan,

Yes, that’s what I’m trying to do. And like I mentioned, it worked fine for my fastai models before, but now that I’m bringing in some fastai mixed with PyTorch, it seems to be throwing a fit.

I suspect the issue is with pickle. There are certain things that pickle can’t pickle, e.g. lambdas, and I suspect that this might be because your class is instantiated outside of your function. One thing that might be worth trying is to instantiate your TabNetModel inside of your FitAndEval function and see if that helps. I could definitely be off about what is happening here, though…

I just gave that a try, but no luck. Same error.
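
For what it's worth, the qualified name in the error message points at a TabNetModel class that is defined inside fit_with, so where the instance gets created may not matter. One thing that might be worth trying is moving the class definition to module level and only creating the instance inside the function. A rough sketch of that idea, assuming a hypothetical stand-in class TabNetLike rather than the real TabNetModel, and a simplified fit_with signature:

```python
import pickle


# Hypothetical stand-in for TabNetModel. It is defined at module level,
# so pickle can locate the class again by its qualified name.
class TabNetLike:
    def __init__(self, n_d, n_a, n_steps):
        self.n_d = n_d
        self.n_a = n_a
        self.n_steps = n_steps


def fit_with(n_d, n_a, n_steps):
    # Only the instance is created per trial; the class lives outside.
    return TabNetLike(n_d=n_d, n_a=n_a, n_steps=n_steps)


def round_trip(obj):
    """Pickle and unpickle obj in memory."""
    return pickle.loads(pickle.dumps(obj))


model = fit_with(n_d=8, n_a=8, n_steps=3)
print(type(model).__qualname__)  # TabNetLike

if __name__ == "__main__":
    restored = round_trip(model)
    print(restored.n_steps)  # 3
```

Run as a script, the round trip succeeds because pickle can look TabNetLike up by name in the module. Whether this carries over to the real Learner export depends on how TabNetModel ends up inside fit_with, which is not shown in the thread.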