Here is the full code snippet showing how I did it, but if someone knows a better way, I'm happy to hear it.

```
def get_data(sz, f_model, transforms, val_idxs, bs=64):
    """Build an ImageClassifierData for one train/validation split.

    Args:
        sz: image size passed to tfms_from_model.
        f_model: model architecture passed to tfms_from_model.
        transforms: augmentation transforms passed as aug_tfms
            (e.g. transforms_top_down).
        val_idxs: row indices (into label_csv) to hold out as the
            validation set.
        bs: batch size.

    Returns:
        The dataset built by ImageClassifierData.from_csv from the
        'newtrain' folder under PATH, labelled by label_csv, with the
        'test' folder as the test set.
    """
    # Honour the caller's augmentations. The previous version reassigned this
    # parameter to the global transforms_top_down, so the argument had no effect.
    tfms = tfms_from_model(f_model, sz, aug_tfms=transforms, max_zoom=1.1)
    return ImageClassifierData.from_csv(PATH, 'newtrain', label_csv,
                                        val_idxs=val_idxs, test_name='test',
                                        tfms=tfms, bs=bs)
# Load the whole labelled set once so every label is available for
# stratification. from_csv requires a validation set, so row 0 is held out
# as a placeholder.
data = get_data(sz, f_model, transforms_top_down, val_idxs=[0])
# Rebuild the label vector in CSV-row order. fastai 0.7 splits rows with a
# boolean mask, so val_y holds row 0 and trn_y holds rows 1..n-1 in their
# original order (NOTE(review): relies on that split behaviour - confirm for
# your fastai version). Stratifying on trn_y alone dropped row 0 and shifted
# every fold index by one relative to the CSV rows that val_idxs refers to.
all_y = np.concatenate([data.val_y, data.trn_y])
skf = StratifiedKFold(n_splits=4, random_state=seed, shuffle=True)
splits = skf.split(np.zeros(len(all_y)), all_y)
# One dataset per fold; each fold's val_index becomes its validation set.
datas = [get_data(sz, f_model, transforms_top_down, val_idxs=val_index)
         for _, val_index in splits]

# Pick a learning rate once, on the first fold.
learn = ConvLearner.pretrained(f_model, datas[0], precompute=False)
lrf = learn.lr_find()
learn.sched.plot()
# NOTE(review): `lr` is not defined in this snippet - set it from the plot
# above before running the loop below.

# Train a fresh pretrained model per fold. Reusing one learner across folds
# means that by fold k the model has already trained on fold k's validation
# images, so its validation metrics are no longer out-of-fold.
learns = []
for data in datas:
    learn = ConvLearner.pretrained(f_model, data, precompute=False)
    learn.fit(lr, 3, cycle_len=1, cycle_mult=2)
    learns.append(learn)
```