I think this is normal when you plot a 2D image; I always see it with plt.imshow. A single-channel 2D array is drawn through a colormap (viridis by default), so it does not come out gray. First, check the dimensions of your data with img.ndim or img.shape (I expect it is 2D). Second, either convert it to a 3-channel image by stacking it three times, np.stack((img,) * 3, axis=-1), or pass a gray colormap, plt.imshow(img, cmap="gray"). This only affects how the image is displayed, so you can ignore it if what you care about is the model.
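Here is a minimal sketch of both options side by side. It assumes your image is a 2D float array with values in [0, 1]. The random array, the to_rgb helper, and the output file name comparison.png are hypothetical, just for illustration.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs without a display
import matplotlib.pyplot as plt


def to_rgb(img):
    """Hypothetical helper: repeat a 2D (H, W) image into an (H, W, 3) array."""
    if img.ndim != 2:
        raise ValueError(f"expected a 2D image, got shape {img.shape}")
    return np.stack((img,) * 3, axis=-1)


# Hypothetical 2D grayscale image with values in [0, 1]
img = np.random.default_rng(0).random((28, 28))
print(img.ndim, img.shape)  # a 2D array, so imshow will apply a colormap

fig, axes = plt.subplots(1, 3)
axes[0].imshow(img)                            # default colormap: looks "colored"
axes[1].imshow(img, cmap="gray", vmin=0, vmax=1)  # option 1: gray colormap
axes[2].imshow(to_rgb(img))                    # option 2: 3 equal channels render as gray
fig.savefig("comparison.png")
plt.close(fig)
```

Note that imshow expects RGB float data to be in [0, 1] (or integers in 0..255). With the colormap option, imshow rescales to the data's min/max unless you pass vmin/vmax. That is why vmin=0, vmax=1 is set here, so both gray panels match.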