Also, a few days ago I went ahead and commented every line of code in DataLoader, and I'm posting it here. It's probably already slightly out of sync with the current version of the library and will drift further over time, but the core ideas should stay pretty much the same. Here it is:
(it’s a Wiki so anyone can edit this with clearer explanations)
(also, you might have to do a bit of homework before completely understanding this!)
@funcs_kwargs # Make delegation work
class DataLoader(GetAttr):
_noop_methods = 'wif before_iter after_item before_batch after_batch after_iter'.split()
for o in _noop_methods:
exec(f"def {o}(self, x=None, *args, **kwargs): return x")
# Define each of the _noop_methods as identity transforms
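# e.g. for o='wif' the exec above generates: def wif(self, x=None, *args, **kwargs): return x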
_methods = _noop_methods + 'create_batches create_item create_batch retain \
get_idxs sample shuffle_fn do_batch create_batch'.split()
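# Thanks to @funcs_kwargs, any of the names in _methods can also be overridden
# by passing a function with that name as a keyword argument to __init__
# (that's the "delegation" the decorator comment above refers to)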
_default = 'dataset'
def __init__(self, dataset=None, bs=None, num_workers=0, pin_memory=False, timeout=0, batch_size=None,
shuffle=False, drop_last=False, indexed=None, n=None, device=None, **kwargs):
if batch_size is not None: bs = batch_size # PyTorch compatibility
assert not (bs is None and drop_last)
if indexed is None: indexed = dataset is not None and hasattr(dataset,'__getitem__')
# indexed will be true if the dataset exists and can be indexed into
if n is None:
try: n = len(dataset)
except TypeError: pass
# n signifies the length of the dataset. This can be set to be smaller than the actual length
# This is probably to allow conveniently using a subset of the dataset
store_attr(self, 'dataset,bs,shuffle,drop_last,indexed,n,pin_memory,timeout,device')
# convenient way to do `self.x = x` for all x in the above string
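# e.g. here it expands to self.dataset,self.bs,... = dataset,bs,... for every name listed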
self.rng,self.nw,self.offs = random.Random(),1,0
# rng is used for shuffling (see shuffle_fn and randomize below); nw and offs are
# the number of workers and this process's offset, which the worker-init code
# updates so each worker picks out its own share of batches in sample()
self.fake_l = _FakeLoader(self, pin_memory, num_workers, timeout)
# The source code for _FakeLoader is a bit complicated.
# The useful bit is the __iter__ function:
# return iter(self.create_batches(self.sample()))
def __len__(self):
if self.n is None: raise TypeError
if self.bs is None: return self.n
return self.n//self.bs + (0 if self.drop_last or self.n%self.bs==0 else 1)
# The length is measured in batches, not items: n//bs full batches,
# plus one more for the leftover items unless drop_last is set
# (or n happens to divide evenly by bs)
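# e.g. n=10, bs=4 gives 10//4 + 1 = 3 batches, or just 10//4 = 2 with drop_last=True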
def get_idxs(self):
idxs = Inf.count if self.indexed else Inf.nones
# Inf.count = itertools.count(0) : a counter (iterator) starting from 0
# Inf.nones = itertools.cycle([None]): indefinitely returns None on each next() call
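# e.g. list(itertools.islice(Inf.count, 3)) == [0, 1, 2]
# and list(itertools.islice(Inf.nones, 3)) == [None, None, None]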
if self.n is not None: idxs = list(itertools.islice(idxs, self.n))
# list(itertools.islice(idxs, self.n)) ~ list(range(0, self.n))
if self.shuffle: idxs = self.shuffle_fn(idxs)
# basic shuffling of indexes
return idxs
def sample(self):
idxs = self.get_idxs()
return (b for i,b in enumerate(idxs) if i//(self.bs or 1)%self.nw==self.offs)
# This shards the indexes across workers: i//bs is the batch number, and each
# worker keeps only the batches where batch_number % nw == offs.
# With the defaults (nw=1, offs=0) the condition is always true, so in the
# single-process case it just yields idxs unchanged.
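# e.g. with bs=2 and nw=2, worker 0 (offs=0) keeps i = 0,1,4,5,8,9,...
# and worker 1 (offs=1) keeps i = 2,3,6,7,10,11,... i.e. alternating batches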
def __iter__(self):
self.randomize()
# Reseed random number generator
self.before_iter()
for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
# _loaders = (_MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter)
# These come from torch.utils.data.dataloader (PyTorch's own iterator classes)
# Indexing with num_workers==0 picks the single-process iterator when there are
# no workers and the multiprocessing one otherwise; either way it drives
# fake_l's __iter__, i.e. create_batches(sample())
if self.device is not None: b = to_device(b, self.device)
yield self.after_batch(b)
# Above 2 lines are pretty standard:
# Get the batch, process it, yield
self.after_iter()
if hasattr(self, 'it'): delattr(self, 'it')
# Cleanup
def create_batches(self, samps):
# `samps` is a list of indexes into the dataset (or Nones for a
# non-indexed dataset), possibly shuffled
self.it = iter(self.dataset) if self.dataset is not None else None
res = filter(lambda o:o is not None, map(self.do_item, samps))
# res is a generator of processed samples in the order given by `samps`,
# with None values (items skipped via SkipItemException) filtered out
yield from map(self.do_batch, self.chunkify(res))
# Returns a batch from res with appropriate processing
# and while trying to retain the original type where it can
def new(self, dataset=None, cls=None, **kwargs):
# Create a copy of the DataLoader, possibly with a new dataset,
# and return it
if dataset is None: dataset = self.dataset
if cls is None: cls = type(self)
cur_kwargs = dict(dataset=dataset, num_workers=self.fake_l.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,
bs=self.bs, shuffle=self.shuffle, drop_last=self.drop_last, indexed=self.indexed, device=self.device)
for n in self._methods: cur_kwargs[n] = getattr(self, n)
return cls(**merge(cur_kwargs, kwargs))
@property
def prebatched(self): return self.bs is None
# prebatched will probably be true when our dataset returns items in batches
# in which case we specify bs=None when creating the dataloader
def do_item(self, s):
try: return self.after_item(self.create_item(s))
# Process and return an item indexed at s
except SkipItemException: return None
def chunkify(self, b): return b if self.prebatched else chunked(b, self.bs, self.drop_last)
# return a batch of samples from the iterator b
# chunked is pretty straightforward
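# e.g. chunked([0,1,2,3,4], 2, drop_last=False) yields [0,1], [2,3], [4];
# with drop_last=True the incomplete [4] chunk is dropped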
def shuffle_fn(self, idxs): return self.rng.sample(idxs, len(idxs))
# Simply shuffle idxs and return
def randomize(self): self.rng = random.Random(self.rng.randint(0,2**32-1))
# reseed RNG
def retain(self, res, b): return retain_types(res, b[0] if is_listy(b) else b)
# retain_types tries to cast each element of `res` back to the type of the corresponding element in `b`
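# e.g. if the items going in were fastai TensorImages, the collated batch comes
# back as a TensorImage rather than a plain Tensor (as far as I understand it)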
def create_item(self, s): return next(self.it) if s is None else self.dataset[s]
# Index into the dataset at `s`, or pull the next element from the dataset
# iterator `self.it` when `s` is None (the non-indexed case)
def create_batch(self, b): return (fa_collate,fa_convert)[self.prebatched](b)
# If it's not prebatched then collate, else convert
# fa_collate uses PyTorch's default_collate when b contains array-like elements (and recurses into tuples/lists otherwise)
# fa_convert simply converts `b` into a Tensor using PyTorch's default_convert
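# e.g. fa_collate([tensor(1), tensor(2)]) gives roughly tensor([1, 2]),
# and for a list of (x, y) tuples it collates the xs and ys separately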
def do_batch(self, b): return self.retain(self.create_batch(self.before_batch(b)), b)
# process elements in batch, collate/convert them, retain the type along the way, and return.
def to(self, device): self.device = device
# change default device of self
def one_batch(self):
# Gets the first batch from `self`.
if self.n is not None and len(self)==0: raise ValueError(f'This DataLoader does not contain any batches')
with self.fake_l.no_multiproc(): res = first(self)
# first just gets the first item from any iterator or `None` if there is no such item
if hasattr(self, 'it'): delattr(self, 'it')
return res
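To make this a bit more concrete, here's a minimal sketch of how I'd exercise the class (assuming a fastai2-era install; the exact import path may differ between versions, and the DoubledDL subclass is just something I made up for illustration):

from fastai2.data.load import DataLoader  # import path is version-dependent

# a plain list works fine as an indexed dataset
dl = DataLoader(list(range(10)), bs=4, shuffle=True)
for b in dl: print(b)   # 3 batches: two tensors of 4 items and one of 2 (drop_last=False)
xb = dl.one_batch()     # just the first batch, with multiprocessing disabled

# any of the names in _methods can be overridden, e.g. by subclassing
class DoubledDL(DataLoader):
    def create_item(self, s): return self.dataset[s] * 2

dl2 = DoubledDL(list(range(10)), bs=4)  # yields batches of doubled values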