How to debug an occasionally bad validation loss (CloudNet+)


I am trying to implement a somewhat convoluted segmentation model for cloud recognition, described in this article:
Cloud-Net+: A Cloud Segmentation CNN for Landsat 8 Remote Sensing Imagery Optimized with Filtered Jaccard Loss Function

The model works and is learning: it gets 80% accuracy on the CamVid example from Part 1, and it even seems to do far better than UNET on Kaggle’s cloud competition (though I really just fed both it and UNET the raw data without any analysis, so both results are quite bad).

The issue I am having is that at some epochs during training, the validation loss jumps to NaN, Inf, or just unreasonably high values, and then goes back to normal on the next epoch. How do I check what is going on? Are there any tools, for example in PyTorch, that help tackle this? What should I be checking among the huge tables of numbers to understand what is happening?
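One generic way to start, independent of Cloud-Net+ itself, is to evaluate the validation set one batch at a time, record which batches give a NaN/Inf or very large loss, and use PyTorch forward hooks (`register_forward_hook`) to see which layers emit non-finite values on those batches. This is a minimal sketch under a few assumptions: the loader yields `(input, target)` pairs, the loss function returns a scalar, and `find_bad_batches`, `register_nonfinite_hooks` and the `1e3` threshold are hypothetical names and values chosen for illustration. The small toy network stands in for the real model.

```python
import math
import torch
import torch.nn as nn


def register_nonfinite_hooks(model: nn.Module, report: list):
    """Attach a forward hook to every module; each hook appends the module's
    name to `report` when its output contains NaN or +/-Inf."""
    handles = []
    for name, module in model.named_modules():
        def hook(mod, inputs, output, name=name):
            outs = output if isinstance(output, (tuple, list)) else (output,)
            if any(torch.is_tensor(o) and not torch.isfinite(o).all() for o in outs):
                # the root module has the empty name "", so fall back to its class name
                report.append(name or type(mod).__name__)
        handles.append(module.register_forward_hook(hook))
    return handles


@torch.no_grad()
def find_bad_batches(model: nn.Module, loader, loss_fn, threshold: float = 1e3):
    """Evaluate batch by batch and return (batch_index, loss, flagged) for every
    batch whose loss is NaN/Inf or above `threshold` (an arbitrary cut-off)."""
    model.eval()
    report: list = []
    handles = register_nonfinite_hooks(model, report)
    bad = []
    try:
        for i, (xb, yb) in enumerate(loader):
            report.clear()
            if not torch.isfinite(xb).all():
                report.append("input")   # the images already contain NaN/Inf
            if not torch.isfinite(yb).all():
                report.append("target")  # the masks already contain NaN/Inf
            loss = loss_fn(model(xb), yb).item()
            if not math.isfinite(loss) or loss > threshold:
                # copy the list, because `report` is cleared for the next batch
                bad.append((i, loss, list(report)))
    finally:
        for h in handles:
            h.remove()  # always detach the hooks, even if an exception is raised
    return bad


# Toy usage: a hypothetical two-batch "loader" where batch 1 contains a NaN pixel.
torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.Tanh(), nn.Conv2d(4, 1, 1))
x_bad = torch.randn(2, 3, 8, 8)
x_bad[0, 0, 0, 0] = float("nan")
loader = [
    (torch.randn(2, 3, 8, 8), torch.rand(2, 1, 8, 8)),
    (x_bad, torch.rand(2, 1, 8, 8)),
]
for idx, loss, flagged in find_bad_batches(model, loader, nn.BCEWithLogitsLoss()):
    print(idx, loss, flagged)
```

Because a hook fires when its module's forward pass finishes, the first layer name in `flagged` is usually the earliest point where non-finite values appear. If `flagged` is empty while the loss is still NaN/Inf, the inputs, targets and network outputs were all finite, so the problem is more likely inside the loss computation itself (for example, a ratio-based loss dividing by zero on a tile with an empty mask). Once a bad batch index is known, its images and masks can be inspected directly. For NaNs that appear in the backward pass during training, PyTorch also offers `torch.autograd.set_detect_anomaly(True)`, which reports the operation that produced the NaN gradient, at a noticeable cost in speed.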

The model code is here. Thanks for your time!