It would be nice to be able to get predictions in the same order as they are in my test dataframe.
Code for reference:
import pandas as pd
from fastai import *
from fastai.tabular import *
path = './'
train_df = pd.read_csv('./train.csv')
split = 40000
valid_idx = range(len(train_df)-split, len(train_df))
test_df = pd.read_csv('./test.csv')
dep_var = 'target'
data = TabularDataBunch.from_df(path, train_df, dep_var, valid_idx=valid_idx, test_df=test_df)
learn = tabular_learner(data, layers=[200,20], metrics=accuracy)
learn.data.show_batch()
learn.lr_find()
learn.recorder.plot()
learn.fit_one_cycle(10, 1e-2)
preds, y = learn.get_preds(DatasetType.Test) # <-- I think these are in a different order than test_0, test_1, etc.
As a workaround, I am currently iterating through the test dataframe and predicting one row at a time, but this is slow and seems wrong.
That is weird, they normally are in the same order. Can you check that data.show_batch(ds_type=DatasetType.Test) returns the same things as your first rows?
Note that test sets are unlabeled, so if you say this because your ys are 0, this isn’t a good check.
Interesting, I see that now. However, there is another issue: when I pass a databunch to get_preds, the predictions do seem to be out of order, or the accuracy drops dramatically. Compared against the actual ground truth, learn.predict() gives ~97% accuracy whereas learn.get_preds() gives only ~50%. Is order being lost in get_preds?
Ah, I think I see my issue now. The element at position 1 of the get_preds result (get_preds()[1]) is the LOCATION (index) of the category in the list of classes, not the category itself. Apologies! Has there been any thought to including the predicted category for situations in tabular regression?
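For anyone hitting the same confusion, here is a minimal sketch of turning per-row class probabilities into category labels before comparing with the ground truth. It uses plain Python lists as stand-ins. `probs` and `classes` are hypothetical values, assumed to play the role of the predictions from get_preds and the class list (something like learn.data.classes in fastai v1), and `decode_predictions` is a made-up helper, not a fastai function.

```python
# Hypothetical stand-ins: in a real run these would come from the learner
# (assumption: predictions from get_preds, class list from the databunch).
classes = ["cat", "dog", "fish"]
probs = [
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
    [0.2, 0.2, 0.6],
]


def decode_predictions(probs, classes):
    """Return, for each row, the class label with the highest probability."""
    labels = []
    for row in probs:
        # Index of the largest probability in this row (the "location").
        best_idx = max(range(len(row)), key=row.__getitem__)
        # Map the index back to the actual category.
        labels.append(classes[best_idx])
    return labels


predicted_labels = decode_predictions(probs, classes)
print(predicted_labels)
```

Comparing raw indices against category values would make the accuracy look far lower than it is, which matches the ~50% vs ~97% gap described above.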
Yeah, is there any follow-up on this issue? I am experiencing the same thing. When I use get_preds(DatasetType.Test), it returns all the rows out of order. Currently, I am iterating through each row in my test set and using the predict method, but this is painfully slow. Would love to know if anyone has solved this problem! Thanks!
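One possible way to avoid the row-by-row loop, assuming (as suggested earlier in the thread) that batched test predictions come back in the same order as the rows of the test dataframe, is to pair the batched output with the row ids by position. The sketch below is plain Python with stand-in data. `row_ids`, `batch_probs`, and `attach_predictions` are hypothetical names, not part of fastai.

```python
def attach_predictions(row_ids, batch_probs, classes):
    """Pair each row id with its decoded label, relying on positional order."""
    # A length mismatch means the positional pairing cannot be trusted.
    if len(row_ids) != len(batch_probs):
        raise ValueError("number of predictions does not match number of rows")
    result = {}
    for row_id, row in zip(row_ids, batch_probs):
        # Decode the index of the highest probability into its category.
        best_idx = max(range(len(row)), key=row.__getitem__)
        result[row_id] = classes[best_idx]
    return result


# Stand-in data: ids in test-dataframe order, probabilities from one batched call.
row_ids = ["test_0", "test_1", "test_2"]
batch_probs = [[0.9, 0.1], [0.3, 0.7], [0.6, 0.4]]
labelled = attach_predictions(row_ids, batch_probs, classes=["0", "1"])
print(labelled)
```

If the labels produced this way disagree with what predict returns for the same rows, that would be evidence the order really is being lost. If they agree, the mismatch is more likely a decoding issue.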