Hi @prashanth
Yes and no. I don’t think it’s possible to run the get_predictions function directly on a test set. However, I found a workaround that allows you to get predictions on an unlabeled test set.
First of all, you need to use an explicit validation set during training so that you can drop it later on and put the test set in its place. You can simply shuffle your whole dataframe, add a new column that indicates whether a row is in the validation set or not, and then use that column to split your data.
df = df.sample(frac=1).reset_index(drop=True)   # shuffle the whole dataframe
df['is_valid'] = None
df.iloc[:int(len(df)*0.8), 2] = False   # first 80% -> training set (2 is the positional index of the new is_valid column)
df.iloc[int(len(df)*0.8):, 2] = True    # remaining 20% -> validation set
src = Seq2SeqTextList.from_df(df, path = path, cols='fr').split_from_df(col='is_valid').label_from_df(cols='en', label_cls=TextList)
Then you continue the usual way: train the model on the training set and evaluate it on the validation set. After that, you create a new dataframe by concatenating the training dataframe and the test dataframe. The test dataframe needs to have the same columns (source language, target language and is_valid, which has to be set to True for every test row). The target language column can be empty, but it has to contain an empty string instead of None. Finally, you use the new dataframe to create a new databunch, reusing the vocab from the old databunch.
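For illustration, assuming the test set is just a list of source-language sentences (test_sentences and the exact column names are placeholders; they only have to match the columns used above), the test dataframe could be built like this:

```python
# hypothetical construction of the test dataframe; columns must match the training dataframe
df_test = pd.DataFrame({'fr': test_sentences})
df_test['en'] = ''          # target column has to exist: empty string, not None
df_test['is_valid'] = True  # so these rows end up in the "validation" split of the new databunch
```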
# keep only the original training rows and append the test rows (whose is_valid is True)
df_new = pd.concat([df[df.is_valid==False], df_test])
df_new.reset_index(drop=True, inplace=True)
# rebuild the source list with the vocab of the original databunch so the token ids stay consistent
src_new = Seq2SeqTextList.from_df(df_new, path=path, cols='fr', vocab=data.vocab).split_from_df(col='is_valid').label_from_df(cols='en', label_cls=TextList)
data_new = src_new.databunch()
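As an optional sanity check, the "validation" set of the new databunch should now contain exactly your test rows:

```python
# the validation set of data_new is really the test set
assert len(data_new.valid_ds) == len(df_test)
```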
Now you have to slightly modify the get_predictions function. Also, for some reason reconstruct didn’t work for me anymore, so I rewrote that part (albeit in a much slower way).
def get_predictions_test(learn, ds_type=DatasetType.Valid):
    learn.model.eval()
    inputs, outputs = [], []
    with torch.no_grad():
        # iterate over the "validation" dataloader of the new databunch, i.e. the test set
        for xb, _ in progress_bar(data_new.dl(ds_type)):
            out = learn.model(xb)
            for x, z in zip(xb, out):
                # decode the token ids back to text, skipping fastai's special tokens (ids 1-8)
                inputs.append(' '.join([learn.data.x.vocab.itos[i] for i in x if i not in range(1, 9)]))
                outputs.append(' '.join([learn.data.y.vocab.itos[i] for i in z.argmax(1) if i not in range(1, 9)]))
    return inputs, outputs
And then you can get your predictions on the test set, which technically is the validation set of the new databunch, so ds_type=DatasetType.Valid stays.
inputs, outputs = get_predictions_test(learn)
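If you want to look at the predictions next to their detokenized inputs, you can for example collect them in a dataframe (just a convenience step, the column names are arbitrary):

```python
# pair each detokenized source sentence with its predicted translation
preds_df = pd.DataFrame({'fr': inputs, 'en_pred': outputs})
preds_df.head()
```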
Certainly not the most elegant way, but it works.