Creating a dynamic sampler

Hi everyone!

For a project of mine, I am currently trying to create a dynamic sampler for my data loader, i.e. a sampler whose weights evolve depending on the loss (an item with a higher loss then gets picked more often). My idea was to create a callback that updates the weights after the loss is computed. To do that, I created a loss class that stores the unreduced loss and returns the reduced version when called:

import torch.nn as nn

class URLoss(nn.Module):
    # Keeps the unreduced loss of the last call in self.loss and returns its reduction
    def __init__(self, loss_func, reduction='mean'):
        super().__init__()
        self.loss_func = loss_func  # e.g. F.cross_entropy; must accept reduction='none'
        self.reduction = reduction
        self.loss = None

    def forward(self, input, target, **kwargs):
        self.loss = self.loss_func(input, target, reduction='none', **kwargs)
        return self.loss.mean() if self.reduction == 'mean' else self.loss.sum()
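The idea behind URLoss can be sketched without PyTorch: a wrapper calls a per-item loss, keeps the unreduced values for later inspection, and hands only the reduced scalar back to the training loop. This is a minimal sketch; squared_error and UnreducedLoss are hypothetical stand-ins for loss_func(..., reduction='none') and the class above.

```python
from typing import Callable, List, Optional, Sequence


def squared_error(preds: Sequence[float], targets: Sequence[float]) -> List[float]:
    """Hypothetical per-item loss, playing the role of loss_func(..., reduction='none')."""
    return [(p - t) ** 2 for p, t in zip(preds, targets)]


class UnreducedLoss:
    """Keeps the per-item losses of the last call and returns their reduction."""

    def __init__(self, loss_func: Callable[[Sequence[float], Sequence[float]], List[float]],
                 reduction: str = "mean") -> None:
        if reduction not in ("mean", "sum"):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        self.loss_func = loss_func
        self.reduction = reduction
        self.loss: Optional[List[float]] = None  # per-item losses of the last batch

    def __call__(self, preds: Sequence[float], targets: Sequence[float]) -> float:
        self.loss = self.loss_func(preds, targets)
        total = sum(self.loss)
        return total / len(self.loss) if self.reduction == "mean" else total


crit = UnreducedLoss(squared_error)
print(crit([1.0, 2.0, 4.0], [1.0, 0.0, 1.0]))  # 4.333... (13 / 3)
print(crit.loss)  # [0.0, 4.0, 9.0]
```

The training loop only ever sees the scalar, while anything holding a reference to the wrapper (such as a callback) can read the per-item values afterwards.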

The sampler then simply holds the weights and, when __iter__ is called, returns an iterator over the result of torch.multinomial:

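import torch
from torch.utils.data import Sampler

class RandomSampler(Sampler):
    # Draws num_samples indices with replacement, proportionally to self.weights
    def __init__(self, num_samples, weights=None):
        self.weights = weights if weights is not None else torch.ones(num_samples)
        self.num_samples = num_samples

    def __iter__(self):
        # The draw happens once, when the DataLoader starts iterating (i.e. per epoch)
        return iter(torch.multinomial(self.weights, self.num_samples, replacement=True).tolist())

    def __len__(self):
        return self.num_samples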
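What torch.multinomial(self.weights, self.num_samples, True) produces is num_samples indices drawn with replacement, each index picked with probability proportional to its weight. A standard-library sketch of the same sampling behaviour using random.choices; the class name WeightedIndexSampler and the seed parameter are hypothetical additions for illustration.

```python
import random
from typing import Iterator, List, Optional


class WeightedIndexSampler:
    """Yields num_samples indices drawn with replacement, proportionally to weights."""

    def __init__(self, num_samples: int, weights: Optional[List[float]] = None,
                 seed: Optional[int] = None) -> None:
        self.num_samples = num_samples
        # Uniform weights by default, like torch.ones(num_samples) in the post.
        self.weights = list(weights) if weights is not None else [1.0] * num_samples
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[int]:
        # Weights are read when iteration starts, so changes made during an
        # epoch only affect the draw for the next epoch.
        population = range(len(self.weights))
        return iter(self.rng.choices(population, weights=self.weights, k=self.num_samples))

    def __len__(self) -> int:
        return self.num_samples


sampler = WeightedIndexSampler(4, weights=[0.0, 1.0, 1.0, 5.0], seed=0)
draw = list(sampler)
```

An item with weight 0 is never drawn, and an item with weight 5 is drawn about five times as often as one with weight 1.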

I then want to write something like:

class UpdateSamplerCallback(LearnerCallback):
    _order = 0
    def on_backward_begin(self, last_input, **kwargs):
        loss = self.learn.loss_func.loss  # unreduced loss stored by URLoss
        for k, x in enumerate(last_input):
            # index_of_x: position of x in the item list, which is not available here
            self.learn.data.train_dl.sampler.weights[index_of_x] = loss[k]

The problem is that I have no way to access index_of_x (the index in the item list). I had a couple of ideas, but none seems to work:

  • Directly add the index as an attribute of the Tensor. Problem: PyTorch creates a new tensor for many basic operations (to, clone, detach and others), which means I would need to override all of them for the index to be carried along.
  • Access the current batch of indices through the sample_iter attribute of PyTorch's _DataLoaderIter. However, that iterator is created inside fastai's DeviceDataLoader and is not exposed, so I currently have no way to reach it.

After some experimenting, I finally found a monkey patch that lets me do this:

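from itertools import tee
from fastai.basic_data import DeviceDataLoader

def new_iter(self):
    dl = iter(self.dl)  # PyTorch's _DataLoaderIter for the wrapped DataLoader
    # Split the iterator over batches of indices: the DataLoader keeps consuming
    # one copy, the other is stored on the DeviceDataLoader as self.sample_iter.
    # Caveat: with num_workers > 0 the iterator has already prefetched the first
    # batches of indices before tee runs, so the stored copy starts late.
    dl.sample_iter, self.sample_iter = tee(dl.sample_iter)
    for b in dl:
        yield self.proc_batch(b)

DeviceDataLoader.__iter__ = new_iter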

I can then get the batch indices in on_backward_begin by calling next(self.learn.data.train_dl.sample_iter).
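The patch relies on itertools.tee, which splits one iterator into independent iterators that yield the same items in the same order. A toy sketch of the idea, with hypothetical batch_indices and toy_loader functions standing in for the batch sampler and the data loader. It also shows the pitfall of splitting too late: anything consumed before tee is called never reaches the copy, which is what a multi-worker PyTorch DataLoader does when it hands a few batches of indices to its workers ahead of time.

```python
from itertools import tee
from typing import Iterator, List


def batch_indices(indices: List[int], bs: int) -> Iterator[List[int]]:
    """Hypothetical batch sampler: yields consecutive chunks of indices."""
    for start in range(0, len(indices), bs):
        yield indices[start:start + bs]


def toy_loader(sample_iter: Iterator[List[int]], data: List[str]) -> Iterator[List[str]]:
    """Hypothetical loader: pulls one batch of indices at a time, then fetches items."""
    for idxs in sample_iter:
        yield [data[i] for i in idxs]


data = ["a", "b", "c", "d", "e"]
order = [3, 0, 4, 1, 2]  # e.g. the output of a shuffled or weighted sampler

# Split before anything is consumed: one copy feeds the loader, one is kept aside.
loader_copy, watcher_copy = tee(batch_indices(order, bs=2))
pairs = [(batch, next(watcher_copy)) for batch in toy_loader(loader_copy, data)]
# pairs == [(['d', 'a'], [3, 0]), (['e', 'b'], [4, 1]), (['c'], [2])]

# Split after a batch was already taken (as prefetching does): the watcher
# copy never sees [3, 0] and its first batch is [4, 1].
sample_iter = batch_indices(order, bs=2)
prefetched = next(sample_iter)
loader_copy, watcher_copy = tee(sample_iter)
first_watched = next(watcher_copy)
```

Because the loader serves batches in the order it received their indices, calling next on the kept-aside copy once per batch keeps the two in step, as long as the split happens before any indices are consumed.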
The few tests I ran suggest it works fine, though it is quite slow (probably because of tee). Do you think there is a better way to accomplish what I am trying to do? And would this be a feature worth implementing in fastai?

Thanks!


Using the monkey patch described above, I managed to build a version that seems to work pretty well:

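import torch
import torch.nn as nn
from dataclasses import dataclass
from torch.utils.data import Sampler
from fastai.basic_train import LearnerCallback

class RandomSampler(Sampler):
    def __init__(self, num_samples, weights=None):
        self.weights = weights if weights is not None else torch.ones(num_samples)
        # True for items that have not been drawn yet in the current epoch
        self.to_update = torch.ones_like(self.weights, dtype=torch.bool)
        self.num_samples = num_samples

    def __iter__(self):
        return iter(torch.multinomial(self.weights, self.num_samples, replacement=True).tolist())

    def __len__(self):
        return self.num_samples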


class UpdateSamplerCallback(LearnerCallback):
    _order = 0
    def on_epoch_end(self, **kwargs):
        sampler = self.learn.data.train_dl.sampler
        # Items not drawn this epoch get a small boost so they are not starved
        sampler.weights[sampler.to_update] += 0.2
        sampler.weights = torch.clamp(sampler.weights, 0., 1.)
        sampler.to_update = torch.ones_like(sampler.weights, dtype=torch.bool)

    def on_backward_begin(self, **kwargs):
        loss = self.learn.loss_func.loss  # unreduced loss of the current batch
        dl = self.learn.data.train_dl
        idxs = next(dl.sample_iter)  # indices of the current batch (from the monkey patch)
        # Map each loss to a weight in [0, 1): a higher loss gives a higher weight
        dl.sampler.weights[idxs] = 1 - torch.exp(-loss.cpu().detach())
        dl.sampler.to_update[idxs] = False


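@dataclass
class URLoss():
    func: nn.Module
    loss: torch.Tensor = None
    reduction: str = 'mean'

    def __post_init__(self):
        self.func.reduction = 'none'

    def __call__(self, input, target):
        self.func.reduction = 'none'  # make sure the wrapped loss stays unreduced
        self.loss = self.func(input, target)  # per-item losses, read by the callback
        if self.reduction == 'mean': return self.loss.mean()
        elif self.reduction == 'sum': return self.loss.sum()
        else: return self.loss

    def __getattr__(self, name):
        # Delegate unknown attributes to the wrapped loss
        return getattr(self.func, name)

    # Explicit __setstate__ so pickling/copying never falls through to __getattr__
    # before self.func exists, which would recurse forever
    def __setstate__(self, data):
        self.__dict__.update(data)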
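The update rule in UpdateSamplerCallback maps each per-item loss to a weight of 1 - exp(-loss), which is 0 for a zero loss and approaches 1 as the loss grows; items that were not drawn during the epoch get +0.2 and everything is clamped to [0, 1]. A plain-Python sketch of that rule on a small weight list, with hypothetical function names and lists in place of the tensors:

```python
import math
from typing import List, Sequence


def update_from_losses(weights: List[float], to_update: List[bool],
                       idxs: Sequence[int], losses: Sequence[float]) -> None:
    """After each batch: weight = 1 - exp(-loss) for the drawn items."""
    for i, loss in zip(idxs, losses):
        weights[i] = 1.0 - math.exp(-loss)  # 0 for a zero loss, tends to 1 as loss grows
        to_update[i] = False  # drawn this epoch, so no end-of-epoch boost


def end_of_epoch(weights: List[float], to_update: List[bool], bump: float = 0.2) -> None:
    """After each epoch: boost items not drawn, clamp to [0, 1], reset the flags."""
    for i in range(len(weights)):
        if to_update[i]:
            weights[i] += bump
        weights[i] = min(max(weights[i], 0.0), 1.0)
        to_update[i] = True


weights = [1.0, 0.1, 1.0, 1.0]
to_update = [True] * len(weights)
update_from_losses(weights, to_update, idxs=[0, 2], losses=[0.0, math.log(2)])
end_of_epoch(weights, to_update)
# weights is now roughly [0.0, 0.3, 0.5, 1.0]
```

Item 0 (zero loss) drops to weight 0, item 2 (loss ln 2) lands at 0.5, item 1 was not drawn and rises from 0.1 to 0.3, and item 3 is capped at 1.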

I have run into that problem a few times (wanting to get the indices of the items drawn in a batch) but didn't find any satisfactory solution. Yours seems pretty good; the other one I had thought of was to change the dataset to return idx, x, y, so that your batches would look like sample_idxs, xs, ys.
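The alternative of having the dataset return idx, x, y, so that each batch carries its own sample_idxs, can be sketched with a thin wrapper around any indexable dataset of (x, y) pairs. IndexedDataset and collate are hypothetical names; PyTorch's default collate function would produce the same three-way split, with the indices turned into a tensor.

```python
from typing import Any, List, Sequence, Tuple


class IndexedDataset:
    """Wraps a dataset of (x, y) pairs so that each item also carries its index."""

    def __init__(self, base: Sequence[Tuple[Any, Any]]) -> None:
        self.base = base

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int) -> Tuple[int, Any, Any]:
        x, y = self.base[idx]
        return idx, x, y


def collate(items: List[Tuple[int, Any, Any]]) -> Tuple[List[int], List[Any], List[Any]]:
    """Turns a list of (idx, x, y) triples into (sample_idxs, xs, ys)."""
    sample_idxs, xs, ys = zip(*items)
    return list(sample_idxs), list(xs), list(ys)


ds = IndexedDataset([("img0", 0), ("img1", 1), ("img2", 0)])
sample_idxs, xs, ys = collate([ds[i] for i in [2, 0]])
# sample_idxs == [2, 0], xs == ["img2", "img0"], ys == [0, 0]
```

The indices then travel with the batch regardless of workers or prefetching; the trade-off is that everything consuming the batch (model, loss, callbacks) has to unpack or ignore the extra element.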


I thought about that as well, but figured it would require many more changes, so I tried to find something easier to monkey patch.