I am trying to apply fastai2 framework on Kaggle Global Wheat Detection Challenge (object detection problem)
I build DataBlock
and DataLoaders
by the following:
def build_dblock(data_path, resize_sz, norm, rand_seed = 144, test_mode = False):
json_path = data_path / 'train_mini.json' if test_mode else data_path / 'train.json'
_, _, img2bbox = decode_coco_json(json_path)
blks = (ImageBlock, BBoxBlock, BBoxLblBlock)
get_ids_func = get_img_ids(json_path)
getters_func = [lambda o: data_path / 'train' / o,
lambda o: img2bbox[o][0],
lambda o: img2bbox[o][1]]
rand_splitter = RandomSplitter(valid_pct = 0.2, seed = rand_seed)
batch_tfms = aug_transforms(size = resize_sz, min_scale = 0.85, do_flip = True)
if norm:
batch_tfms += [Normalize.from_stats(*imagenet_stats)]
dblock = DataBlock(
blocks = blks, splitter = rand_splitter,
get_items = get_ids_func, getters = getters_func,
batch_tfms = batch_tfms, n_inp = 1
)
return dblock
def build_dataloaders(
data_path, bs, resize_sz = 256,
norm = False, rand_seed = 144, test_mode = False
):
"""
:param:
data_path : str/ Path, path to wheat datasets
resize_sz : int, length after resized (assume square)
rand_seed : int, andom seed id
"""
if isinstance(data_path, str):
data_path = Path(data_path)
dblk = build_dblock(data_path, resize_sz, norm = norm,
rand_seed = rand_seed, test_mode = test_mode)
dls = dblk.dataloaders(data_path / 'train', bs = bs)
dls.c = 2
return dls
But when I tried to show_batch
from the DataLoaders
built by the above function, I got the following. I think it has something to do with batch resizing in aug_transforms
(everything is alright if I instead used Resize
in item_tfms
), but so far can’t pinpoint which lines of code causing the problem. Anyone has idea how to fix this?