Segmentation nan loss debugging

I am using segmentation for a data set with 5 classes + background.

Loss examples with different validation splits are below, but I keep encountering nan losses. Occasionally, with a specific validation split, I can get one epoch to run with good train and val loss, but then I typically get nan for the loss on the second epoch.

Loss function is FlattenedLoss of CrossEntropyLoss()



Depending on the run, loss vs learning rate from lr_find() can also look pretty normal, e.g.:

I also tried training without the 1 cycle policy, but got a similar result.


Using a really low learning rate (below) works for one epoch, then I get nan on the next:

When I train on a very small subset of the dataset, training works fine. Perhaps I need to split the training set into n parts, train on each, and see if I can identify a problem in the data that is throwing the training off. Or, even better, do this programmatically by using a callback to track the image names when nan is encountered. Has anyone done something similar?

Has anyone encountered similar nan issues when training segmentation models, and do you have any other suggestions for debugging?
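For the idea of tracking image names when nan shows up, here is a minimal, framework-agnostic sketch. It is not the fastai callback API: `first_non_finite_batch`, `batches` and `loss_fn` are hypothetical names standing in for a real data loader and training step. The same check could be moved into whatever per-batch hook your training loop offers.

```python
import math


def first_non_finite_batch(batches, loss_fn):
    """Return (index, names) of the first batch whose loss is nan/inf, else None.

    `batches` yields (names, data) pairs and `loss_fn(data)` returns a number.
    Both are hypothetical stand-ins for a real data loader and training step.
    """
    for index, (names, data) in enumerate(batches):
        loss = float(loss_fn(data))
        if not math.isfinite(loss):
            # Stop at the first bad batch so the culprit files are easy to inspect.
            return index, list(names)
    return None


def mean_loss(values):
    """Toy loss: the mean of the values, which turns nan if any value is nan."""
    return sum(values) / len(values)


# Toy usage: the second batch contains a nan value that poisons the mean.
batches = [
    (["img_001.png", "img_002.png"], [0.2, 0.4]),
    (["img_003.png", "img_004.png"], [0.1, float("nan")]),
]
culprit = first_non_finite_batch(batches, mean_loss)
print(culprit)
```

Once you have the file names for the offending batch, you can open those images and masks directly instead of bisecting the whole training set by hand.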

I reran the workflows that process the training data, and it all works OK now.

I have found this in the past too: if you get errors, the first place to check is your input data.
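As a rough sketch of what "check your input data" could look like for segmentation: the thread does not say what was actually wrong with the data, so the checks below are assumptions about common failure modes (non-finite pixels, mask values outside the expected class ids, image/mask size mismatch). `check_sample` is a hypothetical helper, and it assumes the 5 classes + background are encoded as ids 0..5.

```python
import numpy as np

N_CLASSES = 6  # 5 classes + background; assumes ids are encoded as 0..5


def check_sample(image, mask, n_classes=N_CLASSES):
    """Return a list of problems found in one image/mask pair (empty if none)."""
    problems = []
    image = np.asarray(image, dtype=np.float64)
    mask = np.asarray(mask)

    # nan/inf pixels propagate straight through the network into the loss.
    if not np.isfinite(image).all():
        problems.append("image contains nan or inf")

    # Mask values outside 0..n_classes-1 (e.g. 255 from an exported PNG)
    # are invalid targets for cross entropy.
    bad_ids = sorted(set(np.unique(mask).tolist()) - set(range(n_classes)))
    if bad_ids:
        problems.append(f"mask contains unexpected class ids {bad_ids}")

    # Image and mask should agree on height and width.
    if image.shape[:2] != mask.shape[:2]:
        problems.append(
            f"shape mismatch: image {image.shape[:2]} vs mask {mask.shape[:2]}"
        )
    return problems


# Toy usage: a clean pair, and a mask that uses 255 for one pixel.
good_image = np.zeros((4, 4, 3))
good_mask = np.zeros((4, 4), dtype=np.uint8)
bad_mask = good_mask.copy()
bad_mask[0, 0] = 255

print(check_sample(good_image, good_mask))
print(check_sample(good_image, bad_mask))
```

Running this over every training and validation pair before training is cheap compared with waiting for an epoch to fail.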

This sometimes happens when running the training in f16 mode; it never happened in f32 mode, for me at least.
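To illustrate why half precision is more fragile, here is a small NumPy sketch. It is only a demonstration of the number format's range, not of what fastai or PyTorch do internally (real frameworks use a numerically stable log-softmax and loss scaling). The largest finite float16 value is 65504, so a deliberately naive softmax overflows for a logit of just 12, while the same computation in float32 stays finite. `naive_softmax` is a hypothetical helper written for this example.

```python
import numpy as np


def naive_softmax(logits, dtype):
    """Softmax without the usual max-subtraction trick, computed in `dtype`."""
    x = np.asarray(logits, dtype=dtype)
    # Silence the overflow/invalid warnings so the nan is visible in the result.
    with np.errstate(over="ignore", invalid="ignore"):
        e = np.exp(x)       # exp(12) is about 162755, above the float16 maximum
        return e / e.sum()  # inf / inf gives nan in float16


half = naive_softmax([12.0, 0.0], np.float16)
single = naive_softmax([12.0, 0.0], np.float32)

print("float16 max:", np.finfo(np.float16).max)
print("float16 result has nan:", bool(np.isnan(half).any()))
print("float32 result is finite:", bool(np.isfinite(single).all()))
```

If nan only appears in f16 mode, that points at overflow somewhere in the forward or backward pass rather than at the data, so comparing a short f32 run on the same split is a quick way to tell the two causes apart.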