Custom callback_fn is crashing lr_find

I’m training a model using SageMaker. SageMaker requires metric definitions written as regex patterns so that training outputs such as loss and accuracy can be tracked and graphed. The way these stats are printed by default makes that difficult. I found a custom callback_fn, written by someone on the AWS SageMaker team, that prints the metrics after each epoch in a format that is easy to match with a regex. The problem is that this callback_fn seems to break learn.lr_find().

Any help understanding why these two things aren’t compatible as-is, and how to fix it, would be greatly appreciated!

Here’s the callback_fn:

```python
@dataclass
class MetricsLogger(LearnerCallback):
    """print metrics in a less pretty, but easier way for them to be picked up with regex"""
    def __init__(self, learn:Learner):
        super().__init__(learn)
    # call when each epoch finishes to print the metrics
    def on_epoch_end(self, epoch: int, smooth_loss: Tensor, last_metrics: MetricsList, **kwargs: Any) -> bool:
        last_metrics = ifnone(last_metrics, [])
        stats = [(name, str(stat)) if isinstance(stat, int) else (name, f'{stat:.6f}')
                 for name, stat in zip(self.learn.recorder.names[1:], [smooth_loss] + last_metrics)]
        for m in stats:
            # HOSTNAME is assumed to be defined elsewhere in train.py (not shown)
            print(f'#quality_metric: host={HOSTNAME}, epoch={epoch}, {m[0]}={m[1]}')
```
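To illustrate why this format is easy to pick up, here is a small sketch that checks candidate patterns against a line in the format the callback prints. The sample line and values are made up, and the list of `{"Name": ..., "Regex": ...}` dicts only mirrors the shape of SageMaker's metric definitions; the pattern puts a single capture group around the value, as in the SageMaker examples. The `parse` helper is hypothetical and only simulates the matching locally with Python's `re`.

```python
import re

# Hypothetical sample line in the format printed by MetricsLogger.on_epoch_end
line = "#quality_metric: host=algo-1, epoch=2, train_loss=0.412345"

# One pattern per metric, with a single capture group around the numeric value
metric_definitions = [
    {"Name": "train_loss", "Regex": r"train_loss=([0-9.]+)"},
    {"Name": "valid_loss", "Regex": r"valid_loss=([0-9.]+)"},
]

def parse(line, definitions):
    """Return {name: float} for every definition whose regex matches the line."""
    found = {}
    for d in definitions:
        m = re.search(d["Regex"], line)
        if m:
            found[d["Name"]] = float(m.group(1))
    return found

print(parse(line, metric_definitions))  # {'train_loss': 0.412345}
```

Because the callback prints one metric per line with a fixed `name=value` layout, each pattern only has to anchor on the metric name. Note that this simple `[0-9.]+` pattern would not match values like `nan`.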

And here’s the error:

```
algo-1-4b5cs_1 | LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
algo-1-4b5cs_1 | 2019-06-06 23:40:14,424 sagemaker-containers ERROR ExecuteUserScriptError:
algo-1-4b5cs_1 | Command "/usr/bin/python -m train --batch_size 64 --classify_thresh 0.2 --csv_labels_filename labels_1200vio_600novio.csv --diff_lr True --diff_scale 0 --epochs 3 --frozen_cycles 3 --frozen_lr 0.001 --image_size 224 --instance_type testing --lr 1e-06 --lr_find True --min_delta 0.0001 --mixed_precision True --mixup False --model_name resnet18 --patience 5 --pretrained True --save_results False --save_type jit --val_percent 0.2 --verbose True --verify_images True"
algo-1-4b5cs_1 | Traceback (most recent call last):
algo-1-4b5cs_1 |   File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
algo-1-4b5cs_1 |     "__main__", mod_spec)
algo-1-4b5cs_1 |   File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
algo-1-4b5cs_1 |     exec(code, run_globals)
algo-1-4b5cs_1 |   File "/opt/ml/code/train.py", line 449, in <module>
algo-1-4b5cs_1 |     _train(parser.parse_args())
algo-1-4b5cs_1 |   File "/opt/ml/code/train.py", line 190, in _train
algo-1-4b5cs_1 |     learn.lr_find()
algo-1-4b5cs_1 |   File "/usr/local/lib/python3.6/dist-packages/fastai/train.py", line 32, in lr_find
algo-1-4b5cs_1 |     learn.fit(epochs, start_lr, callbacks=[cb], wd=wd)
algo-1-4b5cs_1 |   File "/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py", line 200, in fit
algo-1-4b5cs_1 |     fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
algo-1-4b5cs_1 |   File "/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py", line 108, in fit
algo-1-4b5cs_1 |     if cb_handler.on_epoch_end(val_loss): break
algo-1-4b5cs_1 |   File "/usr/local/lib/python3.6/dist-packages/fastai/callback.py", line 317, in on_epoch_end
algo-1-4b5cs_1 |     self('epoch_end', call_mets = val_loss is not None)
algo-1-4b5cs_1 |   File "/usr/local/lib/python3.6/dist-packages/fastai/callback.py", line 251, in __call__
algo-1-4b5cs_1 |     for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)
algo-1-4b5cs_1 |   File "/usr/local/lib/python3.6/dist-packages/fastai/callback.py", line 241, in _call_and_update
algo-1-4b5cs_1 |     new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
algo-1-4b5cs_1 |   File "/opt/ml/code/train.py", line 76, in on_epoch_end
algo-1-4b5cs_1 |     for name, stat in zip(self.learn.recorder.names[1:], [smooth_loss] + last_metrics)]
algo-1-4b5cs_1 |   File "/opt/ml/code/train.py", line 76, in <listcomp>
algo-1-4b5cs_1 |     for name, stat in zip(self.learn.recorder.names[1:], [smooth_loss] + last_metrics)]
algo-1-4b5cs_1 | TypeError: unsupported format string passed to NoneType.__format__
```
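The error can be reproduced without fastai. This is a minimal sketch that assumes, as the traceback suggests, that during lr_find the "epoch" ends without a validation pass, so `last_metrics` arrives as `[None]` instead of a list of numbers. `format_stats` is a hypothetical copy of the list comprehension from `on_epoch_end`, and the `names` list is a made-up stand-in for `self.learn.recorder.names`.

```python
def format_stats(names, smooth_loss, last_metrics):
    """Same list comprehension as MetricsLogger.on_epoch_end, minus fastai."""
    last_metrics = last_metrics if last_metrics is not None else []  # stands in for ifnone
    return [(name, str(stat)) if isinstance(stat, int) else (name, f'{stat:.6f}')
            for name, stat in zip(names[1:], [smooth_loss] + last_metrics)]

names = ['epoch', 'train_loss', 'valid_loss', 'accuracy']  # hypothetical recorder.names

# Normal training epoch: validation ran, so every value is a number.
print(format_stats(names, 0.5, [0.4, 0.9]))

# LR-finder "epoch" (assumed): no validation, so last_metrics holds a single None.
try:
    format_stats(names, 0.5, [None])
except TypeError as e:
    print(e)  # unsupported format string passed to NoneType.__format__
```

`ifnone` only guards against `last_metrics` itself being `None`. It does nothing about a list that *contains* `None`, which is the case that reaches `f'{stat:.6f}'`.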

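Edit: Solved it. Instead of attaching MetricsLogger to the Learner through callback_fns, I should have passed MetricsLogger(learn) in the callbacks list of learn.fit. As far as I can tell, callbacks created from callback_fns take part in every fit, including the one lr_find runs internally (`learn.fit(epochs, start_lr, callbacks=[cb], wd=wd)` in the traceback). The LR finder skips validation, so last_metrics contains None, and formatting it with `.6f` fails.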
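If you would rather keep the logger attached to the Learner, another option is to make the formatting tolerate missing values. The sketch below makes the same assumption as above, that `None` marks a metric that was not computed. `safe_format_stats` is a hypothetical helper, not part of fastai. Inside `on_epoch_end` you would call it with `self.learn.recorder.names`, `smooth_loss` and `last_metrics`, then print each pair as before.

```python
def safe_format_stats(names, smooth_loss, last_metrics):
    """Pair metric names with formatted values, dropping values that are None."""
    values = [smooth_loss] + list(last_metrics or [])
    stats = []
    for name, stat in zip(names[1:], values):
        if stat is None:
            continue  # e.g. no validation loss during lr_find
        stats.append((name, str(stat) if isinstance(stat, int) else f'{stat:.6f}'))
    return stats

names = ['epoch', 'train_loss', 'valid_loss', 'accuracy']  # hypothetical recorder.names
print(safe_format_stats(names, 0.5, [None]))      # [('train_loss', '0.500000')]
print(safe_format_stats(names, 0.5, [0.4, 0.9]))
```

Passing the callback per call, as in the edit above (for example `learn.fit(epochs, lr, callbacks=[MetricsLogger(learn)])`), keeps it out of lr_find altogether. The None check is only needed if the logger has to stay registered on the Learner.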
