Inconsistency between get_preds calls

(jswong) #1

I am facing some inconsistencies using the get_preds function.

_, y_train = learn.get_preds(ds_type=src.train_ds) # 7897 good
_, y_train = learn.get_preds(ds_type=DatasetType.Train) # 7818 bad

_, y_valid = learn.get_preds(ds_type=src.valid_ds) # 7897 bad
_, y_valid = learn.get_preds(ds_type=DatasetType.Valid) # 1974 good

The number after each # is the length of the returned targets. I wonder why this is the case. Any help is appreciated. Thank you.


#2

src.valid_ds and src.train_ds are not valid values for ds_type (it expects a DatasetType member), so you shouldn’t use them.
Note that the training dataloader drops the last batch if it has fewer than batch-size items, which is why you get a shorter length with DatasetType.Train. Using ds_type=DatasetType.Fix gives you the training set unshuffled and with that last batch included.
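To see why dropping the last partial batch shortens the result, here is a minimal pure-Python sketch (no fastai involved). It assumes, as described above, that the training loader shuffles and drops the last incomplete batch while the "Fix" loader does neither. The helper names iterate_batches and collected_len, and the sizes 1000 and 64, are hypothetical and only for illustration.

```python
import random

def iterate_batches(items, bs, shuffle=False, drop_last=False, seed=0):
    """Hypothetical helper: yield batches the way a typical batching loader would."""
    idx = list(range(len(items)))
    if shuffle:
        random.Random(seed).shuffle(idx)
    for start in range(0, len(idx), bs):
        batch = [items[i] for i in idx[start:start + bs]]
        # With drop_last, a final batch smaller than bs is skipped entirely
        if drop_last and len(batch) < bs:
            break
        yield batch

def collected_len(items, bs, **kwargs):
    """Hypothetical helper: total number of items seen across all batches."""
    return sum(len(b) for b in iterate_batches(items, bs, **kwargs))

items = list(range(1000))  # hypothetical dataset of 1000 items
bs = 64

# Training-style loader: shuffled, last partial batch dropped
print(collected_len(items, bs, shuffle=True, drop_last=True))    # 960
# Fix-style loader: unshuffled, last partial batch kept
print(collected_len(items, bs, shuffle=False, drop_last=False))  # 1000
```

The number of items lost this way is len(dataset) % bs, so it depends on both the dataset size and the batch size.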
