Reducing VRAM usage

Hello, so I’m pretty new to this, but I’ve just started using @ilovescience’s UPIT library (tmabraham/UPIT on GitHub) on AWS SageMaker, and I’d appreciate some help figuring out my VRAM woes.

Basically I’m running into very high VRAM usage with what I’d thought wouldn’t be a particularly intensive use case (I may be wrong). My dataset is about 6,000 images at 400×400. I’ve found that 64 features trains OK on a 16 GB card at a batch size of 1, and I can just about squeeze 128 features onto a 16 GB card, though it can spill over, again at batch size 1. To go any higher than that, or to increase the batch size, I end up having to use very expensive machines with 24/30 GB of VRAM.
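For context, the model itself is just the stock UPIT CycleGAN with the feature count turned up or down, something like the sketch below (written from memory, so the argument names may be slightly off):

```python
from upit.models.cyclegan import *

# 64 features trains fine on a 16 GB card at batch size 1;
# 128 only just fits and sometimes spills over
cycle_gan = CycleGAN(3, 3, 64)  # (ch_in, ch_out, n_features), as far as I remember
```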

One potentially relevant detail is that I’m loading my dataset from a pickled file as a Pandas DataFrame and then casting the images to tensors. I wasn’t sure whether I could load this data as a generator instead, and whether that would keep VRAM lower (i.e. loading and unloading data as needed). I’m also slightly resizing on the fly; would that increase VRAM usage (i.e. does the original size stay in memory)? I’m also assuming that parallelisation would only help me increase batch size / training speed, rather than letting me fit a bigger network. I’d be interested to know whether this is just a fundamental issue and I have to scale up as the network gets bigger, but it feels like I might be doing something wrong.
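For reference, my loading code is roughly the sketch below. It’s simplified, and the pickle path, column name, and resize target are placeholders, but the shape of it is accurate: everything comes out of the pickle up front and gets cast to tensors, with the small resize applied per item.

```python
import pandas as pd
import torch
import torch.nn.functional as F

# The whole dataset comes from one pickled DataFrame, loaded up front
df = pd.read_pickle("images.pkl")  # placeholder path
# Each row holds a 400x400 HWC uint8 array; cast them all to CHW float tensors
imgs = [torch.as_tensor(arr).permute(2, 0, 1).float() / 255 for arr in df["image"]]

def get_image(i):
    # the slight on-the-fly resize I mentioned (384 is just a placeholder size)
    x = imgs[i].unsqueeze(0)
    x = F.interpolate(x, size=(384, 384), mode="bilinear", align_corners=False)
    return x.squeeze(0)
```

These items then get wrapped into the DataLoaders that feed the learner, with a batch size of 1.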

Any help is really appreciated, thank you!