Found the answer. For those who have the same question, here are my notes:
1a. All metrics calculation happens in the `Recorder` callback. By default, metric functions are evaluated only on the validation set, not on the training set, unless you pass `train_metrics=True` when instantiating the `Recorder` object:
```python
class Recorder(Callback):
    "Callback that registers statistics (lr, loss and metrics) during training"
    remove_on_fetch,run_after = True,TrainEvalCallback
    def __init__(self, add_time=True, train_metrics=False, valid_metrics=True, beta=0.98):
        store_attr(self, 'add_time,train_metrics,valid_metrics')
        self.loss,self.smooth_loss = AvgLoss(),AvgSmoothLoss(beta=beta)
```
2a. The metric functions are stored and called in the `Recorder` callback. Specifically, the relevant code is in its `after_batch` method:
```python
# (line 423) learner.Recorder.after_batch:
def after_batch(self):
    "Update all metrics and records lr and smooth loss in training"
    if len(self.yb) == 0: return
    mets = self._train_mets if self.training else self._valid_mets
    for met in mets: met.accumulate(self.learn)  # <-- metric function is called inside
    if not self.training: return
    self.lrs.append(self.opt.hypers[-1]['lr'])
    self.losses.append(self.smooth_loss.value)
    self.learn.smooth_loss = self.smooth_loss.value
```
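The dispatch between training and validation metrics in `after_batch` can be mimicked with a minimal stand-alone sketch (no fastai required; `MiniRecorder` and the metric names are made up for illustration, only the selection logic mirrors the snippet above):

```python
# Sketch: how `after_batch` decides which metric objects get updated.
# With train_metrics=False (the default), the training-side list is empty,
# so metric functions only ever run on validation batches.

class MiniRecorder:
    def __init__(self, train_metrics=False):
        self._valid_mets = ["valid_metric"]
        self._train_mets = ["train_metric"] if train_metrics else []

    def after_batch(self, training):
        # mirrors: mets = self._train_mets if self.training else self._valid_mets
        mets = self._train_mets if training else self._valid_mets
        return mets  # fastai then calls met.accumulate(self.learn) on each

rec = MiniRecorder(train_metrics=False)
print(rec.after_batch(training=True))   # [] -> nothing tracked during training
print(rec.after_batch(training=False))  # ['valid_metric']
```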
2b. In the snippet above, the metric function is called inside the `met.accumulate` method, where `met` is an `AvgMetric` object. `learn.pred` and `learn.yb` serve as the inputs to the metric function:
```python
# (line 348) learner.AvgMetric.accumulate
class AvgMetric(Metric):
    "Average the values of `func` taking into account potential different batch sizes"
    def __init__(self, func): self.func = func
    def reset(self): self.total,self.count = 0.,0
    def accumulate(self, learn):
        bs = find_bs(learn.yb)
        self.total += to_detach(self.func(learn.pred, *learn.yb))*bs  # <-- self.func is the metric function
        self.count += bs
```
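The batch-size-weighted bookkeeping in `accumulate` can be reproduced in a stand-alone sketch (no fastai needed; `MiniAvgMetric` and the toy `accuracy` function below are invented names, only the total/count logic matches the class above):

```python
# Sketch: accumulate a batch-size-weighted sum so the final value is a
# correct average even when the last batch is smaller than the others.

class MiniAvgMetric:
    def __init__(self, func):
        self.func = func
        self.reset()
    def reset(self):
        self.total, self.count = 0.0, 0
    def accumulate(self, pred, yb):
        bs = len(yb)                       # stands in for find_bs(learn.yb)
        self.total += self.func(pred, yb) * bs
        self.count += bs
    @property
    def value(self):
        return self.total / self.count

def accuracy(pred, yb):
    return sum(p == y for p, y in zip(pred, yb)) / len(yb)

met = MiniAvgMetric(accuracy)
met.accumulate([1, 0, 1, 1], [1, 0, 0, 1])  # batch of 4, accuracy 0.75
met.accumulate([1, 1], [1, 1])              # batch of 2, accuracy 1.0
print(met.value)  # (0.75*4 + 1.0*2) / 6 = 5/6 ≈ 0.8333
```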
2c. Therefore, you can change the flow of metrics calculation by monkey-patching the `AvgMetric.accumulate` method.

(All of this happens in the `learner.py` module.)
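A minimal sketch of such a monkey patch, illustrated on a stand-in class rather than fastai's real `AvgMetric` (with fastai installed you would patch `learner.AvgMetric.accumulate` in the same way; the logging behavior is just an example of injected custom logic):

```python
# Stand-in for fastai's AvgMetric, reduced to the relevant logic.
class AvgMetric:
    def __init__(self, func):
        self.func, self.total, self.count = func, 0.0, 0
    def accumulate(self, pred, yb):
        bs = len(yb)
        self.total += self.func(pred, yb) * bs
        self.count += bs

# Keep a reference to the original so the patch can wrap it.
_orig_accumulate = AvgMetric.accumulate

def logged_accumulate(self, pred, yb):
    # Custom behavior injected around the original flow, e.g. logging.
    print(f"accumulating a batch of size {len(yb)}")
    _orig_accumulate(self, pred, yb)

# Every existing and future AvgMetric instance now uses the patched method.
AvgMetric.accumulate = logged_accumulate
```

Wrapping the saved original (rather than replacing the body wholesale) keeps the patch robust if the library's internals change slightly.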