Fastai V1 specify loss function

I am currently using fastai v1 for an image segmentation problem I’m doing at work (binary classification for now, though I’ll eventually want to extend it to multi-class). To accomplish this, I’m building my own custom pipeline, and as part of that I’ve created my own “SegDataset” because I wanted to alter some of “SegmentationDataset”’s functionality. Here’s some code below:

X_train, X_val, y_train, y_val = train_test_split(x_names, y_names, test_size = 0.2, random_state=21)

data_trn = SegDataset(x=X_train, y=y_train, path=ENGINES, size=sz)
data_val = SegDataset(x=X_val, y=y_val, path=ENGINES, size=sz)

data_loader_trn = DataLoader(data_trn, batch_size = bs, num_workers = nw)
data_loader_val = DataLoader(data_val, batch_size = bs, num_workers = nw)

x, y = next(iter(data_loader_trn))

databunch = ImageDataBunch(train_dl=data_loader_trn, valid_dl=data_loader_val, path=ENGINES)

databunch.show_batch(rows=3, figsize=(6,6))

learn = ConvLearner(databunch, tvm.resnet34, metrics=[accuracy])
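(Since SegDataset itself isn’t shown, here is a minimal hypothetical stand-in illustrating what DataLoader actually requires of such a class: just __len__ and __getitem__. TinySegDataset and its fake in-memory data are invented for illustration; a real version would load and resize image/mask files.)

```python
class TinySegDataset:
    """Hypothetical sketch of a segmentation dataset. PyTorch's DataLoader
    only needs __len__ and __getitem__, so any class providing those two
    methods can feed it. Here each item is a fake (image, mask) pair built
    from nested lists instead of files read from disk."""

    def __init__(self, x_names, y_names, size):
        assert len(x_names) == len(y_names)
        self.x_names, self.y_names, self.size = x_names, y_names, size

    def __len__(self):
        return len(self.x_names)

    def __getitem__(self, i):
        # A real implementation would open self.x_names[i] / self.y_names[i]
        # and resize both to (size, size); here we fabricate data of the
        # right shape so the example is self-contained.
        img = [[0.0] * self.size for _ in range(self.size)]    # 1-channel "image"
        mask = [[i % 2] * self.size for _ in range(self.size)]  # 0/1 class labels
        return img, mask

ds = TinySegDataset(["a.png", "b.png"], ["a_mask.png", "b_mask.png"], size=4)
```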

Through examining x and y, I receive the outputs that I expect. The code all works fine until the call to fit, where I receive the error below. Does anyone know where I’m going wrong? Many thanks in advance! :slight_smile:

RuntimeError Traceback (most recent call last)
14 # learn.opt_fn = criterion(output, target)
---> 16

~/.conda/envs/fastaiv1/lib/python3.6/site-packages/fastai/ in fit(self, epochs, lr, wd, callbacks)
136 callbacks = [cb(self) for cb in self.callback_fns] + listify(callbacks)
137 fit(epochs, self.model, self.loss_fn, opt=self.opt,, metrics=self.metrics,
--> 138 callbacks=self.callbacks+callbacks)
140 def create_opt(self, lr:Floats, wd:Floats=0.)->None:

~/.conda/envs/fastaiv1/lib/python3.6/site-packages/fastai/ in fit(epochs, model, loss_fn, opt, data, callbacks, metrics)
89 except Exception as e:
90 exception = e
---> 91 raise e
92 finally: cb_handler.on_train_end(exception)

~/.conda/envs/fastaiv1/lib/python3.6/site-packages/fastai/ in fit(epochs, model, loss_fn, opt, data, callbacks, metrics)
79 for xb,yb in progress_bar(data.train_dl, parent=pbar):
80 xb, yb = cb_handler.on_batch_begin(xb, yb)
---> 81 loss = loss_batch(model, xb, yb, loss_fn, opt, cb_handler)[0]
82 if cb_handler.on_batch_end(loss): break

~/.conda/envs/fastaiv1/lib/python3.6/site-packages/fastai/ in loss_batch(model, xb, yb, loss_fn, opt, cb_handler, metrics)
22 if not loss_fn: return to_detach(out), yb[0].detach()
---> 23 loss = loss_fn(out, *yb)
24 mets = [f(out,*yb).detach().cpu() for f in metrics] if metrics is not None else []

~/.conda/envs/fastaiv1/lib/python3.6/site-packages/torch/nn/ in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
1663 if size_average is not None or reduce is not None:
1664 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 1665 return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)

~/.conda/envs/fastaiv1/lib/python3.6/site-packages/torch/nn/ in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
1520 .format(input.size(0), target.size(0)))
1521 if dim == 2:
-> 1522 return torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
1523 elif dim == 4:
1524 return torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

RuntimeError: multi-target not supported at /opt/conda/conda-bld/pytorch-nightly_1539602533843/work/aten/src/THCUNN/generic/

You should use a different loss function (set it via learn.loss_fn). For a segmentation task, CrossEntropyFlat() from the fastai library should work.
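To illustrate why flattening fixes the “multi-target not supported” error: cross-entropy over (N, C) logits expects a target of shape (N,) holding class indices, so a loss like CrossEntropyFlat reshapes (N, C, H, W) predictions and (N, H, W) masks into one classification example per pixel before applying ordinary cross-entropy. Here is a pure-Python sketch of that idea (not fastai’s actual implementation; the nested-list shapes are just for illustration):

```python
import math

def cross_entropy_flat(preds, targets):
    """Sketch of the flatten-then-cross-entropy idea. `preds` is a nested
    list shaped (N, C, H, W) of raw scores; `targets` is shaped (N, H, W)
    of integer class indices. Each pixel becomes one classification
    example, then standard cross-entropy is averaged over all of them."""
    n, c = len(preds), len(preds[0])
    h, w = len(preds[0][0]), len(preds[0][0][0])
    flat_preds, flat_targets = [], []
    for i in range(n):
        for y in range(h):
            for x in range(w):
                flat_preds.append([preds[i][ch][y][x] for ch in range(c)])
                flat_targets.append(targets[i][y][x])
    # Standard cross-entropy over the flattened (N*H*W, C) logits,
    # using the max-subtraction trick for numerical stability.
    total = 0.0
    for logits, t in zip(flat_preds, flat_targets):
        m = max(logits)
        log_z = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_z - logits[t]
    return total / len(flat_targets)
```

With uniform logits over C classes, the loss comes out to log(C) regardless of the targets, which is a handy sanity check for any cross-entropy implementation.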

And it sounds like, if I understood you right, we’ll soon be able to set that in the data: Is_multi/is_reg equivalent in fastaiv1

More like the library will pick it for you. CrossEntropyFlat() is now the default for a segmentation model.

Is this mentioned in the docs? I was looking for the default loss functions for specific tasks but couldn’t find it yet.

Not yet, no. Jeremy will talk a bit more about it in the next lessons, and then we’ll figure out how to put it in the docs.
If in doubt, just type learn.loss_func to see what the default is.