Hand keypoint prediction

I am trying to train a model that predicts hand keypoints using fastai.
My code is:

# Download the CMU Panoptic hand-keypoint archive (hand_labels.zip).
# NOTE(review): wget saves the zip into the current working directory, while
# `path` below points into Google Drive; the archive presumably still has to
# be extracted to that location (manual_train/, manual_test/) — confirm.
!wget http://domedb.perception.cs.cmu.edu/panopticDB/hands/hand_labels.zip

# Notebook conveniences: auto-reload edited modules, inline matplotlib plots.
%reload_ext autoreload

%autoreload 2

%matplotlib inline

from fastai import *

from fastai.vision import *

from fastai.vision import image as im

import torch.nn as nn

from torch.nn.functional import mse_loss

import json

import re

# Root of the extracted hand_labels dataset; it must contain the
# manual_train/ and manual_test/ folders used by split_by_folder below.
# Plain ASCII quotes are required here: typographic quotes are a SyntaxError.
path = Path('./drive/My Drive/hand/hand_labels')

path.ls()

# Geometric augmentations, applied to the images and (via tfm_y=True in the
# data pipeline) to the keypoints as well.
# NOTE(review): rotations of up to 45 degrees can move keypoints near the
# image corners outside the frame; with remove_out=False those points keep
# coordinates beyond [-1, 1] — consider a smaller max_rotate if that happens.
transforms = get_transforms(do_flip=False, max_zoom=1.1, max_warp=0.01, max_rotate=45)

def get_y_func(x):
    """Return the keypoint labels for image ``x`` as a float tensor.

    The labels come from the JSON file next to the image that shares its
    base name (``foo.jpg`` -> ``foo.json``). Only the first two values of
    each entry in its ``"hand_pts"`` list are used, and they are swapped, so
    each output row is ``[k[1], k[0]]``. Any further values per entry (e.g. a
    visibility flag) are ignored.

    NOTE(review): the swap presumably converts (x, y) into the (row, col)
    order that fastai v1 point labels expect — confirm against the dataset.

    Args:
        x: Path (str or path-like) of the image file.

    Returns:
        ``torch.float`` tensor with one ``[k[1], k[0]]`` row per entry of
        ``"hand_pts"``.
    """
    pre, _ = os.path.splitext(x)
    # The image itself is not needed to build the label, so it is not opened
    # (the previous version decoded every image and discarded it).
    with open(pre + '.json') as f:
        hand_pts = json.load(f)["hand_pts"]
    return torch.tensor([[k[1], k[0]] for k in hand_pts], dtype=torch.float)

# Build the DataBunch: .jpg images under `path`, split by folder, with
# keypoint labels read by get_y_func and transformed together with the image.
# resize_method=SQUISH: with an int `size`, fastai v1 defaults to a centre
# crop, which can cut keypoints out of non-square images. With
# remove_out=False those points keep coordinates outside [-1, 1], which the
# Tanh-bounded head can never predict; remove_out=True would instead change
# the number of points per image. Squishing keeps every labelled point in frame.
data = (PointsItemList.from_folder(path=path, extensions=['.jpg'])
        .split_by_folder(train='manual_train', valid='manual_test')
        .label_from_func(get_y_func)
        .transform(transforms, size=224, tfm_y=True, remove_out=False,
                   resize_method=ResizeMethod.SQUISH)
        .databunch(bs=8)
        .normalize(imagenet_stats))

class Reshape(nn.Module):
    """Layer that reshapes its input to a fixed target shape.

    Args:
        *args: Target shape, passed as-is to ``Tensor.reshape`` (``-1`` is
            allowed for one inferred dimension, e.g. ``Reshape(-1, 21, 2)``).
    """

    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def forward(self, x):
        # reshape (not view) so non-contiguous inputs do not raise; for
        # contiguous inputs it still returns a view without copying.
        return x.reshape(self.shape)

# Custom regression head attached to the CNN body by cnn_learner below.
# NOTE(review): 512 * 7 * 7 presumably matches the ResNet-34 feature map for
# 224x224 inputs (size=224 in the data pipeline); it must change if the
# input size or backbone changes.
# NOTE(review): the backbone output is presumably already post-ReLU, so the
# first ReLU looks redundant — confirm before removing it, because removing it
# would shift the Sequential indices used as state_dict keys.
head_reg = nn.Sequential(
    Flatten(),                    # feature map -> one vector per sample
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512 * 7 * 7, 256),
    nn.ReLU(),
    nn.BatchNorm1d(256),
    nn.Dropout(0.5),
    nn.Linear(256, 42),           # 21 keypoints x 2 coordinates
    Reshape(-1, 21, 2),           # -> (batch, 21, 2)
    nn.Tanh(),                    # squash every coordinate into (-1, 1)
)

class MSELossFlat(nn.MSELoss):
    """``nn.MSELoss`` that flattens prediction and target before comparing.

    This lets a ``(batch, 21, 2)`` prediction be scored against a target of
    any shape with the same number of elements.

    NOTE(review): this shadows fastai's own ``MSELossFlat`` brought in by the
    star import; rename it if the fastai version is also needed.
    """

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # reshape (not view) so non-contiguous tensors do not raise.
        return super().forward(input.reshape(-1), target.reshape(-1))

# The learner needs a loss *instance* (a callable module), not the class
# itself, so MSELossFlat is instantiated here before being passed in.
mse_loss_flat = MSELossFlat()

learn = cnn_learner(
    data,
    models.resnet34,
    custom_head=head_reg,
    loss_func=mse_loss_flat,
)

After training, this is the output I get on the validation data: [output image not included]

What am I doing wrong?