This is probably a bug in PyTorch itself rather than in your code, which may well be fine.
See here and here.