[CycleGAN v3] Confused on CycleGANTrainer.on_batch_end() Arguments

I just finished lesson 10 and I am trying to review the callback logic in CycleGANTrainer (it's from the CycleGAN v3 notebook). I am a bit confused about how I should define the arguments for each method (e.g. on_batch_begin, on_backward_begin, etc.) in the callback.

For example, I notice that the callback's method CycleGANTrainer.on_batch_end() is defined with two arguments (i.e. last_input and last_output):

def on_batch_end(self, last_input, last_output, **kwargs):
    self.G_A.zero_grad(); self.G_B.zero_grad()
    fake_A, fake_B = last_output[0].detach(), last_output[1].detach()
    real_A, real_B = last_input
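(Spoiler from the answer below: this signature works because the caller unpacks one big dict of keyword arguments, and each callback method names only the keys it cares about, absorbing the rest in **kwargs. A minimal sketch of that pattern; the function and keys here are just an illustration, not the real fastai code:)

```python
# Hypothetical callback method: it declares only the state keys it needs
# and soaks up every other key via **kwargs.
def on_batch_end(last_input, last_output, **kwargs):
    return last_input, last_output

# The caller passes its entire state dict with ** unpacking; extra keys
# like 'last_loss' and 'iteration' land in **kwargs harmlessly.
state_dict = {'last_input': ('real_A', 'real_B'),
              'last_output': ('fake_A', 'fake_B'),
              'last_loss': 0.5, 'iteration': 3}
print(on_batch_end(**state_dict))
```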


However, when I reviewed the place where this method is called (it is called by a cb_handler in Learner, if I understand correctly), only loss is passed as an argument:

(extracted from here, look at the last line of the snippet)

def fit(epochs:int, learn:BasicLearner, callbacks:Optional[CallbackList]=None, metrics:OptMetrics=None)->None:
    "Fit the `model` on `data` and learn using `loss_func` and `opt`."
    assert len(learn.data.train_dl) != 0, f"""Your training dataloader is empty, can't train a model.
        Use a smaller batch size (batch size={learn.data.train_dl.batch_size} for {len(learn.data.train_dl.dataset)} elements)."""
    cb_handler = CallbackHandler(callbacks, metrics)
    pbar = master_bar(range(epochs))
    cb_handler.on_train_begin(epochs, pbar=pbar, metrics=metrics)

    for epoch in pbar:
        for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
            xb, yb = cb_handler.on_batch_begin(xb, yb)
            loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler)
            if cb_handler.on_batch_end(loss): break

Why is this so? Perhaps there is something I have missed here. Could anyone give me a pointer?

After some inspection with the Python debugger, I think I (partially) understand why it is so. The reason I got stuck is that I had ignored the role of CallbackHandler. It plays an important part here: it propagates useful objects between CycleGANTrainer and Learner. As an analogy, it serves as a bridge for exchanging useful objects between the learner and the different callbacks.

For those who have the same struggle, let me try to explain here. (This is, after all, limited by my own current understanding; any correction is welcome.)

Questions at Stake

How does Learner propagate the required arguments (i.e. last_input, last_output) to CycleGANTrainer.on_batch_end()?
And why does Learner pass only loss to on_batch_end?

Step 1: Learner delegates the work to CallbackHandler

The method Learner.fit() will eventually call a standalone function fit(). Inside that function, CallbackHandler is asked to call its on_batch_end method, as follows. Note that only loss is passed as an argument here.

if cb_handler.on_batch_end(loss): break

(extracted from here)

Step 2: CallbackHandler triggers on_batch_end and updates its state

CallbackHandler.on_batch_end takes loss as an argument and updates its state (aka self.state_dict) with the loss:

def on_batch_end(self, loss:Tensor)->Any:
    "Handle end of processing one batch with `loss`."
    self.state_dict['last_loss'] = loss
    self('batch_end', call_mets = not self.state_dict['train'])
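So the loss that fit() hands over never reaches a callback under the name loss: it is first stored under the key last_loss, and then travels to every callback inside the state dict. A small sketch of that renaming step (a simplified stand-in, not the actual fastai class):

```python
class MiniHandler:
    """Simplified stand-in for fastai's CallbackHandler (illustration only)."""
    def __init__(self):
        self.state_dict = {'last_input': None, 'last_output': None,
                           'last_loss': None}

    def on_batch_end(self, loss):
        # `loss` is renamed on arrival: it enters the shared state
        # as 'last_loss', which is the name callbacks will see.
        self.state_dict['last_loss'] = loss

h = MiniHandler()
h.on_batch_end(0.42)
print(h.state_dict['last_loss'])  # 0.42
```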

Step 3: Trigger _call_and_update in CallbackHandler

The last line of the above snippet subsequently calls _call_and_update:

for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)

_call_and_update is used to trigger the target callback (aka cb) and call its target method (aka cb_name). In this case, cb is CycleGANTrainer and cb_name is batch_end.

Step 4: CallbackHandler propagates its state to on_batch_end

The following line is called in _call_and_update:

new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())

Note that self.state_dict is passed as keyword arguments to CycleGANTrainer.on_batch_end (i.e. getattr(cb, f'on_{cb_name}')).
self.state_dict is a dict tracking all the useful objects during training, among them last_input and last_output! Below are all the keys of self.state_dict:

'epoch', 'iteration', 'num_batch', 
'skip_validate', 'n_epochs', 'pbar', 
'metrics', 'stop_training', 
'last_input', 'last_target', 
'train', 'stop_epoch', 
'skip_step', 'skip_zero', 'skip_bwd', 
'last_output', 'last_loss', 'smooth_loss'
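To see this propagation concretely, here is a stripped-down imitation of _call_and_update. The class and key names are illustrative, and the real fastai version also validates which returned keys may update the state:

```python
def ifnone(a, b):
    "Return `a` unless it is None, in which case return `b` (as in fastai)."
    return b if a is None else a

class ToyTrainer:
    """Toy callback that, like CycleGANTrainer, names only the keys it needs."""
    def on_batch_end(self, last_input, last_output, **kwargs):
        self.seen = (last_input, last_output)  # other keys fall into **kwargs

def call_and_update(cb, cb_name, state_dict, **kwargs):
    # 'batch_end' -> method name 'on_batch_end'; the whole state dict is
    # unpacked into keyword arguments, so the callback picks what it wants.
    new = ifnone(getattr(cb, f'on_{cb_name}')(**state_dict, **kwargs), dict())
    state_dict.update(new)  # a callback may return a dict of state updates

state = {'last_input': ('real_A', 'real_B'),
         'last_output': ('fake_A', 'fake_B'),
         'last_loss': 0.7, 'iteration': 12}
cb = ToyTrainer()
call_and_update(cb, 'batch_end', state)
print(cb.seen)  # (('real_A', 'real_B'), ('fake_A', 'fake_B'))
```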

I guess now you realize why CycleGANTrainer.on_batch_end can take last_input and last_output as arguments: they are among the keys of self.state_dict!

CycleGANTrainer.on_batch_end simply picks out the values of last_input and last_output from self.state_dict, and absorbs every other key in **kwargs.
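Putting the four steps together, the whole round trip can be sketched in a few lines. Everything below is a simplified stand-in for the fastai classes, purely to show the mechanism:

```python
class MiniHandler:
    """Toy CallbackHandler: holds the state dict and fans calls out."""
    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.state_dict = {'last_input': None, 'last_output': None,
                           'last_loss': None}

    def on_batch_begin(self, xb, yb):
        self.state_dict['last_input'] = xb   # stash the batch for later hooks

    def on_loss_begin(self, out):
        self.state_dict['last_output'] = out  # stash the model output

    def on_batch_end(self, loss):
        self.state_dict['last_loss'] = loss   # Step 2: record the loss
        for cb in self.callbacks:             # Step 3: fan out to callbacks
            # Step 4: the whole state dict becomes keyword arguments
            getattr(cb, 'on_batch_end')(**self.state_dict)

class MiniCycleGANTrainer:
    """Toy callback that names only the keys it wants, like CycleGANTrainer."""
    def on_batch_end(self, last_input, last_output, **kwargs):
        self.got = (last_input, last_output)

cb = MiniCycleGANTrainer()
h = MiniHandler([cb])
h.on_batch_begin(('real_A', 'real_B'), None)  # Learner-side calls
h.on_loss_begin(('fake_A', 'fake_B'))
h.on_batch_end(0.3)   # only `loss` is passed here, as in fit()...
print(cb.got)         # ...yet the callback receives last_input and last_output
```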