What do you see when you run the following in a separate terminal? Start it first, then launch your Jupyter notebook and watch the memory usage as you step through the notebook.
watch -n 1 nvidia-smi
If you reduce the batch size, you should see memory usage go down. Make sure your dataloader is actually being given the new batch size. From memory, run the following after you set bs and check that it is correct (if it isn't right, you may need to do some digging to find it):
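This is a minimal sketch of how that check could look. It assumes your dataloader is a plain PyTorch torch.utils.data.DataLoader, which is an assumption and not necessarily what you are using. The toy TensorDataset and the value bs = 16 are hypothetical stand-ins for your own data and batch size. It compares the batch size the loader was configured with against the size of a batch it actually yields.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset (assumption): 100 samples with 3 features each.
dataset = TensorDataset(torch.zeros(100, 3))

bs = 16  # hypothetical: the batch size you intend to train with
loader = DataLoader(dataset, batch_size=bs, shuffle=True)

# 1) The batch size the DataLoader was configured with.
print("configured batch_size:", loader.batch_size)

# 2) The size of a batch it actually yields (the first one).
(xb,) = next(iter(loader))
print("first batch size:", xb.shape[0])

# Number of batches per epoch: ceil(100 / 16) = 7, and the last one holds only 4 samples.
print("batches per epoch:", len(loader))
```

If both numbers match the bs you set but memory in nvidia-smi does not drop, the loader is probably not the problem. If they don't match, the batch size is being set somewhere else (for example, by whatever wrapper builds the loader), and that is where to dig.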