Can we use PyTorch samplers in "DataLoaders.from_dsets"?

Here’s what I’m trying to do …

train_sampler = RandomSampler(train_ds)
valid_sampler = SequentialSampler(valid_ds)

collate_fn = partial(hft_sum_collate, tokenizer=hft_tokenizer, block_size=max_seq_len)

dls = DataLoaders.from_dsets(train_ds, valid_ds, path=PATH, dl_type=SortedDL,
                             batch_size=(bsz,bsz*2), 
                             sampler=(train_sampler, valid_sampler), #????
                             create_batch=collate_fn)

I’m not sure how to pass the samplers in … or even if we can in v2. Either way, what is the v2 approach?

Also, is there a way to specify a different dl_type for the training vs. the validation dataloader here?

Thanks


fastai v2 replaces the PyTorch DataLoader with its own implementation, so you can't pass PyTorch samplers to it. But you can probably do the same thing with the hooks in the fastai DataLoader, since it has many more of them :). I think the one you want is get_idxs, which returns the order in which items are drawn from the dataset.
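To make the get_idxs suggestion concrete, here is a minimal sketch, not a definitive implementation. It assumes a map-style dataset with a known length; `random_order`, `make_sampler_dl` and `SamplerDL` are hypothetical names made up for illustration. As far as I know, a plain `shuffle=True` / `shuffle=False` already covers what RandomSampler and SequentialSampler do, so overriding the hook is mainly worth it for custom orderings.

```python
import random


def random_order(n, seed=None):
    """Random permutation of range(n), roughly what RandomSampler yields."""
    idxs = list(range(n))
    random.Random(seed).shuffle(idxs)
    return idxs


def make_sampler_dl(dataset, bs, seed=None, **kwargs):
    """Build a fastai DataLoader whose `get_idxs` hook plays the sampler's role.

    Hypothetical helper: the name and arguments are illustrative only.
    """
    # Imported here so `random_order` stays usable without fastai installed.
    from fastai.data.core import TfmdDL

    class SamplerDL(TfmdDL):  # hypothetical subclass name
        def get_idxs(self):
            # Called at the start of each pass over the data;
            # the returned list is the order in which items are fetched.
            return random_order(self.n, seed)

    return SamplerDL(dataset, bs=bs, **kwargs)
```

Swap `random_order` for any function that returns a list of indices (weighted, sorted by length, and so on) to get the equivalent of a custom sampler. Note that a fixed `seed` gives the same order every epoch, while `seed=None` reshuffles each time.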

As for specifying a different dl_type for train and valid, this is a case where you should create each DataLoader yourself and pass them to the DataLoaders constructor (its __init__) instead of going through from_dsets.
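Building the loaders by hand could look roughly like the sketch below. It is written under assumptions: `build_dls` is a hypothetical helper, the choice of `TfmdDL` for training and `SortedDL` for validation is only an example, and `SortedDL` with its default sort function expects each item to be a tuple whose first element has a length. Any extra keyword arguments (for instance `create_batch=collate_fn` from the question) are forwarded to both loaders.

```python
def build_dls(train_ds, valid_ds, bs=8, path=".", **shared_kwargs):
    """Create train/valid loaders of different types and wrap them in DataLoaders.

    Hypothetical helper: the name, defaults and loader choices are illustrative.
    """
    # Imported lazily so the module can be loaded without fastai installed.
    from fastai.data.core import DataLoaders, TfmdDL
    from fastai.text.data import SortedDL

    # Training loader: shuffled, incomplete last batch dropped.
    train_dl = TfmdDL(train_ds, bs=bs, shuffle=True, drop_last=True,
                      **shared_kwargs)

    # Validation loader: a different class, sorted by length, larger batches.
    valid_dl = SortedDL(valid_ds, bs=bs * 2, shuffle=False, **shared_kwargs)

    # DataLoaders takes the loaders positionally: train first, then valid.
    return DataLoaders(train_dl, valid_dl, path=path)
```

With the names from the question, the call would be something like `dls = build_dls(train_ds, valid_ds, bs=bsz, path=PATH, create_batch=collate_fn)`.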
