Comparison of new activation functions on ImageNette. New Winner = TRelu

Edit and update: Based on a question by @yonatan365, I made a new activation function, TRelu, and it is now the winner of this initial runoff. Originally the winner was FTSwishPlus, which is why the early posts talk about FTSwishPlus as the winner (details in the thread!).

There were two new papers on activation functions on arXiv recently, and I thought it would be interesting to put them to the test along with GeneralRelu and, of course, ReLU.

In brief, the two new activation functions, plus the two reference points:

1 - Flattened threshold swish (FTSwish): a swish-like activation that returns x*sigmoid(x) + threshold for non-negative inputs and a fixed (flat) value equal to the threshold for negative inputs. The paper's recommended default threshold is -.2.
Paper: https://arxiv.org/abs/1812.06247

2 - Linear Scaled Hyperbolic Tangent (LiSHT): a parabola-shaped activation. Because x*tanh(x) is symmetric, a negative input produces the same output as its absolute value, so negative values become positive. The code is simple: x = x * torch.tanh(x)
Paper: https://arxiv.org/abs/1901.05894#

3 - General ReLU (GRelu): you already know this one, but it's ReLU + mean shift + leak + clamping.

4 - ReLU - used as the ‘baseline’ for comparison.
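For reference, a minimal functional sketch of the four contenders as described above. The FTSwish and LiSHT formulas follow the list; the general_relu parameter names (leak, sub, maxv) are assumptions for illustration, not necessarily the exact GeneralRelu used in these runs:

```python
import torch
import torch.nn.functional as F

def ftswish(x, threshold=-0.2):
    # x*sigmoid(x) + threshold for x >= 0; relu(x) is 0 for x < 0, leaving a flat threshold
    return F.relu(x) * torch.sigmoid(x) + threshold

def lisht(x):
    # x*tanh(x) is symmetric: a negative input gives the same output as its absolute value
    return x * torch.tanh(x)

def general_relu(x, leak=None, sub=None, maxv=None):
    # ReLU (or leaky ReLU) followed by an optional mean shift and an optional upper clamp
    # (parameter names are assumptions for this sketch)
    x = F.leaky_relu(x, leak) if leak is not None else F.relu(x)
    if sub is not None:
        x = x - sub
    if maxv is not None:
        x = x.clamp(max=maxv)
    return x

# the baseline ReLU is just F.relu
x = torch.linspace(-3, 3, 7)
for name, fn in [("ReLU", F.relu), ("FTSwish", ftswish), ("LiSHT", lisht), ("GeneralRelu", general_relu)]:
    print(name, fn(x))
```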

Results: After testing all four on ImageNette (XResNet18, cyclical learning rate with lr=1e-2) for up to 40 epochs each, General ReLU was the winner. I then decided to add mean shift + clamping to both LiSHT and FTSwish, making LiSHT+ and FTSwish+.

Clamping stabilized the training path (i.e. epoch-to-epoch accuracy gains were steadier) but ultimately fell short on final accuracy.

The mean shift on FTSwish+, however, turned out to be a very big improvement.
Surprisingly, the best mean shift was -.1.
With it, FTSwish+ took top place with a -.2 threshold.
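No code for the clamped "+" variants is given in this post. A minimal sketch of what LiSHT+ might look like, assuming the same convention as FTSwishPlus (the mean shift is subtracted) and an upper clamp; the class name and the maxv parameter are hypothetical, and no tuned values are implied:

```python
import torch
import torch.nn as nn

class LiSHTPlus(nn.Module):
    # Hypothetical "LiSHT+": LiSHT followed by a mean shift and an optional upper clamp
    def __init__(self, mean_shift=None, maxv=None):
        super().__init__()
        self.mean_shift = mean_shift
        self.maxv = maxv

    def forward(self, x):
        x = x * torch.tanh(x)
        if self.mean_shift is not None:
            # same convention as FTSwishPlus: subtracting a negative shift raises the output
            x = x - self.mean_shift
        if self.maxv is not None:
            # the clamp caps large activations, which is what steadied training above
            x = x.clamp(max=self.maxv)
        return x
```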

I then tested -.3 and -.4, which did well but not as well as -.2.
Finally, I tested -.25 (since -.3 had done quite well and -.2 was the current winner), and that proved to be the final winner. The original paper found -.2 to be the optimal threshold, but since that was on MNIST, I was not surprised to see a slightly larger threshold do better on 3-channel images.

Here are the results:

One other bonus for FTSwish+ was that its epoch-to-epoch path was very smooth.
ReLU and GRelu, for example, both had epochs that were worse than a previous epoch. By contrast, FTSwish+ was very smooth, making progress every epoch or at worst holding steady: never a step backward. This was also true of LiSHT+, except it never reached a competitive final accuracy (I let it run additional epochs to see what would happen, but it tended to stall out).
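The "steps backward" comparison is easy to make concrete from a list of per-epoch accuracies. A minimal sketch with hypothetical helper names; the accuracy lists are made-up illustrations, not results from these runs:

```python
def count_steps_back(accuracies):
    """Count epochs whose accuracy is lower than the previous epoch's."""
    return sum(1 for prev, cur in zip(accuracies, accuracies[1:]) if cur < prev)

def never_steps_back(accuracies):
    """True if accuracy never drops (ties count as holding steady)."""
    return count_steps_back(accuracies) == 0

# Made-up curves for illustration only
bumpy = [0.60, 0.72, 0.70, 0.80, 0.79, 0.85]
smooth = [0.58, 0.70, 0.70, 0.78, 0.83, 0.85]
print(count_steps_back(bumpy), never_steps_back(smooth))
```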

Here’s the code for FTSwish+ if you would like to try it:

```python
import torch
import torch.nn as nn

class FTSwishPlus(nn.Module):
    def __init__(self, threshold=-.25, mean_shift=-.1):
        super().__init__()
        self.threshold = threshold
        self.mean_shift = mean_shift

    def forward(self, x): 
        
        #FTSwish+ for positive values
        pos_value = (x*torch.sigmoid(x)) + self.threshold
        
        #FTSwish+ for negative values (created on x's own device/dtype, so multi-GPU works)
        tval = torch.tensor([self.threshold], device=x.device, dtype=x.dtype)
        
        #apply to x tensor based on positive or negative value
        x = torch.where(x>=0, pos_value, tval)
        
        #apply mean shift to drive mean to 0. -.1 was tested as optimal for kaiming init
        if self.mean_shift is not None:
            x.sub_(self.mean_shift)

        return x
```
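To try FTSwishPlus in an existing model, one option is to swap out its nn.ReLU modules. A minimal sketch with a hypothetical helper (replace_activations); it only replaces activations registered as modules, so calls to F.relu inside a forward method are not affected. With the class above you would pass FTSwishPlus as make_new; nn.LeakyReLU is used here only so the snippet runs on its own:

```python
import torch
import torch.nn as nn

def replace_activations(model, old_type, make_new):
    """Recursively replace every submodule of type `old_type` with `make_new()`."""
    for name, child in model.named_children():
        if isinstance(child, old_type):
            setattr(model, name, make_new())  # re-registers the new module under the same name
        else:
            replace_activations(child, old_type, make_new)
    return model

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(),
                      nn.Sequential(nn.Conv2d(8, 8, 3), nn.ReLU()))
replace_activations(model, nn.ReLU, nn.LeakyReLU)  # e.g. make_new=FTSwishPlus instead
print(model)
```

Passing a class (or a lambda) rather than an instance makes sure each position gets its own module.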

This is really impressive and promising. Thanks for sharing!


Thanks, a very interesting experiment!

I’d be curious to know whether a straight-line function that resembles FTSwish (e.g. a ReLU shifted down by 0.1) would get a similar score, as it is easier to calculate and looks almost identical. If it’s not similar, it’s interesting to think about why this slight curve makes a difference…
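One way to see how close the straight-line version is: for negative inputs, FTSwish and a ReLU shifted by the same threshold are identical, and for positive inputs they differ by relu(x) - x*sigmoid(x) = x*sigmoid(-x). A minimal numerical check (the threshold cancels out, so -.25 is an arbitrary choice):

```python
import torch
import torch.nn.functional as F

threshold = -0.25  # cancels in the difference; any value gives the same gap
x = torch.linspace(-6, 6, 120001)

ftswish = F.relu(x) * torch.sigmoid(x) + threshold
shifted_relu = F.relu(x) + threshold

gap = (shifted_relu - ftswish).abs()
print(f"max gap {gap.max().item():.4f} at x = {x[gap.argmax()].item():.3f}")
```

The gap peaks at roughly 0.28 near x ≈ 1.28 and shrinks toward 0 for large x, so the two functions differ mainly in the small-positive region.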


Nice work!

I’m probably missing something here, but isn’t it possible to implement it using relu(...)? I mean something like:

```python
x = torch.nn.functional.relu(x) * torch.sigmoid(x) + self.threshold
```

(Actually, there’s a little mention of this in the paper.)
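A quick numerical check that the relu form matches the original torch.where form, using the -.25 threshold from FTSwishPlus (the function names are hypothetical):

```python
import torch
import torch.nn.functional as F

def ftswish_where(x, threshold=-0.25):
    # piecewise form: swish + threshold for x >= 0, flat threshold for x < 0
    return torch.where(x >= 0, x * torch.sigmoid(x) + threshold, torch.full_like(x, threshold))

def ftswish_relu(x, threshold=-0.25):
    # relu form: relu(x) is 0 for x < 0, so only the threshold remains there
    return F.relu(x) * torch.sigmoid(x) + threshold

x = torch.randn(10_000)
print(torch.allclose(ftswish_where(x), ftswish_relu(x)))
```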

Very cool! Can you try repeating your experiments for 20 epochs and 80 epochs, using rn50, for imagenette and imagewoof, so you can compare to the leaderboard? Otherwise it’s a little hard to interpret experiments like this because we can’t tell if they’re being trained well.


Hi Herchu,
You are absolutely correct here. In fact, I originally tried to compute ReLU * sigmoid via nn.ReLU, which didn’t work because it gave me a ReLU module object instead of a tensor, and that’s why I ended up with the code above.
(i.e.: TypeError: mul(): argument ‘other’ (position 1) must be Tensor, not ReLU)
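The difference between the two forms: nn.ReLU is a module class, so nn.ReLU() only constructs a layer object and has to be called on a tensor to compute anything, whereas F.relu computes directly. A minimal check:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(5)

relu_layer = nn.ReLU()  # a module object, not a tensor
# multiplying relu_layer itself by a tensor fails with a TypeError like the one above

a = relu_layer(x) * torch.sigmoid(x)  # calling the module returns a tensor
b = F.relu(x) * torch.sigmoid(x)      # the functional form returns a tensor directly
print(torch.equal(a, b))
```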

Using F.relu, however, does return a tensor, so the code is now greatly simplified, and it also avoids the device handling since there’s no need to create a new tensor manually:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FTSwishPlus(nn.Module):
    def __init__(self, threshold=-.25, mean_shift=-.1):
        super().__init__()
        self.threshold = threshold
        self.mean_shift = mean_shift

    def forward(self, x): 

        x = F.relu(x) * torch.sigmoid(x) + self.threshold
        #note on above - why not F.sigmoid?: 
        #PyTorch docs - "nn.functional.sigmoid is deprecated. Use torch.sigmoid instead."

        #apply mean shift to drive mean to 0. -.1 was tested as optimal for kaiming init
        if self.mean_shift is not None:
            x.sub_(self.mean_shift)

        return x
```

Anyway, thanks for pointing this out; the code is much cleaner now 🙂
I also put this in a GitHub repo to track any future changes:


That’s a really great question regarding ReLU with a threshold. I’ve made TRelu and will throw it into the mix so we can find out. Most of the papers seem to feel the big advantage comes from letting a small negative value flow through on the forward pass; I haven’t seen much emphasis on the curve for small positive values.

But we can see with TRelu if that curve matters:

```python
import torch.nn as nn
import torch.nn.functional as F

class TRelu(nn.Module):
    def __init__(self, threshold=-.25, mean_shift=-.03):
        super().__init__()
        self.threshold = threshold
        self.mean_shift = mean_shift
    
    def forward(self,x):
        x = F.relu(x)+self.threshold
        
        if self.mean_shift is not None:
            x.sub_(self.mean_shift)
            
        return x
```

I found -.03 centers the mean nicely, so I’ll use that as the default for testing. Updates to come!
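The post doesn’t show how the mean was measured. One way to check activation means inside a network is a forward hook; below is a minimal sketch assuming a hypothetical two-layer conv stack, kaiming-normal init and random inputs, with a hypothetical helper name (activation_means). The values it prints depend heavily on model, init and data, so they are not a reproduction of the -.03 figure:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TRelu(nn.Module):
    # same as the TRelu above, repeated so this sketch is self-contained
    def __init__(self, threshold=-.25, mean_shift=-.03):
        super().__init__()
        self.threshold = threshold
        self.mean_shift = mean_shift

    def forward(self, x):
        x = F.relu(x) + self.threshold
        if self.mean_shift is not None:
            x.sub_(self.mean_shift)
        return x

def activation_means(model, xb, act_type):
    """Run one batch and record the output mean of every `act_type` layer, in order."""
    means, handles = [], []
    for m in model.modules():
        if isinstance(m, act_type):
            handles.append(m.register_forward_hook(
                lambda mod, inp, out: means.append(out.mean().item())))
    with torch.no_grad():
        model(xb)
    for h in handles:
        h.remove()  # otherwise the hooks keep firing on later forward passes
    return means

model = nn.Sequential(nn.Conv2d(3, 16, 3), TRelu(), nn.Conv2d(16, 16, 3), TRelu())
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
print(activation_means(model, torch.randn(8, 3, 32, 32), TRelu))
```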


Yes, absolutely! I was planning to move to ImageWoof next, as I felt a more challenging dataset should further clarify any true distinctions.
I’ll have to set up on FloydHub for it, as my Windows laptop with num_workers=0 (to avoid the forkedpickle bug) won’t cut it with RN50.

Anyway, I’ll run it for 20 and 80 epochs with RN50 and update. I’ll also throw the new ReluT into the mix.


So some exciting news:
1 - TRelu got off to the fastest start of all (.61, first epoch)
2 - TRelu also achieved highest single epoch score (.90)
and…
3 - TRelu is the new winner of the 12-epoch ImageNette comparison!

To be consistent with GRelu, I renamed this threshold ReLU to TRelu (in the first posting I called it ReluT). Here are the final results and its training curve:

Tentatively, I think I can propose two hypotheses here:

1 - The curve in the 0-1 section (vs. the straight line for ReLU) is what creates the smoother training curve. All the ReLU variants take steps backward during training, whereas those with curves follow smoother paths with minimal steps back. TRelu took multiple steps back.
LiSHT, which curves on both sides (-1 to 0 and 0 to 1), is in fact the smoothest of them all and has 2x the curvature, but ultimately LiSHT consistently underperformed in terms of final accuracy.
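Hypothesis 1 is about curvature, which can be made concrete as the second derivative of each activation. A minimal sketch using double backprop (the helper name is hypothetical): ReLU’s second derivative is 0 everywhere except at its kink, swish (FTSwish’s positive side) has a second derivative of 0.5 at the origin, and LiSHT’s is 2 there.

```python
import torch
import torch.nn.functional as F

def second_derivative(fn, xs):
    """Pointwise d^2 fn / dx^2 via double backprop (fn must act elementwise)."""
    x = xs.clone().requires_grad_(True)
    (dy,) = torch.autograd.grad(fn(x).sum(), x, create_graph=True)
    if not dy.requires_grad:
        # the first derivative no longer depends on x: piecewise linear here
        return torch.zeros_like(xs)
    (d2y,) = torch.autograd.grad(dy.sum(), x)
    return d2y

xs = torch.tensor([-1.0, 0.0, 1.0])
for name, fn in [("relu", F.relu),
                 ("swish", lambda x: x * torch.sigmoid(x)),
                 ("lisht", lambda x: x * torch.tanh(x))]:
    print(name, second_derivative(fn, xs))
```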

2 - Allowing a small amount of negative output is what drives the accuracy improvements; the curve on the positive side is not what drives the accuracy gains. And it seems that capping (thresholding) the negative value is what promotes higher (highest?) accuracy.
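Hypothesis 2 is about how much negative output each function lets through. A quick check of each activation’s output floor, using the defaults quoted in this thread (the leaky slope of .1 is an assumption, since no GeneralRelu settings were given):

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-10, 10, 20001)

# Output of each activation over a wide input range
outputs = {
    "ReLU": F.relu(x),
    "TRelu": F.relu(x) + (-0.25) - (-0.03),                      # threshold -.25, mean_shift -.03
    "FTSwish+": F.relu(x) * torch.sigmoid(x) + (-0.25) - (-0.1),  # threshold -.25, mean_shift -.1
    "LiSHT": x * torch.tanh(x),
    "leaky ReLU": F.leaky_relu(x, 0.1),                           # assumed slope
}
for name, y in outputs.items():
    # the floor is the most negative value the layer can pass forward
    print(f"{name:10s} min output = {y.min().item():.3f}")
```

ReLU floors at 0, LiSHT never goes below 0, the two thresholded functions floor at a small negative constant, and the leaky version is unbounded below (its minimum here only reflects the chosen input range).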

I’ll have to test longer runs (20 or 80 epochs) and on ImageWoof next, but maybe we have the makings of a new top-form activation function with TRelu!


This is a really interesting experiment, thank you for sharing your work and results.


Hi Less. I hope you can clarify some of my questions and confusion about TRelu.

I have been trying TRelu in my own work but found no significant difference from ordinary ReLU in that case. That observation prompted my questions.

  1. When you talk about the mean of TRelu, what does that refer to? Is it the mean over some particular distribution of activations, like a Gaussian around zero?

  2. Looking at the code, mean_shift and threshold combine to shift Relu’s y-value by a constant. That’s the extent of their effects. Right? (I wonder if you meant for one of these parameters to shift the x-axis?)

  3. If all activations are shifted by a constant and then passed to a convolution or FC, it looks like the only effect would be to add a different constant to that layer’s output. Same for avepool and maxpool. That constant would even be “neutralized” by the next layer’s bias or batchnorm. So I have a hard time seeing how TRelu would behave any differently than vanilla Relu. It may create a different training path, but it seems the models would be equivalent.

The above argument is “in my head”. I have not actually tested it in a model.
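The argument above can be checked numerically. A minimal sketch, assuming standard-normal inputs for question 1 (the thread doesn’t say which distribution the -.03 mean shift was tuned on) and arbitrary layer shapes for question 3. The two TRelu parameters collapse into one constant; under unit-Gaussian inputs the output mean is about 0.18 rather than 0; and with no padding that constant folds exactly into the next conv’s bias. With zero padding, however, the border positions see 0 instead of the shifted value, so the fold is no longer exact there, which is one way a constant shift can still change what a network computes.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
threshold, mean_shift = -0.25, -0.03  # TRelu defaults from this thread
c = threshold - mean_shift            # Q2: both parameters collapse into one constant (-0.22)

# Q1: for z ~ N(0, 1), E[relu(z)] = 1/sqrt(2*pi), so TRelu's mean is that plus c
z = torch.randn(1_000_000)
trelu_out = F.relu(z) + threshold - mean_shift  # what TRelu.forward computes
analytic_mean = 1 / math.sqrt(2 * math.pi) + c
print(f"analytic mean {analytic_mean:.4f}, sampled mean {trelu_out.mean().item():.4f}")

# Q3: conv(x + c) = W*x + c*sum(W) + b, so fold c*sum(W) into a copy of the conv's bias
x = F.relu(torch.randn(2, 4, 8, 8))
results = {}
for padding in (0, 1):
    conv = nn.Conv2d(4, 6, 3, padding=padding)
    folded = nn.Conv2d(4, 6, 3, padding=padding)
    with torch.no_grad():
        folded.weight.copy_(conv.weight)
        folded.bias.copy_(conv.bias + c * conv.weight.sum(dim=(1, 2, 3)))
    results[padding] = torch.allclose(conv(x + c), folded(x), atol=1e-5)
    print(f"padding={padding}: shift equivalent to folded bias? {results[padding]}")
```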

Thanks for all your work comparing these activation functions, and for addressing my questions. I’ll be trying FTSwish+ next.