Is it possible to implement cross-validation in fastai?

I tried to implement cross-validation in fastai, but I failed.


I used sklearn to generate the folds, then trained a fastai classifier on each train/validation split.
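Here is the approach described above, sketched without fastai: a plain-Python fold generator plus a loop that calls a per-fold training function once for each split. `kfold_indices`, `cross_validate` and `train_and_score` are hypothetical names. In practice `sklearn.model_selection.KFold` does the same index bookkeeping. In fastai v2, the validation positions of each fold could be passed to `TabularDataLoaders.from_df(..., valid_idx=...)`.

```python
import random
from typing import Callable, Iterator, List, Tuple


def kfold_indices(n_rows: int, n_splits: int = 5, shuffle: bool = True,
                  seed: int = 42) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_idx, valid_idx) positional indices, one pair per fold.

    Every row lands in exactly one validation fold, and fold sizes differ
    by at most one (the first n_rows % n_splits folds get the extra row).
    """
    if not 2 <= n_splits <= n_rows:
        raise ValueError("n_splits must be between 2 and n_rows")
    order = list(range(n_rows))
    if shuffle:
        random.Random(seed).shuffle(order)
    base, extra = divmod(n_rows, n_splits)
    start = 0
    for k in range(n_splits):
        size = base + (1 if k < extra else 0)
        valid = sorted(order[start:start + size])
        valid_set = set(valid)
        train = [i for i in range(n_rows) if i not in valid_set]
        start += size
        yield train, valid


def cross_validate(n_rows: int,
                   train_and_score: Callable[[List[int], List[int]], float],
                   n_splits: int = 5) -> List[float]:
    """Call train_and_score(train_idx, valid_idx) once per fold; return the scores."""
    return [train_and_score(tr, va) for tr, va in kfold_indices(n_rows, n_splits)]
```

A fresh model has to be created inside `train_and_score` for every fold. Reusing one learner would leak what it learned on earlier folds into later validation scores.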

Could you explain in more detail? I divided the dataset into 5 folds, and on the second iteration I got an error.

Here is how I iterate over my folds.

I save the results after each learner is trained.
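One way to save the results is to write each fold's metrics to disk as soon as that fold finishes. That way the finished folds are kept even if a later fold crashes, as in the second-iteration error mentioned earlier. The helpers below are a plain-Python sketch with hypothetical names and made-up metric values in the test. `summarize_folds` reports the mean and sample standard deviation across folds.

```python
import csv
import os
import statistics
from typing import Dict, List


def append_fold_result(path: str, fold: int, metrics: Dict[str, float]) -> None:
    """Append one fold's metrics to a CSV file as soon as that fold finishes."""
    write_header = not os.path.exists(path)
    with open(path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["fold", *metrics])
        if write_header:
            writer.writeheader()
        writer.writerow({"fold": fold, **metrics})


def summarize_folds(fold_results: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
    """Mean and sample standard deviation of each metric across folds."""
    summary = {}
    for name in fold_results[0]:
        values = [r[name] for r in fold_results]
        summary[name] = {
            "mean": statistics.mean(values),
            "std": statistics.stdev(values) if len(values) > 1 else 0.0,
        }
    return summary
```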


This is my way to implement stratified k-fold cross-validation with fastai and scikit-learn.

I think it is OK and easy, but I’m a newbie.
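Stratified k-fold differs from plain k-fold in that each validation fold keeps roughly the same class proportions as the full label column. The plain-Python sketch below uses a simplified round-robin scheme. It is not scikit-learn's exact `StratifiedKFold` allocation, and `stratified_kfold_indices` is a hypothetical name. With a DataFrame, `labels` would be the label column, e.g. `df['y']`.

```python
import random
from collections import defaultdict
from typing import Hashable, Iterator, List, Sequence, Tuple


def stratified_kfold_indices(labels: Sequence[Hashable], n_splits: int = 5,
                             seed: int = 42) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_idx, valid_idx) pairs whose validation folds keep the class mix.

    Rows of each class are shuffled and dealt round-robin across the folds, so
    every fold gets about 1/n_splits of each class.
    """
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)
    rng = random.Random(seed)
    folds: List[List[int]] = [[] for _ in range(n_splits)]
    dealt = 0  # continue the round-robin across classes to balance fold sizes
    for y in sorted(by_class, key=repr):
        rows = by_class[y]
        rng.shuffle(rows)
        for j, i in enumerate(rows):
            folds[(dealt + j) % n_splits].append(i)
        dealt += len(rows)
    for valid in folds:
        valid_set = set(valid)
        train = [i for i in range(len(labels)) if i not in valid_set]
        yield train, sorted(valid)
```

A class with fewer rows than `n_splits` will be missing from some folds. scikit-learn warns about that case, and this sketch does not.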


Have you ever run into an out-of-memory error? Thanks in advance!

You can probably reduce the batch_size to avoid that error.
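Besides reducing the batch size, another cause may apply when the out-of-memory error appears only on the second or a later fold. The previous fold's learner can still be holding GPU memory. The sketch below assumes a PyTorch backend, which fastai uses. `free_memory` is a hypothetical helper, while `gc.collect`, `torch.cuda.is_available` and `torch.cuda.empty_cache` are standard calls.

```python
import gc


def free_memory() -> bool:
    """Free memory left over from the previous fold; return True if the CUDA cache was emptied.

    Call it only after deleting your own references (e.g. `del learn, dls`),
    because a function cannot drop variables that live in the caller's scope.
    """
    gc.collect()  # collect objects kept alive only by reference cycles
    try:
        import torch
    except ImportError:
        return False  # PyTorch not installed: nothing GPU-related to release
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand cached, unused GPU blocks back to the driver
        return True
    return False
```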


I was just asking myself: what is the purpose of cross-validation in this situation? My guess: to get a better estimate of the model's accuracy on the test set?
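One way to make that guess concrete: with K folds, let V_k be the rows of fold k and \hat{f}_{-k} the model trained on the other K-1 folds (notation introduced here for illustration). The cross-validated accuracy is the average of the K held-out accuracies, and s measures how much it varies from fold to fold:

```latex
\widehat{\mathrm{Acc}}_{\mathrm{CV}} = \frac{1}{K}\sum_{k=1}^{K}\mathrm{Acc}_k,
\qquad
\mathrm{Acc}_k = \frac{1}{|V_k|}\sum_{i \in V_k}\mathbf{1}\!\left[\hat{f}_{-k}(x_i) = y_i\right],
\qquad
s = \sqrt{\frac{1}{K-1}\sum_{k=1}^{K}\left(\mathrm{Acc}_k - \widehat{\mathrm{Acc}}_{\mathrm{CV}}\right)^2}
```

Every row is used for validation exactly once. As a result, the average depends less on one particular train/validation split than a single hold-out score does.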

Also, what about early stopping when the error starts increasing on the validation set? How can we do that?


Regarding your second question, you should probably have a look at the EarlyStoppingCallback in the docs.
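In fastai v2, `EarlyStoppingCallback` takes `monitor`, `min_delta` and `patience` arguments. A typical use would be something like `learn.fit_one_cycle(20, cbs=EarlyStoppingCallback(monitor='valid_loss', patience=3))`. The plain-Python sketch below only illustrates the patience logic such a callback applies to a validation-loss history. It is not fastai's code, and `early_stopping_epoch` is a hypothetical name.

```python
import math
from typing import List, Optional


def early_stopping_epoch(valid_losses: List[float], patience: int = 3,
                         min_delta: float = 0.0) -> Optional[int]:
    """Return the epoch index at which training would stop, or None if it never does.

    An epoch counts as an improvement only if its loss beats the best loss so far
    by more than min_delta; training stops after `patience` epochs in a row
    without improvement.
    """
    best = math.inf
    wait = 0
    for epoch, loss in enumerate(valid_losses):
        if loss < best - min_delta:
            best, wait = loss, 0  # new best: reset the counter
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None
```

Inside a k-fold loop, each fold would stop independently, so different folds may train for different numbers of epochs.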


Is the y in df['y'] the tags/label column in the df?

Yes, it is 🙂


Thank you, muellerzr!

Hi all, how would you adapt @farconada’s method to fastai v2?

I tried playing with @muellerzr’s notebook but can’t get it to work with tabular data. I also have only one DataFrame to work with (no separate validation set), so I’m finding it difficult to adapt his code. I’d love some help, @muellerzr!

Otherwise there is this notebook:

(It’s an old notebook, but the only change required is importing from fastai.x rather than from fastai2.)

Hi all,

I’m going a bit nuts implementing k-fold CV in fastai, so I hope someone can help on this slightly older thread!

Essentially, I am using fastai v2 with a tabular learner on continuous and categorical columns to build a neural network. Before that, I create a robust train set and an unseen test set:

```python
from sklearn.model_selection import train_test_split

X_train, X_test = train_test_split(all_train_data, test_size=0.15, random_state=1)
```

Having followed: I am then trying to create a stratified k-fold split and loop over the folds. In each iteration I load the training data, train the model, and then predict on the unseen test data (X_test) above.
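Every fold's model predicts on the same unseen X_test, so the K sets of test predictions can be combined into one. A common choice is to average the predicted class probabilities. Below is a plain-Python sketch; `average_fold_predictions` is a hypothetical helper. With the PyTorch prediction tensors returned by `learn.get_preds`, `torch.stack(preds).mean(0)` does the same averaging.

```python
from typing import List, Sequence


def average_fold_predictions(
        fold_preds: Sequence[Sequence[Sequence[float]]]) -> List[List[float]]:
    """Average per-class probabilities for the same test rows across fold models.

    fold_preds[k][r][c] is fold k's predicted probability of class c for test row r.
    """
    if not fold_preds:
        raise ValueError("need predictions from at least one fold")
    n_rows = len(fold_preds[0])
    if any(len(p) != n_rows for p in fold_preds):
        raise ValueError("every fold must predict the same test rows")
    n_folds = len(fold_preds)
    return [
        # zip the K probability vectors for row r, then average each class column
        [sum(col) / n_folds for col in zip(*(p[r] for p in fold_preds))]
        for r in range(n_rows)
    ]
```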

However, the fun/annoyance begins when I call get_preds(dl=test_dl) and I get:

```
~/opt/anaconda3/envs/capstone_python/lib/python3.8/site-packages/torch/nn/ in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   1812         # remove once script supports set_grad_enabled
   1813         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1814     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

IndexError: index out of range in self
```

Originally I thought this was due to the random split done by scikit-learn’s train_test_split. However, I’ve reset the index of both train and test, so I can’t figure out why I’m getting this error. Is it because learn.get_preds() can’t find an embedding element or something?

The code that runs the loop is as follows:

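```python
from fastai.tabular.all import *
from sklearn.model_selection import StratifiedKFold

val_pct = L()
test_pct = L()
test_preds = L()
f1_results = []
auc_results = []
mcc_results = []
bal_acc_results = []

strat_kfold = StratifiedKFold(n_splits=5, shuffle=True)
# split() yields positional (train, valid) index arrays, one pair per fold
res = strat_kfold.split(X_train.index, X_train['macro_LGA'])

# The config does not change between folds, so build it once;
# ps holds one dropout value per hidden layer in layers=[100, 50]
tab_config = tabular_config(embed_p=0.25, use_bn=True, ps=[0.025, 0.015], act_cls=Mish())

for x, y in res:
    ix = (L(list(x)), L(list(y)))

    # The original line was cut off in the post; the remaining arguments are
    # assumed here (same columns as the test set below, this fold as validation)
    train_dl = TabularDataLoaders.from_df(df=X_train, procs=procs, cat_names=cat_cols,
                                          cont_names=cont_cols, y_names="macro_LGA",
                                          valid_idx=ix[1])
    learn = tabular_learner(train_dl, config=tab_config, loss_func=loss_func,
                            metrics=[accuracy, roc_auc], lr=0.00005, emb_szs=emb_szs,
                            layers=[100, 50], opt_func=RAdam,
                            cbs=SaveModelCallback(monitor='valid_loss'))
    # NOTE: this TabularPandas sets up its own procs (e.g. Categorify) on X_test
    # alone, so its category codes need not match the training embeddings
    test_dl = TabularPandas(X_test, procs=procs, cat_names=cat_cols,
                            cont_names=cont_cols, y_names="macro_LGA")
    test_dl = TabDataLoader(test_dl)
```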
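A likely explanation, not confirmed in the thread: `TabularPandas(X_test, procs=procs, ...)` sets up its procs, such as `Categorify`, on X_test alone. The integer codes it produces then follow a vocabulary the trained embeddings have never seen. Any code at or above an embedding's number of rows raises `IndexError: index out of range in self`. The toy sketch below mimics this with hypothetical `fit_vocab`/`encode` helpers and made-up category values. It is loosely modeled on, but not the same as, fastai's `Categorify`, which likewise reserves a `#na#` slot. In fastai v2, the usual way to build a test loader that reuses the training preprocessing is `test_dl = learn.dls.test_dl(X_test)` followed by `learn.get_preds(dl=test_dl)`.

```python
from typing import Dict, List, Sequence


def fit_vocab(values: Sequence[str]) -> Dict[str, int]:
    """Map each category seen at fit time to an integer code; 0 is reserved for unknown."""
    vocab = {"#na#": 0}
    for v in sorted(set(values)):
        vocab[v] = len(vocab)
    return vocab


def encode(values: Sequence[str], vocab: Dict[str, int]) -> List[int]:
    """Encode with an existing vocab, sending unseen categories to code 0."""
    return [vocab.get(v, 0) for v in values]


train_col = ["red", "green", "green", "blue"]
test_col = ["green", "cyan", "magenta", "yellow", "black"]

train_vocab = fit_vocab(train_col)
n_embeddings = len(train_vocab)  # rows in an embedding matrix sized from the training data

# Refitting on the test set: codes follow a new vocab ("green" even changes code)
refit_codes = encode(test_col, fit_vocab(test_col))
# Reusing the training vocab: every code stays inside the embedding matrix
reused_codes = encode(test_col, train_vocab)
```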



