Navigating the fastai2 codebase

Hi,

I’m having some trouble navigating the fastai2 codebase. I have been reading it in a simple editor instead of through nbdev, so I wonder whether nbdev is required. For example, I wanted to know which loss function is used for simple image classification (cross-entropy loss with weights for unbalanced classes, I guess, but how are the weights calculated?).

I started with cnn_learner(...) from vision/learner.py. This function has an argument called loss_func, which is None by default and is passed on to Learner. In the Learner constructor I found the following code:

if loss_func is None:
    loss_func = getattr(dls.train_ds, 'loss_func', None)

So dls.train_ds has to provide it by default. After some more digging I found that dls is of type DataLoaders, but I could find loss_func neither in DataLoaders nor in Datasets itself. I also found the class CrossEntropyLossFlat in layers.py, but could not figure out whether (and where) it is instantiated when calling cnn_learner(...).
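The Learner fallback quoted above is plain attribute lookup with a default: an explicit loss_func wins, otherwise whatever the training dataset exposes is used, otherwise None. A minimal sketch of that resolution order follows; `resolve_loss_func`, `FakeTrainDs` and `BareDs` are hypothetical names for illustration only, and the "loss" is a string placeholder rather than a real loss object.

```python
from typing import Any, Callable, Optional, Union

LossLike = Union[Callable, str]  # a string stands in for a real loss object here


def resolve_loss_func(loss_func: Optional[LossLike], train_ds: Any) -> Optional[LossLike]:
    """Hypothetical helper mirroring the `if loss_func is None: ...` fallback."""
    if loss_func is None:
        # Fall back to the dataset's default; None if it defines nothing
        loss_func = getattr(train_ds, "loss_func", None)
    return loss_func


class FakeTrainDs:
    """Hypothetical stand-in for dls.train_ds carrying a default loss."""
    loss_func = "CrossEntropyLossFlat()"  # placeholder, not the real fastai object


class BareDs:
    """Hypothetical stand-in for a dataset without any loss_func attribute."""


print(resolve_loss_func(None, FakeTrainDs()))       # -> CrossEntropyLossFlat()
print(resolve_loss_func("my_loss", FakeTrainDs()))  # -> my_loss (explicit wins)
print(resolve_loss_func(None, BareDs()))            # -> None
```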

Please enlighten me!

If you want to know which loss_func a Learner object uses, the learn.loss_func attribute gives you the loss function in use.

Usually the loss function is passed explicitly when the learner is created, like this:
learn = cnn_learner(dls, models.resnet34, loss_func=CrossEntropyLossFlat(), config=cnn_config(ps=0.25))

Whenever I have doubts about fastai code, I just run the Jupyter notebooks they have in the repo. It helps a lot. You can clone the entire repo and run whichever notebook you want; cloning the entire repo ensures that the dependent files are there too.

Thank you for your answer. I know how to pass it to the learner; my problem is fully understanding the code. But since it is all developed in notebooks, I might have a look at those instead of the Python source files.

@jc-denton
By default, if you don’t provide a loss function, fastai2 has a mechanism to infer it from the data you feed in. I guess your question is (i) how the mechanism works, and (ii) at which level of the API such inference occurs.

When I searched for the keyword “loss_func” in the source code, I found two modules that could help answer these questions:

1. data/transforms.py

This module defines the low-level data transforms. I noticed that some transform classes have a loss function attached:

Example

class Categorize(Transform):
    "Reversible transform of category string to `vocab` id"
    loss_func,order=CrossEntropyLossFlat(),1
    def __init__(self, vocab=None, add_na=False):
        self.add_na = add_na
        self.vocab = None if vocab is None else CategoryMap(vocab, add_na=add_na)
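The line `loss_func,order=CrossEntropyLossFlat(),1` uses tuple unpacking to set two class attributes at once, so every Categorize instance carries that default loss without any constructor code. A minimal sketch of the same pattern, with hypothetical stand-in classes (`TransformSketch`, `CategorizeSketch`) and a string placeholder instead of the real loss object; the vocab handling is simplified and is not fastai's CategoryMap:

```python
class TransformSketch:
    """Hypothetical minimal base class; fastai's real Transform does much more."""
    order = 0


class CategorizeSketch(TransformSketch):
    "Hypothetical stand-in for Categorize: maps category strings to vocab ids."
    # Tuple unpacking sets two class attributes at once, as in `loss_func,order=...,1`
    loss_func, order = "CrossEntropyLossFlat()", 1

    def __init__(self, vocab=None):
        # Simplified vocab: sorted unique values (an assumption for this sketch)
        self.vocab = None if vocab is None else sorted(set(vocab))

    def encodes(self, o):
        return self.vocab.index(o)


tfm = CategorizeSketch(vocab=["dog", "cat", "dog"])
print(tfm.loss_func)       # class attribute, visible on every instance
print(tfm.order)           # 1, overriding the base class's 0
print(tfm.encodes("dog"))  # vocab is ['cat', 'dog'] -> 1
```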

Related keyword search results in data/transforms.py:

data/transforms.py:    loss_func,order=CrossEntropyLossFlat(),1
data/transforms.py:    loss_func,order=BCEWithLogitsLossFlat(),1
data/transforms.py:    loss_func,order=BCEWithLogitsLossFlat(),1

2. vision/core.py

This module defines the low-level base data classes (e.g. PILImage, PILMask). Some of these base data classes also have a loss function attached:

Example

OpenMask = Transform(PILMask.create)
OpenMask.loss_func = CrossEntropyLossFlat(axis=1)
PILMask.create = OpenMask
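Here the factory function PILMask.create is wrapped in a Transform instance, a loss function is set as an attribute on that instance, and the wrapper then replaces the original factory. This works because the wrapper is an ordinary object that accepts arbitrary attributes, which is presumably what lets the data pipeline find the loss later. A minimal sketch of the pattern with hypothetical stand-ins (`FuncWrapper` for Transform, `MaskSketch` for PILMask, a string for the loss):

```python
class FuncWrapper:
    """Hypothetical stand-in for fastai's Transform(f): a callable wrapping f."""
    def __init__(self, f):
        self.f = f

    def __call__(self, *args, **kwargs):
        return self.f(*args, **kwargs)


class MaskSketch:
    "Hypothetical stand-in for PILMask."
    def __init__(self, data):
        self.data = data

    @classmethod
    def create(cls, data):
        return cls(data)


open_mask = FuncWrapper(MaskSketch.create)            # wrap the (bound) factory
open_mask.loss_func = "CrossEntropyLossFlat(axis=1)"  # attach a default loss to the wrapper
MaskSketch.create = open_mask                         # the wrapper replaces the factory

m = MaskSketch.create([[0, 1], [1, 0]])  # still builds a MaskSketch
print(type(m).__name__, m.data)
print(MaskSketch.create.loss_func)       # the attached default is reachable
```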

Related keyword search results in vision/core.py:

vision/core.py:OpenMask.loss_func = CrossEntropyLossFlat(axis=1)
vision/core.py:TensorPointCreate.loss_func = MSELossFlat()
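A keyword search like the one above does not need nbdev; a few lines of Python over a local clone are enough. The sketch below walks a directory tree and yields every line of a `.py` file that contains a keyword, similar to `grep -rn`. `search_source` is a hypothetical helper written for this thread, not part of fastai.

```python
import os
from typing import Iterator, Tuple


def search_source(root: str, keyword: str, ext: str = ".py") -> Iterator[Tuple[str, int, str]]:
    """Yield (relative path, line number, stripped line) for lines containing `keyword`."""
    for dirpath, dirnames, filenames in os.walk(root):
        dirnames.sort()  # in-place sort gives a deterministic traversal order
        for fname in sorted(filenames):
            if not fname.endswith(ext):
                continue  # only search files with the requested extension
            path = os.path.join(dirpath, fname)
            with open(path, encoding="utf-8", errors="replace") as f:
                for lineno, line in enumerate(f, start=1):
                    if keyword in line:
                        yield os.path.relpath(path, root), lineno, line.strip()
```

For example, from the root of a local clone, `for rel, n, text in search_source(".", "loss_func"): print(f"{rel}:{n}:{text}")` prints hits in a `path:line:text` shape.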

TBH, I am still trying to grasp the whole picture of how the loss function is finally propagated from those transform classes / base data classes to dls.train_ds, but I hope the above helps at least partially.

Hi, thanks for your answer. Yes, my question is how the mechanism works and where it is instantiated. The loss functions are defined in 01_layers.ipynb:

@log_args
@delegates(keep=True)
class CrossEntropyLossFlat(BaseLoss):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    y_int = True
    def __init__(self, *args, axis=-1, **kwargs): super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)
    def decodes(self, x):    return x.argmax(dim=self.axis)
    def activation(self, x): return F.softmax(x, dim=self.axis)
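To make `activation` and `decodes` concrete: for one sample, softmax turns the model's raw outputs (logits) into probabilities along the class axis, and argmax picks the predicted class id. nn.CrossEntropyLoss itself takes the raw logits and applies log-softmax internally. The plain-Python sketch below (no PyTorch, made-up numbers, helper names of our own) shows what these compute for a single row, plus the per-sample cross-entropy that nn.CrossEntropyLoss averages over a batch by default:

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Plain-Python softmax over one row of logits (what `activation` applies along `axis`)."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values: List[float]) -> int:
    """Index of the largest value (what `decodes` returns along `axis`)."""
    return max(range(len(values)), key=values.__getitem__)


def cross_entropy(logits: List[float], target: int) -> float:
    """Negative log-probability of the target class for a single sample."""
    return -math.log(softmax(logits)[target])


row = [2.0, 0.5, -1.0]  # made-up logits for one image and three classes
probs = softmax(row)
print([round(p, 3) for p in probs])
print(argmax(row))  # 0: the predicted class id
print(round(cross_entropy(row, 0), 3))
```

Since softmax is monotonic, taking argmax of the probabilities or of the raw logits gives the same class.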

So for cross-entropy it simply wraps nn.CrossEntropyLoss. There are two things I could not figure out:

  • How does it account for class imbalance? I would expect it to set the weight parameter of nn.CrossEntropyLoss, but I could not find where that happens. Or is that something you have to do manually?
  • Where is it attached to dls.train_ds?

@jc-denton

  1. Class imbalance is not handled automatically, but you can handle it with CrossEntropyLossFlat. As the snippet of CrossEntropyLossFlat shows, it subclasses BaseLoss and wraps a PyTorch nn.Module, namely nn.CrossEntropyLoss. nn.CrossEntropyLoss accepts per-class weights through its weight argument (see the PyTorch documentation of nn.CrossEntropyLoss for details). If you instantiate CrossEntropyLossFlat with a weight argument, it is propagated to nn.CrossEntropyLoss thanks to BaseLoss's __init__:
@log_args
class BaseLoss():
    "Same as `loss_cls`, but flattens input and target."
    activation=decodes=noops
    def __init__(self, loss_cls, *args, axis=-1, flatten=True, floatify=False, is_2d=True, **kwargs):
        store_attr(self, "axis,flatten,floatify,is_2d")
        self.func = loss_cls(*args,**kwargs) # << __init__ arguments propagated to __init__ of loss_cls
        functools.update_wrapper(self, self.func)
  2. As for where the loss function is attached to dls.train_ds, I am still looking for the answer; perhaps someone else can help with that.
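Since the weights have to be supplied by you, one common heuristic is inverse class frequency, n_samples / (n_classes * count_c), the same formula scikit-learn uses for class_weight='balanced'. This is a swapped-in technique, not something fastai computes. A minimal plain-Python sketch (`balanced_class_weights` is a hypothetical helper; the labels are made up):

```python
from collections import Counter
from typing import List, Sequence


def balanced_class_weights(labels: Sequence[int], n_classes: int) -> List[float]:
    """Inverse-frequency weights: n_samples / (n_classes * count_c).

    A common heuristic, not something fastai computes for you.
    Classes absent from `labels` get weight 0.0 to avoid dividing by zero.
    """
    counts = Counter(labels)
    n = len(labels)
    return [n / (n_classes * counts[c]) if counts[c] else 0.0 for c in range(n_classes)]


# Made-up, imbalanced label ids: 6 samples of class 0, 3 of class 1, 1 of class 2
labels = [0] * 6 + [1] * 3 + [2]
weights = balanced_class_weights(labels, n_classes=3)
print([round(w, 3) for w in weights])  # -> [0.556, 1.111, 3.333]
```

The list must follow the vocab order (the ids produced by Categorize). With PyTorch and fastai2 installed, it would be passed as `CrossEntropyLossFlat(weight=torch.tensor(weights))`, which BaseLoss forwards to nn.CrossEntropyLoss; when training on a GPU you may need to move that tensor to the model's device.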

If you use fastai to gather your data, dls.train_ds is either a Datasets or a TfmdLists. By default, on read access, each of them exposes the attributes of all its underlying transforms (if several transforms have the same attribute, you get a list of those values). For instance, Categorize has a vocab attribute, which is then readable as dataset.vocab; that is what lets us know the number of classes when we set the number of outputs of our model.

It’s the same for the default loss function: it depends on the type of target you have, which is identified by the transform used. So Categorize sets CrossEntropyLossFlat as the default loss function, MultiCategorize sets BCEWithLogitsLossFlat, etc.
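The forwarding described in these two paragraphs can be sketched with Python's `__getattr__`, which is only called when normal attribute lookup fails. This is a simplified, hypothetical model (`DatasetSketch`, `CategorizeLike`, `ResizeLike` and the string loss are stand-ins); the real Datasets/TfmdLists classes differ in detail:

```python
from typing import Any, List


class CategorizeLike:
    "Hypothetical target transform carrying a vocab and a default loss."
    loss_func = "CrossEntropyLossFlat()"  # placeholder for the real loss object

    def __init__(self, vocab: List[str]):
        self.vocab = vocab


class ResizeLike:
    "Hypothetical input transform with no vocab or loss_func attribute."


class DatasetSketch:
    """Hypothetical stand-in for Datasets/TfmdLists attribute forwarding."""

    def __init__(self, tfms: List[Any]):
        self.tfms = tfms

    def __getattr__(self, name: str) -> Any:
        # Only reached for attributes the dataset itself does not define
        if name == "tfms":  # avoid infinite recursion before __init__ has run
            raise AttributeError(name)
        found = [getattr(t, name) for t in self.tfms if hasattr(t, name)]
        if not found:
            raise AttributeError(name)
        # One match -> the value itself; several matches -> a list of them
        return found[0] if len(found) == 1 else found


ds = DatasetSketch([ResizeLike(), CategorizeLike(["cat", "dog"])])
print(ds.vocab)       # ['cat', 'dog']
print(len(ds.vocab))  # number of classes, i.e. the size of the model's output layer
print(ds.loss_func)   # default loss picked up from the target transform
```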


Sure, I understand the confusion. I had a lot of it myself initially; with practice it goes away.

Thanks for these words; it is already going away!
