I want to build a CNN that will accept standard continuous/categorical variables after the corresponding image data have passed through the convolutional layers.
In Keras, I’ve seen this done by concatenating the given variables to the output of the flattened layer coming out of the convolutional layers. I’d imagine that a concat could be performed in the forward method in PyTorch, but I’m wondering if fastai has prebuilt implementations of this.
In short: how would I pass data to a fastai learner, like ConvLearner.from_model_data(…), if the data are not just images?
Thank you in advance.
You’ll have to write your own architecture, but there’s no reason why you can’t do this. Take a look at the fast.ai source for inspiration and add the appropriate functions.
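For what it's worth, here is a minimal plain-PyTorch sketch of the concat-in-forward idea from the question. It is not a fastai API: the class name ImageTabularNet, the tiny conv body, and the layer sizes are all made up for illustration, and a pretrained resnet with its head removed would presumably take the place of the stand-in body.

```python
import torch
import torch.nn as nn


class ImageTabularNet(nn.Module):
    """Hypothetical model: conv features concatenated with continuous columns."""

    def __init__(self, n_cont, n_out):
        super().__init__()
        # Stand-in conv body (assumption); swap in a headless pretrained network here.
        self.body = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # -> (batch, 16, 1, 1)
            nn.Flatten(),             # -> (batch, 16)
        )
        # The head sees the image features plus the extra columns.
        self.head = nn.Sequential(
            nn.Linear(16 + n_cont, 32),
            nn.ReLU(),
            nn.Linear(32, n_out),
        )

    def forward(self, img, cont):
        feats = self.body(img)
        # Concatenate along the feature dimension, as in the Keras approach.
        x = torch.cat([feats, cont], dim=1)
        return self.head(x)


model = ImageTabularNet(n_cont=4, n_out=2)
img = torch.randn(8, 3, 32, 32)   # batch of 8 RGB images
cont = torch.randn(8, 4)          # 4 continuous variables per image
out = model(img, cont)
print(out.shape)
```

The dataset and loader would then need to yield (img, cont) pairs so that both tensors reach forward together.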
I’m working on something similar if you want a hand, but it’ll take a few days before I’ve got it complete.
I was just going to pose the same question tonight too! I also did this before pretty easily with Keras (the Part 1 lecture last year showed pretty clearly how to do it), but I was much more confused about how to fit it into the new fastai/PyTorch framework. I would love to use the automatic embeddings and combine them with image layer data.
I’m guessing the answer is to first lop off the later layers of a resnet, store the output as an intermediate state, and then concatenate it with the column data, either by flattening it and appending it to the column data before using the fastai API, or by building a new function that takes both the column and resnet data and concatenates them. I am probably going to try the first technique in the coming week or two.
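A rough sketch of that first technique in plain PyTorch, assuming a frozen body: the small conv stack below is only a stand-in for a truncated resnet, and the tensor shapes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for a pretrained resnet with its later layers removed (assumption).
body = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
body.eval()  # inference mode: the body is used as a fixed feature extractor

images = torch.randn(10, 3, 32, 32)  # 10 images
columns = torch.randn(10, 5)         # 5 tabular columns per image

# Precompute the image features once, without tracking gradients.
with torch.no_grad():
    img_feats = body(images)         # shape (10, 8)

# Append the flattened image features to the column data.
combined = torch.cat([columns, img_feats], dim=1)
print(combined.shape)
```

The combined tensor can then be treated as ordinary continuous columns by whatever tabular model comes next.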
Even - if you get something that works, would you be willing to publish a version of what you come up with, either here or on GitHub?
Actually - just realized my last comment won’t work if you want to retrain the resnet. So we’ll probably have to pass along the images and combine the models to get that added benefit (although, for our image data, retraining the resnet even a little bit did not have much of an impact on the model).
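To illustrate the end-to-end version, here is a hedged plain-PyTorch sketch in which the images pass through the body on every step, so it can be frozen first and fine-tuned later. The helper set_trainable, the stand-in body, and the learning rates are assumptions for illustration, not fastai functionality.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in conv body (assumption); a pretrained resnet would go here.
body = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
head = nn.Linear(8 + 3, 1)  # 8 image features + 3 tabular columns


def set_trainable(module, flag):
    """Hypothetical helper: freeze or unfreeze every parameter of a module."""
    for p in module.parameters():
        p.requires_grad = flag


# Separate parameter groups so the body can use a smaller learning rate.
opt = torch.optim.SGD([
    {"params": body.parameters(), "lr": 1e-4},
    {"params": head.parameters(), "lr": 1e-2},
])

img = torch.randn(4, 3, 16, 16)
cols = torch.randn(4, 3)
target = torch.randn(4, 1)


def train_step():
    pred = head(torch.cat([body(img), cols], dim=1))
    loss = nn.functional.mse_loss(pred, target)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()


# Phase 1: body frozen, only the head learns.
set_trainable(body, False)
frozen_weights = body[0].weight.clone()
train_step()
body_unchanged = torch.equal(frozen_weights, body[0].weight)

# Phase 2: body unfrozen, gradients flow back through the concat.
set_trainable(body, True)
train_step()
body_has_grad = body[0].weight.grad is not None
```

Because the concat happens inside the training step, gradients reach the body whenever it is unfrozen, which is what precomputed features cannot offer.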