Semantic Segmentation: How to Train on Crops and Predict on Full Images?

This Kaggle discussion post shows that a segmentation model can be trained on small images (“crops”) cut from the training set and then used to predict masks on arbitrarily shaped images in the test set.
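To make the "train on crops" half concrete: the key requirement is that the image and its mask are cut at the same position. A minimal numpy sketch, assuming channels-last arrays of shape (H, W, C) for the image and (H, W, K) for the mask; the function name `random_crop` is hypothetical and not taken from the Kaggle post:

```python
import numpy as np

def random_crop(image, mask, crop_h=256, crop_w=256, rng=None):
    """Cut the same random (crop_h, crop_w) window out of an image and its mask.

    image: (H, W, C) array; mask: (H, W, K) array with the same H and W.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape[:2]
    if h < crop_h or w < crop_w:
        raise ValueError("crop is larger than the image")
    # Pick the top-left corner so the window stays inside the image.
    top = rng.integers(0, h - crop_h + 1)
    left = rng.integers(0, w - crop_w + 1)
    return (image[top:top + crop_h, left:left + crop_w],
            mask[top:top + crop_h, left:left + crop_w])

# Example: a 256x1600 image with a 4-class mask, cropped to 256x256 for training.
rng = np.random.default_rng(0)
img = rng.random((256, 1600, 3), dtype=np.float32)
msk = (rng.random((256, 1600, 4)) > 0.5).astype(np.float32)
crop_img, crop_msk = random_crop(img, msk, rng=rng)
print(crop_img.shape, crop_msk.shape)  # (256, 256, 3) (256, 256, 4)
```

Applying this on the fly at every epoch gives the model a different window of each image each time.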

In Keras (using the segmentation_models library), training on 256x256 pixel crops looks like this:

from segmentation_models import FPN
from segmentation_models.losses import bce_jaccard_loss
# dice_coef is a user-defined Dice metric; its definition is not part of this snippet
model = FPN('inceptionv3', input_shape=(256, 256, 3), classes=4, activation='sigmoid')
model.compile(optimizer='adam', loss=bce_jaccard_loss, metrics=[dice_coef])
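The reason the same weights can serve a different input size is that a fully convolutional network (FPN here) only learns convolution kernels, whose shapes depend on kernel size and channel counts but never on the image height or width. A toy numpy illustration of that property, using a single 3x3 "same"-padded convolution rather than the actual FPN; `conv2d_same` is a hypothetical helper written for this example:

```python
import numpy as np

def conv2d_same(x, kernel, bias):
    """'Same'-padded 2D convolution (cross-correlation, as in deep learning libraries).

    x: (H, W, C_in); kernel: (kh, kw, C_in, C_out); bias: (C_out,).
    Note that neither parameter shape mentions H or W.
    """
    kh, kw = kernel.shape[:2]
    ph, pw = kh // 2, kw // 2
    xp = np.pad(x, ((ph, ph), (pw, pw), (0, 0)))  # zero padding keeps H and W
    # windows[h, w, c, i, j] == xp[h + i, w + j, c]; shape (H, W, C_in, kh, kw)
    windows = np.lib.stride_tricks.sliding_window_view(xp, (kh, kw), axis=(0, 1))
    return np.einsum('hwcij,ijco->hwo', windows, kernel) + bias

rng = np.random.default_rng(0)
kernel = rng.normal(size=(3, 3, 3, 4)).astype(np.float32)  # stand-in for trained weights
bias = np.zeros(4, dtype=np.float32)
for shape in [(256, 256, 3), (256, 1600, 3)]:
    x = rng.random(shape, dtype=np.float32)
    print(shape, "->", conv2d_same(x, kernel, bias).shape)
# (256, 256, 3) -> (256, 256, 4)
# (256, 1600, 3) -> (256, 1600, 4)
```

The same `kernel` and `bias` are applied unchanged to both sizes; stacking many such layers (plus pooling and upsampling) keeps this property, which is what the Keras example relies on.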

Prediction on full 256x1600 pixel images then works with the same weights:

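# Same architecture, larger input; the convolutional weights do not depend on H and W
model2 = FPN('inceptionv3', input_shape=(256, 1600, 3), classes=4, activation='sigmoid')
model2.set_weights(model.get_weights())  # one way to reuse the trained weights
model2.compile(optimizer='adam', loss=bce_jaccard_loss, metrics=[dice_coef])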
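One practical caveat for "arbitrary" test shapes: encoder-decoder models like FPN downsample the input several times, so height and width typically have to be divisible by the network's total stride (32 is the usual value for five downsampling stages; 256 and 1600 both qualify). For images that are not, a common workaround is to pad, predict, and crop the prediction back. A minimal numpy sketch; `pad_to_multiple` and `predict_full` are hypothetical names, and `predict_fn` stands in for whatever runs the model on one padded image (for example a wrapper around `model2.predict`):

```python
import numpy as np

def pad_to_multiple(image, multiple=32):
    """Zero-pad an (H, W, C) image at the bottom/right so H and W are multiples of `multiple`."""
    h, w = image.shape[:2]
    pad_h = (-h) % multiple
    pad_w = (-w) % multiple
    padded = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)))
    return padded, (h, w)

def predict_full(image, predict_fn, multiple=32):
    """Pad, run predict_fn (assumed to return (H_pad, W_pad, classes)), crop back."""
    padded, (h, w) = pad_to_multiple(image, multiple)
    pred = predict_fn(padded)
    return pred[:h, :w]

# Dummy model: checks the divisibility requirement and returns an empty 4-class mask.
def dummy_predict(x):
    assert x.shape[0] % 32 == 0 and x.shape[1] % 32 == 0
    return np.zeros(x.shape[:2] + (4,), dtype=np.float32)

img = np.ones((250, 1590, 3), dtype=np.float32)
print(pad_to_multiple(img)[0].shape)           # (256, 1600, 3)
print(predict_full(img, dummy_predict).shape)  # (250, 1590, 4)
```

Zero padding is the simplest choice; reflection padding (`mode="reflect"` in `np.pad`) is an alternative that avoids hard borders at the image edge.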

How could this be achieved in ? My understanding is that the input shape is determined by the dataloaders, but I suspect that simply providing a test dataloader with differently shaped images would not be enough. Instead, how could one change the CNN's input size accordingly in
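Assuming the network in question is fully convolutional like the Keras FPN above, the model itself usually has no fixed input size; the sticking point tends to be the data loader, because images stacked into one batch must share a shape. Two common options are a batch size of 1 for full-size test images, or padding each batch to its largest image. A framework-agnostic numpy sketch of the second option (the name `pad_collate` is hypothetical; channels-last layout as in the Keras code). If the framework is PyTorch, a function along these lines, converted to tensors and channels-first, is the kind of thing one would pass as a `DataLoader`'s `collate_fn`:

```python
import numpy as np

def pad_collate(images):
    """Stack (H_i, W_i, C) arrays of different sizes into one (N, H_max, W_max, C) batch.

    Smaller images are zero-padded at the bottom/right; the original sizes are
    returned so each prediction can be cropped back afterwards.
    """
    h_max = max(im.shape[0] for im in images)
    w_max = max(im.shape[1] for im in images)
    batch = np.zeros((len(images), h_max, w_max, images[0].shape[2]),
                     dtype=images[0].dtype)
    sizes = []
    for n, im in enumerate(images):
        h, w = im.shape[:2]
        batch[n, :h, :w] = im  # copy the image into the top-left corner
        sizes.append((h, w))
    return batch, sizes

imgs = [np.ones((256, 1600, 3), np.float32), np.ones((256, 256, 3), np.float32)]
batch, sizes = pad_collate(imgs)
print(batch.shape, sizes)  # (2, 256, 1600, 3) [(256, 1600), (256, 256)]
```

After prediction, `pred[n, :h, :w]` with `(h, w) = sizes[n]` recovers the mask for each original image.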