I’ve been working through the metrics documentation and this post started out as me asking for help with errors and questions about the relationship between metrics classes, metrics functions, and callbacks. In writing the post and doing research for it, I ended up answering my questions for myself after a bit of trial and error. I thought rather than delete the post, I’d leave those answers here in case others have similar questions. (Or if I’ve gotten anything wrong, I’d welcome corrections instead.)
The docs read:
This is why in fastai, every metric is implemented as a callback.
If you pass a regular function, the library transforms it into a proper
AverageCallback. The callback metrics are only called
during the validation phase, and only for the following events:
As an example, here is the exact implementation of the
callback that transforms a function like
accuracy into a metric callback.
AverageCallback is not an implemented
Callback, as best I can tell. (Searching for this string in the fast.ai GitHub repo turns up no results.) I think this is supposed to read
AverageMetric instead; that class is implemented here.
(Also, minor discovery 2: the
error_rate metric used for
ConvLearner in Lesson 1 is not listed in the docs, but is implemented in the source code, here.)
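To make the wrapping behavior concrete for myself, here is a minimal sketch of what I understand AverageMetric to do. The class name, method names, and toy accuracy function below are my own simplification, not fastai's exact implementation: the idea is just a running, batch-size-weighted average of a plain metric function across the validation batches of an epoch.

```python
class AverageMetricSketch:
    """Simplified stand-in for fastai's AverageMetric (my own sketch)."""

    def __init__(self, func):
        self.func = func  # the plain metric function, e.g. accuracy

    def on_epoch_begin(self):
        # reset the accumulators at the start of each validation pass
        self.total, self.count = 0.0, 0

    def on_batch_end(self, preds, targets):
        # weight each batch's metric value by the batch size
        bs = len(targets)
        self.total += self.func(preds, targets) * bs
        self.count += bs

    def on_epoch_end(self):
        # the epoch-level value is the weighted mean over all batches
        return self.total / self.count


def accuracy(preds, targets):
    # toy accuracy on plain lists of predicted/true labels
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


metric = AverageMetricSketch(accuracy)
metric.on_epoch_begin()
metric.on_batch_end([1, 0, 1], [1, 1, 1])  # batch of 3, accuracy 2/3
metric.on_batch_end([0, 0], [0, 1])        # batch of 2, accuracy 1/2
epoch_acc = metric.on_epoch_end()          # (2 + 1) / 5 = 0.6
```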
I didn’t quite grok this immediately, so I thought maybe there’d be an example I could work off of in the source code. Here I ran into another point of confusion: there are both an
fbeta() function and an
FBeta class. The
FBeta class does inherit from
Callback, but I thought it was somehow supposed to relate to
fbeta(), and I couldn’t quite piece that together. After a bit of fiddling around, I am pretty sure that
fbeta() is just defined such that it computes the F_beta average for each epoch - since it is a “regular function”, it is implemented as an instance of the
AverageMetric class, here. By comparison,
FBeta (the class) is implemented so that it can accumulate counts per batch (I think?).
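This distinction matters numerically: averaging per-batch F-beta scores (which is what wrapping a plain function in an averaging metric gives you) is not the same as computing F-beta once over the whole epoch. A toy illustration with a hand-rolled binary F1 (my own simplified function, not fastai's):

```python
def f1(preds, targs):
    # binary F1 on plain lists of 0/1 labels (toy implementation)
    tp = sum(p == 1 and t == 1 for p, t in zip(preds, targs))
    fp = sum(p == 1 and t == 0 for p, t in zip(preds, targs))
    fn = sum(p == 0 and t == 1 for p, t in zip(preds, targs))
    if tp == 0:
        return 0.0
    prec = tp / (tp + fp)
    rec = tp / (tp + fn)
    return 2 * prec * rec / (prec + rec)


batch1 = ([1, 1, 0, 0], [1, 0, 0, 0])  # F1 = 2/3
batch2 = ([1, 0, 0, 0], [1, 1, 1, 0])  # F1 = 1/2

# averaging the per-batch scores...
per_batch_mean = (f1(*batch1) + f1(*batch2)) / 2  # 7/12 ~ 0.583

# ...differs from F1 computed over the concatenated epoch
all_preds = batch1[0] + batch2[0]
all_targs = batch1[1] + batch2[1]
epoch_f1 = f1(all_preds, all_targs)               # 4/7 ~ 0.571
```

Which is (I believe) why a metric like this gets its own Callback that accumulates counts across batches and only computes the score at the end of the epoch.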
(Minor discovery 3: The
metrics argument in various
Learner initializations only takes functions or class instances, not class definitions. This seems kind of obvious in hindsight, but at first I was trying to pass in
metrics=[error_rate, Precision] and getting a
TypeError. This is resolved by making an instance first, e.g.
precision = Precision(), and then passing
precision instead of
Precision into the metrics argument.)
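The reason the class itself doesn't work falls out of the isinstance check discussed below. A toy illustration, using stand-in classes (not fastai's actual Callback and Precision, just minimal ones with the same relationship):

```python
class Callback:
    """Stand-in for fastai's Callback base class."""
    pass


class Precision(Callback):
    """Stand-in for a Callback-based metric class."""
    pass


# the class object itself is NOT an instance of Callback...
print(isinstance(Precision, Callback))    # False
# ...but an instance of the class is
print(isinstance(Precision(), Callback))  # True
```

So the bare class fails the check, gets treated as a "regular function", and things break downstream.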
It was really helpful for me to track through the source code to see how
Learner implements this “let’s see if your metric is a regular function or a
Callback instance.” There’s a line in the
validate() method of the
Learner class that starts this process:
cb_handler = CallbackHandler(self.callbacks + ifnone(callbacks, []), metrics)
Which, in turn, leads to the basic decision point in CallbackHandler:
self.metrics = [(met if isinstance(met, Callback) else AverageMetric(met)) for met in self.metrics]
In other words, if the metric you passed in inherits from
Callback, the CallbackHandler will just leave it be. If not, it’s treated as an AverageMetric.
So, in summary: if you write your own metric function, make sure it makes sense to have it treated as an
AverageMetric. If not, define your own class which inherits from
Callback, going by the rules outlined in the docs, and then pass an instance of that class into the
metrics argument for your
Learner of choice.
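Putting it all together, here is a minimal sketch of that wrap-or-keep dispatch. The Callback, AverageMetric, and FBeta names below mirror fastai's, but these are my own stand-in classes; the real logic lives in CallbackHandler.

```python
class Callback:
    """Stand-in for fastai's Callback base class."""
    pass


class AverageMetric(Callback):
    """Stand-in wrapper for a plain metric function."""
    def __init__(self, func):
        self.func = func


class FBeta(Callback):
    """Stand-in for a metric implemented directly as a Callback."""
    pass


def error_rate(preds, targs):
    # a "regular function" metric: this is what gets wrapped
    return sum(p != t for p, t in zip(preds, targs)) / len(targs)


metrics = [error_rate, FBeta()]

# the decision point, mirroring the CallbackHandler line quoted above:
# Callback instances pass through, plain functions get wrapped
dispatched = [(m if isinstance(m, Callback) else AverageMetric(m))
              for m in metrics]

print(type(dispatched[0]).__name__)  # AverageMetric
print(type(dispatched[1]).__name__)  # FBeta
```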
(Note: Nothing really new here above and beyond what’s already in the documentation - I just needed to go through all of the above myself first in order to be able to understand what they were saying.)