The intended workflow is as follows:
- create a learner with get_tabular_learner, passing a train and a validation dataframe
- train it and store the model
- at a later point, load the model and use it on a new dataframe (prediction only)
This works if I run everything in one script:
```python
from fastai.tabular import *   # fastai v1-era imports; provides get_tabular_learner and TabularModel
import torch
import torch.nn.functional as F

data, embeddings = prepare_df()  # builds the DataBunch and the embedding-size overrides
learner = get_tabular_learner(data, emb_szs=embeddings, layers=[1000, 10],
                              y_range=(-1., 1.), metrics=F.mse_loss)
learner.loss_fn = F.l1_loss
cbs = [EarlyStopping(), SaveBest(learner), OneCycleScheduler(learner, 1e-3)]  # callback definitions omitted
learner.fit(1000, 1e-3, callbacks=cbs)

learner.load(cbs[1].best_model_name)  # this works

model = TabularModel(emb_szs=data.get_emb_szs(embeddings),
                     n_cont=len(data.train_ds.cont_names),
                     out_sz=1, layers=[1000, 10], y_range=(-1., 1.))
model.load_state_dict(torch.load("models/{}.pth".format(cbs[1].best_model_name)))  # this works as well
```
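For context: as far as I can tell from the fastai source, learner.load(name) is roughly equivalent to the manual route, which would explain why both options behave the same way. A sketch of my reading of it (not the exact fastai code):

```python
# roughly what learner.load("best_model_learner") does internally
# (my reading of the fastai v1 source, not the exact code)
state = torch.load("models/best_model_learner.pth")
learner.model.load_state_dict(state)
```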
So both options, learner.load and model.load_state_dict, work as long as they are called with the same DataBunch (data). However, if I restart the script and only use
```python
data, embeddings = prepare_df()  # fresh run, fresh train/validation split
learner = get_tabular_learner(data, emb_szs=embeddings, layers=[1000, 10],
                              y_range=(-1., 1.), metrics=F.mse_loss)
learner.load("best_model_learner")
```
or if I use this:
```python
data, embeddings = prepare_df()  # again a fresh split
model = TabularModel(emb_szs=data.get_emb_szs(embeddings),
                     n_cont=len(data.train_ds.cont_names),
                     out_sz=1, layers=[1000, 10], y_range=(-1., 1.))
model.load_state_dict(torch.load("models/{}.pth".format("best_torch_model")))
```
it fails. The only thing that changes between the runs is the train/test split, which happens inside prepare_df.
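My guess at the mechanism: the set of categories that actually appears in the train split determines the number of rows in each embedding table, so a different random split can shift a table's size by a row or two. Note that the mismatches in the trace below go in both directions, which fits a different split rather than anything systematic. A minimal pandas sketch of the effect (column and values made up for illustration):

```python
import pandas as pd

df = pd.DataFrame({"color": ["red", "green", "blue", "blue", "green", "violet"]})

train_a = df.iloc[:5]  # violet never appears -> 3 distinct categories
train_b = df           # violet appears       -> 4 distinct categories
print(train_a["color"].nunique(), train_b["color"].nunique())  # 3 4

# An embedding table built from train_a has 3 rows, one built from train_b
# has 4 (fastai also seems to reserve an extra row, e.g. for missing values,
# but the relative off-by-one is the same) -- exactly the kind of size
# mismatch shown in the trace below.
```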
Stack trace:

```
RuntimeError: Error(s) in loading state_dict for TabularModel:
size mismatch for embeds.15.weight: copying a param of torch.Size([3, 2]) from checkpoint, where the shape is torch.Size([4, 2]) in current model.
size mismatch for embeds.17.weight: copying a param of torch.Size([3, 8]) from checkpoint, where the shape is torch.Size([4, 8]) in current model.
size mismatch for embeds.26.weight: copying a param of torch.Size([3, 2]) from checkpoint, where the shape is torch.Size([4, 2]) in current model.
size mismatch for embeds.29.weight: copying a param of torch.Size([5, 4]) from checkpoint, where the shape is torch.Size([6, 4]) in current model.
size mismatch for embeds.34.weight: copying a param of torch.Size([6, 4]) from checkpoint, where the shape is torch.Size([5, 4]) in current model.
size mismatch for embeds.44.weight: copying a param of torch.Size([6, 4]) from checkpoint, where the shape is torch.Size([5, 4]) in current model.
size mismatch for embeds.47.weight: copying a param of torch.Size([11, 8]) from checkpoint, where the shape is torch.Size([8, 8]) in current model.
size mismatch for embeds.62.weight: copying a param of torch.Size([11, 8]) from checkpoint, where the shape is torch.Size([12, 8]) in current model.
size mismatch for embeds.67.weight: copying a param of torch.Size([15, 8]) from checkpoint, where the shape is torch.Size([14, 8]) in current model.
size mismatch for embeds.71.weight: copying a param of torch.Size([4, 2]) from checkpoint, where the shape is torch.Size([3, 2]) in current model.
size mismatch for embeds.77.weight: copying a param of torch.Size([13, 8]) from checkpoint, where the shape is torch.Size([12, 8]) in current model.
size mismatch for embeds.87.weight: copying a param of torch.Size([11, 8]) from checkpoint, where the shape is torch.Size([12, 8]) in current model.
size mismatch for embeds.92.weight: copying a param of torch.Size([13, 8]) from checkpoint, where the shape is torch.Size([11, 8]) in current model.
size mismatch for embeds.97.weight: copying a param of torch.Size([13, 8]) from checkpoint, where the shape is torch.Size([12, 8]) in current model.
size mismatch for embeds.102.weight: copying a param of torch.Size([11, 8]) from checkpoint, where the shape is torch.Size([12, 8]) in current model.
size mismatch for embeds.112.weight: copying a param of torch.Size([10, 8]) from checkpoint, where the shape is torch.Size([11, 8]) in current model.
size mismatch for embeds.115.weight: copying a param of torch.Size([3, 2]) from checkpoint, where the shape is torch.Size([4, 2]) in current model.
size mismatch for embeds.124.weight: copying a param of torch.Size([7, 4]) from checkpoint, where the shape is torch.Size([6, 4]) in current model.
size mismatch for embeds.125.weight: copying a param of torch.Size([4, 2]) from checkpoint, where the shape is torch.Size([3, 2]) in current model.
size mismatch for embeds.127.weight: copying a param of torch.Size([12, 8]) from checkpoint, where the shape is torch.Size([14, 8]) in current model.
size mismatch for embeds.131.weight: copying a param of torch.Size([3, 2]) from checkpoint, where the shape is torch.Size([4, 2]) in current model.
size mismatch for embeds.132.weight: copying a param of torch.Size([13, 8]) from checkpoint, where the shape is torch.Size([14, 8]) in current model.
size mismatch for embeds.140.weight: copying a param of torch.Size([6, 3]) from checkpoint, where the shape is torch.Size([5, 3]) in current model.
```
It looks like the embeddings are the problem. The embeddings variable already contains appropriate sizes for all possible categories of the categorical variables. Am I using the functions wrong, or is this a bug? Similar threads do not really help, as they all seem to be about image classification.
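To narrow it down, the checkpoint can be diffed against the freshly built model with plain PyTorch; a small sketch (using model and the file name from my second snippet above). If the diagnosis is right, every mismatch it prints should be an embeds.*.weight tensor:

```python
import torch

ckpt = torch.load("models/best_torch_model.pth")
# compare the saved shapes with the shapes the fresh model expects
for name, param in model.state_dict().items():
    if name in ckpt and ckpt[name].shape != param.shape:
        print(name, "checkpoint:", tuple(ckpt[name].shape),
              "current:", tuple(param.shape))
```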
Using load_state_dict(…, strict=False) does not change anything.
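This seems expected: strict=False only skips missing and unexpected keys; parameters that exist under the same name but with a different shape still raise. Filtering the state dict by shape before loading, as sketched below, would silence the error, but the mismatched embedding tables would then keep their random initialisation, so it is not a real workaround:

```python
# sketch: load only the parameters whose shapes still match; the skipped
# embedding tables keep their random init, so predictions would be unreliable
ckpt = torch.load("models/best_torch_model.pth")
current = model.state_dict()
filtered = {k: v for k, v in ckpt.items()
            if k in current and v.shape == current[k].shape}
model.load_state_dict(filtered, strict=False)
```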