I just finished lesson 10 and I am trying to review the callback logic in CycleGANTrainer (it's from the CycleGAN v3 notebook). I am a bit confused about how I should define the arguments for each of the callback's methods (e.g. on_batch_begin, on_backward_begin, etc.).
For example, I notice that the callback method CycleGANTrainer.on_batch_end() is defined to take 2 arguments (i.e. last_input and last_output):
def on_batch_end(self, last_input, last_output, **kwargs):
    self.G_A.zero_grad(); self.G_B.zero_grad()
    fake_A, fake_B = last_output[0].detach(), last_output[1].detach()
    real_A, real_B = last_input
    self._set_trainable(D_A=True)
    ...
However, when I reviewed the place where this method is called (it is called by a cb_handler in Learner, if I understand correctly), it is actually called with only loss as an argument (extracted from here; look at the last line of the snippet):
def fit(epochs:int, learn:BasicLearner, callbacks:Optional[CallbackList]=None, metrics:OptMetrics=None)->None:
    "Fit the `model` on `data` and learn using `loss_func` and `opt`."
    assert len(learn.data.train_dl) != 0, f"""Your training dataloader is empty, can't train a model.
        Use a smaller batch size (batch size={learn.data.train_dl.batch_size} for {len(learn.data.train_dl.dataset)} elements)."""
    cb_handler = CallbackHandler(callbacks, metrics)
    pbar = master_bar(range(epochs))
    cb_handler.on_train_begin(epochs, pbar=pbar, metrics=metrics)
    exception=False
    try:
        for epoch in pbar:
            learn.model.train()
            cb_handler.set_dl(learn.data.train_dl)
            cb_handler.on_epoch_begin()
            for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
                xb, yb = cb_handler.on_batch_begin(xb, yb)
                loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler)
                if cb_handler.on_batch_end(loss): break
Why is it so? Perhaps there is something I have missed here; could anyone give me some guidance?
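For context, here is my current guess at how the two signatures might be reconciled: a minimal sketch, assuming CallbackHandler accumulates a state dict over the batch and then passes the whole dict to every callback as keyword arguments. The class and key names below are simplified assumptions of mine, not the real fastai code.

```python
class CallbackHandler:
    """Simplified stand-in for fastai's CallbackHandler (my assumption)."""
    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.state_dict = {}                      # state accumulated during the batch

    def on_batch_begin(self, xb, yb):
        self.state_dict['last_input'] = xb        # inputs stored here, early in the batch
        return xb, yb

    def on_loss_begin(self, out):
        self.state_dict['last_output'] = out      # model output stored after the forward pass
        return out

    def on_batch_end(self, loss):
        # Only `loss` arrives as an argument, but the handler merges it into the
        # state it already collected, then calls every callback with the FULL
        # state as keyword arguments.
        self.state_dict['last_loss'] = loss
        for cb in self.callbacks:
            cb.on_batch_end(**self.state_dict)


class RecordBatch:
    """Toy callback: declares only the keys it needs; **kwargs absorbs the rest."""
    def on_batch_end(self, last_input, last_output, **kwargs):
        self.got = (last_input, last_output)


cbh = CallbackHandler([RecordBatch()])
cbh.on_batch_begin('xb', 'yb')
cbh.on_loss_begin('out')
cbh.on_batch_end(0.5)                             # callback still receives last_input/last_output
```

If that model is right, it would explain why on_batch_end in CycleGANTrainer can declare last_input and last_output even though fit only passes loss; but I'd appreciate confirmation.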