Lesson 9 Discussion & Wiki (2019)

As I mentioned above, please do not submit PRs with links to unlisted videos, because the fastai_docs repo is public. Until MOOC part 2 is released, the links can only go into the special “Part 2 (2019)” section of the forum or into the notes - see the other threads with the notes. You can also share the link to your repo here. Thank you for understanding.

Thank you for the clarification.
I have another question. I spoke with Jeremy Howard about the annotated notebooks, and Jeremy asked if these annotations could be integrated with the video viewer. I suppose the integration might be possible after the course ends, while the public release is in the works, but it would be nice to discuss it with someone working on the viewer ahead of time.

I was reviewing the lecture and came across something I was not sure about.

In 04_callbacks.ipynb, there is a TestCallback class that looks like the following:

class TestCallback(Callback):
    _order=1
    def after_step(self):
        if self.n_iter>=10: return True

Returning True here means “stop”. But Runner’s functions look like:

    def one_batch(self, xb, yb):
        self.xb,self.yb = xb,yb
        if self('begin_batch'): return
        self.pred = self.model(self.xb)
        if self('after_pred'): return
        self.loss = self.loss_func(self.pred, self.yb)
        if self('after_loss') or not self.in_train: return
        self.loss.backward()
        if self('after_backward'): return
        self.opt.step()
        if self('after_step'): return # <<<<<<<<<<<<<< HERE
        self.opt.zero_grad()

    def all_batches(self, dl):
        self.iters = len(dl)
        for xb,yb in dl:
            if self.stop: break
            self.one_batch(xb, yb)
            self('after_batch')
        self.stop=False

So after n_iter reaches 10, TestCallback continues to return True, but the loop in all_batches keeps going; the only thing it achieves is never setting the gradients back to zero.

I decided to add TestCallback and experiment.

stats = [TestCallback(), AvgStatsCallback([accuracy])]
run = Runner(cbs=stats)

Trial 1

The first thing I changed was to put self.run.stop=True back in TestCallback. This causes the for loop inside all_batches to break after 10 iterations.

Then I noticed that n_iter gets set to zero at the beginning of the fit function. So during the second epoch, TestCallback sees that n_iter >= 10 right away, and the all_batches loop terminates (gist).

Trial 2

I thought “maybe self.run.n_iter=0 should happen at the beginning of each epoch”. I tried that, and now the training loop exits after 10 iterations every epoch while the validation loop runs to completion (gist).

Trial 3

I tried resetting n_iter and also incrementing it during the validation phase, but because of this line in the one_batch function:

if self('after_loss') or not self.in_train: return

it never gets to TestCallback's after_step during the validation phase (gist).

Question

  • Is it okay to reset n_iter at the beginning of each epoch?
  • Do we want TestCallback to terminate the validation loop the way it does the training loop?

Hey @amanmadaan, I think that’s not quite right. I would expand this code out to:

res = True
for cb in self.cbs:
   if res:
      res = cb.something(learn)
   else:
      res = False
return res

That is, if res is already false then it won’t call cb.something(learn) at all!
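To see the short-circuit behaviour in isolation, here is a toy version in plain Python (toy callbacks, not the course code), where the second callback is never invoked once res goes false:

```python
calls = []

def cb1():
    calls.append('cb1')
    return False  # this callback signals "don't continue"

def cb2():
    calls.append('cb2')
    return True

res = True
for cb in (cb1, cb2):
    res = res and cb()  # once res is False, cb() is skipped entirely

print(calls)  # ['cb1'] -- cb2 was never called
```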


@gietema, do you feel that this use of callbacks returning True might be a code smell?

I know that I need to think extra hard when I consider what the right callbacks return value should be!

It might make things more cumbersome, but I wonder if we would benefit from being explicit about what we mean when the callback returns True.

Maybe we could make a silly class that behaves like True/False but gives the return value a semantic meaning:

class ContinueCallback():
    def __bool__(self):
        return True
class StopCallback():
    def __bool__(self):
        return False

if ContinueCallback(): print('go')
if not StopCallback(): print('stop')

which returns

go
stop

Maybe encode the possible states, which would also allow future flexibility beyond stop and go?

from enum import Enum

class CallbackState(Enum):
    CONT = 1
    STOP = 2
    ...

class MyCallback():
    def after_step(self):
        ...
        return CallbackState.STOP # or CallbackState.CONT

Then it’s much more intuitive. And of course could alias them to something shorter…


I believe that very few callbacks would stop the loop. I can think of callbacks for

  • early stopping
  • pausing the training

Couldn’t they send a message to the runner to stop it instead of everybody having to call back?


Yeah, I do think I’m reaching for something like an Enum.
But it would need to have some kind of bool-ish property; otherwise we would need to refactor all the places that look like this:

for cb in self.cbs: res = res and cb.begin_fit(learn)

To do something special depending upon the callback.

Something like

for cb in self.cbs: special_handler(cb.begin_fit(learn))
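A sketch of what that wrapper might look like (hypothetical names; assuming one possible convention where None or True means “continue” and False means “stop”):

```python
def special_handler(res):
    # Interpret a single callback's return value in one central place:
    # None (the common case) and True mean "continue", False means "stop".
    return res is not False

class CbA:
    def begin_fit(self, learn): return None   # most callbacks return nothing

class CbB:
    def begin_fit(self, learn): return False  # this one vetoes continuing

cbs = [CbA(), CbB()]
keep_going = all(special_handler(cb.begin_fit(None)) for cb in cbs)
print(keep_going)  # False
```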

Couldn’t they send a message to runner to stop it instead of everybody having to call back ?

This makes sense.

… otherwise we would need to refactor out all the places that look like this

Oh, for sure. There is only one training loop and many callbacks, so it should not be a problem to write a wrapper to handle those.


I guess a deeper question here is: what other point is there to the return value of the callback? Is there any other signal that needs to be sent?
I’m not sure how else to send a signal to the runner besides a return value. I guess we could mutate the state of the runner, but that would lead to some tough-to-maintain code, and I don’t think we want to set up a kind of message queue either…

I’m not sure how else to send a signal to the runner besides a return value?

Just have a cb-level state flag, so no need to return anything and the callback handler checks the flag.

This fits well with @Kaspar’s suggestion, since most of the time there is nothing to communicate back. So this should be the exception and not the rule.

Changing a flag would also lend itself well to parallel training, as it’d be easier to communicate to all parallel processes from a central loop rather than from the callback, no? Say, if one worker wants to flag all other workers to stop. And here we would need a richer set of states, rather than stop/cont.
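A minimal sketch of the flag idea (hypothetical names, not the notebook’s Runner): each callback only sets its own flag, and the central loop polls the flags instead of inspecting return values.

```python
class Callback:
    stop_requested = False  # flag polled by the central loop

class EarlyStop(Callback):
    def __init__(self, max_iters):
        self.max_iters, self.n = max_iters, 0
    def after_step(self):
        self.n += 1
        if self.n >= self.max_iters:
            self.stop_requested = True  # no return value needed

def run(cbs, n_batches):
    done = 0
    for _ in range(n_batches):
        for cb in cbs:
            cb.after_step()
        done += 1
        if any(cb.stop_requested for cb in cbs):  # one central check
            break
    return done

print(run([EarlyStop(3)], 10))  # 3
```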


Yeah that could work.

Looking back at the usage of the return value, it seems that returning True merely aborts the current batch, while if the Runner.stop property is set, then the runner stops completely.

So I can identify 3 states:

  1. Continue
  2. Abort current batch
  3. Abort training
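Encoded as an Enum, and mapped onto the two mechanisms the notebook actually has (a sketch with hypothetical names; the notebook itself only uses True/False returns plus Runner.stop):

```python
from enum import Enum

class LoopState(Enum):
    CONTINUE = 1        # today: a falsy return value
    ABORT_BATCH = 2     # today: returning True from a callback
    ABORT_TRAINING = 3  # today: setting self.run.stop = True

def handle(res, runner):
    # Translate a callback's richer return value back into the
    # "abort this batch?" boolean plus the runner-level stop flag.
    if res is LoopState.ABORT_TRAINING:
        runner.stop = True
        return True  # also abort the current batch
    return res is LoopState.ABORT_BATCH

class FakeRunner:
    stop = False

r = FakeRunner()
print(handle(LoopState.ABORT_TRAINING, r), r.stop)  # True True
```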

Thanks for expanding on this important fact! Note that the execution will not stop, since the loop goes like for cb in self.cbs.

Ah good point! Sorry @sergeman I should have thought of this when we discussed it.

Let’s talk to @zachcaceres about this - he wrote the video viewer 🙂

Zach maybe we can discuss on Monday?


You’re right - that should be there!

Not needed. You can use the non-integer part of n_epoch for that if needed (since it’s a floating point value). Although this callback is meant to totally stop after 10 batches, so (with your change above) it’s already doing the right thing.
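If I follow, the fractional-epoch trick looks like this (a sketch with assumed names; the exact attribute in the notebooks may be spelled differently):

```python
def n_epochs(epoch, batch_idx, iters):
    # One float tracks position in training: the integer part is the
    # completed-epoch count, the fraction is progress within the epoch.
    return epoch + batch_idx / iters

print(n_epochs(2, 25, 100))  # 2.25 -> a quarter of the way through epoch 2
```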

Yes I think so.

Good morning!

I think what I am seeing is somewhat related to what @amanmadaan, @johnnyv, and @stas are discussing about stopping. I am finding it difficult to get TestCallback to stop everything once it decides to stop.

What happens is this (please refer to the code block below for exactly where each item happens).

  1. Training loop starts
    – After 10 batches, TestCallback sets self.run.stop=True
  2. self.run.stop=True will make the training loop break
  3. self.run.stop gets set back to False
  4. Validation loop starts
    – At this point, there is nothing that would stop this from running.
    def all_batches(self, dl):
        self.iters = len(dl)
        for xb,yb in dl:
            if self.stop: break         # 2. self.run.stop=True will make the training loop break
            self.one_batch(xb, yb)
            self('after_batch')
        self.stop=False                 # 3. self.run.stop gets set to False

    def fit(self, epochs, learn):
        self.epochs,self.learn,self.loss = epochs,learn,tensor(0.)

        try:
            for cb in self.cbs: cb.set_runner(self)
            if self('begin_fit'): return
            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'): 
                    self.all_batches(self.data.train_dl)  # 1 Training loop starts 

                with torch.no_grad(): 
                    if not self('begin_validate'): 
                        self.all_batches(self.data.valid_dl) # 4. Validation loop starts
                if self('after_epoch'): break
            
        finally:
            self('after_fit')
            self.learn = None

The only idea I had was to modify TestCallback like this, telling it to stop again in after_epoch, as below. But the validation loop will still run, because during validation the flow in one_batch returns right after after_loss (before the backward path), so after_step is never reached.

class TestCallback(Callback):
    _order=1
    def after_step(self):
        if self.n_iter>=10:
            self.run.stop=True  # CHANGED
            return True
    def after_epoch(self):      # ADDED
        return True

Full gist

I am not sure how to tell it to “please don’t run anything else!”, because self.run.stop gets reset as soon as it breaks one loop. Maybe we could throw an exception, say Terminate, and add an except Terminate: block to the try block in the fit function?

i.e.

    def fit(self, epochs, learn):
        try:
            ...
        except Terminate:
            print('One of the callbacks requested fit to terminate')     
        finally:
            ...

But I am not crazy about this idea either.
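For what it’s worth, the exception route does do what I want in a toy version (hypothetical Terminate class and a fake fit, not the notebook’s code): validation is skipped, but the finally-style cleanup still runs.

```python
class Terminate(Exception):
    # Raised by a callback to abort the whole fit: training, validation, everything.
    pass

def fit(epochs, batches_per_epoch, stop_after):
    log = []
    try:
        for epoch in range(epochs):
            for batch in range(batches_per_epoch):
                log.append('train')
                if log.count('train') >= stop_after:
                    raise Terminate  # a callback would raise this
            log.append('validate')   # never reached once Terminate fires
    except Terminate:
        pass
    finally:
        log.append('after_fit')      # cleanup still runs, like self('after_fit')
    return log

print(fit(3, 100, 10))  # ten 'train' entries, no 'validate', then 'after_fit'
```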

I’m sorry about the scatter-brained comment. I will think about this a little more and see if I can come up with something better.


No problem. I did not think of it either 🙂

Yeah I’ve been wondering about that too. For now maybe something like:

class TestCallback(Callback):
    _order=1
    def check_stop(self):
        if self.n_iter<10: return False
        self.run.stop=True
        return True

    def after_step(self): return self.check_stop()
    def after_epoch(self): return self.check_stop()
    def begin_validate(self): return self.check_stop()

Sure, let’s chat about it Monday.