Split_idx-based augmentation cannot be used with DataLoaders.from_dsets

Hi All,
I am working on a medical image segmentation problem where each call to the dataset returns a 2-tuple comprising an image and a mask. The training and validation datasets integrate fine with fastai dataloaders and train with any transforms, as long as those transforms apply to both sets. But as soon as I start using split_idx=0 to selectively transform only the training dataset (e.g. warp / zoom), no transform is applied to either dataset.

The reason is that the DataLoaders.from_dsets function does not set a split_idx attribute when passed distinct train and valid datasets (distinguished by an is_valid attribute). This behaviour can be confirmed in the Siamese tutorial: https://docs.fast.ai/tutorial.siamese.html#Using-the-mid-level-API. None of the dataloaders created in that tutorial - even those containing separate train and valid datasets - have split_idx until the tutorial reaches the 'Training and validation sets' section and passes a splitter directly to TfmdLists. I am not sure if this is intentional, but if so, it makes it impossible to use fastai's selective augmentation on datasets imported via DataLoaders.from_dsets, at least.
Because the train and valid datasets have no split_idx set, every transform call hits this check in the fastai library and returns the input untransformed:

def _call(self, fn, x, split_idx=None, **kwargs):
    if split_idx!=self.split_idx and self.split_idx is not None: return x
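To make the failure mode concrete, here is a minimal, self-contained sketch of how that gating behaves (this mimics the check above; MiniTransform and TrainOnly are made-up names, not fastai classes):

```python
# Standalone reproduction of the split_idx gating logic quoted above
# (assumption: simplified mimic, not the actual fastai code path).
class MiniTransform:
    split_idx = None                    # None => apply on both train and valid
    def __call__(self, x, split_idx=None):
        # Skip the transform when the caller's split_idx does not match
        # the transform's own split_idx.
        if split_idx != self.split_idx and self.split_idx is not None:
            return x
        return self.encodes(x)
    def encodes(self, x):
        return x * 2

class TrainOnly(MiniTransform):
    split_idx = 0                       # train-only transform

t = TrainOnly()
print(t(3, split_idx=0))  # 6 -> applied (train pipeline)
print(t(3, split_idx=1))  # 3 -> skipped (valid pipeline)
print(t(3))               # 3 -> skipped: no split_idx supplied, as with from_dsets
```

The last call is the situation described above: with no split_idx propagated from the dataloaders, a transform carrying split_idx=0 is skipped everywhere.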

My dataset declaration:

class LITS_dataset_fastai(torch.utils.data.Dataset):
    def __init__(self, root_prefix, preds_prefix, casesDF, slices_with_mets_only=False, segment='liver', outputChannels=3, is_valid=False):
        ...
    def __len__(self):
        -> returns length of dataset
    def encodes(self, idx):
        -> returns self.__getitem__(idx)
    def __getitem__(self, idx):
        -> returns TensorImage(volume), TensorImage(mask)

The transforms:

class TrainingTransform(DisplayedTransform):
    split_idx,order = 0,2
    def __init__(self, aug):
        store_attr()
        self.p = 1.

    def encodes(self, x):
        print("works")
        img,mask = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return TensorImage(aug["image"]), TensorImage(aug["mask"])
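As a side note, encodes above assumes an albumentations-style callable: it is invoked with image=/mask= keyword arrays and returns a dict keyed by "image" and "mask". A tiny stand-in that satisfies the same contract (dummy_aug is hypothetical, used only to illustrate the interface):

```python
import numpy as np

# Hypothetical stand-in for the albumentations-style callable passed to
# TrainingTransform: keyword arrays in, dict of augmented arrays out.
def dummy_aug(image, mask):
    # Horizontal flip as a toy augmentation, applied to image and mask alike.
    return {"image": np.flip(image, axis=-1).copy(),
            "mask":  np.flip(mask, axis=-1).copy()}

img = np.arange(6, dtype=np.float32).reshape(2, 3)
msk = (img > 2).astype(np.uint8)
out = dummy_aug(image=img, mask=msk)
print(out["image"][0])  # [2. 1. 0.] -> first row flipped
```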

tfms = [TrainingTransform(Resize(224))]

I create the datasets in a 2-list of train/valid:

dsets = [LITS_dataset_fastai(…, is_valid=valid) for valid in [False, True]]

The dataloaders line (here is the problem, I am sure):

dls = DataLoaders.from_dsets(*dsets, after_batch=tfms, shuffle=True, bs=16)

Calling:

m,n = dls.train.one_batch()

returns untransformed images.
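Until this is resolved, one workaround that stays entirely in user code is to drop split_idx and gate the augmentation on the dataset's own is_valid flag inside __getitem__, so only the training dataset augments. A minimal sketch (ToyDataset and dummy_aug are made-up names; in practice ToyDataset would subclass torch.utils.data.Dataset, and the same branch would drop into LITS_dataset_fastai, which already carries is_valid):

```python
import numpy as np

# Hypothetical augmentation: adds 1 so we can see that it ran.
def dummy_aug(image, mask):
    return {"image": image + 1.0, "mask": mask}

# Stand-in for a torch.utils.data.Dataset (torch omitted so the sketch
# is dependency-free).
class ToyDataset:
    def __init__(self, aug=None, is_valid=False):
        self.aug, self.is_valid = aug, is_valid
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        img = np.zeros((2, 2), dtype=np.float32)
        msk = np.ones((2, 2), dtype=np.uint8)
        # Gate on is_valid instead of split_idx: only the training
        # dataset ever runs the augmentation.
        if self.aug is not None and not self.is_valid:
            out = self.aug(image=img, mask=msk)
            img, msk = out["image"], out["mask"]
        return img, msk

train_ds, valid_ds = [ToyDataset(aug=dummy_aug, is_valid=v) for v in (False, True)]
print(train_ds[0][0].max())  # 1.0 -> augmented
print(valid_ds[0][0].max())  # 0.0 -> untouched
```

This sidesteps fastai's split_idx machinery entirely, at the cost of the augmentation running per-item on the CPU rather than per-batch.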

Usman Bashir