Filter at every epoch based on threshold value

Hello everyone,

I have to implement functionality that does the following:

  1. At the start of every epoch go through every batch of the entire train_dataloader.
  2. For each batch, pass its input tensor to the model and compare the outputs with a threshold value. The indices of the samples whose output values exceed the threshold will be used to extract data from the batch.
    (For example, with a batch size of 32, if one batch of the dataloader has size (32, 256) and the values at indices 0, 3, and 9 satisfy the condition, the tensor I need from this batch would have size (3, 256).)
  3. Repeat step 2 for every batch in the dataloader
  4. In the end, I would have a filtered subset of data I need from all the batches.
  5. Use this filtered subset to create a new dataloader with the same batch size (32) and use this for training the next epoch.
  6. Repeat until I get an empty dataloader
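The steps above can be sketched in plain PyTorch (not fastai-specific). This is a minimal sketch, assuming the model returns one scalar per sample; `filter_by_threshold` and the helper names are mine, not part of any library:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def filter_by_threshold(model, loader, threshold, batch_size):
    """Collect the samples whose model output exceeds `threshold`
    and return a new DataLoader over them (or None if none survive)."""
    kept_inputs, kept_targets = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = model(inputs)
            # Boolean mask of rows whose scalar output exceeds the threshold
            mask = outputs.squeeze(-1) > threshold
            kept_inputs.append(inputs[mask])
            kept_targets.append(targets[mask])
    kept_inputs = torch.cat(kept_inputs)
    kept_targets = torch.cat(kept_targets)
    if len(kept_targets) == 0:
        return None  # signals step 6: stop training
    # Step 5: new dataloader over the filtered subset, same batch size
    dataset = TensorDataset(kept_inputs, kept_targets)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
```

Indexing with a boolean mask (`inputs[mask]`) keeps the data as tensors throughout, which avoids the tensor-to-list-to-tensor round trip.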

Codewise, I am doing the following:


class MyCallback(LearnerCallback):
    def on_epoch_end(self, **kwargs):
        thresholded_inputs = []
        thresholded_targets = []
        for i, (inputs, targets) in enumerate(learn.data.train_dl):
            outputs = learn.model(inputs)
            indices = getIndicesWithOutputValuesGreaterThanThreshold(outputs)
            thresholded_inputs += inputs[indices].tolist()
            thresholded_targets += targets[indices].tolist()
        if len(thresholded_targets) == 0:  # takes care of an empty result
            return {'stop_training': True}
        learn.data.train_dl = createNewDataloaderWithSameBatchSize(
            thresholded_inputs, thresholded_targets)
        return {'stop_training': False}

I want to know the best way to write the createNewDataloaderWithSameBatchSize method, as I am facing some trouble with it. Currently I am trying the following:


def createNewDataloaderWithSameBatchSize(thresholded_inputs, thresholded_targets):
    aux_data_class = TextClasDataBunch.from_ids(
        dataset_location,
        train_ids=thresholded_inputs, valid_ids=thresholded_inputs,
        train_lbls=thresholded_targets, valid_lbls=thresholded_targets,
        vocab=data_lm.train_ds.vocab, bs=batch_size)
    return aux_data_class.train_dl

But it turns out this is not behaving the way I expected it to.
Does anybody know the right way to do it?

Also, I would like to hear about other approaches to this problem, as I suspect there is a more elegant way to solve it.