For a project of mine, I am currently trying to create a dynamic sampler for my data loader, i.e. a sampler whose weights evolve depending on the loss (an item with a higher loss will then get picked more often). The idea I had was to create a callback which updates the weights after the loss is computed. To do that, I created a loss class that stores the unreduced loss and returns the reduced version when called:
```python
from torch import nn

class URLoss(nn.Module):
    "Wraps a loss function, keeps the unreduced loss and returns the reduced one."
    def __init__(self, loss_func, reduction='mean'):
        super().__init__()
        self.loss_func = loss_func
        self.reduction = reduction
        self.loss = None

    def forward(self, input, target, **kwargs):
        # Store the per-item losses so a callback can read them later
        self.loss = self.loss_func(input, target, reduction='none', **kwargs)
        return self.loss.mean() if self.reduction == 'mean' else self.loss.sum()
```
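For instance, wrapping `F.cross_entropy` (a minimal sketch with random data, just to show where the unreduced loss ends up):

```python
import torch
import torch.nn.functional as F

loss_func = URLoss(F.cross_entropy)
preds, targets = torch.randn(4, 10), torch.randint(0, 10, (4,))
reduced = loss_func(preds, targets)   # scalar, what the training loop uses
per_item = loss_func.loss             # shape (4,), one loss value per sample
```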
The sampler then simply contains the weights and yields the result of a call to `torch.multinomial` when `__iter__` is called:
```python
import torch
from torch.utils.data import Sampler

class RandomSampler(Sampler):
    "Samples indices with probabilities proportional to `weights` (with replacement)."
    def __init__(self, num_samples, weights=None):
        self.weights = weights if weights is not None else torch.ones(num_samples)
        self.num_samples = num_samples

    def __iter__(self):
        return iter(torch.multinomial(self.weights, self.num_samples, True).tolist())

    def __len__(self):
        return self.num_samples
```
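To make sure we are on the same page, this is how I would plug it into a plain PyTorch `DataLoader` (a sketch with a hypothetical `dataset`):

```python
from torch.utils.data import DataLoader

sampler = RandomSampler(len(dataset))               # uniform weights to begin with
dl = DataLoader(dataset, batch_size=32, sampler=sampler)
# Bumping a weight later makes that item more likely to be drawn again
sampler.weights[3] = 10.0
```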
I then want to write something like:
```python
from fastai.basic_train import LearnerCallback

class UpdateSamplerCallback(LearnerCallback):
    _order = 0

    def on_backward_begin(self, last_input, **kwargs):
        loss = self.learn.loss.fn.loss  # the unreduced loss stored by URLoss
        for k, x in enumerate(last_input):
            # index_of_x is the index of x in the item list -- this is what I'm missing
            self.learn.data.train_dl.sampler.weights[index_of_x] = loss[k]
```
The problem is that I have no way to access `index_of_x` (the index of the item in the item list). I had a couple of ideas, but none seems to work:
- Directly add the index as an attribute to the Tensor. Problem: PyTorch creates a new tensor for many basic operations (`detach` and others), which means I would need to override all of them if I wanted the index to be passed along (see the quick check after this list).
- Access the current batch of indices using the `sample_iter` attribute of PyTorch's `_DataLoaderIter`. However, there is no way I can currently do that.
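For the first idea, a quick check shows why it breaks down (minimal sketch):

```python
import torch

t = torch.zeros(3)
t.idx = 7                    # attach the dataset index as a custom attribute
u = t.detach()               # detach() builds a brand new tensor...
print(hasattr(u, 'idx'))     # False: ...and the custom attribute is gone
```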
I then made some tests, and finally found a monkey patch that enables me to accomplish that:
```python
from itertools import tee
from fastai.basic_data import DeviceDataLoader

def new_iter(self):
    dl = iter(self.dl)
    # Duplicate the index iterator so the indices stay readable from the outside
    dl.sample_iter, self.sample_iter = tee(dl.sample_iter)
    for b in dl:
        yield self.proc_batch(b)

DeviceDataLoader.__iter__ = new_iter
```
I can then get the batch indices in `on_backward_begin` by pulling from the copied `sample_iter` of the train dataloader.
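Concretely, the callback would look something like this (a sketch, assuming the tee'd iterator stays in sync with the batches actually consumed):

```python
class UpdateSamplerCallback(LearnerCallback):
    _order = 0

    def on_backward_begin(self, **kwargs):
        dl = self.learn.data.train_dl
        idxs = next(dl.sample_iter)          # indices of the current batch
        loss = self.learn.loss.fn.loss       # unreduced losses stored by URLoss
        for k, i in enumerate(idxs):
            dl.sampler.weights[i] = loss[k].detach()
```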
The few tests I made make me believe it works fine, though it is quite slow (probably because of `tee`). Do you think there is a better way to accomplish what I am trying to do? And would this be a functionality worth implementing in fastai?