Hello everyone,
New to fastai and really enjoying this library
Currently, I’m working on a segmentation problem and use a U-Net model for training. I compare two augmentation strategies: online augmentation, which applies random transforms to the images on the fly during training, and offline augmentation, which pre-generates augmented images once and adds them to the dataset on disk.
When training the model with online augmentation, the dice score is not very good. However, with offline augmentation, the dice score is good. The dataset consists of 2500 images with roughly an 85/15 split between the training and validation sets. There is also a separate test set, but it is not relevant for this topic.
For offline augmentation, I generate 40 augmented images for each image, so starting from 2500 images the offline dataset grows to roughly 100,000 images.
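(For reference, the offline pre-generation is essentially a loop like the sketch below, not my exact script; WIDTH, HEIGHT, pairs, out_dir and the file naming are just placeholders.)

import albumentations as A
import numpy as np
from pathlib import Path
from PIL import Image

# Sketch of the offline pre-generation: each image/mask pair is augmented
# N_AUG times with p=1 transforms and the results are written to disk.
N_AUG = 40
offline_tfm = A.Compose([
    A.Affine(scale=(0.6, 1.4), rotate=(-40, 40), p=1),
    A.PadIfNeeded(min_height=HEIGHT, min_width=WIDTH, border_mode=0, p=1),
    A.CenterCrop(HEIGHT, WIDTH, p=1),
    A.RandomBrightnessContrast(p=1),
])

out_dir = Path('augmented')
for img_path, mask_path in pairs:                 # pairs: list of (image path, mask path)
    img = np.array(Image.open(img_path))
    mask = np.array(Image.open(mask_path))
    for i in range(N_AUG):
        aug = offline_tfm(image=img, mask=mask)   # new random parameters on every call
        Image.fromarray(aug['image']).save(out_dir / f'{img_path.stem}_{i}.png')
        Image.fromarray(aug['mask']).save(out_dir / f'{mask_path.stem}_{i}_mask.png')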
For online augmentation, my idea was to train for 40 times as many epochs as with offline augmentation, since each epoch sees 40 times fewer images. However, the model eventually overfits and there is no sign that running for more epochs would improve performance. The results can be seen here:
Model training with offline augmentation:
Model training with online augmentation:
We see that the offline run starts to overfit earlier (which makes sense), but its loss is also lower. The loss used for both training regimes is CrossEntropyLossFlat,
and the code to generate the augmentations is as follows (for offline augmentations, we set p=1 for each transform):
import albumentations as A
from albumentations import Affine, PadIfNeeded, CenterCrop, RandomBrightnessContrast
from fastai.vision.all import *

class SegmentationAlbumentationsTransform(ItemTransform):
    def __init__(self, width, height):
        self.width = width
        self.height = height
        self.affine = Affine(scale=(0.6, 1.4), rotate=(-40, 40))
        self.pad = PadIfNeeded(min_height=self.height, min_width=self.width, border_mode=0)
        self.crop = CenterCrop(self.height, self.width)
        self.brightness = RandomBrightnessContrast()
        self.transform = A.Compose([self.affine, self.pad, self.crop, self.brightness])

    def encodes(self, x):
        data = x[1]                      # second element of the item tuple holds the raw sample
        img = data.image                 # image as a numpy array
        mask = data.mask                 # mask as a numpy array
        aug = self.transform(image=img, mask=mask)   # albumentations keeps image and mask aligned
        return PILImage.create(aug['image']), PILMask.create(aug['mask'])
I would actually prefer to use online augmentation, since it avoids storing this many images on disk. Does anyone have ideas, or can anyone pinpoint what could explain this large difference in dice score and loss?
Thanks in advance