I am trying to use the method `tfms_from_model`, defined in `/fastai/transforms.py`. I have no issues when using 3-channel data, but things break down when I use 1-channel data. The function is defined as:
```python
imagenet_stats = A([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

def tfms_from_model(f_model, sz, aug_tfms=None, max_zoom=None, pad=0, crop_type=CropType.RANDOM,
                    tfm_y=None, sz_y=None, pad_mode=cv2.BORDER_REFLECT, norm_y=True, scale=None):
    stats = inception_stats if f_model in inception_models else imagenet_stats
    return tfms_from_stats(stats, sz, aug_tfms, max_zoom=max_zoom, pad=pad, crop_type=crop_type,
                           tfm_y=tfm_y, sz_y=sz_y, pad_mode=pad_mode, norm_y=norm_y, scale=scale)
```
Since `f_model` is my own model (and therefore not in `inception_models`), `stats = imagenet_stats`. The two arguments of the function `A` are the mean and standard deviation of ImageNet.
My dataset is dsprites, which contains single-channel data. I thought that defining `stats = A([0.5], [0.5])` (or whatever the dsprites statistics are) inside `tfms_from_model` would be the appropriate change, but I get the following error:
```python
sz = 64
img = np.uint8(np.random.uniform(150, 180, (sz, sz, 1)))/255
train_tfms, val_tfms = tfms_from_model(<my_model>, sz)
val_tfms(img)
```
```
---------------------------------------------------------------------------
AxisError                                 Traceback (most recent call last)
<ipython-input-12-fe09d5fb79ff> in <module>
      2 img = np.uint8(np.random.uniform(150, 180, (sz, sz, 1)))/255
      3 train_tfms, val_tfms = tfms_from_model(model.module.encoder, sz)
----> 4 val_tfms(img)

<ipython-input-7-830c40232b51> in __call__(self, im, y)
    630         self.tfms.append(ChannelOrder(tfm_y))
    631
--> 632     def __call__(self, im, y=None): return compose(im, y, self.tfms)
    633     def __repr__(self): return str(self.tfms)
    634

<ipython-input-7-830c40232b51> in compose(im, y, fns)
    605     for fn in fns:
    606         #pdb.set_trace()
--> 607         im, y =fn(im, y)
    608     return im if y is None else (im, y)
    609

<ipython-input-7-830c40232b51> in __call__(self, x, y)
    175
    176     def __call__(self, x, y):
--> 177         x = np.rollaxis(x, 2)
    178         #if isinstance(y,np.ndarray) and (len(y.shape)==3):
    179         if self.tfm_y==TfmType.PIXEL: y = np.rollaxis(y, 2)

~/anaconda3/envs/filter_visualizer_env/lib/python3.6/site-packages/numpy/core/numeric.py in rollaxis(a, axis, start)
   1449     """
   1450     n = a.ndim
-> 1451     axis = normalize_axis_index(axis, n)
   1452     if start < 0:
   1453         start += n

AxisError: axis 2 is out of bounds for array of dimension 2
```
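The traceback shows that the array reaching `np.rollaxis(x, 2)` in the channel-reordering step is already 2-D, although `img` starts out as `(sz, sz, 1)`. The NumPy-only sketch below reproduces that failure mode in isolation. It assumes that some earlier transform dropped the trailing singleton axis; one plausible but unconfirmed culprit is an OpenCV resize, since `cv2.resize` typically returns a 2-D array for single-channel input. `flat` and `restored` are illustrative names, and re-adding the axis is a hypothetical guard, not part of fastai.

```python
import numpy as np

sz = 64
# Same synthetic single-channel image as in the question: shape (sz, sz, 1).
img = np.uint8(np.random.uniform(150, 180, (sz, sz, 1))) / 255

# HWC -> CHW, as in the failing line: fine while the channel axis exists.
chw = np.rollaxis(img, 2)
print(chw.shape)  # (1, 64, 64)

# Simulate a transform that drops the singleton channel axis.
flat = img[:, :, 0]  # shape (64, 64)
failed, msg = False, ""
try:
    np.rollaxis(flat, 2)
except ValueError as e:  # NumPy's AxisError subclasses ValueError
    failed, msg = True, str(e)
print(failed, msg)

# Hypothetical guard: re-add the channel axis before reordering.
restored = flat[..., np.newaxis] if flat.ndim == 2 else flat
print(np.rollaxis(restored, 2).shape)  # (1, 64, 64)
```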
How should I change the code to be able to use it on single channel data?