I’ve been working through the metrics documentation and this post started out as me asking for help with errors and questions about the relationship between metrics classes, metrics functions, and callbacks. In writing the post and doing research for it, I ended up answering my questions for myself after a bit of trial and error. I thought rather than delete the post, I’d leave those answers here in case others have similar questions. (Or if I’ve gotten anything wrong, I’d welcome corrections instead.)
The docs read:
> This is why in fastai, every metric is implemented as a callback. If you pass a regular function, the library transforms it to a proper `AverageCallback`. The callback metrics are only called during the validation phase, and only for the following events:
>
> As an example, here is the exact implementation of the callback that transforms a function like `accuracy` into a metric callback.
`AverageCallback` is not an implemented `Callback`, as best I can tell. (Running a search for this string in the fast.ai GitHub repo turns up no results.) I think this is supposed to read `AverageMetric` instead, and this class is implemented here.
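To make the `AverageMetric` idea concrete for myself, here’s a minimal pure-Python sketch of what I understand it to do: call the wrapped function on each validation batch, weight the result by the batch size, and divide by the total count at the end of the epoch. The `Callback` and `AverageMetricSketch` classes and the list-based `accuracy` are hypothetical stand-ins written for illustration, not fastai’s actual code (which works on tensors).

```python
from typing import Callable, Sequence


class Callback:
    """Hypothetical stand-in for fastai's Callback base class."""

    def on_epoch_begin(self) -> None: ...
    def on_batch_end(self, output, target) -> None: ...
    def on_epoch_end(self) -> None: ...


class AverageMetricSketch(Callback):
    """Hypothetical AverageMetric-like wrapper around a plain metric function."""

    def __init__(self, func: Callable[[Sequence[int], Sequence[int]], float]):
        self.func = func
        self.name = func.__name__
        self.total, self.count = 0.0, 0
        self.metric = None

    def on_epoch_begin(self) -> None:
        self.total, self.count = 0.0, 0  # reset running sums each epoch

    def on_batch_end(self, output, target) -> None:
        n = len(target)
        self.total += n * self.func(output, target)  # weight by batch size
        self.count += n

    def on_epoch_end(self) -> None:
        self.metric = self.total / self.count  # weighted mean over the epoch


def accuracy(preds: Sequence[int], targs: Sequence[int]) -> float:
    """A 'regular function' metric: fraction of matching labels in one batch."""
    return sum(p == t for p, t in zip(preds, targs)) / len(targs)


batches = [([1, 0, 1, 1], [1, 0, 0, 1]), ([0, 0], [0, 1])]
m = AverageMetricSketch(accuracy)
m.on_epoch_begin()
for preds, targs in batches:
    m.on_batch_end(preds, targs)
m.on_epoch_end()
print(m.name, m.metric)
```

Because each batch is weighted by its size, this particular example comes out equal to plain accuracy over all six samples (4 correct out of 6).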
(Also minor discovery 2: the `error_rate` metric used for `ConvLearner` in Lesson 1 is not listed in the docs, but is implemented in the source code, here.)
I didn’t quite grok this immediately, so I thought maybe there’d be an example I could work off of in the source code. Here I ran into another point of confusion: there is both an `fbeta()` function and an `FBeta` class defined in the library’s metrics source code. The `FBeta` class does inherit from `Callback`, but I thought it was somehow supposed to relate to `fbeta()`, and I couldn’t quite piece that together. After a bit of fiddling around, I am pretty sure that `fbeta()` is just a regular function that computes F_beta for a single batch; since it is a “regular function”, it gets wrapped in an instance of the `AverageMetric` class (implemented here), which averages those per-batch values over the epoch. By comparison, `FBeta` (the class) is a `Callback` in its own right: as far as I can tell, it accumulates counts batch by batch and computes F_beta from those totals at the end of the epoch.
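Part of why the distinction matters: F_beta isn’t linear, so averaging per-batch scores (what wrapping a plain function in an averaging callback gives you) is not the same as computing F_beta once from counts accumulated over the whole epoch, which a dedicated `Callback` can do because it sees every batch. Here’s a small pure-Python sketch comparing the two on two toy batches of binary labels with beta=1. All function names are my own hypothetical helpers, not fastai functions; fastai’s own `fbeta()` works on tensors and has its own defaults.

```python
from typing import List, Sequence, Tuple

Batch = Tuple[List[int], List[int]]  # (predictions, targets), binary labels


def confusion_counts(preds: Sequence[int], targs: Sequence[int]) -> Tuple[int, int, int]:
    """Return (true positives, false positives, false negatives)."""
    tp = sum(1 for p, t in zip(preds, targs) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(preds, targs) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(preds, targs) if p == 0 and t == 1)
    return tp, fp, fn


def fbeta_from_counts(tp: int, fp: int, fn: int, beta: float = 1.0) -> float:
    """F_beta = (1 + b^2) tp / ((1 + b^2) tp + b^2 fn + fp); 0.0 if undefined."""
    b2 = beta ** 2
    denom = (1 + b2) * tp + b2 * fn + fp
    return (1 + b2) * tp / denom if denom else 0.0


def fbeta_batch(preds: Sequence[int], targs: Sequence[int], beta: float = 1.0) -> float:
    """A 'regular function' metric: F_beta of a single batch."""
    return fbeta_from_counts(*confusion_counts(preds, targs), beta=beta)


def averaged_over_batches(batches: List[Batch], beta: float = 1.0) -> float:
    """Batch-size-weighted mean of per-batch F_beta (averaging-wrapper style)."""
    total = sum(len(t) * fbeta_batch(p, t, beta) for p, t in batches)
    return total / sum(len(t) for _, t in batches)


def from_epoch_counts(batches: List[Batch], beta: float = 1.0) -> float:
    """Accumulate counts over every batch, then compute F_beta once."""
    tp = fp = fn = 0
    for p, t in batches:
        btp, bfp, bfn = confusion_counts(p, t)
        tp, fp, fn = tp + btp, fp + bfp, fn + bfn
    return fbeta_from_counts(tp, fp, fn, beta)


batches: List[Batch] = [
    ([1, 1, 0, 0], [1, 0, 1, 0]),  # F1 of this batch: 0.5
    ([1, 1, 1, 1], [1, 1, 1, 1]),  # F1 of this batch: 1.0
]
print(averaged_over_batches(batches))  # 0.75
print(from_epoch_counts(batches))      # 0.8333...
```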
(Minor discovery 3: The `metrics` argument in various `Learner` initializations only takes functions or class instances, not the classes themselves. This seems kind of obvious in hindsight, but at first I was trying to pass in `metrics=[error_rate, Precision]` and getting a `TypeError`. This is resolved by making an instance first, e.g. `precision = Precision()`, and then passing `precision` instead of `Precision` into the `metrics` argument.)
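A toy illustration of the class-versus-instance point, using stub classes I made up (not fastai’s `Callback` or `Precision`): the class object itself fails an `isinstance(..., Callback)` check while an instance passes, and if the bare class gets treated like a regular metric function and called on a batch, Python ends up calling its constructor with that batch. In this stub that raises a `TypeError`; I haven’t confirmed that this is exactly where fastai’s own `TypeError` comes from.

```python
class Callback:
    """Hypothetical stand-in for fastai's Callback base class."""


class PrecisionSketch(Callback):
    """Hypothetical Precision-like metric callback (not fastai's Precision)."""

    def __init__(self, pos_label: int = 1):
        self.pos_label = pos_label


# The class object is not a Callback instance, but an instance of it is.
print(isinstance(PrecisionSketch, Callback))    # False
print(isinstance(PrecisionSketch(), Callback))  # True

# Treating the bare class as a metric function, i.e. calling it as
# func(output, target), runs the constructor with the batch data, which
# this stub's __init__ rejects (too many positional arguments).
try:
    PrecisionSketch([1, 0, 1], [1, 1, 1])
    raised = False
except TypeError as exc:
    raised = True
    print("TypeError:", exc)
```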
It was really helpful for me to track through the source code to see how `Learner` implements this “let’s see if your metric is a regular function or a `Callback` instance” check. There’s a line in the `validate()` method of the `Learner` class that starts this process:
```python
cb_handler = CallbackHandler(self.callbacks + ifnone(callbacks, []), metrics)
```
Which, in turn, leads to the basic decision point in `CallbackHandler`:
```python
self.metrics = [(met if isinstance(met, Callback) else AverageMetric(met)) for met in self.metrics]
```
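The decision point can be reproduced outside fastai with a few stand-in classes. This is a sketch under assumptions: `Callback`, `AverageMetricSketch`, and `CountingMetric` are hypothetical stubs, and `wrap_metrics` is a hypothetical helper that folds the `None` handling (via an `ifnone` written the way I understand fastai’s helper to behave) into the same list comprehension quoted above.

```python
from typing import Any, Callable, List, Optional


def ifnone(a: Any, b: Any) -> Any:
    """Return b if a is None, else a."""
    return b if a is None else a


class Callback:
    """Hypothetical stand-in for fastai's Callback base class."""


class AverageMetricSketch(Callback):
    """Hypothetical stand-in for AverageMetric: just remembers the wrapped function."""

    def __init__(self, func: Callable):
        self.func = func


class CountingMetric(Callback):
    """Hypothetical user-defined metric callback."""


def accuracy(preds, targs):
    return sum(p == t for p, t in zip(preds, targs)) / len(targs)


def wrap_metrics(metrics: Optional[List[Any]]) -> List[Callback]:
    # Keep anything that is already a Callback instance; wrap everything else.
    return [(met if isinstance(met, Callback) else AverageMetricSketch(met))
            for met in ifnone(metrics, [])]


custom = CountingMetric()
wrapped = wrap_metrics([accuracy, custom])
print([type(m).__name__ for m in wrapped])  # ['AverageMetricSketch', 'CountingMetric']
print(wrapped[1] is custom)                 # True
print(wrap_metrics(None))                   # []
```

The plain function comes back wrapped, while the callback instance passes through untouched (it is literally the same object).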
In other words, if the metric you passed in is an instance of a class that inherits from `Callback`, `CallbackHandler` will just leave it be. If not, it’s treated as an `AverageMetric`.
So, in summary, if you make your own metrics function, make sure it makes sense to have it treated as an `AverageMetric`, i.e. computed on each batch and then averaged over the validation set. If not, define your own class which inherits from `Callback`, following the rules outlined in the docs, and then pass an instance of that class into the `metrics` argument for your `Learner` of choice.
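For the “define your own class” route, here is a rough skeleton of a metric callback that computes binary recall from counts accumulated over the epoch. It assumes the hook names `on_epoch_begin`, `on_batch_end`, and `on_epoch_end`, keyword arguments `last_output`/`last_target`, and a `metric` attribute holding the final value; the `Callback` base is a do-nothing stand-in and `RecallSketch` is hypothetical, so check the docs for the exact hook signatures in your fastai version before adapting it.

```python
from typing import Optional, Sequence


class Callback:
    """Hypothetical stand-in for fastai's Callback base class (hooks do nothing)."""

    def on_epoch_begin(self, **kwargs) -> None:
        pass

    def on_batch_end(self, **kwargs) -> None:
        pass

    def on_epoch_end(self, **kwargs) -> None:
        pass


class RecallSketch(Callback):
    """Binary recall computed from counts accumulated over the whole epoch."""

    def __init__(self, pos_label: int = 1):
        self.pos_label = pos_label
        self.tp = self.fn = 0
        self.metric: Optional[float] = None

    def on_epoch_begin(self, **kwargs) -> None:
        self.tp = 0  # true positives seen so far this epoch
        self.fn = 0  # false negatives seen so far this epoch

    def on_batch_end(self, last_output: Sequence[int], last_target: Sequence[int], **kwargs) -> None:
        for p, t in zip(last_output, last_target):
            if t == self.pos_label:
                if p == self.pos_label:
                    self.tp += 1
                else:
                    self.fn += 1

    def on_epoch_end(self, **kwargs) -> None:
        total = self.tp + self.fn
        self.metric = self.tp / total if total else 0.0


# Driving the hooks by hand, the way a validation loop would.
recall = RecallSketch()  # an instance, not the class
recall.on_epoch_begin()
recall.on_batch_end(last_output=[1, 0, 1], last_target=[1, 1, 0])
recall.on_batch_end(last_output=[1, 1], last_target=[1, 1])
recall.on_epoch_end()
print(recall.metric)  # 0.75
```

Since the score is computed once from epoch-level totals, it doesn’t suffer from the averaging issue that a per-batch function would.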
(Note: Nothing really new here above and beyond what’s already in the documentation - I just needed to go through all of the above myself first in order to be able to understand what they were saying.)