How to write custom loss function?

Hello everyone! I’m currently working on CenterNet (Objects as Points) for object detection, but first I need to build a keypoint estimator. The model output has shape [2,3,128,128], and the target has the same shape, [2,3,128,128].
The first dimension is the batch, and the second holds three channels: the object heatmap, the x-coordinate offset, and the y-coordinate offset.
My question is: how do I write a custom loss function for this?
If I index the prediction as pr[0][0], isn’t that only comparing the first item in the batch?
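To make the indexing question concrete, here is a small sketch with random tensors; the names pr and gt and their random contents are hypothetical stand-ins for the real output and target. With a [batch, channels, H, W] layout, pr[0][0] selects sample 0 and channel 0, so it really is only the first image's heatmap, while slicing with pr[:, 0] keeps every sample in the batch.

```python
import torch

# Hypothetical tensors with the shapes from the post: [batch, channels, H, W]
pr = torch.rand(2, 3, 128, 128)  # stand-in for the model output
gt = torch.rand(2, 3, 128, 128)  # stand-in for the target

# pr[0][0] selects sample 0, channel 0: only the first image's heatmap
first_heatmap = pr[0][0]
print(first_heatmap.shape)  # torch.Size([128, 128])

# pr[:, 0] keeps the batch dimension: the heatmap channel of every sample
all_heatmaps = pr[:, 0]
print(all_heatmaps.shape)  # torch.Size([2, 128, 128])

# pr[:, 1:3] gives the x and y offset channels of every sample
all_offsets = pr[:, 1:3]
print(all_offsets.shape)  # torch.Size([2, 2, 128, 128])
```

The same slices applied to gt give the matching target channels, so a loss can be computed over the whole batch at once instead of looping over pr[i][j].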
Is there another way to write a custom loss function for my problem?
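One possible approach, sketched here under stated assumptions, is to wrap the loss in an nn.Module and slice channels over the whole batch. The heatmap term follows the penalty-reduced pixel-wise focal loss described in the CenterNet paper (alpha=2, beta=4), and the offset term is an L1 loss evaluated only at object centers and normalized by the number of centers. The class name CenterNetStyleLoss, applying a sigmoid to channel 0, the 1e-4 clamp, and the assumption that the target heatmap is exactly 1.0 at object centers are illustrative choices, not the official CenterNet implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CenterNetStyleLoss(nn.Module):
    """Hypothetical heatmap + offset loss for [B, 3, H, W] tensors.

    Assumptions (not from the original post):
      * channel 0 of `pred` holds raw heatmap logits; a sigmoid is applied here
      * channel 0 of `target` is a heatmap equal to 1.0 exactly at object centers
      * channels 1 and 2 hold x/y offsets, supervised only at those centers
    """

    def __init__(self, alpha: float = 2.0, beta: float = 4.0, offset_weight: float = 1.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.offset_weight = offset_weight

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Slice channels for the whole batch, not just pred[0][0]
        heat_pred = torch.sigmoid(pred[:, 0]).clamp(1e-4, 1 - 1e-4)  # [B, H, W], clamped for stable logs
        heat_gt = target[:, 0]                                        # [B, H, W]
        off_pred = pred[:, 1:3]                                       # [B, 2, H, W]
        off_gt = target[:, 1:3]                                       # [B, 2, H, W]

        pos = heat_gt.eq(1).float()         # 1 at object centers, 0 elsewhere
        neg = 1.0 - pos
        num_pos = pos.sum().clamp(min=1.0)  # avoid division by zero

        # Penalty-reduced pixel-wise focal loss on the heatmap
        pos_loss = torch.log(heat_pred) * (1 - heat_pred) ** self.alpha * pos
        neg_loss = (torch.log(1 - heat_pred) * heat_pred ** self.alpha
                    * (1 - heat_gt) ** self.beta * neg)
        heat_loss = -(pos_loss.sum() + neg_loss.sum()) / num_pos

        # L1 offset loss, counted only at center pixels
        mask = pos.unsqueeze(1)  # [B, 1, H, W], broadcasts over the x/y channels
        off_loss = F.l1_loss(off_pred * mask, off_gt * mask, reduction="sum") / num_pos

        return heat_loss + self.offset_weight * off_loss


# Usage with made-up data: one hypothetical object center per image
criterion = CenterNetStyleLoss()
pred = torch.randn(2, 3, 128, 128, requires_grad=True)
target = torch.zeros(2, 3, 128, 128)
target[0, 0, 40, 50] = 1.0
target[0, 1:3, 40, 50] = torch.tensor([0.3, 0.7])
target[1, 0, 10, 90] = 1.0
target[1, 1:3, 10, 90] = torch.tensor([0.5, 0.1])

loss = criterion(pred, target)
loss.backward()  # gradients flow back to every sample in the batch
```

Because every term is computed on batch-wide slices and then summed, the loss covers both images in the batch, and it returns a single scalar that can be used with loss.backward() like any built-in criterion. Real CenterNet targets usually splat a Gaussian around each center; the (1 - heat_gt) ** beta factor is what down-weights the negatives near those centers.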
Thanks.