Yeah, the autoreload is quite buggy with fastai, I don’t know why. Always restart your notebook/reload in case of doubt.
Adding a new transform is as simple as what you coded, so it should work perfectly.
Ah, yes. Your y would have two dimensions then, so it’s a bug. Thanks for the fix, feel free to put it in a PR, otherwise I’ll add it tonight or tomorrow.
So I wanted to implement AutoAugment transforms and am looking for some advice.
AA works with PIL images using 14 basic functions implemented via Image.transform(Image.AFFINE), PIL.ImageEnhance and PIL.ImageOps. While you can find a few of these functions among the fastai transforms, most are not yet in the library. So the question is: does it make sense to just wrap the original AA with tensor-to-PIL-image conversions (and back) and use it that way, especially since it can be used as the only augmentation? Or is splitting AA into its basic functions and then reassembling them into a whole transform the only reasonable way?
The best would be to add the functions that are missing in fastai v1: if you look at the source code of vision.transform, it’s really easy to code new transforms (though maybe a few of those functions are difficult).
Otherwise, you can just wrap the original AA with tensor-to-PIL conversions and back, and use it that way, as you said.
The slight fix for the mixup model is more than welcome in a direct PR to fastai.
For the senet file, let me check if the 0.7 version can be refactored with our new tools first, and once I’ve added it to the library, your PR with the metadata will be more than welcome too (note that my refactoring might potentially change the indexes you have).
I did a bit of experimenting (mostly fruitless), and for now settled on:
policy = ImageNetPolicy()
def _autoaugment(x):
    # Convert the fastai tensor to a PIL image, apply the AutoAugment
    # policy, then convert back and rescale to [0, 1].
    pil_img = PIL.Image.fromarray(image2np(x * 255).astype('uint8'))
    x = policy(pil_img)
    x = pil2tensor(x, np.float32)
    return x.div_(255)

autoaugment = TfmPixel(_autoaugment)
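The wrapper above hinges on the tensor↔PIL round trip: scaling floats in [0, 1] up to uint8 for PIL, then back down after the policy runs. A minimal check of that conversion (pure NumPy, no fastai or PIL needed) shows the only loss is uint8 quantization:

```python
import numpy as np

# Simulate a float image in [0, 1], convert to uint8 (as PIL expects)
# and back, exactly as the wrapper does around the AutoAugment call.
x = np.array([[0.0, 0.5], [0.25, 1.0]], dtype=np.float32)
as_uint8 = (x * 255).astype('uint8')           # tensor -> PIL-style bytes
restored = as_uint8.astype(np.float32) / 255   # PIL-style bytes -> tensor

# The round trip quantizes values to steps of 1/255, so the original
# and restored images agree within that tolerance.
assert np.allclose(x, restored, atol=1 / 255)
```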
It does the trick. In fact, I played around with training ImageNet and started the training with the standard get_transforms(). Then I changed the transforms to