Adding an (image-meta-data) input layer as an intermediate layer in a CNN

(Will Sutton) #1

So I have a set of images I’m training for an ImagePoints task. Each image also has associated meta-data in boolean, continuous, and categorical form, e.g.:

  • label_is_dirty: true or false
  • img_source: ‘google-image’, ‘pinterest’, ‘reddit’, etc…
  • shutter_capture_time: 0.0185

To train on the images, I’m obviously using a cnn_learner. But there’s no sensible (?) reason to pass the meta-data variables through the convolutions, and there’s no {x, y, channel} location to place the data in the traditional input-activation layer.

I’m wondering whether it’s best to modify my architecture so this data is added after the flatten() stage, as extra input to the first fully-connected layer. This seems like a common enough desire that there must be a name for the concept; what is it? Also, where in the fastai courses / library can I learn about and use this capability?


(Malcolm McLean) #2

Hi Will. I see that no one has responded to your questions. Though I am by no means an expert, I may be able to help you get started, having struggled through a similar task. Hopefully, experts will correct and add to these points.

Your idea to include metadata with the image is reasonable. And yes, you would add it into the FC (Linear) layer, rather than try to send it through the CNN. I assume you also want to use a pretrained CNN. The issue of course is how.

To do this you need to understand a number of topics: the datablock API, how fastai creates a CNN, PyTorch Modules, and methods that manipulate Modules.

The steps are…

  1. Decide how to preprocess and encode the metadata.

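capture_time is already a number, so it can go straight into Linear, though it is usually worth normalizing it (e.g. subtracting the mean and dividing by the standard deviation) so its scale is comparable to the other inputs.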

Source is categorical. With a small number of categories, the simplest approach is to use one-hot encoding directly. Later, you may want to use and train an embedding, which converts a one-hot category into a smaller vector of real numbers.
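To make the one-hot vs. embedding relationship concrete: an `nn.Embedding` lookup gives the same result as multiplying a one-hot vector by the embedding's weight matrix, just more cheaply. A minimal PyTorch check, assuming a hypothetical three-category source and a 2-dimensional embedding:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_sources, emb_dim = 3, 2            # hypothetical: 3 source categories -> 2-d vectors
emb = nn.Embedding(n_sources, emb_dim)

idx = torch.tensor([0, 2, 1, 2])     # category indices for a batch of 4 images
one_hot = F.one_hot(idx, num_classes=n_sources).float()  # 4 x 3

via_lookup = emb(idx)                # 4 x 2: selects rows of emb.weight
via_matmul = one_hot @ emb.weight    # 4 x 2: same rows, via a matrix product
print(torch.allclose(via_lookup, via_matmul))  # True
```

The embedding's weights are trained along with the rest of the model, so in the head you would concatenate `emb(idx)` where the one-hot vector would otherwise go.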

is_dirty is just another one-hot item.

You’ll need to decide what to do about any missing metadata. Source may need an “unknown” category. capture_time is tricky because Linear will treat any special missing value (say, -1) as if it were on the same continuous scale as the real times. The standard solution is to fill the missing value with something neutral, such as the mean, and add yet another boolean input that says whether capture_time is valid.

So now you have additional input data, which you can think of as a minibatch of size bs x (1 + #sourcecategories + 1 + 1). This will eventually be concatenated onto the input of the Linear layer.
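A minimal sketch of step 1 in plain Python, producing one row of that bs x (1 + #sourcecategories + 1 + 1) metadata batch. The `SOURCES` list, the `encode_meta` name, and filling a missing capture_time with the mean (0 after standardizing) are assumptions for illustration; `time_mean` and `time_std` would be computed on the training set.

```python
import math
from typing import List, Optional

# Hypothetical category list; "unknown" catches missing or unseen sources.
SOURCES = ["google-image", "pinterest", "reddit", "unknown"]

def encode_meta(is_dirty: bool, source: Optional[str], capture_time: Optional[float],
                time_mean: float = 0.0, time_std: float = 1.0) -> List[float]:
    """Encode one image's metadata as [time, *source_one_hot, is_dirty, time_valid]."""
    # capture_time: standardized; a missing value becomes 0.0 (the mean) plus a flag
    valid = capture_time is not None and not math.isnan(capture_time)
    t = (capture_time - time_mean) / time_std if valid else 0.0
    # source: one-hot over SOURCES, mapping missing/unseen values to "unknown"
    src = source if source in SOURCES else "unknown"
    one_hot = [1.0 if s == src else 0.0 for s in SOURCES]
    return [t, *one_hot, float(is_dirty), float(valid)]

print(encode_meta(True, "reddit", 0.0185))  # [0.0185, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0]
print(encode_meta(False, None, None))       # [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
```

Each row has 1 + 4 + 1 + 1 = 7 values; stacking a batch of rows gives the metadata tensor that is later concatenated onto the Linear layer's input.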

  2. Next, use the datablock API to have your input x batches include both the image batch and the above metadata batch as a tuple. I think the API can help you with the categorical encoding, decoding, and display. But usually when I try to do anything fancy with that API, I have to ask for help. Advising you is out of my league.
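If the datablock route gets fiddly, one alternative (plain PyTorch rather than the fastai datablock API) is a Dataset whose items are ((image, metadata), target). PyTorch's default collate function then produces batches with the same nesting (tuples come back as lists, which unpack the same way). A sketch with random stand-in tensors; `ImageMetaDataset` is a hypothetical name:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ImageMetaDataset(Dataset):
    """Each item is ((image, metadata), target points)."""
    def __init__(self, images, metas, points):
        assert len(images) == len(metas) == len(points)
        self.images, self.metas, self.points = images, metas, points

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        return (self.images[i], self.metas[i]), self.points[i]

# Random stand-in data: 10 images, 7 metadata features, 4 (x, y) points each.
ds = ImageMetaDataset(torch.randn(10, 3, 32, 32), torch.randn(10, 7), torch.randn(10, 4, 2))
dl = DataLoader(ds, batch_size=4, shuffle=True)

(img_b, meta_b), pts_b = next(iter(dl))   # x is an (images, metadata) pair
print(img_b.shape, meta_b.shape, pts_b.shape)
```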

  3. Now you get to slice and dice the image model. Fortunately, you can append, prepend, delete, and replace modules as you wish, and PyTorch’s autograd keeps the gradients flowing through whatever you assemble.

First, let fastai create the pretrained image model.

Prepend a custom Module at the very beginning which extracts and outputs just the image batch. The image gets passed along to the image model. The metadata batch gets extracted and dumped into a global variable.

Next, find the Linear layer that fastai has built as part of the head of the resnet. Replace it with another custom Module. This Module receives a tensor of image features of size bs x #imageFeatures (512? Print the model to check, since it depends on the backbone and on which Linear in the head you replace). It contains a Linear layer like the original one, but taking an input of size #imageFeatures + #metadataFeatures. In forward(), use torch.cat to join the metadata features from the global onto the image features. The result is a tensor of size bs x (#imageFeatures + (1 + #sourcecategories + 1 + 1)). This tensor gets passed to the new Linear, which learns from both the image and metadata features. Linear’s output gets passed on to the rest of the head. As far as the rest of the head knows, nothing special has happened, and training proceeds on the image points.
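A minimal sketch of this global-variable approach in plain PyTorch. Assumptions: the tiny `body` and `head` are hypothetical stand-ins for the model fastai builds (print your real model to find the Linear to replace), the names `SplitInputs`, `CatMetaLinear` and `_META` are made up for illustration, and `N_META = 7` assumes four source categories including "unknown".

```python
import torch
import torch.nn as nn

_META = {}  # global stash for the current batch's metadata (the "disapproved hack")

class SplitInputs(nn.Module):
    """Prepended first: takes (images, metadata), stashes metadata, passes images on."""
    def forward(self, x):
        imgs, meta = x
        _META["batch"] = meta
        return imgs

class CatMetaLinear(nn.Module):
    """Drop-in for a head Linear: joins the stashed metadata onto the image features."""
    def __init__(self, old: nn.Linear, n_meta: int):
        super().__init__()
        # Freshly initialised Linear with room for the extra metadata inputs.
        self.lin = nn.Linear(old.in_features + n_meta, old.out_features)

    def forward(self, feats):
        meta = _META["batch"].to(feats.dtype)
        return self.lin(torch.cat([feats, meta], dim=1))  # bs x (n_feats + n_meta)

# Hypothetical stand-in for the fastai-built model: a conv "body" and a small "head".
body = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1))
head = nn.Sequential(nn.Flatten(), nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))

N_META = 7                                        # 1 + 4 source categories + 1 + 1
head[1] = CatMetaLinear(head[1], N_META)          # replace the first Linear in the head
model = nn.Sequential(SplitInputs(), body, head)  # prepend the splitter

imgs, meta = torch.randn(4, 3, 32, 32), torch.randn(4, N_META)
out = model((imgs, meta))
print(out.shape)  # torch.Size([4, 8]), e.g. 4 (x, y) points per image
```

Note that the replacement Linear starts from fresh random weights, so it needs training even if the rest of the head was already fitted.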

Using a global variable this way is a disapproved hack. So mea culpa. But it gets the job done, building on existing good work, with minimal opportunities for coding errors. The right way would be to make a custom module that itself loads and contains the image model, keeping the metadata internal. Alternatively, you could copy parts of the fastai code to build the desired model directly. There are lots of ways to get the task done.
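The "right way" mentioned above can be sketched as a single module that owns the image body and a new head and does the concatenation itself, so no global is needed. Assumptions: `ImageMetaModel` is a hypothetical name; the body here is a tiny stand-in (in practice it would be a pretrained backbone with its original classifier removed, e.g. a torchvision resnet with `fc` replaced by `nn.Identity()`); and `n_out = 8` assumes four (x, y) corner points.

```python
import torch
import torch.nn as nn

class ImageMetaModel(nn.Module):
    """Owns the image body and a head that sees image features + metadata."""
    def __init__(self, body: nn.Module, n_img_feats: int, n_meta: int,
                 n_out: int, hidden: int = 64):
        super().__init__()
        self.body = body
        self.head = nn.Sequential(
            nn.Linear(n_img_feats + n_meta, hidden),  # joint image + metadata layer
            nn.ReLU(),
            nn.Linear(hidden, n_out),                 # e.g. flattened (x, y) points
        )

    def forward(self, img: torch.Tensor, meta: torch.Tensor) -> torch.Tensor:
        feats = torch.flatten(self.body(img), 1)      # bs x n_img_feats
        joined = torch.cat([feats, meta.to(feats.dtype)], dim=1)
        return self.head(joined)

# Hypothetical stand-in body: conv + global average pool -> 8 features per image.
body = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1))
model = ImageMetaModel(body, n_img_feats=8, n_meta=7, n_out=8)

imgs, meta = torch.randn(4, 3, 32, 32), torch.randn(4, 7)
pred = model(imgs, meta)
print(pred.shape)  # torch.Size([4, 8])
```

Because the body is a separate attribute, it can still be frozen (e.g. `model.body.requires_grad_(False)`) while the new head trains.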

Good luck! I would be interested to know whether including the metadata increases accuracy.

P.S. I don’t know of any official name for this process, but I am going to start calling it model wrangling.


(Will Sutton) #3

Big thank you for a detailed and helpful answer. Curious if you can share the exercise where you did this?


(Malcolm McLean) #4

It was for a Kaggle competition, Histopathologic Cancer Detection.

Here is an example…


#5

Just a thought. Do you want the image pixels to influence the decision boundary for the image metadata, and vice versa? The network might pick up spurious correlations between the two, which could influence the classifier. You may be better off training two independent models and using a voting scheme to combine their predictions.


(Will Sutton) #6

Interesting, I hadn’t thought about that. I’m doing ‘ImagePoints’ tasks, where I’m finding the corners of particular objects in the picture. So in the middle of my net, I assume I have lots of candidate corners, and the meta-data is supposed to help select one of those candidates. In the case of classification I could see it working (e.g. say Pinterest is usually cat and Reddit is usually dog), but I’m not sure how a separate model of the meta-data, without the image, could select an (x,y) point in the image or usefully modify the (x,y) output of the image net.


#7

Right, I missed the use case in the original post. I agree with you: using separate nets won’t help in this case. So you already have a training set with annotated image points (or you annotated them yourself) that came from Google, Pinterest, etc., and you want to see whether adding the metadata would improve the learner?
