Progressive resizing for segmentation

I am using the following code to perform progressive resizing for a semantic segmentation problem using UNET.

Problem:
I don’t think the resizing part actually works: each epoch of learn_224 takes exactly as long to train as an epoch of learn_112. When I create learn_224 from scratch with unet_learner(), each epoch takes much longer to train.

Am I doing this incorrectly?

    size = 460 // 2
    half = (224 // 2, 224 // 2)
    dblock_112 = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                       get_items=get_relevant_images,
                       splitter=splitter,
                       get_y=get_mask, 
                       item_tfms=Resize((size,size)),
                       batch_tfms=[*aug_transforms(size=half), Normalize.from_stats(*imagenet_stats)])

    dls_112 = dblock_112.dataloaders(path/'images', bs=16)

    learn_112 = unet_learner(dls_112, resnet34,
                         cbs=callbacks, 
                         self_attention=True,
                         metrics=[Dice()]
                        )
    learn_112.fine_tune(50)


    # Let's train using double the size (224)
    size = 460 
    half = (224, 224)

    dblock_2 = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),
                       get_items=get_relevant_images,
                       splitter=splitter,
                       get_y=get_mask, 
                       item_tfms=Resize((size,size)),
                       batch_tfms=[*aug_transforms(size=half), Normalize.from_stats(*imagenet_stats)])

    dls_224 = dblock_2.dataloaders(path/'images', bs=16)

    learn_224 = learn_112
    learn_224.data = dls_224
    x = next(iter(learn_224.data.train))
    x[0].size()    # prints (as expected): torch.Size([16, 3, 224, 224])

    learn_224.unfreeze()
    learn_224.fit_one_cycle(30, lr_max=slice(1e-4, 1e-3))

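Your issue is that you’re mixing fastai v1 and v2 notation. In v2 the Learner keeps its DataLoaders in learn.dls; learn.data was the v1 name, so assigning to it just adds an unused attribute and training keeps using dls_112.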

Try:

    learn_224.dls = dls_224

after building dls_224 at the larger size.
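To see why assigning to .data silently changed nothing (while the batch check on learn_224.data still printed 224), here is a minimal pure-Python sketch. It assumes only that the v2 Learner reads its data from .dls during training; ToyLearner and ToyDataLoaders are hypothetical stand-ins, not fastai classes.

```python
# Hypothetical stand-ins for illustration only; these are NOT fastai classes.
class ToyDataLoaders:
    def __init__(self, size):
        self.size = size  # side length of the images the batches would contain


class ToyLearner:
    def __init__(self, dls):
        self.dls = dls  # assumption mirroring fastai v2: training reads from .dls

    def fit(self):
        # Training only ever looks at self.dls, never at self.data.
        return self.dls.size


dls_112, dls_224 = ToyDataLoaders(112), ToyDataLoaders(224)
learn = ToyLearner(dls_112)

learn.data = dls_224  # v1-style: just creates a new, unused attribute
print(learn.data.size, learn.fit())  # 224 112

learn.dls = dls_224  # v2-style: replaces what training actually uses
print(learn.fit())  # 224
```

In real fastai v2 code, the equivalent sanity check is to inspect learn.dls.one_batch()[0].shape rather than anything hanging off learn.data.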


Thank you! That fixed the problem. Mixing the v1 and v2 APIs has been an ongoing issue for me.