Running get_preds with ds_type=DatasetType.Train returns fewer predictions than the number of images in the train set?

I have a databunch of 1000 images, 800 in the train set, 200 in the valid set.
When I run get_preds on valid (without specifying the dataset type), I get 200 predictions. All good.

But when I run get_preds on the train part of the databunch (with .get_preds(ds_type=DatasetType.Train)), I get only 768 predictions. I don’t understand what I am doing wrong. What could cause this?

I tried several different datasets: it always works on the valid set, and a few predictions are always missing for the train ds type.

Here is the simplest version of the code:

np.random.seed(42)
mini_data_all = (ImageDataBunch.from_df("DATA/train_img/", mini_df, ds_tfms=tfms, size=64,label_col='label',fn_col='file_name',valid_pct=.2)).normalize(imagenet_stats)
mini_data_all

out: ImageDataBunch;

Train: LabelList (800 items)
x: ImageList
Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64)
y: CategoryList
0,1,1,1,1
Path: DATA/train_img;

Valid: LabelList (200 items)
x: ImageList
Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64),Image (3, 64, 64)
y: CategoryList
0,0,0,0,1
Path: DATA/train_img;


learn_for_mini_data = create_cnn(mini_data_all, arch, metrics=[accuracy])
learn_for_mini_data.load('good_model')
preds_mini_all_val= learn_for_mini_data.get_preds()
classes_mini_val=np.argmax(preds_mini_all_val[0], axis=1)
print (len(classes_mini_val))

out: 200

preds_mini_all_train= learn_for_mini_data.get_preds(ds_type=DatasetType.Train)
classes_mini_train=np.argmax(preds_mini_all_train[0], axis=1)
print (len(classes_mini_train))

out: 768

If I take the ground-truth classes directly from preds_mini_all_train[1], the count is exactly the same:

classes_mini_train=np.argmax(preds_mini_all_train[0], axis=1)
classes_mini_train=preds_mini_all_train[1]
print (len(classes_mini_train))

out: 768

Could someone help me get the full number of predictions when running on the training set?
TIA
K.

It’s dropping the last batch. There may be a way to tell it not to drop the last batch, but I don’t know how off the top of my head.
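To see why the count would be 768, here is a rough sketch of the arithmetic. It assumes a batch size of 64, which is not stated in the post; n_predictions is a hypothetical helper written for illustration, not a fastai function.

```python
def n_predictions(n_items: int, batch_size: int, drop_last: bool) -> int:
    """Count the items covered when iterating over a dataset in batches.

    Hypothetical helper for illustration only; not part of fastai.
    """
    full_batches, remainder = divmod(n_items, batch_size)
    # With drop_last=True the final, incomplete batch is skipped entirely.
    return full_batches * batch_size + (0 if drop_last else remainder)


# Assumption: batch size 64 (not given in the original post).
print(n_predictions(800, 64, drop_last=True))   # train set, last batch dropped
print(n_predictions(800, 64, drop_last=False))  # train set, last batch kept
print(n_predictions(200, 64, drop_last=False))  # valid set, last batch kept
```

Under that assumption, 800 items make 12 full batches of 64 plus 32 leftover items, so dropping the incomplete batch leaves 768, which matches the number reported above.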


You’re absolutely right. Thank you.
The solution, from @sgugger’s post:

get_preds(ds_type= DatasetType.Fix )
