How to use tfms_from_model for 1 channel data

I am trying to use the function tfms_from_model defined in /fastai/transforms.py. It works fine with 3-channel data, but breaks down with 1-channel data. The function is defined as:

imagenet_stats = A([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
def tfms_from_model(f_model, sz, aug_tfms=None, max_zoom=None, pad=0, crop_type=CropType.RANDOM,
                    tfm_y=None, sz_y=None, pad_mode=cv2.BORDER_REFLECT, norm_y=True, scale=None):

    stats = inception_stats if f_model in inception_models else imagenet_stats
    return tfms_from_stats(stats, sz, aug_tfms, max_zoom=max_zoom, pad=pad, crop_type=crop_type,
                           tfm_y=tfm_y, sz_y=sz_y, pad_mode=pad_mode, norm_y=norm_y, scale=scale)

Since f_model is my own model, it is not in inception_models, so stats = imagenet_stats. The two arguments passed to A are the per-channel mean and standard deviation of ImageNet, one value for each of the three RGB channels.
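To see what these three-element statistics do to images with different channel counts, here is a NumPy-only sketch. It assumes (not checked against this fastai version) that the normalizer built from the stats computes (x - mean) / std, with mean and std broadcast along the last (channel) axis of an H x W x C image; the normalize helper is hypothetical.

```python
import numpy as np

# ImageNet per-channel statistics, same values as imagenet_stats above (RGB order).
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

def normalize(x, mean, std):
    """Hypothetical per-channel normalization: stats broadcast along the last axis."""
    return (x - mean) / std

sz = 64
rgb = np.random.uniform(0.0, 1.0, (sz, sz, 3))   # H x W x 3 image
gray = np.random.uniform(0.0, 1.0, (sz, sz, 1))  # H x W x 1 image
flat = gray[..., 0]                              # H x W, channel axis dropped

print(normalize(rgb, imagenet_mean, imagenet_std).shape)         # (64, 64, 3)
# Three-element stats silently broadcast a 1-channel image to 3 channels.
print(normalize(gray, imagenet_mean, imagenet_std).shape)        # (64, 64, 3)
# One-element stats keep the single channel.
print(normalize(gray, np.array([0.5]), np.array([0.5])).shape)   # (64, 64, 1)
# One-element stats also broadcast over a 2-D array, so this step
# cannot detect that the channel axis has gone missing.
print(normalize(flat, np.array([0.5]), np.array([0.5])).shape)   # (64, 64)
```

The last case matters here: with one-element stats, normalization succeeds whether or not the image still has its channel axis, so a missing axis only surfaces in a later step that indexes axis 2 explicitly.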

My dataset is dSprites, which contains single-channel images. I thought that defining

stats = A([0.5], [0.5])  # or whatever the actual dSprites statistics are

inside tfms_from_model would be enough, but I get the following error:
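If the real dataset statistics are wanted instead of [0.5], a minimal sketch of computing them is below. It assumes the images are available as one NumPy array of shape (N, H, W) or (N, H, W, 1); the random binary array is a hypothetical stand-in for the actual dSprites data, and single_channel_stats is a hypothetical helper, not a fastai function.

```python
import numpy as np

def single_channel_stats(images):
    """Mean and std over all pixels of an (N, H, W) or (N, H, W, 1) array.

    Returns one-element arrays: the 1-channel analogue of the
    three-element ImageNet statistics.
    """
    x = np.asarray(images, dtype=np.float64)
    return np.array([x.mean()]), np.array([x.std()])

# Hypothetical stand-in for dSprites-like data: 100 binary 64x64 images.
# Replace this with the real dataset array.
rng = np.random.default_rng(0)
images = (rng.uniform(size=(100, 64, 64)) > 0.9).astype(np.uint8)

mean, std = single_channel_stats(images)
print(mean.shape, std.shape)  # (1,) (1,)
```

The resulting values could take the place of the [0.5] and [0.5] in the A(...) call above; note that they must be computed on the same scale (here 0 to 1) as the images that are later fed to the transforms.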

sz = 64
img = np.uint8(np.random.uniform(150, 180, (sz, sz, 1)))/255
train_tfms, val_tfms = tfms_from_model(<my_model>, sz)
val_tfms(img)

---------------------------------------------------------------------------
AxisError                                 Traceback (most recent call last)
<ipython-input-12-fe09d5fb79ff> in <module>
  2 img = np.uint8(np.random.uniform(150, 180, (sz, sz, 1)))/255
  3 train_tfms, val_tfms = tfms_from_model(model.module.encoder, sz)
----> 4 val_tfms(img)

<ipython-input-7-830c40232b51> in __call__(self, im, y)
630         self.tfms.append(ChannelOrder(tfm_y))
631 
--> 632     def __call__(self, im, y=None): return compose(im, y, self.tfms)
633     def __repr__(self): return str(self.tfms)
634 

<ipython-input-7-830c40232b51> in compose(im, y, fns)
605     for fn in fns:
606         #pdb.set_trace()
--> 607         im, y =fn(im, y)
608     return im if y is None else (im, y)
609 

<ipython-input-7-830c40232b51> in __call__(self, x, y)
175 
176     def __call__(self, x, y):
--> 177         x = np.rollaxis(x, 2)
178         #if isinstance(y,np.ndarray) and (len(y.shape)==3):
179         if self.tfm_y==TfmType.PIXEL: y = np.rollaxis(y, 2)

~/anaconda3/envs/filter_visualizer_env/lib/python3.6/site-packages/numpy/core/numeric.py in 
rollaxis(a, axis, start)
1449     """
1450     n = a.ndim
-> 1451     axis = normalize_axis_index(axis, n)
 1452     if start < 0:
 1453         start += n

AxisError: axis 2 is out of bounds for array of dimension 2 
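The last frame shows the ChannelOrder step calling np.rollaxis(x, 2) on an array that is already 2-D, even though img was created with shape (64, 64, 1). One plausible cause, which is an assumption not verified against this fastai version: the scaling transform presumably resizes with OpenCV's cv2.resize, which returns an (H, W) array when given an (H, W, 1) input, so the channel axis is lost before normalization. The NumPy-only sketch below reproduces the error and shows a hypothetical helper, to_channels_first, that restores the missing axis before moving it to the front; it illustrates the shape problem and is not a fastai API.

```python
import numpy as np

def to_channels_first(x):
    """Hypothetical helper: move the channel axis to the front, like
    np.rollaxis(x, 2) in the traceback, but first give a 2-D (H, W)
    array a trailing channel axis so it becomes (H, W, 1)."""
    if x.ndim == 2:
        x = x[..., np.newaxis]   # (H, W) -> (H, W, 1)
    return np.rollaxis(x, 2)     # (H, W, C) -> (C, H, W)

sz = 64
# Simulates what reaches ChannelOrder once the channel axis has been dropped.
img_2d = np.random.uniform(0.0, 1.0, (sz, sz))

# Reproduces the error from the traceback: a 2-D array has no axis 2.
try:
    np.rollaxis(img_2d, 2)
except ValueError as err:   # numpy's AxisError subclasses ValueError
    print(type(err).__name__, err)

print(to_channels_first(img_2d).shape)                 # (1, 64, 64)
print(to_channels_first(np.zeros((sz, sz, 3))).shape)  # (3, 64, 64)
```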

How should I change the code to be able to use it on single channel data?