Data collate function always uses default function

This line redefines the collate function of the dataloader. However, if the dataloader already has its own collate function, it gets overwritten. Either the code should check whether a collate function is already set, or a collate function should be accepted as an argument when creating the databunch.

You’re missing the default defined just at the top: that’s the magic of dataclasses, you can type-annotate a field and define its default super easily.

DataBunch.create also passes this default function (which is, incidentally, the pytorch default applied after extracting .data from all the elements of the input).
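To illustrate the idea of "extracting .data before collating", here is a minimal pure-Python sketch; `to_data` and `Item` are illustrative stand-ins, not the actual fastai source.

```python
# Hypothetical sketch: before handing a batch to the usual collate
# function, recursively replace any item carrying a `.data` attribute
# with that attribute, mirroring the behavior described above.
def to_data(b):
    """Recursively unwrap `.data` from items in nested lists/tuples."""
    if isinstance(b, (list, tuple)):
        return type(b)(to_data(o) for o in b)
    return b.data if hasattr(b, "data") else b

class Item:
    """Stand-in for a fastai item type that wraps its tensor in .data."""
    def __init__(self, data):
        self.data = data

batch = [Item(1), Item(2), (Item(3), Item(4))]
print(to_data(batch))  # -> [1, 2, (3, 4)]
```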

Suppose I have dataloaders already created, say trn_dl and val_dl. To create the databunch I would do db = DataBunch(trn_dl, val_dl, pth). However, DataBunch creates the dataloader in line 62 of

Here it doesn’t take a collate function as an input, so self.train_dl always has its collate function reassigned.

Please correct me if I am misunderstanding anything.

Ah, I see where this comes from. There are two levels of dataloaders in a DataBunch: a pytorch dataloader (which takes a collate_fn) wrapped inside a DeviceDataLoader (which is responsible for putting the batches on a given device and applying batch transforms). When you create a DataBunch, you pass it pytorch dataloaders with the collate_fn you have chosen yourself (probably the pytorch default), and it isn’t overridden.
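The two-level structure can be sketched in plain Python; `DataLoaderStub` and `DeviceDataLoader` below are simplified stand-ins, not the real pytorch/fastai classes.

```python
# Sketch of the two levels: an inner dataloader owns the collate_fn,
# and an outer wrapper is only responsible for device placement.
class DataLoaderStub:
    """Stand-in for torch.utils.data.DataLoader: holds a collate_fn."""
    def __init__(self, dataset, collate_fn=None):
        self.dataset = dataset
        self.collate_fn = collate_fn or (lambda xs: list(xs))
    def __iter__(self):
        yield self.collate_fn(self.dataset)

class DeviceDataLoader:
    """Stand-in wrapper: would move each batch to a device."""
    def __init__(self, dl, device="cpu"):
        self.dl, self.device = dl, device
    def __iter__(self):
        for b in self.dl:
            yield b  # the real version would move b to self.device

def my_collate(xs):
    """A user-supplied collate_fn; it survives the wrapping untouched."""
    return tuple(xs)

inner = DataLoaderStub([1, 2, 3], collate_fn=my_collate)
wrapped = DeviceDataLoader(inner)
print(next(iter(wrapped)))  # -> (1, 2, 3)
```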

When you create a DataBunch via the .create function, the default collate_fn is used.

Exactly. However, in my case I am using a different collate function. The temporary hack I am using is
data.train_dl.dl.collate_fn = collater, where collater is my collate function.
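A minimal sketch of why that attribute path works: the outer wrapper exposes the inner dataloader as `.dl`, so you can reach through it and swap the collate function after creation. All three classes below are illustrative stand-ins, not the real fastai code.

```python
# Hypothetical stand-ins for the three layers involved in the hack.
class Inner:  # stands in for torch.utils.data.DataLoader
    def __init__(self):
        self.collate_fn = lambda b: b  # default: identity

class Wrapper:  # stands in for DeviceDataLoader
    def __init__(self):
        self.dl = Inner()

class Bunch:  # stands in for DataBunch
    def __init__(self):
        self.train_dl = Wrapper()

def collater(batch):
    """Custom collate function to install after the fact."""
    return tuple(batch)

data = Bunch()
data.train_dl.dl.collate_fn = collater  # the workaround from the post
print(data.train_dl.dl.collate_fn([1, 2]))  # -> (1, 2)
```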

I think you’re right. Also, I notice **kwargs there isn’t being used. I suspect it should be passed to DataBunch.__init__, which in turn should pass it to the DeviceDataLoader, but I’ll leave it to @sgugger to confirm.
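The suggested fix of forwarding **kwargs down the chain could look roughly like this; the classes are simplified stand-ins, and whether this matches the actual fastai signatures is for @sgugger to confirm.

```python
# Sketch of forwarding **kwargs from DataBunch down to the wrapper,
# so options like a custom collate_fn are not silently dropped.
class DeviceDataLoader:
    """Stand-in wrapper that records whatever kwargs reach it."""
    def __init__(self, dl, **kwargs):
        self.dl, self.kwargs = dl, kwargs

class DataBunch:
    """Stand-in: passes its extra kwargs through to each wrapper."""
    def __init__(self, train_dl, valid_dl, **kwargs):
        self.train_dl = DeviceDataLoader(train_dl, **kwargs)
        self.valid_dl = DeviceDataLoader(valid_dl, **kwargs)

db = DataBunch("trn", "val", collate_fn=list)
print(db.train_dl.kwargs)  # -> {'collate_fn': <class 'list'>}
```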

Ah, now I understand, sorry for being slow. Will correct this sometime today, thanks for pointing it out!

Should be fixed in this commit.