Histogram equalization in fastai

Does fastai2 have a built-in transform for histogram equalization? If not, what would be the easiest way to implement it as a custom transform and pass it to batch_tfms in ImageDataLoaders?

Since nobody has offered a solution, here is mine. It is slow, but it works:

# Parts taken from: https://medium.com/hackernoon/histogram-equalization-in-python-from-scratch-ebb9c8aa3f23

def get_histogram(image, bins):
    """Count how often each integer value in ``[0, bins)`` occurs in *image*.

    Args:
        image: integer pixel values — a flat sequence or an array of any
            shape (it is flattened first).
        bins: number of histogram bins; every pixel must lie in ``[0, bins)``.

    Returns:
        1-D float64 numpy array of length *bins*; entry ``v`` is the number of
        pixels equal to ``v``.

    Raises:
        IndexError: if any pixel value is negative or ``>= bins``.
    """
    pixels = np.asarray(image).ravel()
    if pixels.size == 0:
        return np.zeros(bins)
    # Reject negatives explicitly: plain indexing would silently count them
    # in the wrong bin (histogram[-1] is the last bin).
    lo, hi = pixels.min(), pixels.max()
    if lo < 0 or hi >= bins:
        raise IndexError(
            f"pixel values must lie in [0, {bins}), got range [{lo}, {hi}]"
        )
    # One C-level pass instead of a Python loop per pixel; minlength pads the
    # result so unused trailing bins are still present.
    return np.bincount(pixels, minlength=bins).astype(float)

def cumsum(a):
    """Return the running (cumulative) sum of *a* as a numpy array.

    ``cumsum([1, 2, 3])`` gives ``array([1, 3, 6])``. Any iterable is accepted;
    nested input is accumulated along its first axis. Empty input yields an
    empty array. Small integer dtypes (e.g. uint8) are accumulated in the
    platform integer, so large totals do not wrap around.
    """
    # Materialize plain iterables/generators; numpy arrays are used as-is.
    values = a if isinstance(a, np.ndarray) else list(a)
    return np.cumsum(values, axis=0)

class HistogramEqualization(Transform):
    """Histogram-equalize ``TensorImage`` inputs (intended for ``batch_tfms``).

    Every pixel of the incoming tensor contributes to a single 256-bin
    histogram, so if ``o`` is a whole batch the batch is equalized jointly,
    not image by image. Non-``TensorImage`` inputs (e.g. labels) are returned
    unchanged, and decoding is the identity.
    """

    def encodes(self, o):
        """Return an equalized copy of ``o`` with the same shape and device.

        NOTE(review): assumes ``o`` holds integer pixel values in [0, 255]
        (e.g. uint8, before ``IntToFloatTensor`` runs); float tensors make the
        histogram step fail — confirm against your pipeline.
        """
        # Skip non-image inputs up front instead of computing and discarding.
        if not isinstance(o, TensorImage):
            return o

        # numpy needs host memory; the result is moved back to o.device below.
        flat = o.cpu().numpy().ravel()

        # Cumulative histogram (CDF) over the 256 possible pixel values.
        cdf = cumsum(get_histogram(flat, 256))

        # Rescale the CDF linearly onto [0, 255]:
        # (cdf - cdf_min) * 255 / (cdf_max - cdf_min).
        cdf_min, cdf_max = cdf.min(), cdf.max()
        span = cdf_max - cdf_min
        if span == 0:
            # Only happens when every pixel is 0 (or o is empty): nothing to
            # stretch, and dividing by zero would produce NaNs.
            return o
        # Truncate to uint8 because the lookup table holds 8-bit pixel values.
        lut = ((cdf - cdf_min) * 255 / span).astype('uint8')

        # Map each pixel through the lookup table and restore the original shape.
        equalized = lut[flat].reshape(o.shape)

        # Keep the input's tensor subclass and device so later batch transforms
        # see the same kind of tensor they would have without this step.
        return type(o)(equalized).to(o.device)

    def decodes(self, o):
        """Identity: equalization is not inverted when decoding."""
        return o
1 Like

A faster solution is to use item_tfms instead of batch_tfms:
from fastai.vision.core import PILImage
class HistogramEqualization_item(Transform):
    """Histogram-equalize ``PILImage`` inputs one item at a time (``item_tfms``).

    Uses Pillow's ``ImageOps.equalize``. Anything that is not a ``PILImage``
    (e.g. the label) passes through unchanged; decoding is the identity.
    """

    def __init__(self, prefix=None):
        super().__init__()
        # NOTE(review): ``prefix`` is stored but never read by this transform.
        self.prefix = prefix or ""

    def encodes(self, o):
        """Return an equalized copy of a PIL image; return other inputs as-is."""
        if not isinstance(o, PILImage):
            return o
        # Import the submodule explicitly: ``import PIL`` alone does not load
        # ``PIL.ImageOps``.
        from PIL import ImageOps
        # ImageOps.equalize returns a plain PIL.Image.Image; re-wrap it in the
        # input's fastai type so later type-dispatched item transforms still
        # recognize it (fastai's PILBase casts an existing PIL image — confirm
        # on your fastai version).
        return type(o)(ImageOps.equalize(o))

    def decodes(self, o):
        """Identity: equalization is not inverted when decoding."""
        return o
# Equalize each image with Pillow at item level, then apply GPU augmentations.
# max_zoom / max_lighting are aug_transforms parameters, so they go there
# rather than to from_folder.
data = ImageDataLoaders.from_folder(
    path, train='train', valid='test',
    item_tfms=[Resize(size=384), HistogramEqualization_item()],
    batch_tfms=[*aug_transforms(min_scale=0.98, do_flip=False,
                                max_zoom=1.1, max_lighting=0.2)],
    bs=16, num_workers=8)