Can’t say that I see any issues with your code. It may be a PyTorch issue. What version of PyTorch are you running?
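If you're not sure how to check, here is a minimal sketch that prints the installed PyTorch version. It reads the package metadata for the `torch` distribution through the standard library's `importlib.metadata`, so it still reports something if `import torch` itself is failing. In a working install, `torch.__version__` should give the same string. The helper name `installed_torch_version` is just for illustration.

```python
import sys
from importlib.metadata import version, PackageNotFoundError


def installed_torch_version():
    """Return the installed PyTorch version string, or None if torch isn't installed.

    Hypothetical helper: reads package metadata instead of importing torch.
    """
    try:
        # The distribution name for PyTorch is "torch"
        return version("torch")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("Python :", sys.version.split()[0])
    print("PyTorch:", installed_torch_version() or "not installed")
```

Posting both lines of output (and your OS and CUDA setup, if you're using a GPU) makes it much easier to match your problem against known PyTorch bugs.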
Check this post, right at the end:
Hope that helps.