I’m trying to add ArcFace as the final layer to an EfficientNet model in the following way:
ArcNet is the EfficientNet backbone with the ArcMarginProduct layer appended to it (a minimal usage sketch follows the code below).
import math

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

class ArcNet(nn.Module):
    def __init__(self, n_cls, model_name='efficientnet_b0', s=30.0, margin=0.4,
                 ls_eps=0.0, theta_zero=0.785, pretrained=True):
        super(ArcNet, self).__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained)
        final_in_feat = self.backbone.classifier.in_features
        # Strip the head so the backbone returns raw feature maps
        self.backbone.classifier = nn.Identity()
        self.backbone.global_pool = nn.Identity()
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.final = ArcMarginProduct(final_in_feat, n_cls, s=s, m=margin,
                                      easy_margin=False)

    def forward(self, x, label):
        batch_size = x.shape[0]
        x = self.backbone(x)                            # (B, C, H, W) feature maps
        feature = self.pooling(x).view(batch_size, -1)  # (B, C) embeddings
        logits = self.final(feature, label)
        return logits
class ArcMarginProduct(nn.Module):
    def __init__(self, in_features, out_features, s=30.0, m=0.50,
                 easy_margin=False, ls_eps=0.0):
        super(ArcMarginProduct, self).__init__()
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        self.m = m
        self.s = s
        # Precomputed constants for cos(theta + m)
        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        self.th = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m
        # self.register_buffer('phi', torch.tensor(0, dtype=torch.float16))

    def forward(self, input, label):
        x = F.normalize(input)
        W = F.normalize(self.weight)
        cosine = F.linear(x, W)
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # Apply the margin only to the target class of each sample
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output
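For context, this is roughly how I instantiate and call the model in fp32 (a minimal sketch; the class count, batch size, and input resolution below are placeholders, not my real training setup):

model = ArcNet(n_cls=10).cuda()                       # placeholder class count
images = torch.randn(8, 3, 224, 224, device='cuda')   # dummy input batch
labels = torch.randint(0, 10, (8,), device='cuda')    # dummy integer targets

logits = model(images, labels)           # fp32 forward pass works fine
loss = F.cross_entropy(logits, labels)   # ArcFace logits feed a plain cross-entropy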
I can train without converting to fp16, but when I convert the model to fp16 I get the following error:
RuntimeError: expected scalar type float but found c10::Half
It is triggered by the line phi = torch.where(cosine > self.th, phi, cosine - self.mm). The variable cosine is float16, but phi is still float32.
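If it helps, the dtype clash can be reproduced in isolation. The snippet below is a minimal sketch with made-up tensors rather than my actual model state; note that recent PyTorch versions may type-promote mixed inputs in torch.where instead of raising:

import torch

cosine = torch.randn(4, 10).half()   # float16, like cosine under fp16
phi = torch.randn(4, 10)             # float32, like phi in my error
# Mixing float32 and float16 operands in torch.where raises
# 'expected scalar type float but found c10::Half' on older PyTorch versions
out = torch.where(cosine > 0.5, phi, cosine - 0.1)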
Here is a link to a minimal Colab notebook to reproduce the issue.
Thanks for the help in advance!