Hi everyone!
For a project of mine, I am trying to create a dynamic sampler for my data loader, i.e. a sampler whose weights evolve with the loss (an item with a higher loss then gets picked more often). My idea was to write a callback that updates the weights after the loss is computed. To do that, I created a loss class that stores the unreduced loss and returns the reduced version when called:
```python
import torch.nn as nn

class URLoss(nn.Module):
    "Stores the unreduced loss in `self.loss` and returns its reduction."
    def __init__(self, loss_func, reduction='mean'):
        super().__init__()
        self.loss_func = loss_func
        self.reduction = reduction
        self.loss = None

    def forward(self, input, target, **kwargs):
        # Keep one loss value per item so a callback can read it later
        self.loss = self.loss_func(input, target, reduction='none', **kwargs)
        return self.loss.mean() if self.reduction == 'mean' else self.loss.sum()
```
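`URLoss` relies on the fact that a loss such as `F.cross_entropy` called with `reduction='none'` returns one value per item, and that the mean of those values equals the default reduced loss. A minimal check, assuming cross-entropy on a toy batch (the tensors here are made up for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)            # toy batch: 4 items, 3 classes
targets = torch.tensor([0, 2, 1, 0])

# One loss value per item, as URLoss stores in self.loss
per_item = F.cross_entropy(logits, targets, reduction='none')

# The reduced value URLoss returns to the training loop
reduced = per_item.mean()

# Same value as the default (reduction='mean') call
default = F.cross_entropy(logits, targets)
print(per_item.shape, reduced.item(), default.item())
```

Since the stored tensor is still attached to the autograd graph, it should be detached before being copied into the sampler weights.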
The sampler simply holds the weights and, when `__iter__` is called, returns an iterator over the indices drawn by `torch.multinomial`:
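```python
import torch
from torch.utils.data import Sampler

class RandomSampler(Sampler):
    def __init__(self, num_samples, weights=None):
        # Uniform weights by default; updated during training
        self.weights = weights if weights is not None else torch.ones(num_samples)
        self.num_samples = num_samples

    def __iter__(self):
        # Draw num_samples indices with replacement, proportionally to the weights
        return iter(torch.multinomial(self.weights, self.num_samples, True).tolist())

    def __len__(self):
        return self.num_samples
```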
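To see the effect of the weights, here is a small sketch of the same `torch.multinomial` call with made-up weights, where one item (standing in for a high-loss item) has ten times the weight of the others. Its expected frequency is 10/13 ≈ 0.77. PyTorch's built-in `torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)` performs the same kind of draw.

```python
import torch

torch.manual_seed(0)
# Hypothetical weights: item 3 has a much higher loss than the others
weights = torch.tensor([1.0, 1.0, 1.0, 10.0])
num_samples = 10_000

# Same call as RandomSampler.__iter__: draw indices with replacement
idxs = torch.multinomial(weights, num_samples, True)

# Empirical frequency of each item
counts = torch.bincount(idxs, minlength=len(weights))
freqs = counts.float() / num_samples
print(freqs)
```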
I then want to write something like:
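```python
from fastai.basic_train import LearnerCallback

class UpdateSamplerCallback(LearnerCallback):
    _order = 0
    def on_backward_begin(self, last_input, **kwargs):
        # Per-item loss stored by URLoss, detached so the weights don't keep the graph
        loss = self.learn.loss_func.loss.detach()
        for k, x in enumerate(last_input):
            # index_of_x is unknown here: this is the problem
            self.learn.data.train_dl.sampler.weights[index_of_x] = loss[k]
```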
The problem is that I have no way to access `index_of_x` (the index of the item in the item list). I had a couple of ideas, but neither seems to work:
- Directly add the index as an attribute of the tensor. Problem: PyTorch creates a new tensor for many basic operations (`to`, `clone`, `detach` and others), which means I would need to override all of them to get the index passed along.
- Access the current batch of indices through the `sample_iter` attribute of PyTorch's `_DataLoaderIter`. However, I currently have no way to reach that iterator from a callback.
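The first idea fails because an attribute set on a tensor lives on that Python object only. A quick sketch (the attribute name `idx` is hypothetical): any operation that returns a new tensor object drops it. Note that `.to()` returns the same object when no dtype or device conversion is needed, and a new one otherwise.

```python
import torch

t = torch.ones(3)
t.idx = 42  # attach the item index as a plain Python attribute

# Each operation below returns a new tensor object, so the attribute is not carried over
derived = {'clone': t.clone(), 'detach': t.detach(), 'add': t + 0}
for name, new in derived.items():
    print(name, hasattr(new, 'idx'))  # False for each of them
```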
I then ran some tests and finally found a monkey patch that lets me do it:
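```python
from itertools import tee
from fastai.basic_data import DeviceDataLoader

def new_iter(self):
    dl = iter(self.dl)
    # Split the batch-index iterator in two: the PyTorch iterator consumes one copy,
    # the other is exposed as self.sample_iter so a callback can read it
    dl.sample_iter, self.sample_iter = tee(dl.sample_iter)
    for b in dl:
        yield self.proc_batch(b)

DeviceDataLoader.__iter__ = new_iter
```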
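The mechanism behind the patch can be shown without fastai or the private `_DataLoaderIter` (whose attributes are internal and have been renamed in later PyTorch versions). This sketch uses a public `BatchSampler` as a stand-in for the loader's index iterator: `itertools.tee` splits it into two copies that yield the same batches, one consumed by the "loader" and one read by the "callback".

```python
from itertools import tee
from torch.utils.data import BatchSampler, SequentialSampler

# Index batches as a DataLoader's batch sampler produces them (10 items, batches of 4)
batch_sampler = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)

# tee splits one iterator into two copies that yield the same batches:
# one stands for the loader, the other for the callback
loader_side, callback_side = tee(iter(batch_sampler))

seen = []
for batch in loader_side:          # the loader draws a batch...
    idxs = next(callback_side)     # ...and the callback reads the same indices
    seen.append((batch, idxs))
    print(batch, idxs)
```

`tee` buffers whatever one copy has consumed and the other has not yet read, so the copies only stay aligned if the callback calls `next` exactly once per batch. Indices drawn before the split is installed (for instance, batches a multi-worker loader may prefetch when its iterator is created) would not appear in the callback's copy.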
I can then get the batch indices in `on_backward_begin` by calling `next(self.learn.data.train_dl.sample_iter)`.
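Once the batch indices are available, the per-item loop in the callback can be replaced by a single indexed assignment. A sketch in plain PyTorch, with hypothetical weights, indices and losses:

```python
import torch

# Hypothetical state: sampler weights for a 6-item dataset
weights = torch.ones(6)

# Batch indices as returned by next(...sample_iter), and the matching per-item loss
idxs = [4, 1, 3]
per_item_loss = torch.tensor([0.5, 2.0, 1.2], requires_grad=True)

# Detach so the weights don't keep the autograd graph alive,
# then write all batch entries at once with advanced indexing
weights[torch.tensor(idxs)] = per_item_loss.detach()
print(weights)
```

Because the sampler draws with replacement, the same index can appear twice in a batch; with a plain indexed assignment, which of the duplicate values ends up stored is not specified.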
The few tests I ran suggest it works, though it is quite slow (probably because of `tee`). Is there a better way to accomplish what I am trying to do? And would this functionality be worth implementing in fastai?
Thanks!