How to convert multi-class to multi-label classification?

Hello all,

Based on the information from lesson 9, it would be beneficial to train models with a loss function based on binary classification (an independent sigmoid per class). This could be a way to detect ambiguous predictions (e.g. a 1 that looks like a 7).
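A minimal sketch of the difference in plain NumPy, using made-up logits (hypothetical values, not from a trained model): softmax forces the ten class probabilities to sum to 1, while a sigmoid scores each class independently, so two classes can both score high.

```python
import numpy as np

def softmax(z):
    # Subtract the max for numerical stability; the result sums to 1.
    e = np.exp(z - z.max())
    return e / e.sum()

def sigmoid(z):
    # Scores each class independently; the results need not sum to 1.
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical logits for one image that looks like both a 1 and a 7:
# classes 1 and 7 score equally high, every other class scores low.
logits = np.array([-4.0, 3.0, -4.0, -4.0, -4.0, -4.0, -4.0, 3.0, -4.0, -4.0])

p_soft = softmax(logits)
p_sig = sigmoid(logits)
print(p_soft[[1, 7]].round(3))  # about 0.498 each: softmax has to split the mass
print(p_sig[[1, 7]].round(3))   # about 0.953 each: both classes can be "on"
```

Because softmax outputs always sum to 1, softmax alone cannot report "none of these classes"; independent sigmoids can, by scoring every class low.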

In a dataset with one-hot encoded labels,

  1. Is it possible to train the model using BCEWithLogitsFlat instead of CrossEntropyFlat?
  2. How would you convert a dataset from multi-class to multi-label?
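For question 2, multi-label targets are a 0/1 vector per image, and a single-label (multi-class) dataset is the special case with exactly one 1 per row, so the images themselves do not change. A NumPy sketch of that conversion; `to_multi_hot` is a hypothetical helper, not a fastai function. In fastai v1, passing label_cls=MultiCategoryList (as mentioned further down) should produce the same kind of one-hot target.

```python
import numpy as np

def to_multi_hot(label_sets, n_classes):
    """Convert per-image label sets, e.g. [[3], [1, 7]], into a float32 0/1 matrix.

    Hypothetical helper for illustration. A single-label (multi-class) dataset is
    the special case where every set has exactly one element.
    """
    targets = np.zeros((len(label_sets), n_classes), dtype=np.float32)
    for row, classes in enumerate(label_sets):
        targets[row, list(classes)] = 1.0
    return targets

# MNIST-style labels: one class index per image.
y = [3, 1, 7]
targets = to_multi_hot([[c] for c in y], n_classes=10)
print(targets.shape)        # (3, 10)
print(targets.sum(axis=1))  # [1. 1. 1.]: exactly one positive class per image
```

Float targets are used because binary cross-entropy compares each logit with a 0.0/1.0 value.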

For example, using MNIST with data.c = 10,

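from fastai.vision import *

mnist = untar_data(URLs.MNIST)
tfms = get_transforms(do_flip=False)                # digits should not be mirrored
data = (ImageList.from_folder(mnist/"training")
        .split_by_rand_pct(0.2, seed=42)
        .label_from_folder()                        # single-label: one class index per image
        .transform(tfms, size=32)
        .databunch(bs=64))
learn = cnn_learner(data, models.resnet18, metrics=accuracy, loss_func=BCEWithLogitsFlat())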

Training it fails with ValueError: Target size (torch.Size([64])) must be the same as input size (torch.Size([640])).

It seems the issue resides in the last FC layer’s out_features.

# learn.layer_groups[-1]
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): AdaptiveMaxPool2d(output_size=1)
  (2): Flatten()
  (3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (4): Dropout(p=0.25, inplace=False)
  (5): Linear(in_features=1024, out_features=512, bias=True)
  (6): ReLU(inplace=True)
  (7): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): Dropout(p=0.5, inplace=False)
  (9): Linear(in_features=512, out_features=10, bias=True)
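The numbers in the error can be reproduced with plain shape arithmetic. This NumPy sketch assumes, based on my reading of the fastai v1 source, that BCEWithLogitsFlat flattens both the model output and the targets before comparing them element-wise: 64 images × 10 logits give 640 inputs, while single-label targets are just 64 class indices. Since out_features=10 already matches data.c, the mismatch appears to come from the target encoding rather than from the last layer.

```python
import numpy as np

batch_size, n_classes = 64, 10
logits = np.zeros((batch_size, n_classes))            # model output: out_features=10
class_targets = np.zeros(batch_size, dtype=np.int64)  # single-label: one index per image
one_hot_targets = np.zeros((batch_size, n_classes))   # multi-label: one 0/1 per class

# A "flat" BCE loss flattens both tensors before comparing them element-wise.
print(logits.reshape(-1).shape)           # (640,)
print(class_targets.reshape(-1).shape)    # (64,)  -> mismatch, as in the ValueError
print(one_hot_targets.reshape(-1).shape)  # (640,) -> shapes agree
```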


The objective is to get the benefits of applying a sigmoid, rather than a softmax, to the outputs of the last fully-connected layer: to show that an input may belong to multiple classes, or to detect when the model is confused.
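A sketch of how sigmoid outputs could express both cases at inference time, assuming a fixed 0.5 threshold (an arbitrary choice; in practice it would be tuned on validation data). `predicted_classes` is a hypothetical helper and the logits are made up: an empty result means no class is confident, more than one class means the input looks like several digits.

```python
import numpy as np

def predicted_classes(logits, thresh=0.5):
    """Return, per image, the list of classes whose sigmoid score exceeds thresh.

    An empty list means no class is confident; more than one means the
    input looks like several classes at once.
    """
    probs = 1.0 / (1.0 + np.exp(-logits))
    return [np.flatnonzero(row > thresh).tolist() for row in probs]

# Hypothetical logits for three images (not from a trained model).
logits = np.full((3, 10), -4.0)
logits[0, 3] = 5.0                   # a clear 3
logits[1, 1] = logits[1, 7] = 2.0    # looks like both a 1 and a 7
# row 2 stays low everywhere: nothing confident
print(predicted_classes(logits))     # [[3], [1, 7], []]
```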

I was able to force the DataBunch labels to be a MultiCategoryList by passing label_cls=MultiCategoryList, and the resulting training then behaves as multi-label classification.
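Relabeling with MultiCategoryList probably also switched the default loss: as far as I can tell from the fastai v1 source, MultiCategoryList defaults to BCEWithLogitsFlat while CategoryList defaults to CrossEntropyFlat. With one-hot targets, a metric written for class-index targets (such as accuracy in the snippet above) may no longer measure what it should; fastai v1 provides accuracy_thresh for the multi-label case. A NumPy sketch of a per-entry thresholded accuracy and a stricter exact-match variant; the helper names, data, and the 0.5 threshold are assumptions for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def accuracy_thresh_np(logits, targets, thresh=0.5):
    """Fraction of (image, class) entries where sigmoid(logit) > thresh agrees with the 0/1 target.

    Hypothetical NumPy helper mirroring the idea behind fastai v1's accuracy_thresh.
    """
    preds = sigmoid(logits) > thresh
    return float((preds == targets.astype(bool)).mean())

def exact_match(logits, targets, thresh=0.5):
    """Stricter variant: an image counts only if all of its class decisions are right."""
    preds = sigmoid(logits) > thresh
    return float((preds == targets.astype(bool)).all(axis=1).mean())

# Two images with true classes 3 and 1; hypothetical logits where the second is predicted as a 7.
targets = np.zeros((2, 10))
targets[0, 3] = targets[1, 1] = 1.0
logits = np.full((2, 10), -4.0)
logits[0, 3] = logits[1, 7] = 5.0

print(accuracy_thresh_np(logits, targets))  # 0.9: 18 of 20 entries agree
print(exact_match(logits, targets))         # 0.5: only the first image is fully right
```

With ten classes and one positive per image, a model that predicts "no class" everywhere already scores 0.9 on the per-entry metric, which is why the exact-match variant can be the more telling number.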

However, the dataset is designed for multi-class classification, with exactly one class (one-hot encoded) per image. Would it be more idiomatic to train with multi-class labels using softmax, and only apply a sigmoid during inference?
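One consideration for this option, sketched in NumPy with made-up logits: softmax, and therefore cross-entropy, is unchanged when the same constant is added to every logit, so training with softmax alone puts no pressure on the absolute scale of the logits. Applying a sigmoid to such logits at inference can give very different per-class scores for models that softmax treats as identical. This is a general property of softmax, not a measurement of any particular trained model.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical logits from a model trained with softmax + cross-entropy.
z = np.array([2.0, 0.5, -1.0])
shifted = z - 5.0   # the same logits minus a constant

# Softmax (and therefore the cross-entropy loss) cannot tell the two apart...
print(np.allclose(softmax(z), softmax(shifted)))   # True
# ...but the sigmoid scores differ a lot, so they are not calibrated per-class probabilities.
print(sigmoid(z).round(3), sigmoid(shifted).round(3))
```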

Also, how does one change the fully-connected layer in fastai? I understand PyTorch combines the FC layer and the loss function into one object.
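For reference, in PyTorch the object that bundles things together is the loss, and what it bundles is the final activation rather than the fully-connected layer: nn.CrossEntropyLoss applies log-softmax plus negative log-likelihood, and nn.BCEWithLogitsLoss applies a sigmoid plus binary cross-entropy, both taking raw logits. So the last Linear(512, 10) could stay as it is, and moving from softmax to sigmoid would mainly mean switching the loss (in fastai v1, e.g. learn.loss_func = BCEWithLogitsFlat()) and using one-hot targets. A NumPy sketch of the two fused losses for a single image; the helper names and values are hypothetical.

```python
import numpy as np

def cross_entropy_from_logits(z, target):
    """Softmax + negative log-likelihood in one step (the idea behind CrossEntropyLoss)."""
    z = z - z.max()                        # stability: softmax is shift-invariant
    log_softmax = z - np.log(np.exp(z).sum())
    return -log_softmax[target]

def bce_from_logits(z, y):
    """Sigmoid + binary cross-entropy in one step (the idea behind BCEWithLogitsLoss), averaged over classes.

    Uses max(z, 0) - z*y + log(1 + exp(-|z|)), a numerically stable rewrite of
    -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))].
    """
    return np.mean(np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z))))

z = np.array([0.5, 2.0, -1.0])   # hypothetical logits from the final Linear layer
y = np.array([0.0, 1.0, 0.0])    # one-hot target for class 1
print(cross_entropy_from_logits(z, 1))
print(bce_from_logits(z, y))
```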