Is there a way to specify a different "collate_fn" for your DataLoader?

I see this in the DataLoader code …

def create_batch(self, b): return (fa_collate,fa_convert)[self.prebatched](b)
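That line chooses between two functions by indexing a tuple with the boolean `self.prebatched` (`False` acts as index 0, `True` as index 1). Here is a minimal pure-Python sketch of the same dispatch. `collate` and `convert` are hypothetical stand-ins for fastai's `fa_collate` and `fa_convert`, not their real implementations:

```python
def collate(samples):
    # Stand-in for fa_collate: turn a list of (x, y) samples into ([xs], [ys])
    return tuple(list(field) for field in zip(*samples))

def convert(batch):
    # Stand-in for fa_convert: the data is already batched, so pass it through
    return batch

def create_batch(b, prebatched=False):
    # (collate, convert)[False] -> collate, (collate, convert)[True] -> convert
    return (collate, convert)[prebatched](b)

print(create_batch([(1, 'a'), (2, 'b')]))                   # ([1, 2], ['a', 'b'])
print(create_batch(([1, 2], ['a', 'b']), prebatched=True))  # ([1, 2], ['a', 'b'])
```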

… but I'm wondering whether there is a way to pass our own collate function?

If not, would the right approach then be to simply pass a create_batch function of our own? Something like:

dls = dblock.dataloaders(imdb_df, bs=bsz, create_batch=my_collate_fn)
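A minimal sketch of what `my_collate_fn` could look like, assuming each sample is an `(x, y)` pair where `x` is a list of token ids of varying length, padded with a hypothetical `pad_idx=0`. When it replaces `create_batch`, it receives the list of samples `b` directly. It is written in plain Python so it runs anywhere; a real version would return tensors (for example by handing the padded samples on to `fa_collate`):

```python
def my_collate_fn(b, pad_idx=0):
    """Pad variable-length inputs, then group the samples into (xs, ys).

    Hypothetical example: b is a list of (x, y) samples where x is a list of ints.
    """
    max_len = max(len(x) for x, _ in b)
    # Right-pad every input to the length of the longest one in this batch
    xs = [list(x) + [pad_idx] * (max_len - len(x)) for x, _ in b]
    ys = [y for _, y in b]
    return xs, ys

batch = my_collate_fn([([5, 6, 7], 1), ([8], 0)])
print(batch)  # ([[5, 6, 7], [8, 0, 0]], [1, 0])
```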

What is your collate function doing? In most cases, you can use before_batch instead of a custom collate function (for instance, for padding). Otherwise, you do indeed need to pass your own create_batch.
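For comparison, a `before_batch` function runs on the list of samples before collation and returns a list of samples, so the default collate can still stack them afterwards. A minimal pure-Python sketch, again assuming `(x, y)` samples with list-of-int inputs and a hypothetical `pad_idx=0`:

```python
def pad_before_batch(samples, pad_idx=0):
    """Right-pad each sample's input so all inputs in the batch share one length.

    Returns a list of (x, y) samples, the same structure it received, so the
    default collate function can stack them afterwards.
    """
    max_len = max(len(x) for x, _ in samples)
    return [(list(x) + [pad_idx] * (max_len - len(x)), y) for x, y in samples]

print(pad_before_batch([([1, 2, 3], 'pos'), ([4], 'neg')]))
# [([1, 2, 3], 'pos'), ([4, 0, 0], 'neg')]
```

With fastai it would presumably be passed the same way as create_batch, e.g. `before_batch=pad_before_batch` in the `dblock.dataloaders(...)` call.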


Actually, you’re right … I was able to do it with before_batch (I had a bug in my code).

Kinda related, but is there a way to tell my Learner that there is only ONE input even if my DataBlock is set up so that there are TWO? I’m actually merging the two inputs from my DataBlock into a single input in a function I pass to before_batch. Unfortunately, the Learner still expects an nn.Module whose forward() function takes TWO arguments.
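A sketch of such a merging `before_batch` function, assuming (hypothetically) that each sample is `(x1, x2, y)` with list-of-token inputs and that the two inputs are joined with a separator token `sep_idx`:

```python
def merge_inputs(samples, sep_idx=-1):
    """Turn (x1, x2, y) samples into (x, y) samples with x = x1 + [sep_idx] + x2."""
    return [(list(x1) + [sep_idx] + list(x2), y) for x1, x2, y in samples]

print(merge_inputs([([1, 2], [3], 0), ([4], [5, 6], 1)]))
# [([1, 2, -1, 3], 0), ([4, -1, 5, 6], 1)]
```

The DataBlock itself still declares two inputs, which is presumably why the Learner keeps expecting a two-argument forward().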

You can set learn.n_inp=1 to force the behavior, but it will probably mess with the show methods, since everything is set up on the assumption that the number of inputs in your data matches the number the Learner uses.

I don’t see a learn.n_inp=1 available … is this something recent?

UPDATE: It’s learn.dls.n_inp=1 … it’s there.
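For context, the Learner roughly splits each batch into inputs and targets using `dls.n_inp`, then calls the model with the inputs unpacked (about `model(*xb)`). This simplified pure-Python sketch is not fastai's actual code. `split_batch` and `count_args` are hypothetical names, and it assumes the two inputs have already been merged so each batch holds `(x, y)`:

```python
def split_batch(batch, n_inp):
    """Simplified batch split: the first n_inp items are inputs, the rest targets."""
    return batch[:n_inp], batch[n_inp:]

def count_args(*xb):
    # Stand-in for model(*xb): reports how many arguments forward() would get
    return len(xb)

# After before_batch merged the two inputs, each batch holds (x, y)
batch = ('merged_x', 'y')

xb, yb = split_batch(batch, n_inp=2)   # what the DataBlock still declares
print(count_args(*xb), yb)             # 2 ()

xb, yb = split_batch(batch, n_inp=1)   # after learn.dls.n_inp = 1
print(count_args(*xb), yb)             # 1 ('y',)
```

With `n_inp=2`, the target is swallowed as a second input, which matches the symptom of forward() being called with two arguments.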

I imagine that won’t be a problem if I return a custom Tuple and include my own typedispatched show_batch and show_results?

… and the learn.predict method (as I just found out).

Maybe the easiest way is to just duplicate the first input and simply ignore the copy in my forward() function?

Or just have your model take two inputs and only use the first.
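A minimal sketch of that approach, written as a plain Python wrapper so it runs without PyTorch. With PyTorch the same logic would live in the `forward(self, x1, x2)` method of an `nn.Module` subclass. `IgnoreSecondInput` and the inner model are hypothetical:

```python
class IgnoreSecondInput:
    """Accept two inputs (as the Learner passes them) but use only the first."""

    def __init__(self, inner):
        self.inner = inner  # any callable that takes a single input

    def forward(self, x1, x2):
        # x2 is accepted only to satisfy the two-input call and is ignored
        return self.inner(x1)

    __call__ = forward

model = IgnoreSecondInput(lambda x: [v * 2 for v in x])
print(model([1, 2, 3], [9, 9]))  # [2, 4, 6]
```

The design choice here is to leave the data side untouched (still two inputs) and adapt only the model, rather than overriding `n_inp`.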


I tried implementing something using DataLoader, but it was very inefficient, especially the execution of collate_fn.