The last one can also be done with a callback, since you can skip the normal step in the training loop by returning the proper flag and do the work in the Callback instead.
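Purely for illustration, here is a toy version of that pattern; the class and method names (`Callback`, `on_step`) and the boolean skip flag are invented for this sketch and are not taken from any particular library:

```python
class Callback:
    """Base class: returning True from on_step means "I handled it, skip the default"."""
    def on_step(self, model, optimizer, loss):
        return False

class CustomStep(Callback):
    """Performs the backward/step itself and tells the loop to skip its own."""
    def on_step(self, model, optimizer, loss):
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        return True  # signal the training loop to skip the normal step

def run_step(model, optimizer, loss, callbacks):
    # Ask each callback whether it took care of this step.
    skip = False
    for cb in callbacks:
        skip = cb.on_step(model, optimizer, loss) or skip
    if skip:
        return
    # Normal step, only reached if no callback asked to skip it.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```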
In order to properly use all 8 cores of the TPU, you need to use the multiprocessing API that PyTorch XLA provides. It comes with its own data loader wrapper to put the data on the 8 cores of the TPU, and its own optimizer step to synchronize the gradients across those 8 cores during training. Unfortunately, it is not as simple as just setting the device parameter.
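As a rough sketch of what that workflow looks like (assuming a model class `MyModel`, a loss function `loss_fn`, and a regular `DataLoader` named `train_loader` are defined elsewhere; they are placeholders here):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # Each spawned process gets its own TPU core as its XLA device.
    device = xm.xla_device()
    model = MyModel().to(device)          # MyModel is a placeholder
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # MpDeviceLoader wraps an ordinary DataLoader and moves each batch
    # onto this process's TPU core in the background.
    loader = pl.MpDeviceLoader(train_loader, device)

    model.train()
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)   # loss_fn is a placeholder
        loss.backward()
        # xm.optimizer_step all-reduces the gradients across the cores
        # before applying the update, keeping the 8 replicas in sync.
        xm.optimizer_step(optimizer)

if __name__ == '__main__':
    # Launch one process per TPU core (8 on a typical v2-8/v3-8).
    xmp.spawn(_mp_fn, args=())
```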