- """Helper functions to get data in a `DataLoaders` in the vision application and higher class `ImageDataLoaders`"""
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/08_vision.data.ipynb.
- # %% ../../nbs/08_vision.data.ipynb 2
- from __future__ import annotations
- from ..torch_basics import *
- from ..data.all import *
- from .core import *
- import types
- # %% auto 0
# Public API of this module (autogenerated from the notebook; edit the notebook instead)
__all__ = [
    'PointBlock', 'BBoxBlock',
    'get_grid', 'clip_remove_empty', 'bb_pad', 'show_batch',
    'ImageBlock', 'MaskBlock',
    'BBoxLblBlock', 'ImageDataLoaders', 'SegmentationDataLoaders',
]
- # %% ../../nbs/08_vision.data.ipynb 7
@delegates(subplots)
def get_grid(
    n:int, # Number of axes in the returned grid
    nrows:int=None, # Number of rows in the returned grid, defaulting to `int(math.sqrt(n))`
    ncols:int=None, # Number of columns in the returned grid, defaulting to `ceil(n/rows)`
    figsize:tuple=None, # Width, height in inches of the returned figure
    double:bool=False, # Whether to double the number of columns and `n`
    title:str=None, # If passed, title set to the figure
    return_fig:bool=False, # Whether to return the figure created by `subplots`
    flatten:bool=True, # Whether to flatten the matplot axes such that they can be iterated over with a single loop
    **kwargs,
) -> (plt.Figure, plt.Axes): # Returns just `axs` by default, and (`fig`, `axs`) if `return_fig` is set to True
    "Return a grid of `n` axes, `rows` by `cols`"
    # Fill in whichever grid dimension is missing; with neither given, aim for a near-square layout.
    if not nrows and not ncols:
        nrows = int(math.sqrt(n))
        ncols = int(np.ceil(n/nrows))
    elif nrows:
        ncols = ncols or int(np.ceil(n/nrows))
    else:
        nrows = int(np.ceil(n/ncols))
    if double:
        # Room for pairs of plots: twice as many columns and twice as many axes requested.
        ncols, n = ncols*2, n*2
    fig, axs = subplots(nrows, ncols, figsize=figsize, **kwargs)
    if flatten:
        flat_axes = list(axs.flatten())
        # The grid can hold more axes than requested; hide the surplus ones and keep only the first `n`.
        for spare in flat_axes[n:]:
            spare.set_axis_off()
        axs = flat_axes[:n]
    if title is not None:
        fig.suptitle(title, weight='bold', size=14)
    if return_fig:
        return fig, axs
    return axs
- # %% ../../nbs/08_vision.data.ipynb 9
def clip_remove_empty(
    bbox:TensorBBox, # Coordinates of bounding boxes
    label:TensorMultiCategory # Labels of the bounding boxes
):
    "Clip bounding boxes with image border and remove empty boxes along with corresponding labels"
    # Coordinates live in the scaled [-1, 1] space, so clipping to the image border is a clamp.
    bbox = torch.clamp(bbox, -1, 1)
    # Boxes are indexed as (x1, y1, x2, y2) on the last axis.
    width  = bbox[...,2] - bbox[...,0]
    height = bbox[...,3] - bbox[...,1]
    # A box is empty when its width OR its height is not positive. Testing the sign of width*height
    # alone would keep boxes inverted on both axes, since their product comes out positive.
    empty = (width <= 0.) | (height <= 0.)
    # The mask is wrapped in `TensorBase` before indexing `label`, presumably so fastai's tensor
    # subclass dispatch handles the indexing — keep as in the original.
    return (bbox[~empty], label[TensorBase(~empty)])
