Epochs of arbitrary length

When training with a lot of data, being able to partition the training set into epochs of arbitrary size can be quite useful (for instance, for better monitoring of training, for saving models, and when using callbacks that reduce the LR or stop training).

I have a manual workaround that works well, but I’m not sure how, or whether, this should be integrated into the library. Here is the code I use:

import torch
from torch.utils.data import Sampler

# https://stackoverflow.com/a/312464/1105837
def chunks(l, n):
    """Yield successive n-sized chunks from l."""
    for i in range(0, len(l), n):
        yield l[i:i + n]

class RandomSamplerWithEpochSize(Sampler):
    """Yields epochs of specified sizes. Iterates over all examples in a data_source in random
    order. Ensures (nearly) all examples have been trained on before beginning the next iteration
    over the data_source - drops the last epoch that would likely be smaller than epoch_size."""
    def __init__(self, data_source, epoch_size):
        self.n = len(data_source)
        self.epoch_size = epoch_size
        self._epochs = []
    def __iter__(self):
        return iter(self.next_epoch())
    def next_epoch(self):
        # refill with a fresh shuffled pass once all stored epochs are used up
        if len(self._epochs) == 0: self.generate_epochs()
        return self._epochs.pop()
    def generate_epochs(self):
        # shuffle all indices, split into epoch_size chunks, drop the (likely shorter) last one;
        # needs len(data_source) > epoch_size, otherwise no epoch is produced
        idxs = torch.randperm(self.n).tolist()
        self._epochs = list(chunks(idxs, self.epoch_size))[:-1]
    def __len__(self):
        return self.epoch_size

class DataBunch():
    "Bind `train_dl`,`valid_dl` and`test_dl` to `device`. tfms are DL tfms (normalize). `path` is for models."
    def __init__(self, train_dl:DataLoader, valid_dl:DataLoader, test_dl:Optional[DataLoader]=None,
                 device:torch.device=None, tfms:Optional[Collection[Callable]]=None, path:PathOrStr='.',
                 collate_fn:Callable=data_collate):
        "Bind `train_dl`,`valid_dl` and`test_dl` to `device`. tfms are DL tfms (normalize). `path` is for models."
        self.tfms = listify(tfms)
        self.device = defaults.device if device is None else device
        self.train_dl = DeviceDataLoader(train_dl, self.device, self.tfms, collate_fn)
        self.valid_dl = DeviceDataLoader(valid_dl, self.device, self.tfms, collate_fn)
        self.test_dl  = DeviceDataLoader(test_dl,  self.device, self.tfms, collate_fn) if test_dl else None
        self.path = Path(path)

    @classmethod
    def create(cls, train_ds:Dataset, valid_ds:Dataset, test_ds:Dataset=None, path:PathOrStr='.', bs:int=64,
               num_workers:int=defaults.cpus, tfms:Optional[Collection[Callable]]=None, device:torch.device=None,
               collate_fn:Callable=data_collate, epoch_size:int=10_000)->'DataBunch':
        "`DataBunch` factory. `bs` batch size, `ds_tfms` for `Dataset`, `tfms` for `DataLoader`."
        datasets = [train_ds,valid_ds]
        if test_ds is not None: datasets.append(test_ds)
        dls = [DataLoader(*o, num_workers=num_workers) for o in
               zip(datasets, (bs,bs*2,bs*2), (True,False,False))]
        dls[0] = DataLoader(train_ds, num_workers=num_workers,
                    batch_sampler=BatchSampler(RandomSamplerWithEpochSize(train_ds, epoch_size), bs, False))
        return cls(*dls, path=path, device=device, tfms=tfms, collate_fn=collate_fn)

    def __getattr__(self,k:int)->Any: return getattr(self.train_ds, k)
    def holdout(self, is_test:bool=False)->DeviceDataLoader:
        "Returns correct holdout `Dataset` for test vs validation (`is_test`)."
        return self.test_dl if is_test else self.valid_dl

    def add_tfm(self,tfm:Callable)->None:
        self.train_dl.add_tfm(tfm)
        self.valid_dl.add_tfm(tfm)
        if self.test_dl: self.test_dl.add_tfm(tfm)

    @property
    def train_ds(self)->Dataset: return self.train_dl.dl.dataset
    @property
    def valid_ds(self)->Dataset: return self.valid_dl.dl.dataset
    @property
    def loss_func(self)->Dataset: return self.train_ds.loss_func

This is a bit rough around the edges (it drops the last epoch so that I don’t have to deal with a smaller one). The same functionality could be achieved by passing a custom BatchSampler into DataBunch.create, but I’m not sure adding this is a good idea. It would also go against the nice zipping mechanism we have for creating dataloaders.


Hi! I second this question.

Is specifying the epoch size out of fashion (as in, has it been shown to be a useless parameter to tune)? When I was using Keras it had some impact on training, especially since the LR schedulers were triggered at the end of each epoch.

My use case: I extract some features (a model’s intermediate layer) from a dataset. With the amazing augmentation capabilities, one can generate a fairly large volume of data. These features then form a training set for some head. By default, fastai routines consider one pass through it to be one epoch, even though one can augment the original dataset 1, 100 or 1000 times.

Are the more advanced fastai callbacks (e.g. one-cycle) somehow immune to this phenomenon? I am looking at SaveModelCallback, and the statistics are checked at the end of each epoch.

Thank you!
p.s. @radek thanks for the code, I will check it out!


Yeah, the real issue here is that the idea of “epoch” is kinda silly for large datasets. It’s more likely you want to do things every n batches, since you may well have so much data you can’t even run one full epoch.

But then the concept of epoch is somewhat convenient sometimes - for smaller datasets it’s useful to think about how many times you want to run through the dataset.

@sgugger should we make it so that if you pass a tuple to n_epoch, the first item is treated as total_batches and the second as epoch_length? So if you pass (500,100), it does 5 “epochs” of 100 batches each.
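A minimal sketch of one way the proposed tuple could be interpreted, in plain Python. `epoch_lengths` is a hypothetical helper, not a fastai function; it assumes an int keeps today’s meaning (full passes over the data) and, as an extra assumption, turns any leftover batches into one shorter final epoch.

```python
def epoch_lengths(n_epoch, batches_per_epoch):
    """Hypothetical helper: turn `n_epoch` into a list of epoch lengths (in batches).

    - int: that many full passes over the data (current behaviour)
    - tuple (total_batches, epoch_length): 'epochs' of epoch_length batches
    """
    if isinstance(n_epoch, tuple):
        total_batches, epoch_length = n_epoch
    else:
        total_batches, epoch_length = n_epoch * batches_per_epoch, batches_per_epoch
    full, rest = divmod(total_batches, epoch_length)
    # assumption: leftover batches form one shorter final epoch
    return [epoch_length] * full + ([rest] if rest else [])


print(epoch_lengths((500, 100), batches_per_epoch=938))  # [100, 100, 100, 100, 100]
```

The fit loop would then only need to run `len(...)` epochs and stop each one after the corresponding number of batches.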


We would need a callback to change the dataloaders then (a bit like what @radek is doing up there), so that it works without impacting the training loop, but I don’t have any objection to this.

I was thinking we would change the training loop…


The only easy way to do this would be to interrupt the train_dl, which means we might go over the same data.

Would it be a good idea to track this potential improvement as a GitHub issue? I can create one if you want.

Issues are for known bugs only. Potential new features are discussed and tracked on the forum :slight_smile:

I’ve been modifying the callbacks to work at the “every n batches” level instead of per epoch; maybe that’s one possible avenue? For example, I’ve modified SaveModelCallback to save every n batches instead of every epoch. A default setting could be n_batches = the total number of batches in an epoch, which (I think) should give the same behavior the callbacks currently have, while allowing more fine-grained control.
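A minimal sketch of the counter-based idea, independent of fastai. `EveryNBatches` is a hypothetical class that only mimics the shape of a fastai v1 `on_batch_end` hook (the `train` keyword is an assumption about what the callback receives). With `n` set to the number of batches in an epoch it fires once per epoch, matching the default suggested above.

```python
class EveryNBatches:
    """Hypothetical callback-style helper: calls `fn(iteration)` every `n` training batches."""
    def __init__(self, n, fn):
        self.n, self.fn = n, fn
        self.iteration = 0  # training batches seen so far, across epochs

    def on_batch_end(self, train=True, **kwargs):
        if not train:  # ignore validation batches
            return
        self.iteration += 1
        if self.iteration % self.n == 0:
            self.fn(self.iteration)


saved = []
cb = EveryNBatches(n=3, fn=saved.append)
for _ in range(10):
    cb.on_batch_end(train=True)
print(saved)  # [3, 6, 9]
```

In a real callback `fn` might call something like `learn.save(...)`; keeping only the best model, as SaveModelCallback does, would need an extra metric comparison.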

That being said, I haven’t dug into the codebase much yet, so I may be missing a lot of the complexity involved.


I’ve been thinking about custom samplers and, with a few changes to DataBunch.create, I added the ability to pass a list of samplers when calling .databunch().

My DataBunch looks like this:

class ImageDataBunch(ImageDataBunch):
    @classmethod
    def create(cls, train_ds:Dataset, valid_ds:Dataset, test_ds:Optional[Dataset]=None, path:PathOrStr='.', bs:int=64,
               val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,
               device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False, sampler=None, **dl_kwargs)->'DataBunch':
        "Create a `DataBunch` from `train_ds`, `valid_ds` and maybe `test_ds` with a batch size of `bs`. Passes `**dl_kwargs` to `DataLoader()`"
        datasets = cls._init_ds(train_ds, valid_ds, test_ds)
        val_bs = ifnone(val_bs, bs)
        if sampler is None: sampler = [RandomSampler] + 3*[SequentialSampler]
        dls = [DataLoader(d, b, sampler=s(d, bs=b), num_workers=num_workers, **dl_kwargs) for d,b,s in
               zip(datasets, (bs,val_bs,val_bs,val_bs), sampler) if d is not None]
        return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)
class ImageList(ImageList):
    _bunch = ImageDataBunch

Then the custom samplers (in RandomSampler and SequentialSampler I just add **kwargs to __init__):

class SequentialSampler(SequentialSampler):
    def __init__(self, data_source, **kwargs):
        # accept and ignore extra kwargs (e.g. bs) so all samplers share one call signature
        super().__init__(data_source)
class RandomSampler(RandomSampler):
    def __init__(self, data_source, replacement=False, num_samples=None, **kwargs):
        super().__init__(data_source, replacement=replacement, num_samples=num_samples)
class FixedLenRandomSampler(RandomSampler):
    def __init__(self, data_source, bs, epoch_size, *args, **kwargs):
        super().__init__(data_source)
        self.epoch_size = epoch_size*bs  # epoch_size is given in batches; store it in samples
    def __iter__(self):
        # a fresh random subset of epoch_size*bs indices each epoch (may repeat across epochs)
        return iter(torch.randperm(len(self.data_source))[:len(self)].tolist())
    def __len__(self):
        return self.epoch_size

Then I create a list of samplers for train_ds, valid_ds, fix_ds and test_ds:

train_sampler = partial(FixedLenRandomSampler, epoch_size=100)
samplers = [train_sampler, SequentialSampler, SequentialSampler, SequentialSampler]

Finally the datablock and learner as usual:

data = (ImageList.from_folder(path) 
        .split_by_folder(train='training', valid='testing')            
        .transform(get_transforms(), size=64) 
        .databunch(sampler=samplers, bs=64))

learn = cnn_learner(data, models.densenet121, metrics=[accuracy])

Then, when calling fit, it runs with the specified epoch_size :slight_smile:

Working example on colab: https://colab.research.google.com/drive/1k2Ut_ZINNSzYkJt2bjUPD9gxC_FGwzQd

To avoid repeating samples, I guess we can modify the Sampler to remember the already sampled indices and sample only from the remaining ones until all have been sampled.

This has many other applications like episode sampling for few-shot learning. I will share a sampler for that soon!


Quick update… This should do the trick for fixed epoch length sampling without replacement.

import numpy as np

class SequentialSampler(SequentialSampler):
    def __init__(self, data_source, **kwargs):
        super().__init__(data_source)
class RandomSampler(RandomSampler):
    def __init__(self, data_source, replacement=False, num_samples=None, **kwargs):
        super().__init__(data_source, replacement=replacement, num_samples=num_samples)
class FixedLenRandomSampler(RandomSampler):
    def __init__(self, data_source, bs, epoch_size, *args, **kwargs):
        super().__init__(data_source)
        self.epoch_size = epoch_size*bs  # epoch length in samples
        self.not_sampled = np.ones(len(data_source), dtype=bool)  # True = unused in this pass
    def _reset_state(self): self.not_sampled[:] = True
    def __iter__(self):
        ns = int(self.not_sampled.sum())
        idx_last = []
        if ns >= len(self):
            # enough unused samples left: draw a whole epoch from them
            idx = np.random.choice(np.where(self.not_sampled)[0], size=len(self), replace=False).tolist()
            self.not_sampled[idx] = False
            if ns == len(self): self._reset_state()  # pass complete, start a new one
        else:
            # take all remaining unused samples, reset, then top up from the new pass
            idx_last = np.where(self.not_sampled)[0].tolist()
            self._reset_state()
            idx = np.random.choice(np.where(self.not_sampled)[0], size=len(self)-len(idx_last), replace=False).tolist()
            self.not_sampled[idx] = False
        idx = [*idx_last, *idx]
        # print(ns, len(idx), len(idx_last)) # debug
        return iter(idx)
    def __len__(self):
        return self.epoch_size

idx_last handles the case where the remaining unused samples are not enough to make an epoch: it takes the available ones, then resets the state and samples as many more as are needed to complete the epoch.
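For comparison, a framework-free sketch of the same idea (fixed-length epochs drawn without replacement). `FixedLenEpochIndices` is hypothetical, not part of fastai or PyTorch, and it departs from the sampler above in one respect: when a new pass starts mid-epoch, the indices already in that epoch are moved to the end of the new pass instead of being eligible again, so an epoch never contains the same index twice. It assumes `epoch_len <= n`.

```python
import numpy as np

class FixedLenEpochIndices:
    """Hypothetical helper: fixed-length epochs of indices sampled without replacement,
    so every index is used once per pass over the data."""
    def __init__(self, n, epoch_len, seed=None):
        if not 0 < epoch_len <= n:
            raise ValueError("epoch_len must be in 1..n")
        self.n, self.epoch_len = n, epoch_len
        self.rng = np.random.default_rng(seed)
        self.pool = []  # indices not yet used in the current pass

    def next_epoch(self):
        idx = self.pool[:self.epoch_len]
        self.pool = self.pool[self.epoch_len:]
        if len(idx) < self.epoch_len:
            # Pass exhausted: start a new shuffled pass. Indices already in this
            # epoch go to the end of the new pass so the epoch has no repeats.
            taken = set(idx)
            perm = self.rng.permutation(self.n).tolist()
            fresh = [i for i in perm if i not in taken] + [i for i in perm if i in taken]
            need = self.epoch_len - len(idx)
            idx += fresh[:need]
            self.pool = fresh[need:]
        return idx


s = FixedLenEpochIndices(n=10, epoch_len=4, seed=0)
for _ in range(3):
    print(sorted(s.next_epoch()))
```

Calling `next_epoch()` from a Sampler’s `__iter__` would be one way to plug this into a DataLoader.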

However, when using lr_find, for example, we “waste” samples. To correct that, the following callback calls _reset_state in on_train_begin:

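class ResetSamplers(LearnerCallback):
    "Reset samplers that track used indices (e.g. FixedLenRandomSampler) when training begins."
    def __init__(self, learn):
        super().__init__(learn)
        self.dls = learn.data.dls
    def on_train_begin(self, **kwargs):
        for o in self.dls:
            if hasattr(o.dl.sampler, '_reset_state'):
                o.dl.sampler._reset_state()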

A possible approach I thought of would be to modify on_batch_end to return {"stop_epoch":True} depending on the quantity of input data already processed.
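A rough sketch of that approach, assuming fastai v1’s convention that a dict returned from `on_batch_end` can set `stop_epoch`, and that the hook receives `last_input` and `train` as keyword arguments. `StopEpochAfterNSamples` is a hypothetical name, and the class is written without fastai imports so the logic can be checked on its own.

```python
class StopEpochAfterNSamples:
    """Hypothetical sketch: end the training epoch once `n_samples` inputs have been processed."""
    def __init__(self, n_samples):
        self.n_samples = n_samples
        self.seen = 0

    def on_epoch_begin(self, **kwargs):
        self.seen = 0  # count afresh in every epoch

    def on_batch_end(self, last_input=None, train=True, **kwargs):
        if not train:
            return None
        self.seen += len(last_input)  # batch size of the current input
        if self.seen >= self.n_samples:
            return {'stop_epoch': True}
        return None


cb = StopEpochAfterNSamples(100)
cb.on_epoch_begin()
results = [cb.on_batch_end(last_input=[0] * 32) for _ in range(4)]
print(results)  # [None, None, None, {'stop_epoch': True}]
```

Because each epoch starts a fresh pass over train_dl, stopping early like this may revisit the same data rather than continuing where the previous epoch stopped.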

Let me know if it seems like a reasonable solution and I can try to implement it.

Callbacks at “every n” batches are a great start! I would love to specify time-based callbacks: “Stop after 4 hours, saving the best model every ten minutes”. I’m forever doing 1-epoch runs and dividing the time available to set the number of epochs.
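A minimal sketch of a time-based trigger, written in the shape of a fastai v1-style callback and assuming that returning `stop_epoch` and `stop_training` from `on_batch_end` ends the run. `TimeBudget` is hypothetical; the clock is injectable so the logic can be tested, and “best model” handling is left out: `save_fn` would have to compare a monitored metric itself.

```python
import time

class TimeBudget:
    """Hypothetical sketch: call `save_fn` every `save_every` seconds, stop after `budget` seconds."""
    def __init__(self, budget, save_every, save_fn, clock=time.monotonic):
        self.budget, self.save_every = budget, save_every
        self.save_fn, self.clock = save_fn, clock

    def on_train_begin(self, **kwargs):
        self.start = self.last_save = self.clock()

    def on_batch_end(self, **kwargs):
        now = self.clock()
        if now - self.last_save >= self.save_every:
            self.save_fn()
            self.last_save = now
        if now - self.start >= self.budget:
            # assumed fastai v1 convention: end both the epoch and the training run
            return {'stop_epoch': True, 'stop_training': True}
        return None


# "Stop after 4 hours, saving every ten minutes" would be:
# TimeBudget(budget=4 * 3600, save_every=10 * 60, save_fn=...)
ticks = iter([0, 5, 10, 15, 20, 25])  # fake clock readings, in seconds
saves = []
cb = TimeBudget(budget=25, save_every=10, save_fn=lambda: saves.append(1), clock=lambda: next(ticks))
cb.on_train_begin()
results = [cb.on_batch_end() for _ in range(5)]
```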

Edit: I see there is a StopAfterNBatches callback. Cool! On medium/large datasets I can use it to get an approximate timing estimate at the start of training and then automatically set the number of epochs (and soon, it sounds, the number of batches) based on the time I have to train.
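The arithmetic is simple enough to sketch. `plan_from_budget` is a hypothetical helper; it assumes the time per batch measured on a short run (for example one cut short by StopAfterNBatches) stays roughly constant for the whole run, and it ignores validation time.

```python
import math

def plan_from_budget(sec_per_batch, batches_per_epoch, budget_sec):
    """Hypothetical helper: how many batches (and full epochs) fit in a time budget."""
    n_batches = math.floor(budget_sec / sec_per_batch)
    n_epochs = n_batches // batches_per_epoch
    return n_epochs, n_batches


# e.g. 0.5 s per batch measured on a short run, 1000 batches per epoch, 4 hours available
print(plan_from_budget(0.5, 1000, 4 * 3600))  # (28, 28800)
```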

While I am on this train of thought, I’ve always wished for a way to “slice” a 1cycle training run so I can run “pieces” when convenient, like ¼ now, ½ tonight, and ¼ on Thursday. I suppose it’s possible if I get my head around the training scheduler enough, but a helper would be cool.
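One way to picture such a helper: compute the learning rate from the *global* iteration index, so a slice is just a range of iterations run whenever convenient. The schedule below is only loosely modelled on fastai v1’s one-cycle shape (cosine warm-up to `lr_max`, then cosine annealing); `one_cycle_lr`, `slice_iterations` and their defaults are assumptions, not fastai API, and a real resume would also need the model and optimizer state saved between sessions.

```python
import math

def cos_anneal(start, end, pct):
    """Cosine interpolation from `start` (pct=0) to `end` (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_lr(it, total, lr_max, pct_start=0.3, div=25.0, final_div=None):
    """Learning rate at global iteration `it` of a `total`-iteration one-cycle schedule."""
    final_div = div * 1e4 if final_div is None else final_div
    warm = round(total * pct_start)  # iterations spent warming up
    if it < warm:
        return cos_anneal(lr_max / div, lr_max, it / warm)
    return cos_anneal(lr_max, lr_max / final_div, (it - warm) / max(1, total - warm))

def slice_iterations(total, fractions):
    """Split iterations 0..total-1 into consecutive [start, end) pieces."""
    pieces, start, acc = [], 0, 0.0
    for f in fractions:
        acc += f
        end = min(total, round(total * acc))
        pieces.append((start, end))
        start = end
    return pieces


total = 1000
for start, end in slice_iterations(total, [0.25, 0.5, 0.25]):  # ¼ now, ½ tonight, ¼ Thursday
    lrs = [one_cycle_lr(it, total, lr_max=1e-2) for it in range(start, end)]
    # ... train iterations start..end-1 with these LRs, then save model + optimizer state
```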

I like the ‘do every n batches’ approach better than the complex samplers; it seems simpler and cleaner to me. Of course, we can already do that just by keeping a counter.

But it would make the common case simple to have a value of ‘n’ acting as a common denominator that receives special treatment in the fit loop, i.e. an on_n_batch_end callback.

Along with that, it would make sense for validation to be part of this ‘every n’ cycle (in fact it could be a callback itself).
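A bare-bones sketch of such a loop, with validation expressed as a callback on the proposed `on_n_batch_end` hook. Every name here (`Callback`, `Validate`, `fit_loop`, `repeat_forever`) is hypothetical and not fastai API; the training step itself is left as a comment.

```python
class Callback:
    """No-op base class so subclasses only override the hooks they need."""
    def on_batch_end(self, iteration): pass
    def on_n_batch_end(self, iteration): pass

class Validate(Callback):
    """Validation as a callback: runs `validate_fn` on every n-batch boundary."""
    def __init__(self, validate_fn):
        self.validate_fn = validate_fn
        self.history = []  # (iteration, validation result) pairs
    def on_n_batch_end(self, iteration):
        self.history.append((iteration, self.validate_fn()))

def repeat_forever(iterable):
    """Restart iteration whenever `iterable` is exhausted (a DataLoader would reshuffle)."""
    while True:
        yield from iterable

def fit_loop(train_data, total_batches, n, callbacks):
    """Run `total_batches` training steps, firing on_n_batch_end every `n` batches."""
    batches = repeat_forever(train_data)
    for it in range(1, total_batches + 1):
        batch = next(batches)
        # ... forward pass, loss, backward, optimizer step on `batch` ...
        for cb in callbacks: cb.on_batch_end(it)
        if it % n == 0:
            for cb in callbacks: cb.on_n_batch_end(it)


val = Validate(validate_fn=lambda: 0.9)  # stand-in for computing a validation metric
fit_loop(range(7), total_batches=20, n=5, callbacks=[val])
print(val.history)  # [(5, 0.9), (10, 0.9), (15, 0.9), (20, 0.9)]
```

`repeat_forever` restarts iteration instead of using `itertools.cycle`, which would keep a copy of every batch in memory.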

I am actually implementing this approach, as I need to see validation results more often than the end of every epoch.

I’m just adding a note that we can use BatchSampler, which makes it easy to create mini-batches.

For example, BatchSampler(RandomSampler(train_ds), batch_size=32, drop_last=False) creates a sampler that goes through the entire dataset in random order, yielding groups of 32 indices without replacement.
The idea is to treat each group as an epoch and iterate through it (for example with a batch size of 4) until the group is used up, which marks the end of that epoch.
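A framework-free sketch of that grouping, as an illustration only: `chunked_epochs` is a hypothetical helper that produces the same kind of grouping as `BatchSampler(RandomSampler(ds), batch_size=epoch_size, drop_last=False)` (one shuffled pass split into groups of `epoch_size` indices, the last one possibly smaller) and then splits each group into mini-batches of `bs`.

```python
import random

def chunked_epochs(n, epoch_size, bs, seed=None):
    """Hypothetical helper: yield 'epochs' as lists of mini-batches of dataset indices.

    One shuffled pass over range(n) is split into groups of `epoch_size` indices
    (the last group may be smaller), and each group is split into batches of `bs`.
    """
    rng = random.Random(seed)
    idxs = list(range(n))
    rng.shuffle(idxs)
    for start in range(0, n, epoch_size):
        group = idxs[start:start + epoch_size]
        yield [group[i:i + bs] for i in range(0, len(group), bs)]


epochs = list(chunked_epochs(70, epoch_size=32, bs=4, seed=0))
print([len(e) for e in epochs])  # [8, 8, 2]
```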

Hi, is there a plan to implement “epoch_size” in fastai? Maybe in fastai v2? Thanks.

I actually have an implementation here: Fastai v2 chat

In case it might be useful for someone, I patched it in fastai v1:

class MyDl:
    "Wraps a DataLoader so that one 'epoch' is `epoch_size` batches."
    def __init__(self, dl, epoch_size):
        self.dl = dl
        self.iter = iter(dl)  # kept across epochs, so each epoch resumes where the last stopped
        self.c = dl.c
        self.dataset = dl.dataset
        self.epoch_size = epoch_size
    def __iter__(self):
        for i in range(self.epoch_size):
            try:
                yield next(self.iter)
            except StopIteration:
                # start from beginning if end of self.iter reached
                self.iter = iter(self.dl)
                yield next(self.iter)

    def __len__(self):
        return self.epoch_size

from fastai.vision import *

path = untar_data(URLs.MNIST)
data = (ImageList.from_folder(path)
        .split_by_folder(train='training', valid='testing')
        .transform(get_transforms(), size=16)
        .databunch())

data.train_dl = MyDl(data.train_dl, 4)

learn = cnn_learner(data, models.resnet18, metrics=[accuracy])

You can find the Colab notebook here https://colab.research.google.com/drive/1bkATL1uNyHOlB4DW4ImYvAseSAZKCeoj

I hope I didn’t break anything :slight_smile:


Thanks for your working example in Colab. I’ll adapt it to my case and test it.
Have a good week-end. :slight_smile:

Just for info: you can now use the “partial_dataloaders” method.