Help coding a dynamic activation function (Flatten-T Swish)

I’m trying to code and test a new activation function, but I’m stuck on how to implement it. Can anyone assist?

Here’s the math, my prototype code and the paper link:

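$$\mathrm{FTS}(x) = \begin{cases} x \cdot \sigma(x) + T, & x \ge 0 \\ T, & x < 0 \end{cases}$$

where σ is the sigmoid and T is the threshold (`self.threshold` in the code).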

And my initial prototype code:

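```python
import torch

def forward(self, x):
    # `if` needs a single True/False, so this fails on multi-element tensors
    if x >= 0:
        x = (x * torch.sigmoid(x)) + self.threshold
    else:
        x = self.threshold
    return x
```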

I’m unclear how to do an element-wise check: apply the function where x >= 0, but use the fixed value where x is negative. (I looked at F.threshold, but how to do the < 0 check with it is also unclear…)
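On the F.threshold part, here is one way it could work (a sketch, not the paper’s code): apply it to the swish term instead of to x. x·sigmoid(x) is positive exactly when x > 0 and is zero at x = 0, so `F.threshold(swish, 0.0, 0.0)` zeroes the whole x ≤ 0 half, and adding T afterwards leaves T there. At x = 0 both branches of the definition give T, so nothing changes at the boundary. The name `fts_threshold` and the example T = -0.2 are my own choices for illustration.

```python
import torch
import torch.nn.functional as F


def fts_threshold(x: torch.Tensor, T: float) -> torch.Tensor:
    # swish term: > 0 for x > 0, == 0 at x == 0, < 0 for x < 0
    swish = x * torch.sigmoid(x)
    # F.threshold keeps elements > 0 and replaces the rest with 0,
    # which zeroes exactly the x <= 0 half
    return F.threshold(swish, 0.0, 0.0) + T


x = torch.tensor([-2.0, -0.5, 0.0, 1.0, 3.0])
print(fts_threshold(x, -0.2))  # -0.2 is only an example value for T
```

The same values come from `F.relu(x) * torch.sigmoid(x) + T`, which some people may find easier to read.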

Anyway, any help on how to write this would be greatly appreciated!

Here’s the paper for reference:

I already answered that in your other post. Please don’t cross-post! 🙂

Sorry, I thought no one would see it in my other thread. I’ll remove this one now 🙂
Edit: it seems I can’t delete the thread!

The answer, thanks to Jeremy, is torch.where()!
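Here is a minimal sketch of the torch.where version as an nn.Module. It assumes T is either a fixed float or, for a "dynamic" variant, an nn.Parameter trained with the rest of the network. The class name `FTS`, the `learnable` flag, and the example value T = -0.2 are illustrative assumptions and do not come from the paper.

```python
import torch
import torch.nn as nn


class FTS(nn.Module):
    """Flatten-T Swish sketch: x * sigmoid(x) + T where x >= 0, else T."""

    def __init__(self, threshold: float, learnable: bool = False):
        super().__init__()
        t = torch.tensor(float(threshold))
        if learnable:
            # hypothetical option: train T together with the network weights
            self.threshold = nn.Parameter(t)
        else:
            # fixed T; a buffer follows the module across .to(device) calls
            self.register_buffer("threshold", t)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # element-wise choice: swish where x >= 0, zero elsewhere
        flat = torch.where(x >= 0, x * torch.sigmoid(x), torch.zeros_like(x))
        # adding T to every element gives x*sigmoid(x) + T or T
        return flat + self.threshold


act = FTS(threshold=-0.2)  # example value for T
print(act(torch.tensor([-2.0, 0.0, 1.0])))
```

torch.where evaluates both branches for every element and then picks per element using the mask, so the swish term is also computed where it gets discarded. That is cheap here, and the gradient for x < 0 comes out as zero, as the flat branch implies.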