aw1231
(Alan Williams)
August 21, 2019, 1:35pm
I am trying to collect the metrics output by `fit_one_cycle` in order to write a `while` loop that runs until `valid_loss` and `train_loss` hit a certain value. Is there a way to do this?
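A minimal, framework-free sketch of the loop described above. It assumes a hypothetical `fit_round()` callable that runs one training round (for example, one `fit_one_cycle` call) and returns the final `(train_loss, valid_loss)` pair. The `max_rounds` cap is an added safeguard so the loop cannot run forever if the targets are never reached.

```python
from typing import Callable, List, Tuple

# One (train_loss, valid_loss) pair per training round.
LossPair = Tuple[float, float]


def train_until(
    fit_round: Callable[[], LossPair],
    train_target: float,
    valid_target: float,
    max_rounds: int = 20,
) -> List[LossPair]:
    """Call fit_round() until both losses are at or below their targets.

    fit_round is a hypothetical stand-in for "run fit_one_cycle once and
    return the last train_loss and valid_loss it reported".
    """
    history: List[LossPair] = []
    while len(history) < max_rounds:  # safeguard against never converging
        train_loss, valid_loss = fit_round()
        history.append((train_loss, valid_loss))
        if train_loss <= train_target and valid_loss <= valid_target:
            break
    return history


def make_fake_fit_round(train: float = 1.0, valid: float = 1.2, decay: float = 0.8):
    """Toy replacement for real training: both losses shrink by `decay` per round."""
    state = {"train": train, "valid": valid}

    def fit_round() -> LossPair:
        state["train"] *= decay
        state["valid"] *= decay
        return state["train"], state["valid"]

    return fit_round


history = train_until(make_fake_fit_round(), train_target=0.3, valid_target=0.4)
print(f"stopped after {len(history)} rounds, last losses: {history[-1]}")
```

To wire this to fastai v1, `fit_round` could call `learn.fit_one_cycle(1)` and then read the last entries of `learn.recorder.losses` and `learn.recorder.val_losses`. That is my understanding of the v1 `Recorder` attributes, so verify them against your installed version.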
muellerzr
(Zachary Mueller)
August 21, 2019, 1:36pm
I'd recommend looking at the source code for EarlyStoppingCallback; you could probably adapt it to stop once your monitored value reaches a target.
https://docs.fast.ai/callbacks.tracker.html#Tracking-Callbacks
```python
# Contribution from @fredguth, https://github.com/fredguth/fastai_playground.
from fastai.torch_core import *
from fastai.callback import *
from fastai.basic_train import *

__all__ = ['TerminateOnNaNCallback', 'EarlyStoppingCallback', 'SaveModelCallback', 'TrackerCallback',
           'ReduceLROnPlateauCallback', 'TrackEpochCallback']

class TerminateOnNaNCallback(Callback):
    "A `Callback` that terminates training if loss is NaN."

    def __init__(self):
        self.stop = False

    def on_batch_end(self, last_loss, epoch, num_batch, **kwargs:Any)->None:
        "Test if `last_loss` is NaN and interrupts training."
        if self.stop: return True  # to skip validation after stopping during training
        if torch.isnan(last_loss):
            print(f'Epoch/Batch ({epoch}/{num_batch}): Invalid loss, terminating training.')
```
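The same idea can be pushed inside training, in the style of TerminateOnNaNCallback: check the loss after every batch and ask the loop to stop. This sketch is framework-free; `StopOnLossCallback` and `toy_fit` are hypothetical names, and the returned flags (`stop_epoch`, `stop_training`) are modeled on the dict-of-flags convention I believe fastai v1 callbacks use, so treat that part as an assumption. Only the training loss is available per batch; `valid_loss` is computed at epoch end.

```python
import math
from typing import Iterable, Optional


class StopOnLossCallback:
    """Hypothetical, framework-free analogue of TerminateOnNaNCallback.

    Requests a stop when the batch loss is NaN or drops to `target`.
    The flag names in the returned dict are an assumption modeled on fastai v1.
    """

    def __init__(self, target: float):
        self.target = target
        self.stop = False
        self.reason: Optional[str] = None

    def on_batch_end(self, last_loss: float, epoch: int, num_batch: int) -> Optional[dict]:
        if self.stop:  # keep requesting a stop if called again after stopping
            return {"stop_epoch": True, "stop_training": True}
        if math.isnan(last_loss):
            self.reason = f"NaN loss at epoch/batch {epoch}/{num_batch}"
        elif last_loss <= self.target:
            self.reason = f"loss {last_loss:.4f} <= {self.target} at epoch/batch {epoch}/{num_batch}"
        else:
            return None  # keep training
        self.stop = True
        return {"stop_epoch": True, "stop_training": True}


def toy_fit(batch_losses: Iterable[Iterable[float]], cb: StopOnLossCallback) -> int:
    """Feed pre-computed losses (one inner list per epoch) to the callback.

    Returns the number of batches processed before a stop was requested.
    """
    seen = 0
    for epoch, losses in enumerate(batch_losses):
        for num_batch, loss in enumerate(losses):
            seen += 1
            flags = cb.on_batch_end(loss, epoch, num_batch)
            if flags and flags.get("stop_training"):
                return seen
    return seen


cb = StopOnLossCallback(target=0.5)
seen = toy_fit([[1.0, 0.9, 0.7], [0.6, 0.45, 0.4]], cb)
print(seen, cb.reason)
```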
chess
April 27, 2020, 8:05pm
In case anyone else runs across this question: this kind of monitoring is implemented in TrackerCallback via get_monitor_value().
https://docs.fast.ai/callbacks.tracker.html#TrackerCallback
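A framework-free sketch of the TrackerCallback pattern applied to this question: look up each monitored name (`train_loss`, `valid_loss`) in the values recorded at epoch end, roughly as get_monitor_value() does, and request a stop once all of them reach their targets. `ThresholdTracker`, `run_epochs` and the per-epoch dict are hypothetical stand-ins, not fastai API, and the `min`/`max` mode only loosely mirrors TrackerCallback's `mode` argument.

```python
import warnings
from typing import Dict, List, Optional


class ThresholdTracker:
    """Hypothetical stand-in for a TrackerCallback subclass.

    Stops once every monitored metric has reached its target.
    """

    def __init__(self, targets: Dict[str, float], mode: str = "min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.targets = targets
        self.mode = mode
        self.stopped_epoch: Optional[int] = None

    def get_monitor_value(self, epoch_values: Dict[str, float], name: str) -> Optional[float]:
        # Warn and return None for an unknown name instead of raising.
        value = epoch_values.get(name)
        if value is None:
            warnings.warn(f"{name!r} not available; recorded: {sorted(epoch_values)}")
        return value

    def reached(self, value: float, target: float) -> bool:
        return value <= target if self.mode == "min" else value >= target

    def on_epoch_end(self, epoch: int, epoch_values: Dict[str, float]) -> Optional[dict]:
        for name, target in self.targets.items():
            value = self.get_monitor_value(epoch_values, name)
            if value is None or not self.reached(value, target):
                return None  # at least one target not met yet
        self.stopped_epoch = epoch
        return {"stop_training": True}


def run_epochs(history: List[Dict[str, float]], tracker: ThresholdTracker) -> int:
    """Replay per-epoch values; return how many epochs ran before stopping."""
    for epoch, values in enumerate(history):
        flags = tracker.on_epoch_end(epoch, values)
        if flags and flags.get("stop_training"):
            return epoch + 1
    return len(history)


epochs = [
    {"train_loss": 0.9, "valid_loss": 1.0},
    {"train_loss": 0.5, "valid_loss": 0.7},
    {"train_loss": 0.3, "valid_loss": 0.45},
    {"train_loss": 0.2, "valid_loss": 0.4},
]
tracker = ThresholdTracker({"train_loss": 0.35, "valid_loss": 0.5})
ran = run_epochs(epochs, tracker)
print(ran, tracker.stopped_epoch)
```

Requiring every target to be met matches the original question (both `train_loss` and `valid_loss`); a single-metric version would simply pass one entry in `targets`.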
Fastai2 callbacks can be found here:
```python
#default_exp callback.tracker

#export
from fastai2.basics import *
from fastai2.callback.progress import *
```
This file has been truncated.