Why is wandb required during inference, and how do I get rid of that dependency?

I trained a model with fastai, using the wandb callback, set up like this:

import wandb
from fastai.callback.wandb import *
wandb.init(project="my-project")

Then I created the learner like this:

learn = cnn_learner(dls=dls,
                    arch=ARCH,
                    loss_func=CrossEntropyLossFlat(),
                    metrics=accuracy,
                    cbs=WandbCallback())
learn.model.cuda()

Then I fine-tuned the model like this:

learn.fine_tune(5, 3e-3, cbs=WandbCallback())

The model finished training as expected and I got what I wanted from wandb.


The problem is not with training but with inference.

I load the learner like this:

learn_ = load_learner(MODEL_PATH)

I get an error that wandb is not installed.

I had no idea wandb would be needed for inference at all.

I installed and imported it anyway, but when I then run inference like this:

N_IMAGES = 5  # number of augmented passes whose predictions TTA averages for the single image

dl = learn_.dls.test_dl([file_path])
pred, _ = learn_.tta(dl=dl,
                     n=N_IMAGES)
cat = learn_.dls.vocab[torch.argmax(pred).item()].lstrip()
cat

I get this error:

Could not gather input dimensions

/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-11-210159cdeb3c> in <module>()
      5 dl = learn_.dls.test_dl([file_path])
      6 pred, _ = learn_.tta(dl=dl,
----> 7                      n=N_IMAGES)
      8 tea_cat = learn_.dls.vocab[torch.argmax(pred).item()].lstrip()
      9 tea_cat
Full error message:
Could not gather input dimensions

/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-11-210159cdeb3c> in <module>()
      5 dl = learn_.dls.test_dl([file_path])
      6 pred, _ = learn_.tta(dl=dl,
----> 7                      n=N_IMAGES)
      8 cat = learn_.dls.vocab[torch.argmax(pred).item()].lstrip()
      9 cat

23 frames
/usr/local/lib/python3.7/dist-packages/fastai/learner.py in tta(self, ds_idx, dl, n, item_tfms, batch_tfms, beta, use_max)
    589             for i in self.progress.mbar if hasattr(self,'progress') else range(n):
    590                 self.epoch = i #To keep track of progress on mbar since the progress callback will use self.epoch
--> 591                 aug_preds.append(self.get_preds(dl=dl, inner=True)[0][None])
    592         aug_preds = torch.cat(aug_preds)
    593         aug_preds = aug_preds.max(0)[0] if use_max else aug_preds.mean(0)

/usr/local/lib/python3.7/dist-packages/fastai/learner.py in get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
    251         if with_loss: ctx_mgrs.append(self.loss_not_reduced())
    252         with ContextManagers(ctx_mgrs):
--> 253             self._do_epoch_validate(dl=dl)
    254             if act is None: act = getattr(self.loss_func, 'activation', noop)
    255             res = cb.all_tensors()

/usr/local/lib/python3.7/dist-packages/fastai/learner.py in _do_epoch_validate(self, ds_idx, dl)
    201         if dl is None: dl = self.dls[ds_idx]
    202         self.dl = dl
--> 203         with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
    204 
    205     def _do_epoch(self):

/usr/local/lib/python3.7/dist-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    163         try: self(f'before_{event_type}');  f()
    164         except ex: self(f'after_cancel_{event_type}')
--> 165         self(f'after_{event_type}');  final()
    166 
    167     def all_batches(self):

/usr/local/lib/python3.7/dist-packages/fastai/learner.py in __call__(self, event_name)
    139 
    140     def ordered_cbs(self, event): return [cb for cb in self.cbs.sorted('order') if hasattr(cb, event)]
--> 141     def __call__(self, event_name): L(event_name).map(self._call_one)
    142 
    143     def _call_one(self, event_name):

/usr/local/lib/python3.7/dist-packages/fastcore/foundation.py in map(self, f, gen, *args, **kwargs)
    152     def range(cls, a, b=None, step=None): return cls(range_of(a, b=b, step=step))
    153 
--> 154     def map(self, f, *args, gen=False, **kwargs): return self._new(map_ex(self, f, *args, gen=gen, **kwargs))
    155     def argwhere(self, f, negate=False, **kwargs): return self._new(argwhere(self, f, negate, **kwargs))
    156     def argfirst(self, f, negate=False): return first(i for i,o in self.enumerate() if f(o))

/usr/local/lib/python3.7/dist-packages/fastcore/basics.py in map_ex(iterable, f, gen, *args, **kwargs)
    664     res = map(g, iterable)
    665     if gen: return res
--> 666     return list(res)
    667 
    668 # Cell

/usr/local/lib/python3.7/dist-packages/fastcore/basics.py in __call__(self, *args, **kwargs)
    649             if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)
    650         fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[self.maxi+1:]
--> 651         return self.func(*fargs, **kwargs)
    652 
    653 # Cell

/usr/local/lib/python3.7/dist-packages/fastai/learner.py in _call_one(self, event_name)
    143     def _call_one(self, event_name):
    144         if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
--> 145         for cb in self.cbs.sorted('order'): cb(event_name)
    146 
    147     def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)

/usr/local/lib/python3.7/dist-packages/fastai/callback/core.py in __call__(self, event_name)
     43                (self.run_valid and not getattr(self, 'training', False)))
     44         res = None
---> 45         if self.run and _run: res = getattr(self, event_name, noop)()
     46         if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
     47         return res

/usr/local/lib/python3.7/dist-packages/fastai/callback/core.py in after_validate(self)
    158         with self.learn.removed_cbs(to_rm + self.cbs) as learn:
    159             self.preds = learn.get_preds(ds_idx=self.ds_idx, dl=self.dl,
--> 160                 with_input=self.with_input, with_decoded=self.with_decoded, inner=True, reorder=self.reorder)

/usr/local/lib/python3.7/dist-packages/fastai/learner.py in get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
    251         if with_loss: ctx_mgrs.append(self.loss_not_reduced())
    252         with ContextManagers(ctx_mgrs):
--> 253             self._do_epoch_validate(dl=dl)
    254             if act is None: act = getattr(self.loss_func, 'activation', noop)
    255             res = cb.all_tensors()

/usr/local/lib/python3.7/dist-packages/fastai/learner.py in _do_epoch_validate(self, ds_idx, dl)
    201         if dl is None: dl = self.dls[ds_idx]
    202         self.dl = dl
--> 203         with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
    204 
    205     def _do_epoch(self):

/usr/local/lib/python3.7/dist-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    161 
    162     def _with_events(self, f, event_type, ex, final=noop):
--> 163         try: self(f'before_{event_type}');  f()
    164         except ex: self(f'after_cancel_{event_type}')
    165         self(f'after_{event_type}');  final()

/usr/local/lib/python3.7/dist-packages/fastai/learner.py in all_batches(self)
    167     def all_batches(self):
    168         self.n_iter = len(self.dl)
--> 169         for o in enumerate(self.dl): self.one_batch(*o)
    170 
    171     def _do_one_batch(self):

/usr/local/lib/python3.7/dist-packages/fastai/data/load.py in __iter__(self)
    108         self.__idxs=self.get_idxs() # called in context of main process (not workers/subprocesses)
    109         for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
--> 110             if self.device is not None: b = to_device(b, self.device)
    111             yield self.after_batch(b)
    112         self.after_iter()

/usr/local/lib/python3.7/dist-packages/fastai/torch_core.py in to_device(b, device, non_blocking)
    256 #         if hasattr(o, "to_device"): return o.to_device(device)
    257         return o
--> 258     return apply(_inner, b)
    259 
    260 # Cell

/usr/local/lib/python3.7/dist-packages/fastai/torch_core.py in apply(func, x, *args, **kwargs)
    201 def apply(func, x, *args, **kwargs):
    202     "Apply `func` recursively to `x`, passing on args"
--> 203     if is_listy(x): return type(x)([apply(func, o, *args, **kwargs) for o in x])
    204     if isinstance(x,dict):  return {k: apply(func, v, *args, **kwargs) for k,v in x.items()}
    205     res = func(x, *args, **kwargs)

/usr/local/lib/python3.7/dist-packages/fastai/torch_core.py in <listcomp>(.0)
    201 def apply(func, x, *args, **kwargs):
    202     "Apply `func` recursively to `x`, passing on args"
--> 203     if is_listy(x): return type(x)([apply(func, o, *args, **kwargs) for o in x])
    204     if isinstance(x,dict):  return {k: apply(func, v, *args, **kwargs) for k,v in x.items()}
    205     res = func(x, *args, **kwargs)

/usr/local/lib/python3.7/dist-packages/fastai/torch_core.py in apply(func, x, *args, **kwargs)
    203     if is_listy(x): return type(x)([apply(func, o, *args, **kwargs) for o in x])
    204     if isinstance(x,dict):  return {k: apply(func, v, *args, **kwargs) for k,v in x.items()}
--> 205     res = func(x, *args, **kwargs)
    206     return res if x is None else retain_type(res, x)
    207 

/usr/local/lib/python3.7/dist-packages/fastai/torch_core.py in _inner(o)
    253     elif device is None: device=default_device()
    254     def _inner(o):
--> 255         if isinstance(o,Tensor): return o.to(device, non_blocking=non_blocking)
    256 #         if hasattr(o, "to_device"): return o.to_device(device)
    257         return o

/usr/local/lib/python3.7/dist-packages/fastai/torch_core.py in __torch_function__(self, func, types, args, kwargs)
    338         convert=False
    339         if _torch_handled(args, self._opt, func): convert,types = type(self),(torch.Tensor,)
--> 340         res = super().__torch_function__(func, types, args=args, kwargs=kwargs)
    341         if convert: res = convert(res)
    342         if isinstance(res, TensorBase): res.set_meta(self, as_copy=True)

/usr/local/lib/python3.7/dist-packages/torch/_tensor.py in __torch_function__(cls, func, types, args, kwargs)
   1021 
   1022         with _C.DisableTorchFunction():
-> 1023             ret = func(*args, **kwargs)
   1024             return _convert(ret, cls)
   1025 

/usr/local/lib/python3.7/dist-packages/torch/cuda/__init__.py in _lazy_init()
    170         # This function throws if there's a driver initialization error, no GPUs
    171         # are found or any other error occurs
--> 172         torch._C._cuda_init()
    173         # Some of the queued calls may reentrantly call _lazy_init();
    174         # we need to just return without initializing in that case.

RuntimeError: No CUDA GPUs are available

Why is this happening, and how do I get rid of this?

I don’t have access to GPUs for inference during deployment.

1 Like

Assuming you exported the Learner with `Learner.export`, `load_learner` restores the original Learner object, including all of its callbacks. In this case, that includes `WandbCallback`.

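To remove the wandb requirement at inference time, remove the `WandbCallback` with `Learner.remove_cb` (or `Learner.remove_cbs`) *before* exporting, e.g. `learn.remove_cb(WandbCallback)` followed by `learn.export()`. You can also remove it after loading, which stops it from running during inference. However, `wandb` must still be installed in that case, because `load_learner` has to unpickle the callback before you get a chance to remove it.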

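`load_learner` moves the model and the Learner's DataLoaders to the CPU by default, so the model itself is not the cause of the CUDA error. The traceback shows the error coming from a callback's `after_validate`. That hook calls `get_preds` on a DataLoader the callback stored itself (`dl=self.dl`), and that DataLoader still tries to move batches to the GPU. Removing the training-time callback removes this code path as well, so the same fix resolves the `No CUDA GPUs are available` error.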

5 Likes

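How do I turn off wandb logging temporarily? If you’re testing code and want to stop wandb from syncing, set the environment variable `WANDB_MODE=offline`. Note that in offline mode the callback still runs and `wandb` must still be installed, so this does not replace removing the callback before export.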

1 Like