SOURCE CODE: Mid-Level API

Also, a few days ago I went ahead and commented every line of code in DataLoader, and I’m posting the result here. It is probably already slightly out of sync with the current version of the library and will drift further over time, but the core ideas should stay close to what they are now. Here it is:
(it’s a Wiki so anyone can edit this with clearer explanations)
(also, you might have to do a bit of homework before completely understanding this!)

# NOTE(review): `funcs_kwargs` is defined outside this block. Judging from how `new`
# forwards every name in `_methods` as a keyword argument, it presumably lets those
# methods be replaced by passing same-named keyword arguments to `__init__` -- confirm.
@funcs_kwargs
class DataLoader(GetAttr):
    """Iterate over `dataset` in batches.

    The per-epoch flow implemented by the methods below is:

        get_idxs -> sample -> create_batches
                                 do_item    = create_item -> after_item
                                 chunkify   (group items unless prebatched)
                                 do_batch   = before_batch -> create_batch -> retain
        -> optional to_device -> after_batch

    `before_iter` and `after_iter` are called at the start and at the normal end
    of `__iter__`. Every name listed in `_noop_methods` is a hook that returns
    its first argument unchanged unless it is replaced.
    """
    _noop_methods = 'wif before_iter after_item before_batch after_batch after_iter'.split()
    # Define each name in `_noop_methods` as a method that returns `x` unchanged.
    # `wif` is not called anywhere in this class.
    # NOTE(review): presumably a worker-init hook used by `_FakeLoader` -- confirm.
    for o in _noop_methods:
        exec(f"def {o}(self, x=None, *args, **kwargs): return x")
    # Names of the methods that `new` copies onto the new loader. The backslash
    # continues the string literal; `.split()` discards the extra whitespace.
    # 'create_batch' is listed twice, which is harmless for the loop in `new`.
    _methods = _noop_methods + 'create_batches create_item create_batch retain \
        get_idxs sample shuffle_fn do_batch create_batch'.split()
    # NOTE(review): presumably the attribute `GetAttr` falls back to when an attribute
    # is not found on the loader itself -- confirm against `GetAttr`.
    _default = 'dataset'
    def __init__(self, dataset=None, bs=None, num_workers=0, pin_memory=False, timeout=0, batch_size=None,
                 shuffle=False, drop_last=False, indexed=None, n=None, device=None, **kwargs):
        """Store the loader configuration.

        Args:
            dataset: source of items; indexed with `dataset[i]` or iterated.
            bs: batch size; `None` means the dataset already yields batches
                (see `prebatched`).
            num_workers, pin_memory, timeout: forwarded to `_FakeLoader`.
            batch_size: alias for `bs`; takes precedence when not `None`.
            shuffle: if true, `get_idxs` passes the indexes through `shuffle_fn`.
            drop_last: drop a final incomplete batch; requires `bs`.
            indexed: whether to fetch items by index; defaults to whether
                `dataset` has `__getitem__`.
            n: number of items to draw; defaults to `len(dataset)` when available.
            device: if not `None`, every batch is moved there in `__iter__`.
            **kwargs: not used in this body.
                NOTE(review): presumably consumed by `funcs_kwargs` -- confirm.
        """
        if batch_size is not None: bs = batch_size # PyTorch compatibility
        # drop_last only makes sense when this loader does the batching itself
        assert not (bs is None and drop_last)
        if indexed is None: indexed = dataset is not None and hasattr(dataset,'__getitem__')
        # `n` stays None when the dataset has no usable `len` (TypeError is ignored);
        # `__len__` then raises and `get_idxs` leaves the index stream unsliced.
        if n is None:
            try: n = len(dataset)
            except TypeError: pass
        # These values are read back below as `self.dataset`, `self.bs`, ...
        # NOTE(review): `store_attr` is an external helper; presumably it assigns each
        # named argument to the same-named attribute -- confirm.
        store_attr(self, 'dataset,bs,shuffle,drop_last,indexed,n,pin_memory,timeout,device')
        # `rng` drives `shuffle_fn`/`randomize`. `nw` and `offs` are used by `sample`;
        # with these initial values (1 and 0) `sample` keeps every index.
        # NOTE(review): presumably they are changed per worker elsewhere -- not visible here.
        self.rng,self.nw,self.offs = random.Random(),1,0
        # `_FakeLoader` is defined outside this block. According to the original
        # annotation its `__iter__` is `iter(self.create_batches(self.sample()))`
        # -- confirm against its definition.
        self.fake_l = _FakeLoader(self, pin_memory, num_workers, timeout)

    def __len__(self):
        """Return the number of batches (or of items when `bs` is None).

        Raises:
            TypeError: if `n` is unknown.
        """
        if self.n is None: raise TypeError
        if self.bs is None: return self.n
        # Full batches, plus one partial batch unless it is dropped or there is no remainder
        return self.n//self.bs + (0 if self.drop_last or self.n%self.bs==0 else 1)

    def get_idxs(self):
        """Return the indexes to draw this epoch.

        When `n` is known the result is a list of `n` entries, shuffled if
        `self.shuffle`; otherwise the unsliced stream is returned as is.
        """
        # NOTE(review): per the original annotation, `Inf.count` is like `itertools.count(0)`
        # and `Inf.nones` like `itertools.cycle([None])` -- confirm.
        idxs = Inf.count if self.indexed else Inf.nones
        if self.n is not None: idxs = list(itertools.islice(idxs, self.n))
        # `shuffle_fn` calls `len(idxs)`, so shuffling presumably requires `n` to be known
        if self.shuffle: idxs = self.shuffle_fn(idxs)
        return idxs

    def sample(self):
        """Return a generator over the indexes this loader should process.

        The index at position `i` belongs to batch `i//(bs or 1)`; it is kept
        when that batch number modulo `self.nw` equals `self.offs`. With the
        values set in `__init__` (`nw=1`, `offs=0`) every index is kept.
        """
        idxs = self.get_idxs()
        return (b for i,b in enumerate(idxs) if i//(self.bs or 1)%self.nw==self.offs)

    def __iter__(self):
        """Yield batches for one pass over the data.

        The code after the loop runs only when the loop is exhausted; if the
        consumer stops early, `after_iter` and the `it` cleanup here are not
        reached (`one_batch` removes `it` itself).
        """
        # Replace the RNG with a freshly seeded one
        self.randomize()
        self.before_iter()
        # Index 1 of `_loaders` is chosen when no worker processes are requested.
        # NOTE(review): per the original annotation,
        # `_loaders = (_MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter)`
        # from torch's dataloader module -- confirm.
        for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
            if self.device is not None: b = to_device(b, self.device)
            yield self.after_batch(b)
        self.after_iter()
        # Drop the dataset iterator that `create_batches` stored on `self`
        if hasattr(self, 'it'): delattr(self, 'it')

    def create_batches(self, samps):
        """Yield processed batches for the indexes in `samps`.

        `samps` is an iterable of indexes (or of `None` for non-indexed
        datasets), in the order the items should be fetched.
        """
        # `create_item` pulls from `self.it` whenever the index is None
        self.it = iter(self.dataset) if self.dataset is not None else None
        # Lazily fetch each item. `do_item` returns None for a skipped item, and the
        # filter drops every None -- including one legitimately returned by `after_item`.
        res = filter(lambda o:o is not None, map(self.do_item, samps))
        # Group the items into batches (unless prebatched) and process each batch
        yield from map(self.do_batch, self.chunkify(res))

    def new(self, dataset=None, cls=None, **kwargs):
        """Return a new loader with this loader's settings.

        Args:
            dataset: dataset for the new loader; defaults to `self.dataset`.
            cls: class to instantiate; defaults to `type(self)`.
            **kwargs: overrides merged with the current settings.

        `n` is not forwarded, so unless it is given in `kwargs` the new loader
        derives it from its dataset again.
        """
        if dataset is None: dataset = self.dataset
        if cls is None: cls = type(self)
        cur_kwargs = dict(dataset=dataset, num_workers=self.fake_l.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,
                          bs=self.bs, shuffle=self.shuffle, drop_last=self.drop_last, indexed=self.indexed, device=self.device)
        # Carry over the current (possibly replaced) methods as keyword arguments
        for n in self._methods: cur_kwargs[n] = getattr(self, n)
        # NOTE(review): `merge` is external; presumably entries of `kwargs` win -- confirm.
        return cls(**merge(cur_kwargs, kwargs))

    # True when `bs` is None, i.e. this loader does no batching of its own
    @property
    def prebatched(self): return self.bs is None
    # Fetch the item for index `s` and pass it through `after_item`;
    # returns None when `SkipItemException` is raised.
    def do_item(self, s):
        try: return self.after_item(self.create_item(s))
        except SkipItemException: return None
    # Return `b` unchanged when prebatched, otherwise `chunked(b, bs, drop_last)`.
    # NOTE(review): `chunked` is external; presumably groups `b` into lists of `bs` items -- confirm.
    def chunkify(self, b): return b if self.prebatched else chunked(b, self.bs, self.drop_last)
    # Return a new list holding the elements of `idxs` in random order (needs `len(idxs)`)
    def shuffle_fn(self, idxs): return self.rng.sample(idxs, len(idxs))
    # Replace `self.rng` with a new generator seeded by an integer drawn from the current one
    def randomize(self): self.rng = random.Random(self.rng.randint(0,2**32-1))
    # Call `retain_types` on `res`, using the first element of `b` as the reference when `b` is listy.
    # NOTE(review): `retain_types`/`is_listy` are external; presumably this restores the
    # types of the elements of `res` to match those of the reference -- confirm.
    def retain(self, res, b):  return retain_types(res, b[0] if is_listy(b) else b)
    # Take the next element of the dataset iterator when `s` is None, otherwise index the dataset at `s`
    def create_item(self, s):  return next(self.it) if s is None else self.dataset[s]
    # Apply `fa_collate` when this loader batches, `fa_convert` when the data is prebatched.
    # NOTE(review): per the original annotation these build on PyTorch's default
    # collate/convert functions -- confirm.
    def create_batch(self, b): return (fa_collate,fa_convert)[self.prebatched](b)
    # Run `before_batch`, then `create_batch`, then `retain` against the original `b`
    def do_batch(self, b): return self.retain(self.create_batch(self.before_batch(b)), b)
    # Set the device batches are moved to in `__iter__`; returns None
    def to(self, device): self.device = device
    def one_batch(self):
        """Return the first batch of this loader.

        Raises:
            ValueError: if `n` is known and the loader has no batches.
        """
        if self.n is not None and len(self)==0: raise ValueError(f'This DataLoader does not contain any batches')
        # NOTE(review): `no_multiproc` and `first` are external; presumably the former
        # temporarily disables worker processes and the latter returns the first
        # element of an iterable (or None if there is none) -- confirm.
        with self.fake_l.no_multiproc(): res = first(self)
        # The iterator is abandoned after one batch, so the cleanup at the end of
        # `__iter__` is not reached; remove `it` here instead.
        if hasattr(self, 'it'): delattr(self, 'it')
        return res
4 Likes