As expected, the memory footprint is now double the size (~6.5GB). However, when I load the model using the code below, I still get a memory footprint of around 3.3GB and no errors in the image generation process:
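(For reference, a minimal sketch of the kind of loading call being described here; the diffusers StableDiffusionPipeline and the runwayml/stable-diffusion-v1-5 checkpoint are assumptions on my part, since the original snippet isn't reproduced in this excerpt:)

```python
import torch
from diffusers import StableDiffusionPipeline

# Assumed pipeline and checkpoint; the original snippet is not shown in this excerpt.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,  # cast the weights to half precision at load time
)
pipe = pipe.to("cuda")
```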
Yes, I know. They will also run a bit faster. My question is: why do I need to set the revision to fp16 and not just specify the torch_dtype argument as torch.float16, like I do in the last code snippet? What's special about the fp16 revision?
Sure, but if I then specify the torch_dtype argument as torch.float16, I'm overriding the model's default dtype to 16-bit floats. At least, that's what the docs say:
torch_dtype (str or torch.dtype, optional) — Override the default torch.dtype and load the model under this dtype. If "auto" is passed the dtype will be automatically derived from the model’s weights.
If you do that, you are still downloading a checkpoint that is double the size of what you need: torch_dtype only casts the weights after the full-precision files have been downloaded, whereas the fp16 revision stores half-precision weights on the Hub. With models that can take several GB, this makes a difference.
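To make that concrete, here's a sketch of the two loading calls side by side (again assuming the diffusers StableDiffusionPipeline and the runwayml/stable-diffusion-v1-5 checkpoint; note that newer diffusers releases replace revision="fp16" with variant="fp16"):

```python
import torch
from diffusers import StableDiffusionPipeline

# Downloads the full-precision (fp32) checkpoint, then casts it to fp16 in memory.
pipe_cast = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # assumed checkpoint
    torch_dtype=torch.float16,
)

# Downloads the half-precision weights stored under the fp16 revision (branch),
# so roughly half as many bytes go over the network and into the local cache.
pipe_fp16 = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # assumed checkpoint
    revision="fp16",
    torch_dtype=torch.float16,
)
```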