I’m writing a little toy example to better understand custom callbacks and ran into a few questions. I wrote the following callback, which saves the model weights after each epoch:
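```python
import datetime
from fastai.text import *  # fastai v1: RNNLearner, LearnerCallback, Learner, URLs

class SaveModel(LearnerCallback):
    """Save Latest Model"""
    def __init__(self, learn:Learner):
        super().__init__(learn)

    def on_epoch_end(self, epoch:int, **kwargs):
        # Timestamp the file name so every epoch gets its own checkpoint
        dt = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        self.learn.save('my_model_' + dt)
        return False  # a falsy return value means: don't stop training

learn = RNNLearner.language_model(data_lm, pretrained_model=URLs.WT103, callback_fns=SaveModel)
learn.fit(2, 1e-2)
```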
I get the expected result: the model weights are saved to disk after each epoch. But I have two questions:
- Is it possible to pass arguments to `SaveModel` when using the callback? As a trivial example, what if I wanted to pass a model-name prefix to `on_epoch_end` while creating `learn`? Where would I do this?
- What is the difference between passing `callback_fns` and `callbacks` to the learner?
If you need to pass any parameter to your callback, do it in `__init__`:
```python
def __init__(self, learn:Learner, arg):
    super().__init__(learn)
    self.arg = arg  # stored so every other method can read it
```
and then you can access the argument through `self.arg` in every method.
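A self-contained sketch of that pattern, runnable without fastai. `ToyLearner` is a hypothetical stand-in whose `save` only records the file name, the toy callback does not inherit from `LearnerCallback`, and `prefix` plays the role of `arg`. In real fastai code you would keep `LearnerCallback` as the base class and call `super().__init__(learn)` as in the snippet above.

```python
import datetime


class ToyLearner:
    """Hypothetical stand-in for fastai's Learner: save() just records the name."""

    def __init__(self):
        self.saved = []

    def save(self, name):
        self.saved.append(name)


class SaveModel:
    """Toy version of the SaveModel callback that takes a file-name prefix."""

    def __init__(self, learn, prefix='my_model_'):
        self.learn = learn
        self.prefix = prefix  # stored once in __init__, readable from every other method

    def on_epoch_end(self, epoch, **kwargs):
        dt = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        self.learn.save(self.prefix + dt)  # the extra argument is used here


learn = ToyLearner()
cb = SaveModel(learn, prefix='lm_')
cb.on_epoch_end(0)
cb.on_epoch_end(1)
print(learn.saved)  # two names, both starting with 'lm_'
```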
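As for the second question: a callback function is something that takes a `Learner` and returns a `Callback`. It is created when you call `fit` and, at the same time, becomes an attribute of the `Learner` (like `learn.recorder`), which can be useful if you want to access it later. `callbacks`, on the other hand, takes callback objects you have already created yourself.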
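A toy imitation of the mechanism described above, runnable without fastai. It is not the library's code: `ToyLearner`, `ToyLearnerCallback` and `EpochLogger` are hypothetical names. Each `callback_fns` entry is treated as a factory that `fit` calls with the learner, and the base class's `__init__` attaches the new callback to the learner under a snake-cased name; objects passed through `callbacks` are ones you built yourself. Because a `callback_fns` entry only has to be a callable taking the learner, `functools.partial` can bind extra arguments. I believe fastai v1 accepts `partial` objects in `callback_fns` too, but treat that as an assumption to check against your version.

```python
import re
from functools import partial


class ToyLearnerCallback:
    """Hypothetical stand-in for LearnerCallback."""

    def __init__(self, learn):
        self.learn = learn
        # Register on the learner under a snake-cased class name (EpochLogger -> epoch_logger)
        name = re.sub(r'(?<!^)(?=[A-Z])', '_', type(self).__name__).lower()
        setattr(learn, name, self)

    def on_epoch_end(self, epoch):
        pass


class ToyLearner:
    """Hypothetical stand-in for Learner: stores callback factories, builds them in fit()."""

    def __init__(self, callback_fns=()):
        self.callback_fns = list(callback_fns)

    def fit(self, epochs, callbacks=()):
        # callback_fns: factories called with the learner at the start of every fit()
        cbs = [fn(self) for fn in self.callback_fns]
        # callbacks: objects the caller has already built
        cbs += list(callbacks)
        for epoch in range(epochs):
            for cb in cbs:
                cb.on_epoch_end(epoch)


class EpochLogger(ToyLearnerCallback):
    def __init__(self, learn, tag='epoch_'):
        super().__init__(learn)
        self.tag = tag
        self.log = []

    def on_epoch_end(self, epoch):
        self.log.append(f'{self.tag}{epoch}')


# callback_fns: partial binds the extra argument, the learner supplies `learn`
learn = ToyLearner(callback_fns=[partial(EpochLogger, tag='lm_')])
print(hasattr(learn, 'epoch_logger'))  # False: nothing is built before fit()
learn.fit(2)
print(learn.epoch_logger.log)          # ['lm_0', 'lm_1']

# callbacks: build the object yourself and hand it to fit()
learn2 = ToyLearner()
my_cb = EpochLogger(learn2, tag='run_')
learn2.fit(1, callbacks=[my_cb])
print(my_cb.log)                       # ['run_0']
```

The factory design is what lets the learner build a fresh callback on every `fit` call while still knowing which learner to attach it to.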
Sorry, I’m still not quite following. Where would I pass in `arg`? I tried this:
```python
learn = RNNLearner.language_model(data_lm, pretrained_model=URLs.WT103, callback_fns=SaveModel, arg='arg')
```
which gives me `TypeError: __init__() got an unexpected keyword argument 'model_name'`. And this:
```python
learn.fit(2, 1e-2, arg='arg')
```
and I get `TypeError: fit() got an unexpected keyword argument 'model_name'`.
I’m not sure whether this is the best way, but it does work if I do something like this instead:
```python
learn = RNNLearner.language_model(data_lm, pretrained_model=URLs.WT103)
my_cb = SaveModel(learn, arg='some arg')  # build the callback yourself, with its argument
learn.fit(2, 1e-2, callbacks=[my_cb])
```
Yes, that is the right way to do it.
The docs mention that:

> Every callback that is passed to `Learner` with the `callback_fns` parameter will be automatically stored as an attribute. The attribute name is snake-cased, so for instance `ActivationStats` will appear as `learn.activation_stats` (assuming your object is named `learn`).
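That snake-casing is why `ShowGraph` shows up as `learn.show_graph`. A minimal sketch of the conversion, using a common two-regex recipe; `camel_to_snake` is a hypothetical name, and fastai ships its own helper for this (`camel2snake` in v1, as far as I know) whose details may differ.

```python
import re

# Pass 1: split before a capitalised word ("ActivationStats" -> "Activation_Stats")
_first_cap = re.compile(r'(.)([A-Z][a-z]+)')
# Pass 2: split a lowercase letter or digit followed by a capital ("a1B" -> "a1_B")
_all_cap = re.compile(r'([a-z0-9])([A-Z])')


def camel_to_snake(name):
    """Convert a class name such as 'ActivationStats' to 'activation_stats'."""
    s = _first_cap.sub(r'\1_\2', name)
    return _all_cap.sub(r'\1_\2', s).lower()


for cls_name in ['ActivationStats', 'ShowGraph', 'SimpleRecorder', 'NewRecorder']:
    print(cls_name, '->', camel_to_snake(cls_name))
```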
I created a simple custom callback like this:
```python
@dataclass
class SimpleRecorder(LearnerCallback):
    learn:Learner

    def on_train_begin(self, **kwargs:Any)->None:
        self.losses = []

    def on_step_end(self, iteration:int, last_loss, **kwargs):
        self.losses.append(last_loss)

    def on_epoch_end(self, last_loss, smooth_loss, **kwargs):
        print('Epoch ended', last_loss, smooth_loss)

    def plot(self, **kwargs):
        losses = self.losses
        iterations = range(len(losses))
        fig, ax = plt.subplots(1,1)
        ax.set_ylabel('Loss')
        ax.set_xlabel('Iteration')
        ax.plot(iterations, losses)
```
I add it via the `callback_fns` parameter:
```python
learn = cnn_learner(data,
                    models.resnet34,
                    metrics=error_rate,
                    callback_fns=[
                        ShowGraph,
                        SimpleRecorder
                    ])
```
But I can’t access it via `learn.simple_recorder`, whereas `learn.show_graph` works. What am I missing?
`LearnerCallback` is not a dataclass anymore, and by using this syntax you bypass the `LearnerCallback` `__init__` (which is where the attribute is set). Just do:
```python
class SimpleRecorder(LearnerCallback):
    def on_train_begin(self, **kwargs:Any)->None:
        ...
```
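The effect described in this reply can be reproduced in plain Python. In this sketch `Base` is a hypothetical stand-in for `LearnerCallback`, reduced to the one thing that matters here: its `__init__` sets an attribute on the learner. Decorating a subclass with `@dataclass` and declaring a `learn` field makes the dataclass generate its own `__init__`, which only assigns the field and never calls `Base.__init__`, so nothing is registered.

```python
from dataclasses import dataclass


class ToyLearner:
    """Hypothetical minimal learner: just an object we can set attributes on."""


class Base:
    """Stand-in for LearnerCallback: its __init__ registers the callback on the learner."""

    def __init__(self, learn):
        self.learn = learn
        setattr(learn, 'registered_' + type(self).__name__.lower(), self)


@dataclass
class WithDataclass(Base):
    learn: ToyLearner  # the generated __init__ only does self.learn = learn


class WithoutDataclass(Base):
    pass  # inherits Base.__init__, so registration happens


learn = ToyLearner()
WithDataclass(learn)
WithoutDataclass(learn)
print(hasattr(learn, 'registered_withdataclass'))     # False
print(hasattr(learn, 'registered_withoutdataclass'))  # True
```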
OK, I removed the `@dataclass` decorator and the `learn:Learner` field, but I still get an `AttributeError`.
I also tried subclassing `Recorder`:
```python
class NewRecorder(Recorder):
    def on_epoch_end(self, **kwargs):
        print('Epoch ended')
```
but I can’t access the `learn.new_recorder` attribute either.