Filter at every epoch based on threshold value

Hello everyone,

I have to implement functionality that does the following:

  1. At the start of every epoch go through every batch of the entire train_dataloader.
  2. For each batch, pass its input tensor to the model and compare the outputs with a threshold value. The indices of the samples whose output values exceed the threshold will be used to extract data from the batch.
    (For example, with a batch size of 32, if one batch of the dataloader has size (32, 256) and the values at indices 0, 3, and 9 satisfy the condition, the tensor I need from this batch would have size (3, 256).)
  3. Repeat step 2 for every batch in the dataloader
  4. In the end, I would have a filtered subset of data I need from all the batches.
  5. Use this filtered subset to create a new dataloader with the same batch size (32) and use this for training the next epoch.
  6. Repeat until I get an empty dataloader
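The steps above can be sketched in plain PyTorch (not fastai-specific). This is a minimal sketch, assuming the model returns one scalar per sample; `filter_by_threshold` and the helper names are mine, not part of any library:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def filter_by_threshold(model, loader, threshold, batch_size):
    """Collect the samples whose model output exceeds `threshold`
    and return a new DataLoader over them (or None if none survive)."""
    kept_inputs, kept_targets = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = model(inputs)
            # Boolean mask of rows whose scalar output exceeds the threshold
            mask = outputs.squeeze(-1) > threshold
            kept_inputs.append(inputs[mask])
            kept_targets.append(targets[mask])
    kept_inputs = torch.cat(kept_inputs)
    kept_targets = torch.cat(kept_targets)
    if len(kept_targets) == 0:
        return None  # signals step 6: stop training
    # Step 5: new dataloader over the filtered subset, same batch size
    dataset = TensorDataset(kept_inputs, kept_targets)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
```

Indexing with a boolean mask (`inputs[mask]`) keeps the data as tensors throughout, which avoids the tensor-to-list-to-tensor round trip.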

Codewise, I am doing the following:


class MyCallback(LearnerCallback):
    def on_epoch_end(self, **kwargs):
        thresholded_inputs = []
        thresholded_targets = []
        for i, (inputs, targets) in enumerate(learn.data.train_dl):
            outputs = learn.model(inputs)
            indices = getIndicesWithOutputValuesGreaterThanThreshold(outputs)
            thresholded_inputs += inputs[indices].tolist()
            thresholded_targets += targets[indices].tolist()
        if len(thresholded_targets) == 0:  # takes care of an empty result
            return {'stop_training': True}
        learn.data.train_dl = createNewDataloaderWithSameBatchSize(
            thresholded_inputs, thresholded_targets)
        return {'stop_training': False}

I want to know the best way to write the createNewDataloaderWithSameBatchSize method, as I am facing some trouble with it. Currently I am trying the following:


def createNewDataloaderWithSameBatchSize(thresholded_inputs, thresholded_targets):
    aux_data_class = TextClasDataBunch.from_ids(
        dataset_location,
        train_ids=thresholded_inputs, valid_ids=thresholded_inputs,
        train_lbls=thresholded_targets, valid_lbls=thresholded_targets,
        vocab=data_lm.train_ds.vocab, bs=batch_size)
    return aux_data_class.train_dl

But it turns out this is not behaving the way I expected it to.
Does anybody know the right way to do it?

Also, I would like to hear about other approaches to this problem, as I suspect there is a more elegant way to solve it.