Siamese network with triplet loss

Hello. I’m trying to implement a siamese network with triplet loss for facial recognition using fastai-v1. I figured out how to add a custom head to a pretrained ResNet so that it outputs an “embedding” of sorts for each image. What I can’t figure out is how to pass the loss function: it needs three outputs (anchor, positive, negative) from the same network, and I’m not sure how to set that up.
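One way to get three outputs into a single loss, sketched in plain PyTorch under assumptions: wrap the embedding network in a module whose `forward` takes the three images and returns a tuple of embeddings, then write a loss function that unpacks that tuple and ignores the target. `TripletNet`, `triplet_loss_func` and the tiny `nn.Linear` embedder are hypothetical stand-ins for the ResNet-with-custom-head. The loss here is PyTorch's built-in `F.triplet_margin_loss`, which uses plain (not squared) Euclidean distances. In fastai-v1 you would then hand the model to `Learner` and the function as `loss_func`; this assumes your DataBunch yields each batch input as the (anchor, positive, negative) triple plus some dummy target, which I haven't verified.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TripletNet(nn.Module):
    """Hypothetical wrapper: applies ONE embedding network to all three images,
    so anchor, positive and negative share the same weights."""

    def __init__(self, embedder: nn.Module):
        super().__init__()
        self.embedder = embedder  # e.g. the pretrained ResNet body plus your custom head

    def forward(self, anchor, positive, negative):
        # The model's "output" is a tuple of three embedding batches.
        return self.embedder(anchor), self.embedder(positive), self.embedder(negative)


def triplet_loss_func(output, target=None, margin=1.0):
    """Loss with the usual (output, target) signature; the target is ignored
    because the triplet loss only needs the three embeddings."""
    emb_a, emb_p, emb_n = output
    # PyTorch's built-in triplet loss (non-squared Euclidean distance, mean over batch).
    return F.triplet_margin_loss(emb_a, emb_p, emb_n, margin=margin)


# Tiny stand-in embedder so the sketch runs without a pretrained ResNet.
torch.manual_seed(0)
embedder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16))
model = TripletNet(embedder)
a, p, n = (torch.randn(4, 3, 8, 8) for _ in range(3))
loss = triplet_loss_func(model(a, p, n))
loss.backward()  # gradients from all three branches accumulate in `embedder`
```

Because the same `embedder` is called three times, there is only one set of weights, which is what makes the network siamese.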

Any help / direction will be much appreciated, thank you!

Disclaimer: I haven’t implemented this myself.
As I understand it, it goes like this:

  1. Pass the anchor image through all the layers
  2. Pass the positive image through the same layers (same weights)
  3. Pass the negative image through the same layers (same weights)
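  4. Compute the loss: $L(A,P,N) = \max\left(\lVert f(A) - f(P) \rVert_2^2 - \lVert f(A) - f(N) \rVert_2^2 + \alpha,\ 0\right)$, where $f$ is the embedding network and $\alpha$ is the margin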
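A minimal PyTorch sketch of step 4, assuming the `||...||2` in the formula means a squared L2 distance (as in the FaceNet formulation) and taking `alpha=0.2` purely as an illustrative default. The function name `triplet_loss` and the hand-made embeddings are hypothetical; it returns one loss per triplet so you can check the numbers by hand.

```python
import torch


def triplet_loss(f_a, f_p, f_n, alpha=0.2):
    """L(A, P, N) = max(||f(A) - f(P)||^2 - ||f(A) - f(N)||^2 + alpha, 0), per triplet.

    f_a, f_p, f_n: (batch, dim) embeddings produced by the SAME network.
    alpha: margin; 0.2 is only an illustrative default.
    """
    d_ap = (f_a - f_p).pow(2).sum(dim=1)  # squared L2 distance anchor-positive
    d_an = (f_a - f_n).pow(2).sum(dim=1)  # squared L2 distance anchor-negative
    return torch.clamp(d_ap - d_an + alpha, min=0.0)  # elementwise max(..., 0)


# Two hand-made triplets of 2-D embeddings.
f_a = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
f_p = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
f_n = torch.tensor([[2.0, 0.0], [0.5, 0.0]])
# Triplet 1: 1 - 4 + 0.2 < 0, so the loss is 0 (negative already far enough).
# Triplet 2: 1 - 0.25 + 0.2 = 0.95 (negative is closer than the positive).
per_triplet = triplet_loss(f_a, f_p, f_n)
batch_loss = per_triplet.mean()  # average over the batch before calling backward()
```

Only triplets where the negative is not at least `alpha` farther away than the positive contribute to the gradient; easy triplets give exactly zero loss.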