Loading saved TabularModel fails due to embeddings

The intended workflow is as follows:

  • Create a learner with get_tabular_learner and train it on a training and a validation dataframe.
  • Save the model.
  • Later, load the model and use it on a new dataframe (prediction only).

This works if I run everything in one script:

    data, embeddings = prepare_df()
    learner = get_tabular_learner(data, emb_szs=embeddings, layers=[1000, 10], y_range=(-1., 1.), metrics=F.mse_loss)
    learner.loss_fn = F.l1_loss
    cbs = [EarlyStopping(), SaveBest(learner), OneCycleScheduler(learner, 1e-3)]
    learner.fit(1000, 1e-3, callbacks=cbs)
    learner.load(cbs[1].best_model_name)  # this works
    model = TabularModel(emb_szs=data.get_emb_szs(embeddings), n_cont=len(data.train_ds.cont_names),
                         out_sz=1, layers=[1000, 10], y_range=(-1., 1.))
    model.load_state_dict(torch.load("models/{}.pth".format(cbs[1].best_model_name)))  # this works as well

So both options, learner.load and model.load_state_dict, work as long as they are called with the same DataBunch data. However, if I restart the script and only use

    learner = get_tabular_learner(data, emb_szs=embeddings, layers=[1000, 10], y_range=(-1., 1.), metrics=F.mse_loss)
    learner.load("best_model_learner")

or if i use this

    model = TabularModel(emb_szs=data.get_emb_szs(embeddings), n_cont=len(data.train_ds.cont_names),
                         out_sz=1, layers=[1000, 10], y_range=(-1., 1.))
    model.load_state_dict(torch.load("models/{}.pth".format("best_torch_model")))

either one fails. The only thing that changes between runs is the train/test split, which happens inside prepare_df.

Stack trace:

    RuntimeError: Error(s) in loading state_dict for TabularModel:
    size mismatch for embeds.15.weight: copying a param of torch.Size([3, 2]) from checkpoint, where the shape is torch.Size([4, 2]) in current model.
    size mismatch for embeds.17.weight: copying a param of torch.Size([3, 8]) from checkpoint, where the shape is torch.Size([4, 8]) in current model.
    size mismatch for embeds.26.weight: copying a param of torch.Size([3, 2]) from checkpoint, where the shape is torch.Size([4, 2]) in current model.
    size mismatch for embeds.29.weight: copying a param of torch.Size([5, 4]) from checkpoint, where the shape is torch.Size([6, 4]) in current model.
    size mismatch for embeds.34.weight: copying a param of torch.Size([6, 4]) from checkpoint, where the shape is torch.Size([5, 4]) in current model.
    size mismatch for embeds.44.weight: copying a param of torch.Size([6, 4]) from checkpoint, where the shape is torch.Size([5, 4]) in current model.
    size mismatch for embeds.47.weight: copying a param of torch.Size([11, 8]) from checkpoint, where the shape is torch.Size([8, 8]) in current model.
    size mismatch for embeds.62.weight: copying a param of torch.Size([11, 8]) from checkpoint, where the shape is torch.Size([12, 8]) in current model.
    size mismatch for embeds.67.weight: copying a param of torch.Size([15, 8]) from checkpoint, where the shape is torch.Size([14, 8]) in current model.
    size mismatch for embeds.71.weight: copying a param of torch.Size([4, 2]) from checkpoint, where the shape is torch.Size([3, 2]) in current model.
    size mismatch for embeds.77.weight: copying a param of torch.Size([13, 8]) from checkpoint, where the shape is torch.Size([12, 8]) in current model.
    size mismatch for embeds.87.weight: copying a param of torch.Size([11, 8]) from checkpoint, where the shape is torch.Size([12, 8]) in current model.
    size mismatch for embeds.92.weight: copying a param of torch.Size([13, 8]) from checkpoint, where the shape is torch.Size([11, 8]) in current model.
    size mismatch for embeds.97.weight: copying a param of torch.Size([13, 8]) from checkpoint, where the shape is torch.Size([12, 8]) in current model.
    size mismatch for embeds.102.weight: copying a param of torch.Size([11, 8]) from checkpoint, where the shape is torch.Size([12, 8]) in current model.
    size mismatch for embeds.112.weight: copying a param of torch.Size([10, 8]) from checkpoint, where the shape is torch.Size([11, 8]) in current model.
    size mismatch for embeds.115.weight: copying a param of torch.Size([3, 2]) from checkpoint, where the shape is torch.Size([4, 2]) in current model.
    size mismatch for embeds.124.weight: copying a param of torch.Size([7, 4]) from checkpoint, where the shape is torch.Size([6, 4]) in current model.
    size mismatch for embeds.125.weight: copying a param of torch.Size([4, 2]) from checkpoint, where the shape is torch.Size([3, 2]) in current model.
    size mismatch for embeds.127.weight: copying a param of torch.Size([12, 8]) from checkpoint, where the shape is torch.Size([14, 8]) in current model.
    size mismatch for embeds.131.weight: copying a param of torch.Size([3, 2]) from checkpoint, where the shape is torch.Size([4, 2]) in current model.
    size mismatch for embeds.132.weight: copying a param of torch.Size([13, 8]) from checkpoint, where the shape is torch.Size([14, 8]) in current model.
    size mismatch for embeds.140.weight: copying a param of torch.Size([6, 3]) from checkpoint, where the shape is torch.Size([5, 3]) in current model.

It looks like the embeddings are the problem. The embeddings variable already contains appropriate sizes covering all possible categories of the categorical variables.

Am I using the functions wrong, or is this a bug? Similar threads do not really help, since as far as I have seen they are all about image classification.

Using load_state_dict(…, strict=False) does not change anything.


I have traced it to this function in tabular/data.py:

    def def_emb_sz(df, n, sz_dict):
        col = df[n]
        n_cat = len(col.cat.categories)+1  # extra cat for NA
        sz = sz_dict.get(n, min(50, (n_cat//2)+1))  # rule of thumb
        return n_cat,sz

n_cat varies with each random split, since not every possible value of a categorical column is necessarily present in the training data. However, I am unsure how best to proceed.
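The effect can be reproduced without fastai. The stdlib-only sketch below mimics the logic of def_emb_sz, approximating pandas' `col.cat.categories` by the set of values actually present in a split. The helper names `emb_sz_for_split` and `train_split` and the toy column are hypothetical. Because the two rare values only sometimes land in the 20% sample, n_cat (and therefore the embedding shape) can change from seed to seed.

```python
import random

def emb_sz_for_split(values, sz_dict=None, name="col"):
    """Mimic def_emb_sz: categories come only from the values present in this split."""
    sz_dict = sz_dict or {}
    n_cat = len(set(values)) + 1  # +1 for the NA slot
    sz = sz_dict.get(name, min(50, (n_cat // 2) + 1))  # same rule of thumb
    return n_cat, sz

# Hypothetical toy column: 5 possible values, two of them rare.
population = ["a"] * 40 + ["b"] * 40 + ["c"] * 15 + ["d"] * 3 + ["e"] * 2

def train_split(rows, frac, seed):
    """Return a random training sample of the given fraction."""
    rng = random.Random(seed)
    shuffled = rows[:]
    rng.shuffle(shuffled)
    return shuffled[: int(len(shuffled) * frac)]

for seed in range(3):
    print(seed, emb_sz_for_split(train_split(population, 0.2, seed)))
```

A model saved after one split therefore expects a different number of embedding rows than a model built from another split, which is exactly the size mismatch in the stack trace.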

Ah, interesting problem. We automated the creation of embeddings to make users’ lives easier, but it looks like this has a bad side effect.
As a workaround for now, load your model with your original data, then replace the data with your new one (not ideal, I know; we’ll try to fix this).

    learner = get_tabular_learner(old_data, emb_szs=embeddings, layers=[1000, 10], y_range=(-1., 1.), metrics=F.mse_loss)
    learner.load("my_awesome_model")
    learner.data = new_data

Thanks for the quick reply! Unfortunately, creating a learner with the training data and then using it to predict on another dataset yields a different error:

    Traceback (most recent call last):
      File "/home/jan/PythonProjects/analysis/fastai_analysis.py", line 163, in <module>
        prediction(best_model_name, pred_data)
      File "/home/jan/PythonProjects/analysis/fastai_analysis.py", line 148, in prediction
        out, reference = predict(model, dl)
      File "/home/jan/PythonProjects/analysis/fastai_analysis.py", line 72, in predict
        res.append(model(*xb).numpy())
      File "/home/jan/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/jan/PythonProjects/fastai/fastai/tabular/models.py", line 27, in forward
        x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
      File "/home/jan/PythonProjects/fastai/fastai/tabular/models.py", line 27, in <listcomp>
        x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
      File "/home/jan/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/jan/anaconda3/lib/python3.6/site-packages/torch/nn/modules/sparse.py", line 113, in forward
        self.norm_type, self.scale_grad_by_freq, self.sparse)
      File "/home/jan/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py", line 1205, in embedding
        return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
    RuntimeError: index out of range at /opt/conda/conda-bld/pytorch-nightly-cpu_1539155206632/work/aten/src/TH/generic/THTensorEvenMoreMath.cpp:191

I had another idea: write a custom embedding-size function that uses the maximum number of categories for each categorical column, computed not from the training dataset but from all possible values of that column. sz_dict maps each name in cat_names to its maximum number of categories, regardless of which values appear in data.

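    def my_get_emb_szs(data, sz_dict):
        emb = []
        for c in data.cat_names:
            assert c in sz_dict  # every categorical column needs a known maximum
            sz_with_nan = sz_dict[c] + 1  # one extra row for NaN
            emb.append((sz_with_nan, min(50, (sz_with_nan//2)+1)))  # same rule of thumb as def_emb_sz
        return emb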

This yields the same error in torch.embedding:

    ...
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
    RuntimeError: index out of range at /opt/conda/conda-bld/pytorch-nightly-cpu_1539155206632/work/aten/src/TH/generic/THTensorEvenMoreMath.cpp:191

It looks like you are using your model with new_data that contains categories the model never saw during training. Be sure to set those new categories to NaN; otherwise you will get these errors.
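One way to follow this advice is to fix each column's vocabulary once (for example from the training data), store it next to the saved model, and encode every later dataframe against that same vocabulary, sending unseen or missing values to index 0, the slot reserved for NA. An embedding lookup raises "index out of range" whenever a code is greater than or equal to the number of embedding rows, so this mapping keeps every code valid. The sketch uses plain Python lists; `build_vocab` and `encode` are hypothetical helpers, not fastai functions, and `None` stands in for NaN.

```python
def build_vocab(values):
    """Map each known category to 1..n; index 0 is reserved for NA/unknown."""
    known = sorted({v for v in values if v is not None})
    return {v: i for i, v in enumerate(known, start=1)}

def encode(values, vocab):
    """Encode against a fixed vocabulary; unseen or missing values become 0."""
    return [vocab.get(v, 0) for v in values]

train_col = ["LHR", "JFK", "CDG", "JFK"]
vocab = build_vocab(train_col)      # save this alongside the model
n_embeddings = len(vocab) + 1       # rows the embedding was built with

new_col = ["JFK", "SFO", None, "CDG"]  # "SFO" never appeared in training
codes = encode(new_col, vocab)
assert max(codes) < n_embeddings       # every code is a valid embedding row
print(codes)
```

Here "SFO" and the missing value both fall back to row 0, so the prediction uses the NA embedding instead of crashing.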

Any update on the best approach for this? I have not been able to successfully train, save, reload, and predict.

I’m unsure how to apply sgugger’s suggestion to replace unknown categorical values with NaN, so I’m using a workaround: one-hot encoding my categorical variables, e.g. with dummyPy or scikit-learn, which can set unknown categories to all zeros automatically (in scikit-learn’s OneHotEncoder this requires handle_unknown='ignore').
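The one-hot workaround boils down to the following: each known category gets its own column, and an unseen category simply produces an all-zero row, so new values can never cause an out-of-range index. The plain-Python helper `one_hot` below is a hypothetical stand-in written only to show the idea; it is not the dummyPy or scikit-learn API.

```python
def one_hot(values, categories):
    """One column per known category; unknown values yield an all-zero row."""
    index = {c: i for i, c in enumerate(categories)}
    rows = []
    for v in values:
        row = [0] * len(categories)
        if v in index:
            row[index[v]] = 1
        rows.append(row)
    return rows

print(one_hot(["b", "z", "a"], ["a", "b", "c"]))
```

The trade-off is that the number of input columns grows with the number of categories, and the model no longer learns a dense embedding per category.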

Any updates on loading tabular models without loading the old data again?

You should check the inference tutorial.


Thank you so much! It worked like a charm.


I’m trying to create embeddings to represent airports for the airline domain, and most of the actual problems I want to solve don’t have data for every single airport so I can’t use the pretrained embeddings.

Any progress on a more general solution for this? I’ve tried using the PyTorch primitives to save and load the state dict of the specific embedding, and it fails on the size mismatch (despite using strict=False).
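The failure is expected: in PyTorch, strict=False only tolerates missing or unexpected keys in a state dict, not tensors whose shapes differ. One workaround, sketched below under the assumption that you kept the category vocabulary from the original training run, is to build a fresh embedding for the new vocabulary and copy over the learned rows for categories present in both, leaving the remaining rows at a small random initialization. The sketch uses nested lists to stay dependency-free; with PyTorch you would apply the same index mapping to the `.weight` tensors inside a `torch.no_grad()` block. `transfer_rows` is a hypothetical helper, not a PyTorch or fastai function.

```python
import random

def transfer_rows(old_weights, old_vocab, new_vocab, dim, seed=0):
    """Build an embedding matrix for new_vocab, reusing rows learned for old_vocab.

    Vocabularies map category -> row index, with row 0 as the NA slot in both.
    Rows for categories absent from old_vocab keep a small random init.
    """
    rng = random.Random(seed)
    n_rows = len(new_vocab) + 1
    new_weights = [[rng.gauss(0.0, 0.01) for _ in range(dim)] for _ in range(n_rows)]
    new_weights[0] = list(old_weights[0])  # keep the learned NA row
    for cat, new_idx in new_vocab.items():
        old_idx = old_vocab.get(cat)
        if old_idx is not None:
            new_weights[new_idx] = list(old_weights[old_idx])
    return new_weights

old_vocab = {"CDG": 1, "JFK": 2}
old_weights = [[0.0, 0.0], [0.1, 0.2], [0.3, 0.4]]  # 3 rows x 2 dims
new_vocab = {"CDG": 1, "JFK": 2, "SFO": 3}          # one extra airport
new_weights = transfer_rows(old_weights, old_vocab, new_vocab, dim=2)
print(len(new_weights), new_weights[2])
```

Known airports keep their learned vectors, and only the new ones start from scratch.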

Did you find a solution for this? I’m also having the same issue…

Actually, yes, but I had to write my own code to save the embedding data and reload it. I posted my code here: Tabular Transfer Learning and/or retraining with fastai


Thank you!