Focal Loss in PyTorch [Code Review]

Hey everyone, I have written a custom loss function (Focal Loss) in PyTorch, which I am going to use in my SSD code.

Here is the Colab URL for the loss function:

I need peer reviews of my code regarding its correctness. I have run some tests, which pass, but I am not sure whether it will work as intended for object detection in SSD.

Any feedback or suggestions on the code would be greatly appreciated. It would really help me out.
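
For anyone who wants something to compare an implementation against, here is a minimal sketch of a binary focal loss, FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t). This is not the code from the notebook: the name binary_focal_loss_sketch is hypothetical, and the sketch assumes the model outputs raw logits (no sigmoid applied) and that targets are 0/1 with the same shape as the logits.

```python
import torch
import torch.nn.functional as F


def binary_focal_loss_sketch(logits, targets, alpha=0.25, gamma=2.0):
    """Hypothetical reference implementation, not the notebook's code.

    Assumes `logits` are raw scores and `targets` holds 0/1 labels of the
    same shape.
    """
    targets = targets.float()
    # Per-element binary cross-entropy equals -log(p_t), computed stably.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    # p_t is the probability the model assigns to the true class.
    p_t = torch.exp(-bce)
    # alpha for positives, (1 - alpha) for negatives.
    alpha_t = alpha * targets + (1.0 - alpha) * (1.0 - targets)
    # Down-weight easy examples (p_t close to 1) by (1 - p_t)^gamma.
    return (alpha_t * (1.0 - p_t) ** gamma * bce).mean()


# Sanity check: with gamma=0 and alpha=0.5 the focal loss reduces to
# half of the plain binary cross-entropy.
logits = torch.tensor([2.0, -1.0, 0.5])
targets = torch.tensor([1, 0, 1])
plain_bce = F.binary_cross_entropy_with_logits(logits, targets.float())
reduced = binary_focal_loss_sketch(logits, targets, alpha=0.5, gamma=0.0)
assert torch.allclose(reduced, 0.5 * plain_bce)
```

The gamma=0 check is a cheap way to test any focal loss implementation, since the modulating factor disappears and only the alpha-weighted cross-entropy remains.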

Dear Aman,
Thank you for publishing the FocalLoss loss function!
I want to try it for semantic segmentation with a U-Net.
But I ran into some problems while using it. Could you please clarify how I can solve them?
The first problem is the following:
# per-class weights for three classes (0, 1, 2)
weights = [15. / np.log(np.sum(y == 0)), 15. / np.log(np.sum(y == 1)), 15. / np.log(np.sum(y == 2))]
# per-class weights for the binary case (background vs. everything else)
weights = [15. / np.log(np.sum(y == 0)), 15. / np.log(np.sum(y > 0))]
x, y = generator()  # batch generator
vx = torch.from_numpy(x).float().to(device)
vy = torch.from_numpy(y.astype(np.int64)).to(device)
res = net(vx)
loss = binary_focal_loss(res, vy)
After training, I got the following picture:

Could you help me?
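
One thing worth checking in the weights lines above: 15./np.log(count) breaks down when a class is missing from the batch (log(0) is -inf) or has exactly one pixel (log(1) is 0, so the division blows up). Here is a minimal sketch of a guarded version. The helper name log_count_weights and the clamp value are assumptions for illustration, and it assumes integer labels in the range [0, num_classes).

```python
import numpy as np


def log_count_weights(y, num_classes, scale=15.0):
    """Hypothetical helper: scale / log(pixel count) for each class.

    Assumes `y` holds integer class labels in [0, num_classes).
    """
    labels = np.asarray(y).ravel().astype(np.int64)
    counts = np.bincount(labels, minlength=num_classes).astype(np.float64)
    # Clamp counts to at least 2 so that log() stays strictly positive:
    # a count of 0 would give log(0) = -inf, a count of 1 would give log(1) = 0.
    safe_counts = np.maximum(counts, 2.0)
    return scale / np.log(safe_counts)


# Example mask in which classes 1 and 2 each have a single pixel.
y_example = np.array([[0, 0, 0, 1],
                      [0, 0, 2, 0]])
w = log_count_weights(y_example, num_classes=3)
assert np.all(np.isfinite(w))
```

For the two-class case, the same helper can be called on a binarized mask, e.g. log_count_weights((y > 0), num_classes=2).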