Update: I think the way to do this is to define the operation inside a class, as follows:
class BiasLayer(torch.nn.Module):
    """Apply a learnable, [0, 1]-clamped bias to the input.

    Depending on ``subtract``, the layer either subtracts the bias
    (``x - b``) or multiplies the input by the negated bias (``-x * b``).
    Idea from https://academic.oup.com/bioinformatics/article/32/12/i52/2288769

    Args:
        subtract: If true, ``forward`` returns ``x - b``; otherwise it
            returns ``-x * b``.
        n_channels: Number of bias entries (one per channel). Defaults to 1,
            i.e. a single bias shared across all channels.
    """

    def __init__(self, subtract: bool, n_channels: int = 1):
        super().__init__()
        # Initialised uniformly in [0, 1), i.e. inside the clamp range used
        # in forward(). Parameters require grad by default.
        self.bias = torch.nn.Parameter(torch.rand(n_channels))
        self.subtract = subtract

    def forward(self, x):
        """Return ``x - b`` or ``-x * b``, where ``b`` is the clamped bias.

        The bias is clamped to [0, 1] and given a leading size-1 dimension,
        so it has shape ``(1, n_channels)`` and broadcasts against the LAST
        dimension of ``x``. Presumably ``x`` is ``(batch, n_channels)``;
        confirm this against the callers.
        """
        # Note: entries that drift outside [0, 1] get zero gradient through
        # clamp, so they stay stuck at the boundary value.
        bias_clamped = self.bias.clamp(min=0, max=1).unsqueeze(0)
        if self.subtract:
            return x - bias_clamped
        return -x * bias_clamped