- # %% ../../nbs/08_vision.data.ipynb 12
def bb_pad(
    samples:list, # List of 3-tuples like (image, bounding_boxes, labels)
    pad_idx=0 # Label that will be used to pad each list of labels
):
    "Function that collects `samples` of labelled bboxes and adds padding with `pad_idx`."
    # Clip every sample's boxes to the image and drop the empty ones together with their labels.
    cleaned = []
    for img, *targets in samples:
        cleaned.append((img, *clip_remove_empty(*targets)))
    # Pad every sample up to the largest remaining number of boxes in the batch.
    longest = max(len(lbls) for _, _, lbls in cleaned)
    padded = []
    for img, boxes, lbls in cleaned:
        # Padding boxes are all-zero; padding labels are `pad_idx`.
        boxes = torch.cat([boxes, boxes.new_zeros(longest - boxes.shape[0], 4)])
        lbls = torch.cat([lbls, lbls.new_zeros(longest - lbls.shape[0]) + pad_idx])
        padded.append((img, boxes, lbls))
    return padded
- # %% ../../nbs/08_vision.data.ipynb 16
@typedispatch
def show_batch(x:TensorImage, y, samples, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    "Show up to `max_n` image `samples` on a grid, then defer the drawing to the generic `show_batch`"
    if ctxs is None:
        n_shown = min(len(samples), max_n)
        ctxs = get_grid(n_shown, nrows=nrows, ncols=ncols, figsize=figsize)
    return show_batch[object](x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)
- # %% ../../nbs/08_vision.data.ipynb 17
@typedispatch
def show_batch(x:TensorImage, y:TensorImage, samples, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    "Show pairs of images: element 0 of each sample in even grid cells, element 1 in odd ones"
    if ctxs is None:
        # `double=True` gives two axes per sample.
        ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize, double=True)
    for part in (0, 1):
        drawn = []
        for img, ax, _ in zip(samples.itemgot(part), ctxs[part::2], range(max_n)):
            drawn.append(img.show(ctx=ax, **kwargs))
        ctxs[part::2] = drawn
    return ctxs
- # %% ../../nbs/08_vision.data.ipynb 20
def ImageBlock(cls:PILBase=PILImage):
    "A `TransformBlock` for images of `cls`"
    # Items are created with `cls.create`; `IntToFloatTensor` is applied at batch level.
    create_item = cls.create
    return TransformBlock(type_tfms=create_item, batch_tfms=IntToFloatTensor)
- # %% ../../nbs/08_vision.data.ipynb 21
def MaskBlock(
    codes:list=None # Vocab labels for segmentation masks
):
    "A `TransformBlock` for segmentation masks, potentially with `codes`"
    # Masks are created with `PILMask.create`, `AddMaskCodes(codes=codes)` runs as an item
    # transform, and `IntToFloatTensor` is applied at batch level.
    mask_codes = AddMaskCodes(codes=codes)
    return TransformBlock(type_tfms=PILMask.create, item_tfms=mask_codes, batch_tfms=IntToFloatTensor)
- # %% ../../nbs/08_vision.data.ipynb 22
# Points: created as `TensorPoint` and rescaled by the `PointScaler` item transform.
PointBlock = TransformBlock(type_tfms=TensorPoint.create, item_tfms=PointScaler)
PointBlock.__doc__ = "A `TransformBlock` for points in an image"
# Bounding boxes: same scaling as points, plus `bb_pad` as the DataLoader's `before_batch` hook,
# which pads every sample's boxes/labels to the same count.
BBoxBlock = TransformBlock(type_tfms=TensorBBox.create, item_tfms=PointScaler,
                           dls_kwargs=dict(before_batch=bb_pad))
BBoxBlock.__doc__ = "A `TransformBlock` for bounding boxes in an image"
- # %% ../../nbs/08_vision.data.ipynb 25
def BBoxLblBlock(
    vocab:list=None, # Vocab labels for bounding boxes
    add_na:bool=True # Forwarded to `MultiCategorize`; adds an extra background ("na") class to the vocab
):
    "A `TransformBlock` for labeled bounding boxes, potentially with `vocab`"
    # Labels are categorized at type level; `BBoxLabeler` runs as an item transform.
    categorize = MultiCategorize(vocab=vocab, add_na=add_na)
    return TransformBlock(type_tfms=categorize, item_tfms=BBoxLabeler)
- # %% ../../nbs/08_vision.data.ipynb 28
class ImageDataLoaders(DataLoaders):
    "Basic wrapper around several `DataLoader`s with factory methods for computer vision problems"
    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_folder(cls, path, train='train', valid='valid', valid_pct=None, seed=None, vocab=None, item_tfms=None,
                    batch_tfms=None, img_cls=PILImage, **kwargs):
        "Create from imagenet style dataset in `path` with `train` and `valid` subfolders (or provide `valid_pct`)"
        # Pick the splitter and the file search together, both keyed on `valid_pct is None`. Testing
        # `valid_pct` for truthiness when choosing `get_items` would make `valid_pct=0` use a
        # `RandomSplitter` while still only collecting files under the `train`/`valid` subfolders.
        if valid_pct is None:
            splitter = GrandparentSplitter(train_name=train, valid_name=valid)
            get_items = partial(get_image_files, folders=[train, valid])
        else:
            splitter = RandomSplitter(valid_pct, seed=seed)
            get_items = get_image_files
        dblock = DataBlock(blocks=(ImageBlock(img_cls), CategoryBlock(vocab=vocab)),
                           get_items=get_items,
                           splitter=splitter,
                           get_y=parent_label,
                           item_tfms=item_tfms,
                           batch_tfms=batch_tfms)
        return cls.from_dblock(dblock, path, path=path, **kwargs)

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_path_func(cls, path, fnames, label_func, valid_pct=0.2, seed=None, item_tfms=None, batch_tfms=None,
                       img_cls=PILImage, **kwargs):
        "Create from list of `fnames` in `path`s with `label_func`"
        # `fnames` are the items; `label_func` is the `get_y` applied to them; the split is random.
        dblock = DataBlock(blocks=(ImageBlock(img_cls), CategoryBlock),
                           splitter=RandomSplitter(valid_pct, seed=seed),
                           get_y=label_func,
                           item_tfms=item_tfms,
                           batch_tfms=batch_tfms)
        return cls.from_dblock(dblock, fnames, path=path, **kwargs)

    @classmethod
    def from_name_func(cls,
        path:str|Path, # Set the default path to a directory that a `Learner` can use to save files like models
        fnames:list, # A list of `os.Pathlike`'s to individual image files
        label_func:callable, # A function that receives a string (the file name) and outputs a label
        **kwargs
    ) -> DataLoaders:
        "Create from the name attrs of `fnames` in `path`s with `label_func`"
        # Lambdas are rejected on Windows, presumably because worker processes there must pickle
        # `label_func` and lambdas cannot be pickled (see link below).
        if sys.platform == 'win32' and isinstance(label_func, types.LambdaType) and label_func.__name__ == '<lambda>':
            # https://medium.com/@jwnx/multiprocessing-serialization-in-python-with-pickle-9844f6fa1812
            raise ValueError("label_func couldn't be lambda function on Windows")
        # Call `label_func` on each path's `name` attribute instead of the whole path.
        f = using_attr(label_func, 'name')
        return cls.from_path_func(path, fnames, f, **kwargs)

    @classmethod
    def from_path_re(cls, path, fnames, pat, **kwargs):
        "Create from list of `fnames` in `path`s with re expression `pat`"
        # The regex `pat` is matched against each full path via `RegexLabeller`.
        return cls.from_path_func(path, fnames, RegexLabeller(pat), **kwargs)

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_name_re(cls, path, fnames, pat, **kwargs):
        "Create from the name attrs of `fnames` in `path`s with re expression `pat`"
        # Same as `from_path_re`, but `from_name_func` applies the labeller to each path's `name`.
        return cls.from_name_func(path, fnames, RegexLabeller(pat), **kwargs)

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_df(cls, df, path='.', valid_pct=0.2, seed=None, fn_col=0, folder=None, suff='', label_col=1, label_delim=None,
                y_block=None, valid_col=None, item_tfms=None, batch_tfms=None, img_cls=PILImage, **kwargs):
        "Create from `df` using `fn_col` and `label_col`"
        # Prefix prepended to every value of `fn_col`: `path` (or `path/folder`) plus a trailing separator.
        pref = f'{Path(path) if folder is None else Path(path)/folder}{os.path.sep}'
        if y_block is None:
            # Several label columns, or delimited labels, mean a multi-label target.
            is_multi = (is_listy(label_col) and len(label_col) > 1) or label_delim is not None
            y_block = MultiCategoryBlock if is_multi else CategoryBlock
        # Split on the boolean column `valid_col` when given, otherwise at random.
        splitter = RandomSplitter(valid_pct, seed=seed) if valid_col is None else ColSplitter(valid_col)
        dblock = DataBlock(blocks=(ImageBlock(img_cls), y_block),
                           get_x=ColReader(fn_col, pref=pref, suff=suff),
                           get_y=ColReader(label_col, label_delim=label_delim),
                           splitter=splitter,
                           item_tfms=item_tfms,
                           batch_tfms=batch_tfms)
        return cls.from_dblock(dblock, df, path=path, **kwargs)

    @classmethod
    def from_csv(cls, path, csv_fname='labels.csv', header='infer', delimiter=None, quoting=csv.QUOTE_MINIMAL, **kwargs):
        "Create from `path/csv_fname` using `fn_col` and `label_col`"
        # Read the CSV, then forward the remaining options (`fn_col`, `label_col`, ...) to `from_df`.
        df = pd.read_csv(Path(path)/csv_fname, header=header, delimiter=delimiter, quoting=quoting)
        return cls.from_df(df, path=path, **kwargs)

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_lists(cls, path, fnames, labels, valid_pct=0.2, seed:int=None, y_block=None, item_tfms=None, batch_tfms=None,
                   img_cls=PILImage, **kwargs):
        "Create from list of `fnames` and `labels` in `path`"
        if y_block is None:
            # Infer the target block from the first label. Any list-like label means multi-label;
            # also requiring `len > 1` would route a first label such as `['cat']` or `[]` to
            # `CategoryBlock`, which does not split list labels into separate categories.
            first = labels[0]
            if is_listy(first): y_block = MultiCategoryBlock
            elif isinstance(first, float): y_block = RegressionBlock
            else: y_block = CategoryBlock
        # Items are `(fname, label)` pairs built column-wise from the two lists.
        dblock = DataBlock.from_columns(blocks=(ImageBlock(img_cls), y_block),
                                        splitter=RandomSplitter(valid_pct, seed=seed),
                                        item_tfms=item_tfms,
                                        batch_tfms=batch_tfms)
        return cls.from_dblock(dblock, (fnames, labels), path=path, **kwargs)

# Expose the keyword arguments of the method each factory forwards `**kwargs` to. Order matters:
# `from_name_re` is patched after `from_name_func`, so it picks up the signature that
# `from_name_func` just received from `from_path_func`.
ImageDataLoaders.from_csv = delegates(to=ImageDataLoaders.from_df)(ImageDataLoaders.from_csv)
ImageDataLoaders.from_name_func = delegates(to=ImageDataLoaders.from_path_func)(ImageDataLoaders.from_name_func)
ImageDataLoaders.from_path_re = delegates(to=ImageDataLoaders.from_path_func)(ImageDataLoaders.from_path_re)
ImageDataLoaders.from_name_re = delegates(to=ImageDataLoaders.from_name_func)(ImageDataLoaders.from_name_re)
- # %% ../../nbs/08_vision.data.ipynb 62
class SegmentationDataLoaders(DataLoaders):
    "Basic wrapper around several `DataLoader`s with factory methods for segmentation problems"
    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_label_func(cls, path, fnames, label_func, valid_pct=0.2, seed=None, codes=None, item_tfms=None, batch_tfms=None,
                        img_cls=PILImage, **kwargs):
        "Create from list of `fnames` in `path`s with `label_func`."
        # Inputs are images of `img_cls`; `label_func` maps each item of `fnames` to its target,
        # which is handled by `MaskBlock(codes=codes)`. Items are split at random.
        io_blocks = (ImageBlock(img_cls), MaskBlock(codes=codes))
        dblock = DataBlock(blocks=io_blocks,
                           splitter=RandomSplitter(valid_pct, seed=seed),
                           get_y=label_func,
                           item_tfms=item_tfms,
                           batch_tfms=batch_tfms)
        return cls.from_dblock(dblock, fnames, path=path, **kwargs)
# NOTE: this listing was truncated at this point; see the original file for any remaining content.