NameError: name 'flatten' is not defined

I was working through the ULMFiT section of the text tutorial, and converting a learner to 16-bit floats (with `.to_fp16()`) causes a `NameError` as soon as training starts.

I could train models with 32-bit floats instead, but I would like to take advantage of 16-bit training. I suspect the error is a mix-up between `flatten` and `Flatten`, since `Flatten` is defined in `fastai.layers`.

Here is the code that raises the error:

from fastai.text.all import *

# Minimal reproduction: ULMFiT-style language-model training on IMDB with
# mixed precision enabled.

# Fetch the IMDB dataset; `path` is then used as the data folder and as the
# learner's working directory.
path = untar_data(URLs.IMDB)

# Language-model DataLoaders (is_lm=True), holding out 10% for validation.
dls_lm = TextDataLoaders.from_folder(path, is_lm=True, valid_pct=0.1)

learn = language_model_learner(
    dls_lm,
    AWD_LSTM,
    metrics=[accuracy, Perplexity()],
    path=path,
    wd=0.1,
)
# Switching to 16-bit floats is what triggers the failure; without
# `.to_fp16()` the same training run does not raise this error.
learn = learn.to_fp16()

# NOTE(review): the traceback ends inside fastai/callback/fp16.py
# (MixedPrecision.after_pred), where the library itself references an
# undefined `flatten`. The problem is therefore in the installed package,
# not in this script. It is presumably a fastai/fastcore version mismatch:
# try upgrading both (`pip install -U fastai fastcore`) and confirm.
learn.fit_one_cycle(1, 1e-2)

Here is the traceback:

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
/tmp/ipykernel_4051/3151969659.py in <module>
      5     dls_lm, AWD_LSTM, metrics=[accuracy, Perplexity()],
      6     path=path, wd=0.1).to_fp16()
----> 7 learn.fit_one_cycle(1, 1e-2)

~/anaconda3/lib/python3.9/site-packages/fastai/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt)
    114     scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
    115               'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 116     self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
    117 
    118 # Cell

~/anaconda3/lib/python3.9/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    220             self.opt.set_hypers(lr=self.lr if lr is None else lr)
    221             self.n_epoch = n_epoch
--> 222             self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
    223 
    224     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

~/anaconda3/lib/python3.9/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    162 
    163     def _with_events(self, f, event_type, ex, final=noop):
--> 164         try: self(f'before_{event_type}');  f()
    165         except ex: self(f'after_cancel_{event_type}')
    166         self(f'after_{event_type}');  final()

~/anaconda3/lib/python3.9/site-packages/fastai/learner.py in _do_fit(self)
    211         for epoch in range(self.n_epoch):
    212             self.epoch=epoch
--> 213             self._with_events(self._do_epoch, 'epoch', CancelEpochException)
    214 
    215     def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):

~/anaconda3/lib/python3.9/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    162 
    163     def _with_events(self, f, event_type, ex, final=noop):
--> 164         try: self(f'before_{event_type}');  f()
    165         except ex: self(f'after_cancel_{event_type}')
    166         self(f'after_{event_type}');  final()

~/anaconda3/lib/python3.9/site-packages/fastai/learner.py in _do_epoch(self)
    205 
    206     def _do_epoch(self):
--> 207         self._do_epoch_train()
    208         self._do_epoch_validate()
    209 

~/anaconda3/lib/python3.9/site-packages/fastai/learner.py in _do_epoch_train(self)
    197     def _do_epoch_train(self):
    198         self.dl = self.dls.train
--> 199         self._with_events(self.all_batches, 'train', CancelTrainException)
    200 
    201     def _do_epoch_validate(self, ds_idx=1, dl=None):

~/anaconda3/lib/python3.9/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    162 
    163     def _with_events(self, f, event_type, ex, final=noop):
--> 164         try: self(f'before_{event_type}');  f()
    165         except ex: self(f'after_cancel_{event_type}')
    166         self(f'after_{event_type}');  final()

~/anaconda3/lib/python3.9/site-packages/fastai/learner.py in all_batches(self)
    168     def all_batches(self):
    169         self.n_iter = len(self.dl)
--> 170         for o in enumerate(self.dl): self.one_batch(*o)
    171 
    172     def _do_one_batch(self):

~/anaconda3/lib/python3.9/site-packages/fastai/learner.py in one_batch(self, i, b)
    193         b = self._set_device(b)
    194         self._split(b)
--> 195         self._with_events(self._do_one_batch, 'batch', CancelBatchException)
    196 
    197     def _do_epoch_train(self):

~/anaconda3/lib/python3.9/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    162 
    163     def _with_events(self, f, event_type, ex, final=noop):
--> 164         try: self(f'before_{event_type}');  f()
    165         except ex: self(f'after_cancel_{event_type}')
    166         self(f'after_{event_type}');  final()

~/anaconda3/lib/python3.9/site-packages/fastai/learner.py in _do_one_batch(self)
    172     def _do_one_batch(self):
    173         self.pred = self.model(*self.xb)
--> 174         self('after_pred')
    175         if len(self.yb):
    176             self.loss_grad = self.loss_func(self.pred, *self.yb)

~/anaconda3/lib/python3.9/site-packages/fastai/learner.py in __call__(self, event_name)
    140 
    141     def ordered_cbs(self, event): return [cb for cb in self.cbs.sorted('order') if hasattr(cb, event)]
--> 142     def __call__(self, event_name): L(event_name).map(self._call_one)
    143 
    144     def _call_one(self, event_name):

~/anaconda3/lib/python3.9/site-packages/fastcore/foundation.py in map(self, f, gen, *args, **kwargs)
    152     def range(cls, a, b=None, step=None): return cls(range_of(a, b=b, step=step))
    153 
--> 154     def map(self, f, *args, gen=False, **kwargs): return self._new(map_ex(self, f, *args, gen=gen, **kwargs))
    155     def argwhere(self, f, negate=False, **kwargs): return self._new(argwhere(self, f, negate, **kwargs))
    156     def argfirst(self, f, negate=False): return first(i for i,o in self.enumerate() if f(o))

~/anaconda3/lib/python3.9/site-packages/fastcore/basics.py in map_ex(iterable, f, gen, *args, **kwargs)
    664     res = map(g, iterable)
    665     if gen: return res
--> 666     return list(res)
    667 
    668 # Cell

~/anaconda3/lib/python3.9/site-packages/fastcore/basics.py in __call__(self, *args, **kwargs)
    649             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    650         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 651         return self.func(*fargs, **kwargs)
    652 
    653 # Cell

~/anaconda3/lib/python3.9/site-packages/fastai/learner.py in _call_one(self, event_name)
    144     def _call_one(self, event_name):
    145         if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
--> 146         for cb in self.cbs.sorted('order'): cb(event_name)
    147 
    148     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

~/anaconda3/lib/python3.9/site-packages/fastai/callback/core.py in __call__(self, event_name)
     55         res = None
     56         if self.run and _run:
---> 57             try: res = getattr(self, event_name, noop)()
     58             except (CancelBatchException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
     59             except Exception as e:

~/anaconda3/lib/python3.9/site-packages/fastai/callback/fp16.py in after_pred(self)
     21     def before_batch(self): self.autocast.__enter__()
     22     def after_pred(self):
---> 23         if next(flatten(self.pred)).dtype==torch.float16: self.learn.pred = to_float(self.pred)
     24     def after_loss(self): self.autocast.__exit__(None, None, None)
     25     def before_backward(self): self.learn.loss_grad = self.scaler.scale(self.loss_grad)

NameError: Exception occured in `MixedPrecision` when calling event `after_pred`:
	name 'flatten' is not defined