CNN Regression

As per the assertion error, I’m interested in creating a regression CNN.

I’m of course willing to contribute to, or write, the initial solution, but it’s not clear to me exactly what additions are needed. It looks like the only changes required to get it working are to change the create_head function and to make a RegressionLearner equivalent to the ClassificationLearner.

The latter may be straightforward, but the former (i.e., creating a default create_head function) is not so clear. I use regression CNNs to learn image-to-image regressions (example). But regression could also mean mapping an image to an arbitrarily sized vector of real numbers. Any thoughts on what the default behavior should be? And what else should be included in the create_head function?
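For the image-to-vector case, one possible default is sketched below, under the assumption that it should mirror a classification head: pool the body's feature map, pass it through a small fully connected block, and end in `n_out` raw real-valued outputs with no softmax, plus an optional sigmoid rescaling into a known target range. `create_regression_head`, `ScaledSigmoid` and the `y_range` parameter are hypothetical names for illustration, not existing fastai API.

```python
import torch
import torch.nn as nn


class ScaledSigmoid(nn.Module):
    """Hypothetical helper: squash outputs into the interval (low, high)."""

    def __init__(self, low: float, high: float):
        super().__init__()
        self.low, self.high = low, high

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.low + (self.high - self.low) * torch.sigmoid(x)


def create_regression_head(nf: int, n_out: int, n_hidden: int = 512,
                           p: float = 0.5, y_range=None) -> nn.Sequential:
    """Hypothetical regression head: (B, nf, H, W) feature map -> (B, n_out) reals."""
    layers = [
        nn.AdaptiveAvgPool2d(1),     # (B, nf, H, W) -> (B, nf, 1, 1) for any H, W
        nn.Flatten(),                # -> (B, nf)
        nn.BatchNorm1d(nf),
        nn.Dropout(p),
        nn.Linear(nf, n_hidden),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(n_hidden),
        nn.Dropout(p),
        nn.Linear(n_hidden, n_out),  # raw real-valued outputs, no softmax
    ]
    if y_range is not None:
        layers.append(ScaledSigmoid(*y_range))  # optional bounded output
    return nn.Sequential(*layers)
```

For example, `create_regression_head(nf=512, n_out=3, y_range=(0.0, 10.0))` would sit on top of a body producing 512 channels and return three values in (0, 10) per image. Image-to-image regression would need a different, decoder-style head, which is part of why a single default is hard to pick.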

I believe you can still make it work. Let’s imagine you wanted to predict a bounding box (4 coordinates) and 6 categories. You could create a head for your CNN like this:

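```python
import torch.nn as nn

head_reg4 = nn.Sequential(
    nn.Flatten(),           # same as fastai's Flatten(): (B, C, H, W) -> (B, C*H*W)
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(25088, 256),  # 25088 = 512 * 7 * 7, e.g. a ResNet34 body on 224x224 images
    nn.ReLU(),
    nn.BatchNorm1d(256),
    nn.Dropout(0.5),
    nn.Linear(256, 4 + 6),  # 4 bounding-box coordinates + 6 category scores
)
```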

Add your custom loss function to your Dataset class before passing it to your DataBunch, and ta-da, you have regression working.
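The custom loss for a head like head_reg4 has to treat the two parts of the output differently. One common choice, sketched here as an assumption rather than the exact loss meant above, is L1 loss on the first 4 outputs (box coordinates) plus cross-entropy on the remaining 6 (category logits). The function name `bbox_and_class_loss` and the weighting factor `alpha` are hypothetical.

```python
import torch
import torch.nn.functional as F


def bbox_and_class_loss(output: torch.Tensor, bbox_target: torch.Tensor,
                        class_target: torch.Tensor, n_coords: int = 4,
                        alpha: float = 1.0) -> torch.Tensor:
    """Hypothetical combined loss for a head emitting n_coords + n_classes values.

    output:       (B, n_coords + n_classes) raw head output
    bbox_target:  (B, n_coords) real-valued box coordinates
    class_target: (B,) integer class indices
    """
    bbox_pred = output[:, :n_coords]     # regression part
    class_logits = output[:, n_coords:]  # classification part (unnormalised logits)
    reg_loss = F.l1_loss(bbox_pred, bbox_target)
    cls_loss = F.cross_entropy(class_logits, class_target)
    return reg_loss + alpha * cls_loss
```

In practice the two terms can differ in scale by a lot (pixel coordinates vs. log-probabilities), so `alpha` or a rescaling of the box targets usually needs tuning.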

Hey, thanks for the response. I was able to get regression working previously (you can see it in this silly example), basically by doing what you suggest and creating a custom head. I was responding more directly to the “bug us on the forums” request in the assertion error, since more native support for regression in the fastai library would be nice.

I also think this may not work as expected with the create_cnn function, since the latest version of create_cnn instantiates a ClassificationLearner, whose predict function is specific to classification. The point being, even with a custom head, not all functionality would be fully supported.
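The mismatch is in what predict does with the raw model output: a classification predict turns logits into a class (argmax plus probabilities), while a regression predict should return the real numbers unchanged. The plain-PyTorch sketch below only illustrates that contrast; `predict_classification` and `predict_regression` are hypothetical helpers, not fastai’s ClassificationLearner.predict or any actual RegressionLearner API.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def predict_classification(model: nn.Module, x: torch.Tensor):
    """Hypothetical: single input -> (predicted class index, class probabilities)."""
    model.eval()
    logits = model(x.unsqueeze(0))[0]    # add, then drop, a batch dimension
    probs = torch.softmax(logits, dim=0)
    return int(probs.argmax()), probs


@torch.no_grad()
def predict_regression(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical: single input -> raw real-valued outputs (no softmax/argmax)."""
    model.eval()
    return model(x.unsqueeze(0))[0]
```

Applying the classification version to a regression model would silently collapse a vector of real-valued predictions into a single index, which is the kind of partial support described above.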

Your comment also alludes to the issue I raised about what the default output of a regression CNN should be, which is very case-dependent. So perhaps a generic regression CNN would not be all that useful (?), since custom heads would have to be made for each use case anyway. Not sure, though.

I’ve got a dataset downloaded now. Will try to do this soon.

I’m just here to nag FASTAI to create CNN Regression!! :smiley:
