LiSHT (Linearly Scaled Hyperbolic Tangent) - better than ReLU? - testing it out

I found a very interesting paper last night about a new activation function called LiSHT. The authors show superior performance over ReLU, Swish, etc., and thanks to this course I was able to jump straight into coding it up and start some initial testing.
Short summary: impressive so far, but I've only used MNIST…ImageNette coming up.

Details - here’s the link to the paper:

What is it?
In math, it’s this:

$$\mathrm{LiSHT}(x) = x \cdot \tanh(x)$$

Which I coded up as:
x = x * torch.tanh(x)

It behaves like a symmetric parabola near zero and approaches |x| for large inputs:
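
A quick numeric check of that shape, as a sketch. It uses only the standard library `math` module on single floats rather than torch tensors, and the helper name `lisht` is just for illustration:

```python
import math

def lisht(x: float) -> float:
    """LiSHT for a single float: x * tanh(x)."""
    return x * math.tanh(x)

# Compare against x^2 (the near-zero behaviour) and |x| (the large-input behaviour).
for x in (-3.0, -1.0, -0.1, 0.0, 0.1, 1.0, 3.0):
    print(f"x={x:+.1f}  lisht={lisht(x):.4f}  x^2={x * x:.4f}  |x|={abs(x):.4f}")
```

The output is symmetric around zero and never negative, which is why the mean of the activations ends up above zero.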

I tested it initially with one of our initial notebooks and promptly got NaN gradients…possibly using BN will suppress this, but to keep testing with the low-level framework I modified the GeneralRelu to add a mean shift and clamping. I called it LightRelu b/c
a) typing Lisht felt odd, light was easier and
b) concatenating relu on the end made it clear it’s an activation function.

import torch
from torch import nn

class LightRelu(nn.Module):
    # LiSHT (x * tanh(x)) with an optional mean shift and upper clamp.
    # .46 was found to shift the mean to 0 on a random distribution test
    # maxv of 7.5 was from initial testing on MNIST.
    # Important - cut your learning rates in half with this…
    def __init__(self, sub=.46, maxv=7.5):
        super().__init__()
        self.sub = sub
        self.maxv = maxv

    def forward(self, x):
        x = x * torch.tanh(x)  # LiSHT
        if self.sub is not None:
            x.sub_(self.sub)  # shift the mean back toward 0
        if self.maxv is not None:
            x.clamp_max_(self.maxv)  # cap large activations
        return x

Similar to ReLU, it shifts the mean of the activations upward (its output is never negative), and I found that subtracting .46 brings the mean back to roughly 0. I put the clamping in for now b/c it explodes 5-8 runs in if you don’t.
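
A minimal sketch of how the mean shift could be measured. It assumes standard-normal inputs, which may not be the distribution behind the .46 figure, so the printed value need not match it. The helper name `lisht_mean` is hypothetical:

```python
import torch

def lisht_mean(n_samples: int = 100_000, seed: int = 0) -> float:
    """Mean of x * tanh(x) over standard-normal samples (assumed input distribution)."""
    gen = torch.Generator().manual_seed(seed)
    x = torch.randn(n_samples, generator=gen)
    return (x * torch.tanh(x)).mean().item()

shift = lisht_mean()
print(f"mean LiSHT output on N(0, 1) input: {shift:.3f}")
```

Whatever value comes out for your own activations is the amount to pass as `sub` if the goal is a mean near zero.
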
The other big thing is to cut the learning rate in half vs normal. It does learn quite fast!
Some comparisons on MNIST with the basic conv framework in notebook 6:

Regular ReLU:

LightRelu(with shift and clamp):

More interesting is the comparison with General Relu and LightRelu (both thus have mean shift and clamp):
General:

Light:

and min activations - LightRelu was nearly double by the fourth layer:
General:

LightRelu:

I’m going to test it with ImageNette next as that way it can be used with BatchNorm. The authors only used it that way and perhaps that removes the need for clamping.
Overall I did get consistently higher accuracy on MNIST with the LightRelu. It also learns rapidly so it’s easy to blow up - the learning rate finder actually doesn’t work well with it b/c it blows up so fast.
More to come…

Just a quick update - it’s looking even better now with BN.
I am running with the new XResNet34 on ImageNette and using the LightRelu. It’s extremely smooth in terms of a silky increase in accuracy.
It does seem BN has eliminated the exploding gradient issue I was seeing when testing it raw…though I still have it clamped at 8.
More testing to go, but I am getting more encouraged this might be a real improvement over ReLU.

After another day of testing:
1 - LightRelu (clamped and mean-shifted LiSHT) had the smoothest training curve, but …
2 - ReLU still beat it out after the same # of epochs. For some reason ReLU gets a fast start on the first epoch (50% vs 18%) and while it takes a wilder up and down ride, ends up a bit higher.

I also tested LiSHT with no mean shift and no clamping, which is how the paper was done, but that was not nearly as good as the modified version with mean shift and clamp (so I applied some useful info from our course :slight_smile: to a paper and improved it, at least).
I’m testing double the clamping range (16) now and that’s looking reasonable…
and finally I want to throw the General ReLU from the notebooks in, since that one kind of disappeared here with the XResNet model.

Sounds like you need a higher LR for LiSHT. Not sure why that would be - but try it!

Edit: actually too much clamping would explain that…
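
A small sketch of why heavy clamping can slow learning: any activation that hits the clamp gets a zero gradient, so those units stop contributing to the update. The input values here are made up for illustration, and 7.5 is the `maxv` default from the LightRelu code earlier in the thread:

```python
import torch

maxv = 7.5  # clamp value taken from the LightRelu defaults
x = torch.tensor([1.0, 10.0], requires_grad=True)

# LiSHT followed by the upper clamp (out-of-place, to keep the example simple)
y = (x * torch.tanh(x)).clamp_max(maxv)
y.sum().backward()

print(y.detach())  # the second value is capped at maxv
print(x.grad)      # the gradient is zero wherever the clamp was active
```

Raising the clamp lets more activations keep a non-zero gradient.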

Thanks Jeremy. I will test that.
I doubled the clamp from 8 to 16 and that was the best performance so far - about matched ReLU with a much smoother path.

I’m trying to test one more activation though and I’m stumped - it’s Fixed Threshold Swish… but how do I code this:
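$$\mathrm{FTS}(x) = \begin{cases} x \cdot \sigma(x) + T, & x \ge 0 \\ T, & x < 0 \end{cases}$$

(where $\sigma$ is the sigmoid and $T$ is the fixed threshold)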

or proto-type code:

def forward(self, x): 
    if x >= 0:
        x = (x*torch.sigmoid(x)) + self.threshold
    else:
        x = self.threshold

I’m unclear how to do an element-wise check for >= 0 and apply the function to those elements, while using the fixed value for the negative ones.

Anyway, any help on how to write this would be greatly appreciated!
Here’s the paper for reference:

https://pytorch.org/docs/master/torch.html#torch.where
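
For anyone landing here later, a minimal sketch of what `torch.where` does. The values and the two branches are made up for illustration and are not the FTS function itself:

```python
import torch

x = torch.tensor([-2.0, -0.5, 0.0, 1.5])

# Take x * 2 where x >= 0, and a fixed -1.0 everywhere else.
# Both branches are computed for every element; where() only selects between them.
out = torch.where(x >= 0, x * 2, torch.full_like(x, -1.0))
print(out)
```

The condition is evaluated per element, so no Python `if` on the tensor is needed.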

Thanks a ton for this answer - it was driving me nuts trying to figure it out…now I can get some sleep.

(Sorry about the other thread…this thread isn’t highly viewed lol, so I figured I should make it its own topic…anyway I can’t remove that other thread but won’t repeat this mistake of cross-posting).

Anyway, thanks again for the answer on this!

This is what I ended up with - the cuda hack is not good, but I was getting an error about it being on the CPU vs CUDA:

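def forward(self, x):
    # Swish-style value used where x >= 0
    pos_value = (x * torch.sigmoid(x)) + self.threshold
    # hack: hard-codes the first GPU, so this fails on CPU or any other device
    cuda0 = torch.device('cuda:0')
    tval = torch.tensor([self.threshold], device=cuda0)
    # element-wise select: pos_value where x >= 0, otherwise the threshold
    x = torch.where(x >= 0, pos_value, tval)
    return x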

We work around that in quite a few notebooks in this course - see if you can see how we do it, and let us know what you find! :slight_smile:

will do :slight_smile: I know offhand exactly what you are talking about (recalling the flexible and less flex comments for torch.device), but need to find it in the notebooks…and will update.
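
In the meantime, one device-agnostic variant as a sketch. This is not necessarily the approach the course notebooks take: it just builds the threshold tensor from `x` itself with `torch.full_like`, so it lands on whatever device and dtype `x` already has. The class name `FTSwish` and the default threshold of -0.2 are placeholders I picked for illustration:

```python
import torch
from torch import nn

class FTSwish(nn.Module):
    """Fixed Threshold Swish sketch: x * sigmoid(x) + t where x >= 0, else t."""
    def __init__(self, threshold: float = -0.2):  # default is a placeholder assumption
        super().__init__()
        self.threshold = threshold

    def forward(self, x):
        pos_value = x * torch.sigmoid(x) + self.threshold
        # full_like copies x's device and dtype, so no device is hard-coded
        tval = torch.full_like(x, self.threshold)
        return torch.where(x >= 0, pos_value, tval)

act = FTSwish()
out = act(torch.tensor([-1.0, 0.0, 2.0]))
print(out)
```

The same module then runs unchanged on CPU or GPU, since every tensor it creates follows the input.
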
For now I just wanted to get this running so it could go overnight while I’m sleeping…so far, it’s impressive but it stalled out. Testing a higher learning rate on it…here’s how it went with same lr as the others: