Problems: 1. Creating mini-batches. 2. CUDA memory error

Hey folks,
I’m already in despair.

I have 61 pairs of images of the same objects [input: dark image (fast shutter), target: bright image (slow shutter)]. Each image is NxM = 4272x2848 pixels.

I’m asking for your help with two issues:

  1. Problem 1: Configuring the mini-batches / patches.
  2. Problem 2: Finding a solution to the CUDA memory problem.

So:

Problem 1:
I want to split the whole dataset into roughly 10 batches. That’s why I set bs=6 (which creates batches of 6 image pairs each).
But within each batch, I’d like to split every image pair into 10-200 smaller mini-batches/patches. That is, every image would be cut into many WxD = 256x256 patches (out of the NxM = 4272x2848 pixels), which would be fed to the model every epoch.
The goal is to use less GPU memory, since the GPU only has 12GB available.
The patches shouldn’t be squished, rescaled, padded, or anything. They should only be smaller crops of the original image.
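For a sense of scale, here is a minimal sketch of how many 256x256 patches fit in one 4272x2848 image. It assumes non-overlapping tiles and simply discards the leftover border, which is only one of several choices; overlapping or random crops can yield more patches.

```python
def tile_grid(width, height, tile):
    """Number of non-overlapping tile x tile patches along each axis.

    Pixels that do not fill a whole tile at the right/bottom edge are dropped.
    """
    return width // tile, height // tile

cols, rows = tile_grid(4272, 2848, 256)
print(cols, rows, cols * rows)          # 16 11 176
covered = (cols * 256) * (rows * 256) / (4272 * 2848)
print(f"{covered:.1%} of the pixels")   # 94.8% of the pixels
```

So one image yields up to 176 full patches under this assumption, which falls inside the 10-200 range described above.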

Which one should I choose?

  1. item_tfms=RandomCrop(256)
  2. batch_tfms=RandomCrop(256)

How can I tell whether it creates 10 batches of 6 images of size MxN, each with 10-200 patches of size WxD?
Because right now it looks like it creates 10 batches of 6 images of size WxD. Right?

Problem 2:
When I tried to start training the model, I got this error:

I don’t get it. What is actually happening behind the scenes?

I searched the forum and found someone recommending a function that prints out all of the tensors allocated in GPU RAM.

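import gc
import torch

def pretty_size(size):
    """Pretty-prints a torch.Size object, e.g. 6 × 3 × 256 × 256."""
    assert isinstance(size, torch.Size)
    return " × ".join(map(str, size))

def dump_tensors(gpu_only=True):
    """Prints the tensors tracked by the garbage collector and their total size.

    A tensor and another object wrapping the same storage (possibly, e.g., a Tensor
    and the TensorImage built on it) are listed and counted separately, so the
    totals can overestimate the memory actually in use.
    """
    total_numel = 0
    total_bytes = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):  # plain tensors, Parameters, TensorImage, ...
                t, name = obj, type(obj).__name__
            elif hasattr(obj, "data") and torch.is_tensor(obj.data):
                t, name = obj.data, f"{type(obj).__name__} → {type(obj.data).__name__}"
            else:
                continue
            if gpu_only and not t.is_cuda:
                continue
            # is_pinned is a method and must be called: the bare attribute is always
            # truthy, which labels every tensor "pinned". The old Variable.volatile
            # flag has had no effect since PyTorch 0.4, so it is no longer printed.
            print("%s:%s%s%s %s" % (name,
                                    " GPU" if t.is_cuda else "",
                                    " pinned" if t.is_pinned() else "",
                                    " grad" if t.requires_grad else "",
                                    pretty_size(t.size())))
            total_numel += t.numel()                     # element count
            total_bytes += t.numel() * t.element_size()  # bytes for this dtype
        except Exception:
            pass  # some objects tracked by gc raise on attribute access
    print("Total elements:", total_numel)
    print(f"Total size: {total_bytes / 1024**3:.2f} GiB")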

Here is the output. What am I actually seeing here?

Tensor: GPU pinned
Tensor: GPU pinned 6 × 64 × 512 × 512
TensorImage: GPU pinned 6 × 64 × 512 × 512
Tensor: GPU pinned 6 × 64 × 512 × 512
TensorImage: GPU pinned 6 × 64 × 512 × 512
Tensor: GPU pinned
Tensor: GPU pinned 6 × 64 × 512 × 512
TensorImage: GPU pinned 6 × 64 × 512 × 512
Tensor: GPU pinned 6 × 64 × 512 × 512
TensorImage: GPU pinned 6 × 64 × 512 × 512
Tensor: GPU pinned
Tensor: GPU pinned 6 × 64 × 512 × 512
TensorImage: GPU pinned 6 × 64 × 512 × 512
Tensor: GPU pinned 6 × 128 × 256 × 256
TensorImage: GPU pinned 6 × 128 × 256 × 256
Tensor: GPU pinned
Tensor: GPU pinned 6 × 128 × 256 × 256
TensorImage: GPU pinned 6 × 128 × 256 × 256
Tensor: GPU pinned 6 × 128 × 256 × 256
TensorImage: GPU pinned 6 × 128 × 256 × 256
Tensor: GPU pinned
Tensor: GPU pinned 6 × 128 × 256 × 256
TensorImage: GPU pinned 6 × 128 × 256 × 256
Tensor: GPU pinned 6 × 128 × 256 × 256
TensorImage: GPU pinned 6 × 128 × 256 × 256
Tensor: GPU pinned
Tensor: GPU pinned 64
Tensor: GPU pinned 64
Tensor: GPU pinned 64
Tensor: GPU pinned 64
Tensor: GPU pinned 64
Tensor: GPU pinned 64
Tensor: GPU pinned 64
Tensor: GPU pinned 64
Tensor: GPU pinned 64
Tensor: GPU pinned 64
Tensor: GPU pinned 64
Tensor: GPU pinned 64
Tensor: GPU pinned 64
Tensor: GPU pinned 64
Tensor: GPU pinned 128
Tensor: GPU pinned 128
Tensor: GPU pinned 128
Tensor: GPU pinned 128
Tensor: GPU pinned 128
Tensor: GPU pinned 128
Tensor: GPU pinned 128
Tensor: GPU pinned 128
Tensor: GPU pinned
Tensor: GPU pinned 128
Tensor: GPU pinned 128
Tensor: GPU pinned
Tensor: GPU pinned 128
Tensor: GPU pinned 128
Tensor: GPU pinned
Tensor: GPU pinned 128
Tensor: GPU pinned 128
Tensor: GPU pinned
Tensor: GPU pinned 128
Tensor: GPU pinned 128
Tensor: GPU pinned
Tensor: GPU pinned 128
Tensor: GPU pinned 128
Tensor: GPU pinned
Tensor: GPU pinned 256
Tensor: GPU pinned 256
Tensor: GPU pinned
Tensor: GPU pinned 256
Tensor: GPU pinned 256
Tensor: GPU pinned
Tensor: GPU pinned 256
Tensor: GPU pinned 256
Tensor: GPU pinned
Tensor: GPU pinned 256
Tensor: GPU pinned 256
Tensor: GPU pinned
Tensor: GPU pinned 256
Tensor: GPU pinned 256
Tensor: GPU pinned
Tensor: GPU pinned 256
Tensor: GPU pinned 256
Tensor: GPU pinned
Tensor: GPU pinned 256
Tensor: GPU pinned 256
Tensor: GPU pinned
Tensor: GPU pinned 256
Tensor: GPU pinned 256
Tensor: GPU pinned
Tensor: GPU pinned 256
Tensor: GPU pinned 256
Tensor: GPU pinned
Tensor: GPU pinned 256
Tensor: GPU pinned 256
Tensor: GPU pinned
Tensor: GPU pinned 256
Tensor: GPU pinned 256
Tensor: GPU pinned
Tensor: GPU pinned 256
Tensor: GPU pinned 256
Tensor: GPU pinned
Tensor: GPU pinned 256
Tensor: GPU pinned 256
Tensor: GPU pinned
Tensor: GPU pinned 512
Tensor: GPU pinned 512
Tensor: GPU pinned
Tensor: GPU pinned 512
Tensor: GPU pinned 512
Tensor: GPU pinned
Tensor: GPU pinned 512
Tensor: GPU pinned 512
Tensor: GPU pinned
Tensor: GPU pinned 512
Tensor: GPU pinned 512
Tensor: GPU pinned
Tensor: GPU pinned 512
Tensor: GPU pinned 512
Tensor: GPU pinned
Tensor: GPU pinned 512
Tensor: GPU pinned 512
Tensor: GPU pinned
Tensor: GPU pinned 512
Tensor: GPU pinned 512
Tensor: GPU pinned
Tensor: GPU pinned 512
Tensor: GPU pinned 512
Tensor: GPU pinned
Tensor: GPU pinned 256
Tensor: GPU pinned 256
Tensor: GPU pinned
Tensor: GPU pinned 128
Tensor: GPU pinned 128
Tensor: GPU pinned
Tensor: GPU pinned 64
Tensor: GPU pinned 64
Tensor: GPU pinned
Tensor: GPU pinned 64
Tensor: GPU pinned 64
Tensor: GPU pinned
Tensor: GPU pinned 6 × 3 × 2048 × 2048
TensorImage: GPU pinned 6 × 3 × 2048 × 2048
Tensor: GPU pinned 6 × 3 × 2048 × 2048
Tensor: GPU pinned 6 × 64 × 1024 × 1024
TensorImage: GPU pinned 6 × 64 × 1024 × 1024
Tensor: GPU pinned
Tensor: GPU pinned 6 × 64 × 1024 × 1024
TensorImage: GPU pinned 6 × 64 × 1024 × 1024
Tensor: GPU pinned 6 × 64 × 512 × 512
TensorImage: GPU pinned 6 × 64 × 512 × 512
Tensor: GPU pinned 6 × 64 × 512 × 512
TensorImage: GPU pinned 6 × 64 × 512 × 512
Tensor: GPU pinned
Tensor: GPU pinned 6 × 64 × 512 × 512
TensorImage: GPU pinned 6 × 64 × 512 × 512
Tensor: GPU pinned 6 × 64 × 512 × 512
TensorImage: GPU pinned 6 × 64 × 512 × 512
Tensor: GPU pinned
Tensor: GPU pinned 6 × 64 × 512 × 512
TensorImage: GPU pinned 6 × 64 × 512 × 512
Tensor: GPU pinned 6 × 64 × 512 × 512
TensorImage: GPU pinned 6 × 64 × 512 × 512
Tensor: GPU pinned
Tensor: GPU pinned 6 × 64 × 512 × 512
TensorImage: GPU pinned 6 × 64 × 512 × 512
Tensor: GPU pinned 6 × 64 × 512 × 512
TensorImage: GPU pinned 6 × 64 × 512 × 512
TensorImage: GPU pinned 6 × 3 × 2048 × 2048
Tensor: GPU pinned 1 × 3 × 1 × 1
Tensor: GPU pinned 1 × 3 × 1 × 1
Parameter: GPU pinned 256 × 256 × 3 × 3
Parameter: GPU pinned 256
Parameter: GPU pinned 3
Parameter: GPU pinned 3 × 99 × 1 × 1
Parameter: GPU pinned 99
Parameter: GPU pinned 99 × 99 × 3 × 3
Parameter: GPU pinned 99
Parameter: GPU pinned 99 × 99 × 3 × 3
Parameter: GPU pinned 384
Parameter: GPU pinned 384 × 96 × 1 × 1
Parameter: GPU pinned 96
Parameter: GPU pinned 96 × 96 × 3 × 3
Parameter: GPU pinned 96
Parameter: GPU pinned 96 × 192 × 3 × 3
Parameter: GPU pinned 64
Parameter: GPU pinned 64
Parameter: GPU pinned 512
Parameter: GPU pinned 512 × 256 × 1 × 1
Parameter: GPU pinned 64
Parameter: GPU pinned 64
Parameter: GPU pinned 64
Parameter: GPU pinned 64
Parameter: GPU pinned 64
Parameter: GPU pinned 64
Parameter: GPU pinned 64
Parameter: GPU pinned 64
Parameter: GPU pinned 64
Parameter: GPU pinned 64
Parameter: GPU pinned 64
Parameter: GPU pinned 64
Parameter: GPU pinned 64
Parameter: GPU pinned 64
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 1024
Parameter: GPU pinned 512
Parameter: GPU pinned 1024
Parameter: GPU pinned 256
Parameter: GPU pinned 256
Parameter: GPU pinned 512
Parameter: GPU pinned 512
Parameter: GPU pinned 1024
Parameter: GPU pinned 128
Parameter: GPU pinned 128
Parameter: GPU pinned 384
Parameter: GPU pinned 384
Parameter: GPU pinned 768
Parameter: GPU pinned 64
Parameter: GPU pinned 64
Parameter: GPU pinned 256
Parameter: GPU pinned 1024 × 512 × 3 × 3
Parameter: GPU pinned 512 × 1024 × 3 × 3
Parameter: GPU pinned 1024 × 512 × 1 × 1
Parameter: GPU pinned 512 × 512 × 3 × 3
Parameter: GPU pinned 512 × 512 × 3 × 3
Parameter: GPU pinned 1024 × 512 × 1 × 1
Parameter: GPU pinned 384 × 384 × 3 × 3
Parameter: GPU pinned 384 × 384 × 3 × 3
Parameter: GPU pinned 768 × 384 × 1 × 1
Parameter: GPU pinned 256 × 256 × 3 × 3
Parameter: GPU pinned 64 × 3 × 7 × 7
Parameter: GPU pinned 64 × 64 × 3 × 3
Parameter: GPU pinned 64 × 64 × 3 × 3
Parameter: GPU pinned 64 × 64 × 3 × 3
Parameter: GPU pinned 64 × 64 × 3 × 3
Parameter: GPU pinned 64 × 64 × 3 × 3
Parameter: GPU pinned 64 × 64 × 3 × 3
Parameter: GPU pinned 128 × 64 × 3 × 3
Parameter: GPU pinned 128 × 128 × 3 × 3
Parameter: GPU pinned 128 × 64 × 1 × 1
Parameter: GPU pinned 128 × 128 × 3 × 3
Parameter: GPU pinned 128 × 128 × 3 × 3
Parameter: GPU pinned 128 × 128 × 3 × 3
Parameter: GPU pinned 128 × 128 × 3 × 3
Parameter: GPU pinned 128 × 128 × 3 × 3
Parameter: GPU pinned 128 × 128 × 3 × 3
Parameter: GPU pinned 256 × 128 × 3 × 3
Parameter: GPU pinned 256 × 256 × 3 × 3
Parameter: GPU pinned 256 × 128 × 1 × 1
Parameter: GPU pinned 256 × 256 × 3 × 3
Parameter: GPU pinned 256 × 256 × 3 × 3
Parameter: GPU pinned 256 × 256 × 3 × 3
Parameter: GPU pinned 256 × 256 × 3 × 3
Parameter: GPU pinned 256 × 256 × 3 × 3
Parameter: GPU pinned 256 × 256 × 3 × 3
Parameter: GPU pinned 256 × 256 × 3 × 3
Parameter: GPU pinned 256 × 256 × 3 × 3
Parameter: GPU pinned 256 × 256 × 3 × 3
Parameter: GPU pinned 256 × 256 × 3 × 3
Parameter: GPU pinned 512 × 256 × 3 × 3
Parameter: GPU pinned 512 × 512 × 3 × 3
Parameter: GPU pinned 512 × 256 × 1 × 1
Parameter: GPU pinned 512 × 512 × 3 × 3
Parameter: GPU pinned 512 × 512 × 3 × 3
Parameter: GPU pinned 512 × 512 × 3 × 3
Parameter: GPU pinned 512 × 512 × 3 × 3
Total size: 5074405187

Considering that I’m using a Tesla T4 (11.17GB of GPU memory), and that 61 images × 4MB each = 244MB, how did it end up storing 10.47GB already? What settings or configuration should I change to fix it?

Thanks

Link to my script:


I feel like the forum is kinda dead, so I’m hereby trying to stimulate it! :wink:

As you can see from the GPU RAM chart in the screenshot, the CUDA out-of-memory error occurs because the Tesla T4 ran out of memory. Restart the environment and reduce the batch size.
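To put numbers on why the batch size matters, here is a back-of-the-envelope sketch. It assumes float32 (4 bytes per element) and uses shapes copied from the dump_tensors output above. The input batches in that dump are 6 × 3 × 2048 × 2048, so the photos appear to reach the model at 2048x2048 rather than 256x256, and a single first-layer activation is then 1.5 GiB. The dump lists several tensors of that size, each of which has to be kept for the backward pass. The 244MB estimate, by contrast, counts compressed JPEG bytes on disk, not decoded pixels.

```python
from math import prod

GiB = 1024 ** 3
MiB = 1024 ** 2

def tensor_bytes(shape, bytes_per_elem=4):
    """Memory for one dense tensor; 4 bytes per element assumes float32."""
    return prod(shape) * bytes_per_elem

# One decoded 4272x2848 RGB photo (the ~4MB JPEG is compressed on disk)
print(tensor_bytes((2848, 4272, 3), bytes_per_elem=1) / MiB)  # uint8, ~34.8 MiB
print(tensor_bytes((3, 2848, 4272)) / MiB)                     # float32, ~139.2 MiB

# Shapes taken from the dump_tensors output (batch size 6)
for shape in [(6, 3, 2048, 2048), (6, 64, 1024, 1024), (6, 64, 512, 512)]:
    print(shape, tensor_bytes(shape) / GiB, "GiB")
# (6, 3, 2048, 2048) 0.28125 GiB
# (6, 64, 1024, 1024) 1.5 GiB
# (6, 64, 512, 512) 0.375 GiB

# The same first activation for 256x256 inputs is 64x smaller
print(tensor_bytes((6, 64, 128, 128)) / GiB)  # 0.0234375
```

Activation memory scales linearly with both the batch size and the number of pixels, so halving either halves these figures.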

Assuming your DataLoaders object is called dls,

x,y = dls.one_batch()
x.shape

will generate one batch, so you can verify the size and shape of your batches.

RandomCrop is an item transform, and the documentation has an example of its output (under OldRandomCrop, which does the same thing).
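To see what a random crop of an input/target pair amounts to, here is a minimal NumPy sketch. It is an illustration, not fastai’s implementation (fastai’s RandomCrop also switches to a center crop on the validation set). One window is chosen per pair and both arrays are sliced with it, so nothing is squished, rescaled or padded. Each call picks a new window, which is why an item transform like this shows the model different patches every epoch.

```python
import numpy as np

def random_pair_crop(dark, bright, size, rng):
    """Crop the same random size x size window from an input/target pair.

    Both arrays are H x W x C; this only slices, it never resizes or pads.
    """
    assert dark.shape[:2] == bright.shape[:2], "pair must have the same size"
    h, w = dark.shape[:2]
    top = rng.integers(0, h - size + 1)    # upper bound is exclusive
    left = rng.integers(0, w - size + 1)
    window = (slice(top, top + size), slice(left, left + size))
    return dark[window], bright[window]

rng = np.random.default_rng(0)
dark = np.zeros((2848, 4272, 3), dtype=np.uint8)    # stand-ins for a real pair
bright = np.ones((2848, 4272, 3), dtype=np.uint8)
x, y = random_pair_crop(dark, bright, 256, rng)
print(x.shape, y.shape)   # (256, 256, 3) (256, 256, 3)
```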

You will need to write a custom preprocessor, as fastai (or any other framework) doesn’t support this use case out of the box. At an image size of 256x256 and with 16GB of GPU RAM, a ResNet50-class model won’t be able to train on more than about 64 images at a time, give or take. So you won’t be able to fit all the patches of an image into one batch if it requires 200 patches.

Hey man, thank you very much!

So yeah, I’m clearly not handling memory allocation properly. If the GPU has 16GB of RAM, does that mean it will store all 61 image pairs on it? Or does it copy one pair of images at a time to GPU RAM?

Can it do something like this:

  0. Begin the first epoch.
  1. Create 61 random crops (let’s call them patches (WxD)) from the images (NxM).
  2. Load them onto the GPU.
  3. Finish the epoch.
  4. Go to 0 to start another epoch, but with new patches.

I wouldn’t want the GPU to run many epochs on the same randomly pre-chosen patches (WxD) from the original photos (NxM) and lose information.
On the other hand, GPU RAM can’t hold tensors of full (NxM) photos even for a single batch. What am I missing in my calculations?

Thanks

I don’t have access to your Colab notebook via the link you provided, but I have a few suggestions that may help. As @bwarner mentioned, you can write a standalone script that goes through your images, creates the smaller tiles, and saves them to a separate folder. You would then just use the new images for training. This will probably make debugging easier, as you’ll be able to look at the files directly without wondering whether the dataloader is doing what you want. It may also improve performance: loading large JPEGs is slow, so if you go with the RandomCrop method, your dataloader will have to re-load the full images every time, which is likely to be a bit slow. The disadvantage is that you will no longer have randomized crops, which give you an effectively infinitely variable dataset.

If you do want to go with the RandomCrop method, you may still want to break the images up into larger tiles first, maybe ~512x512, and use those with RandomCrop, to avoid the overhead of loading very large images of which you only use a tiny portion each epoch. If you have enough CPU RAM, you could also load all the images into CPU RAM before training as a NumPy array or tensor, so you can skip the JPEG decoding, which can be relatively slow for large images, but that may be a bit trickier to implement.

Pre-processing your images into 256x256 tiles is probably a good starting point, since your dataloader will be simpler and likely faster; you can look into building a more sophisticated dataloader later.
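A minimal sketch of such a standalone tiling script, assuming NumPy and Pillow are installed. The folder layout (out_dir/dark and out_dir/bright with matching file names) and the tile_pair helper are hypothetical, chosen so that a target tile can be found from its input tile’s name. Tiles are non-overlapping and the leftover border is dropped.

```python
from pathlib import Path

import numpy as np
from PIL import Image

def tile_image(img, tile):
    """Yield (row, col, patch) for every full, non-overlapping tile x tile patch."""
    h, w = img.shape[:2]
    for r in range(h // tile):
        for c in range(w // tile):
            yield r, c, img[r * tile:(r + 1) * tile, c * tile:(c + 1) * tile]

def tile_pair(dark_fn, bright_fn, out_dir, tile=256):
    """Cut an input/target pair into aligned tiles and save them as PNGs.

    Hypothetical layout: out_dir/dark/<stem>_r<row>_c<col>.png, with the same
    name under out_dir/bright. Returns the number of tile pairs written.
    """
    out_dir = Path(out_dir)
    for sub in ("dark", "bright"):
        (out_dir / sub).mkdir(parents=True, exist_ok=True)
    dark = np.asarray(Image.open(dark_fn).convert("RGB"))
    bright = np.asarray(Image.open(bright_fn).convert("RGB"))
    assert dark.shape == bright.shape, "input and target must be the same size"
    stem = Path(dark_fn).stem
    n = 0
    for (r, c, d), (_, _, b) in zip(tile_image(dark, tile), tile_image(bright, tile)):
        name = f"{stem}_r{r:02d}_c{c:02d}.png"  # PNG is lossless, so no extra JPEG artefacts
        Image.fromarray(d).save(out_dir / "dark" / name)
        Image.fromarray(b).save(out_dir / "bright" / name)
        n += 1
    return n
```

Called once per pair (e.g. `tile_pair("dark/IMG_0001.jpg", "bright/IMG_0001.jpg", "tiles")`, with hypothetical paths), a 4272x2848 pair produces 16 x 11 = 176 tiles in each folder, which can then be viewed directly before training.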


Great answer.
I’m still analyzing your ideas, and I think I found more ideas in Jeremy’s tutorials.
I’ll keep this thread updated as soon as I’m done.


I’ve made some progress and managed to fix the problem as well. I also understand your message much better now.

So the difference between these two is basically:

  1. item_tfms=RandomCrop(256) - transformations applied to each image by the CPU, before it is sent to the GPU
  2. batch_tfms=RandomCrop(256) - transformations (augmentations) applied to the images by the GPU itself, after they have been stored on it

I’ve got a question for you here:

You will need to write a custom preprocessor, as fastai (or any other framework) doesn’t support this use case out of the box. At an image size of 256x256 and with 16GB of GPU RAM, a ResNet50-class model won’t be able to train on more than about 64 images at a time, give or take. So you won’t be able to fit all the patches of an image into one batch if it requires 200 patches.

Did you mean 64 images in one batch? (A batch of 64 images)

I’ve been thinking about this a lot.
I’ve noticed that when I resize the images to smaller ones, the results get significantly better. So even if the image isn’t NxM = 4272x2848, the NN learns much more from a smaller, resized version of it. So cropping isn’t enough; the network needs to see the whole picture.

I’m now considering adding way more images to the dataset. I’ve only got ~120 images. Maybe I’ll need at least a few thousand to make progress?

So just to understand: every epoch, does the CPU copy a new batch of images to the GPU? Or are all of the images copied to the GPU before training?

Hi Dan, the batch size is the number of images that get sent to the GPU at one time. After each batch, the model weights get updated. Increasing the batch size may shave time off training, but it also requires more VRAM. Most commonly, batch sizes are set to a power of 2 (2, 4, 8, 16, etc.); I think this is due to how the computation units on GPUs are structured, but I am not totally familiar with the theory.

It’s important to realize that this is easy to confuse with our everyday idea of the word “batch.” So if you are just thinking of arranging groups of images for your own organizational purposes, that is not related to the batch concept above, nor do you need to do it beforehand (unless you are making different categories, or different train/test/validation groups). You would just put all of your images into a single dataset, and the program will decide on its own batches, limited to whatever size you specify.

An epoch is a complete pass through the entire set of images. So if you increase the number of images, that will require more “batches” to cycle through before the epoch is complete. The hope is that you train for enough epochs to get the model to fit well without overfitting.
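To make the batch/epoch relationship concrete, here is a small sketch of the arithmetic. Whether the last, partial batch is kept is a loader setting (`drop_last`): PyTorch’s `DataLoader` keeps it by default, while fastai’s training loader drops it because it shuffles. The counts ignore any train/validation split, and the 176 figure assumes each 4272x2848 image is cut into 16 x 11 non-overlapping 256x256 tiles.

```python
import math

def batches_per_epoch(n_items, bs, drop_last=False):
    """How many batches one epoch (one full pass over n_items) takes."""
    return n_items // bs if drop_last else math.ceil(n_items / bs)

print(batches_per_epoch(61, 6))                   # 11: ten full batches plus one of 1
print(batches_per_epoch(61, 6, drop_last=True))   # 10
print(batches_per_epoch(200, 64))                 # 4: 200 patches do not fit in one batch of 64
print(batches_per_epoch(61 * 176, 64))            # 168 batches if every pair becomes 176 tiles
```

The total number of weight updates in a run is the number of epochs times the batches per epoch, so adding images adds updates per epoch without changing the batch size.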

Also, those are some massive images! Most of the ones I’ve used in DL courses are much smaller.


Thank you for your explanation! Great input.

May I ask two things to clarify my understanding of what I can achieve?

  1. More images in the dataset would mean more batches fed to the GPU during a single epoch. Suppose I want better results (let’s measure them with a parameter [K]) with [X] images and [Y] epochs. Which option would improve them more:
    Increasing the number of images (let’s say to 100[X]) and keeping [Y] epochs; or
    Increasing the number of epochs (let’s say to 10[Y]) and keeping [X] images.
    Which of these would make [K] bigger?
    (For instance, would the first option yield 4[K] but the second option only 2[K]?)

  2. Is it possible to write a script in fastai where I first train on dataset A, and then train the same model on another dataset B?

Thank you very much

Hi Dan, for #1 you will usually be better off increasing the number of images in your training set. That almost always helps things along. Other things you may want to try are adjusting the number of layers or nodes in your neural network, augmenting the data (as discussed in what others have posted above) and increasing the dropout rate. I’m sure all of these are covered somewhere in Part 1 and/or 2, but I confess I am an old-school Andrew Ng/Coursera deep learner, and am less familiar with the fast.ai course progression.

The more data and/or complex features you add, the more epochs you will probably need to run, as the model will be slower to converge. But you are more likely to get better results. This is very different from just running for more epochs without changing anything else (your second suggestion). If you do that, the model will appear to keep getting better and better on your training set, but eventually it may overfit and your real-world results will actually get worse. This is the difference between training loss and validation loss (the latter is what I would call your “K parameter”).
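A minimal sketch of the practical consequence: record the validation loss every epoch and keep the epoch where it is lowest, not the last one. In fastai, `SaveModelCallback` (which monitors `valid_loss` by default) does this bookkeeping during training. The loss values below are made up purely for illustration.

```python
def best_epoch(valid_losses):
    """Index of the epoch with the lowest validation loss (first one on ties)."""
    return min(range(len(valid_losses)), key=valid_losses.__getitem__)

def epochs_since_best(valid_losses):
    """How many epochs have passed without improving on the best validation loss."""
    return len(valid_losses) - 1 - best_epoch(valid_losses)

# Illustrative, made-up numbers: training loss keeps falling, validation loss turns up.
train = [0.90, 0.60, 0.45, 0.35, 0.28, 0.22]
valid = [0.95, 0.70, 0.58, 0.55, 0.57, 0.62]
print(best_epoch(train))         # 5: judged by training loss, the last epoch always looks best
print(best_epoch(valid))         # 3: the model to keep
print(epochs_since_best(valid))  # 2: likely overfitting since then
```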

For #2, if your purpose is to simultaneously compare results based on two different training sets, I think you’d have to run that on two machines (or one machine with 2 GPUs, using a different run on each GPU). There’s another concept called data parallelization, where you spread a single training set over multiple GPUs, but that doesn’t sound like what you have in mind (and I’m not very familiar with it).

Thank you verily!

  1. Yes, you confirmed my thoughts about this too. You also made me smarter about the nodes and dropout rate. I will have to read a little more about these.

  2. I meant that I’d want the model to be trained on dataset A, then stop, and save the model. Then load another dataset B and train the previous model on that new dataset B. Not simultaneously, not in parallel.
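That sequential flow is supported. In fastai it would usually be `learn.save('stage_a')` after training on A, then pointing the learner at the new data (`learn.dls = dls_b`) and calling fit again; in a fresh session, `learn.load('stage_a')` restores the saved weights first. The same idea in plain PyTorch is sketched below, assuming PyTorch is installed. The tiny convolutional model and random tensors are hypothetical stand-ins for a real image-to-image model and datasets A and B.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def train(model, loader, epochs, lr):
    """Plain training loop; the model keeps its weights between calls."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    model.train()
    for _ in range(epochs):
        for xb, yb in loader:
            opt.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            opt.step()
    return model

# Hypothetical stand-ins for dataset A and dataset B (random tensors)
torch.manual_seed(0)
ds_a = TensorDataset(torch.randn(64, 3, 32, 32), torch.randn(64, 3, 32, 32))
ds_b = TensorDataset(torch.randn(32, 3, 32, 32), torch.randn(32, 3, 32, 32))

model = nn.Conv2d(3, 3, kernel_size=3, padding=1)   # toy image-to-image model
train(model, DataLoader(ds_a, batch_size=8, shuffle=True), epochs=2, lr=1e-3)
torch.save(model.state_dict(), "stage_a.pth")        # stop and save after A

model_b = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # e.g. in a later session
model_b.load_state_dict(torch.load("stage_a.pth"))   # reload the weights learned on A
train(model_b, DataLoader(ds_b, batch_size=8, shuffle=True), epochs=2, lr=1e-3)
```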

Got it. I think this thread will give you what you need to retrain your model:

Why do you want to do this? There are good reasons for doing this in certain circumstances such as progressive resizing, language model pre-training, etc. but it can also be the wrong approach in many circumstances as well. It just depends on what you’re trying to do. You may be better off just training on both dataset A and B at the same time.

Suppose that my model restores colors from black and white photos.
I’d want it first to be able to restore modest colors, and only after a while, to train it to restore very vivid colors as well.
If I train it all together, it might learn “more” about restoring the vivid colors than the modest ones.

Gotcha. I have not worked on any image colorization myself, so I don’t have any direct experience to know what works best. I think you may be better off training with (A+B) first and then fine-tuning the model with just the (B) dataset at a lower learning rate. The reason I think this may be better is that during the initial training your model will learn how to re-color both vivid and modest colors; even if the vivid ones happen to dominate, it will still gain some experience with the modest ones. Afterwards you can apply additional training with only the more modest colors, so your model will be biased more towards the more modest colors.

I am assuming you’re using a U-Net-like architecture with a pre-trained encoder originally trained on ImageNet. It’s likely that you will be re-training with a much smaller dataset than the original ImageNet dataset of ~1M images that was used to create the pre-trained model. When you retrain a pre-trained model on a smaller dataset, it tends to lose some of the sophisticated abilities it previously learned that are no longer pertinent to its new task, and once those abilities are gone through re-training, they don’t tend to return. Essentially, when dealing with smaller datasets that are likely less sophisticated than the pre-training dataset, you want to avoid ‘task switching’ (pre-trained > task A > task B) and instead do (pre-trained > task A & task B > task B). I’m assuming here that your dataset for task B is much smaller than the one for task A, and that tasks A and B are very similar/complementary.

My intuition could be wrong in this case, but this is generally a good approach when thinking about re-training models. I haven’t thought about how to explain this before so hopefully this makes sense. There are quite a few papers on colorization you may want to check out that would provide better insights than what I can provide: Colorization | Papers With Code
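A minimal PyTorch sketch of the suggested schedule, assuming PyTorch is installed: first train on A and B together, then fine-tune on B alone with a smaller learning rate. The toy model, the random tensors standing in for A and B, and the 10x learning-rate drop are all hypothetical choices for illustration. In fastai terms this would roughly be one `fit_one_cycle` on DataLoaders built from A and B together, then another on B alone with a smaller `lr_max`.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

def fit(model, dataset, epochs, lr, bs=8):
    """Minimal training loop over `dataset` with Adam and an MSE loss."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loader = DataLoader(dataset, batch_size=bs, shuffle=True)
    model.train()
    for _ in range(epochs):
        for xb, yb in loader:
            opt.zero_grad()
            F.mse_loss(model(xb), yb).backward()
            opt.step()

# Hypothetical stand-ins: ds_b is the smaller set the model should lean toward at the end
torch.manual_seed(0)
ds_a = TensorDataset(torch.randn(64, 3, 32, 32), torch.randn(64, 3, 32, 32))
ds_b = TensorDataset(torch.randn(16, 3, 32, 32), torch.randn(16, 3, 32, 32))

model = nn.Conv2d(3, 3, kernel_size=3, padding=1)   # toy stand-in for a U-Net

# Stage 1: train on A and B together, so the model sees both kinds of examples
fit(model, ConcatDataset([ds_a, ds_b]), epochs=3, lr=1e-3)
# Stage 2: fine-tune on B only, with a 10x smaller learning rate (an arbitrary choice)
fit(model, ds_b, epochs=2, lr=1e-4)
```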


Thank you so much for your additional input!
Great insights. I’m going to apply them to my work.
