Implement own batch processing

I’ve been trying to implement a new semi-supervised learning method called “MixMatch”, described in a paper here: link.

Unfortunately, I’ve no idea how to handle batch processing of items when we have to divide our training set into two parts: one part labeled and the other - unlabeled.

According to the paper, we have to divide each batch equally, so that precisely one half of the batch is labeled (those items have to come from the “labeled” folder, basically) and the other half is unlabeled. It is unclear to me, however, how to implement such behaviour in the library.

I know that it’s possible to implement a custom PyTorch data loader, but I’m not sure if that’s the right approach (and if it is, I’m unsure where I should put it).

I’d appreciate any help or feedback on this, as I’ve been trying to figure this out for a while together with a bunch of my colleagues.

The first idea I’d try would be:

  • Create a custom ItemList where the first part of self.items holds the labeled data and the second part holds the unlabeled data. You can store the index where it changes from labeled to unlabeled.
  • Create a custom sampler that yields 2 iterators: one over the labeled indices, one over the unlabeled ones.
  • Create a custom batch sampler that when given the previous 2 iterators, yields batches where the first part comes from the labeled iterator, the second part from the unlabeled one.

I don’t know if you are familiar with PyTorch samplers, though; feel free to ask if anything is unclear.

Another way to do this would be to have your item list store whatever you want but return pairs, each containing one labeled item and one unlabeled item. You then need to design your model and your loss to take this into account (and use a custom collate_fn function to specify how to create your batch from a list of pairs).
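
To make the sampler-based idea from the list a bit more concrete, here is a rough sketch of a batch sampler in plain Python. It is not taken from the paper or from the library: the class name HalfLabeledBatchSampler and its parameters are made up for illustration, and it assumes the layout described in the first bullet (indices 0..n_labeled-1 are labeled, the rest are unlabeled). It also assumes an epoch simply stops when the smaller pool runs out, which may or may not be what you want.

```python
import random


class HalfLabeledBatchSampler:
    """Hypothetical batch sampler: each batch is half labeled, half unlabeled.

    Assumes indices 0..n_labeled-1 are labeled and n_labeled..n_total-1
    are unlabeled (the "split index" layout from the first bullet).
    """

    def __init__(self, n_labeled, n_total, batch_size, seed=None):
        if batch_size % 2 != 0:
            raise ValueError("batch_size must be even to split it in half")
        self.n_labeled = n_labeled
        self.n_total = n_total
        self.half = batch_size // 2
        self.rng = random.Random(seed)

    def __len__(self):
        # Assumption: an epoch ends when the smaller pool is exhausted.
        n_unlabeled = self.n_total - self.n_labeled
        return min(self.n_labeled, n_unlabeled) // self.half

    def __iter__(self):
        # Shuffle the two index pools independently at the start of each epoch.
        labeled = list(range(self.n_labeled))
        unlabeled = list(range(self.n_labeled, self.n_total))
        self.rng.shuffle(labeled)
        self.rng.shuffle(unlabeled)
        for b in range(len(self)):
            lo, hi = b * self.half, (b + 1) * self.half
            # First half of the batch is labeled, second half unlabeled.
            yield labeled[lo:hi] + unlabeled[lo:hi]


# Toy example: 10 labeled items followed by 30 unlabeled ones.
sampler = HalfLabeledBatchSampler(n_labeled=10, n_total=40, batch_size=4, seed=0)
batches = list(sampler)
```

With PyTorch, an object like this can be passed to the data loader via its batch_sampler argument, e.g. DataLoader(dataset, batch_sampler=sampler), in which case batch_size and shuffle should be left at their defaults. The default collate function will then need every item to have the same structure, so the unlabeled items probably need some placeholder label (or a custom collate_fn).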