Batch_size set to None after specifying batch_sampler

Hi!

After specifying a batch_sampler in the DataLoader for the training data, my batch_size is set to None instead of the batch size used by the batch_sampler. This causes downstream issues, e.g. data.show_batch fails with: TypeError: '<' not supported between instances of 'NoneType' and 'int'

I set the batch_sampler as follows, similar to @sgugger’s suggestion for setting a sampler in the thread "Adding a custom sampler to databunch":

data.train_dl = data.train_dl.new(shuffle=False, drop_last=False,
                                  sampler=None, batch_sampler=mybatchsampler)

Simply setting the batch size afterwards with data.batch_size = mybatchsize raised the ValueError about batch_sampler being mutually exclusive with batch_size and the other options.

For now I’ve solved it by creating a custom DataLoader with a change in __init__ for the batch_sampler like this:

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')
            else:
                self.batch_size = batch_sampler.batch_size

It’s not very nice: every other class that uses the DataLoader now has to use my custom DataLoader instead, which means a lot of copy-pasting from the library source for a very minor tweak. Is there another way to make the DataLoader's batch_size take the value of the batch_size specified in the batch_sampler? Any suggestions are much appreciated!

You can do something like DataLoader.__init__ = yourinit at the start of your code, which will replace the default __init__ with yours for every class that uses DataLoader. Besides, if your change makes sense, you could submit a PR so future users don't hit this bug.
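
To illustrate the DataLoader.__init__ = yourinit idea, here is a minimal sketch of the pattern. It deliberately uses a toy stand-in class so it runs without PyTorch or fastai: Loader, FixedSizeBatchSampler and patched_init are hypothetical names for this example only, not library APIs. With the real DataLoader the details will differ; as far as I know, PyTorch's DataLoader guards batch_size against reassignment after construction, so the assignment would likely have to live inside the replacement __init__ itself rather than run after the original one.

```python
import functools


class Loader:
    """Hypothetical stand-in for a library DataLoader (not the real torch class)."""

    def __init__(self, dataset, batch_size=1, batch_sampler=None):
        self.dataset = dataset
        # Mimics the behaviour described in this thread: batch_size ends up
        # as None whenever a batch_sampler is supplied.
        self.batch_size = None if batch_sampler is not None else batch_size
        self.batch_sampler = batch_sampler


class FixedSizeBatchSampler:
    """Hypothetical batch sampler that exposes its own batch_size attribute."""

    def __init__(self, n_items, batch_size):
        self.n_items = n_items
        self.batch_size = batch_size

    def __iter__(self):
        indices = list(range(self.n_items))
        for start in range(0, self.n_items, self.batch_size):
            yield indices[start:start + self.batch_size]


# Keep a reference to the original constructor before replacing it.
_original_init = Loader.__init__


@functools.wraps(_original_init)
def patched_init(self, *args, **kwargs):
    # Run the unmodified constructor first, then fill in the missing value.
    _original_init(self, *args, **kwargs)
    sampler = getattr(self, "batch_sampler", None)
    if self.batch_size is None and sampler is not None:
        # Falls back to None if the sampler has no batch_size attribute.
        self.batch_size = getattr(sampler, "batch_size", None)


# The monkey patch: every later Loader(...) call now goes through patched_init.
Loader.__init__ = patched_init

dl = Loader(list(range(10)), batch_sampler=FixedSizeBatchSampler(10, 4))
```

Because the patch is applied to the class itself, code elsewhere that constructs a Loader picks up the change without being edited, which is the point of the suggestion. The trade-off is that the patch is global and silent, so it is worth keeping it in one clearly marked place at the top of the script.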

Thanks, that is a great idea!

Hey @florobax, I did this:

class my_sampler(Sampler):
    def __init__(self, data_source):
        super(my_sampler, self).__init__(data_source)
        self.data_source = data_source
        self.batch_size=64

db.train_dl = db.train_dl.new(shuffle=False, sampler=my_sampler())

I am getting the following error:

TypeError: __init__() missing 1 required positional argument: 'data_source'

It works if I do:

db.train_dl = db.train_dl.new(shuffle=False, sampler=my_sampler(db.train_ds))

Yes, because my_sampler.__init__ declares data_source as a required positional argument, so you have to pass it when you create the sampler.
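
For reference, a minimal sketch of why the call without arguments fails and what a usable sampler roughly needs. It is written as a plain class so it runs without PyTorch installed; in real code you would normally subclass torch.utils.data.Sampler as in the post above. SequentialSketchSampler and the toy train_ds list are hypothetical names for this example, and the sequential ordering is just an assumption to keep it short.

```python
class SequentialSketchSampler:
    """Hypothetical sampler: yields dataset indices in order."""

    def __init__(self, data_source, batch_size=64):
        # data_source has no default value, so it must be passed in.
        self.data_source = data_source
        self.batch_size = batch_size

    def __iter__(self):
        # A sampler yields indices into the dataset, not the items themselves.
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


train_ds = ["a", "b", "c"]  # toy stand-in for db.train_ds

# Passing the dataset works.
sampler = SequentialSketchSampler(train_ds)
indices = list(sampler)

# Omitting it reproduces the "missing 1 required positional argument" error.
try:
    SequentialSketchSampler()
    missing_arg_error = None
except TypeError as err:
    missing_arg_error = err
```

Note that the my_sampler class in the post only defines __init__; a sampler also needs __iter__ (and usually __len__) before a DataLoader can draw indices from it.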