I get an error with a model that was trained with the `OverSamplingCallback` callback and then loaded with `load_learner()`. The error is raised when I call `learn.get_preds`:
```
IndexError Traceback (most recent call last)
<ipython-input-12-5a5f3d8917f7> in <module>
13 ids_test.append(match.group(0))
14 learn = load_learner(learner_exported, 'final_vgg19_8.pkl', test=ImageList.from_folder(subdir)).to_fp16()
---> 15 preds, y = learn.get_preds(ds_type=DatasetType.Test)
16 preds_test.append(activation(preds))
17 preds_test_mxp.append(max_proba(preds))

/opt/anaconda3/lib/python3.7/site-packages/fastai/basic_train.py in get_preds(self, ds_type, activ, with_loss, n_batch, pbar)
340 if not getattr(self, 'opt', False): self.create_opt(lr, wd)
341 else: self.opt.lr,self.opt.wd = lr,wd
--> 342 self.callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(self.callbacks)
343 self.cb_fns_registered = True
344 return get_preds(self.model, self.dl(ds_type), cb_handler=CallbackHandler(self.callbacks),

/opt/anaconda3/lib/python3.7/site-packages/fastai/basic_train.py in <listcomp>(.0)
340 if not getattr(self, 'opt', False): self.create_opt(lr, wd)
341 else: self.opt.lr,self.opt.wd = lr,wd
--> 342 self.callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(self.callbacks)
343 self.cb_fns_registered = True
344 return get_preds(self.model, self.dl(ds_type), cb_handler=CallbackHandler(self.callbacks),

/opt/anaconda3/lib/python3.7/site-packages/fastai/callbacks/oversampling.py in __init__(self, learn, weights)
15 _, counts = np.unique(self.labels,return_counts=True)
16 self.weights = (weights if weights is not None else
---> 17 torch.DoubleTensor((1/counts)[self.labels]))
18 self.label_counts = np.bincount([self.learn.data.train_dl.dataset.y[i].data for i in range(len(self.learn.data.train_dl.dataset))])
19 self.total_len_oversample = int(self.learn.data.c*np.max(self.label_counts))

IndexError: arrays used as indices must be of integer (or boolean) type
```
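The failing expression on line 17 of `oversampling.py` is pure NumPy, so the error can be reproduced without fastai. This is a minimal sketch: the `torch.DoubleTensor` wrapper is left out, `inverse_frequency_weights` is a hypothetical name, and it assumes (this is not shown in the traceback) that the reloaded learner's training labels come back as an empty array, whose default NumPy dtype is float64 rather than an integer type.

```python
import numpy as np

def inverse_frequency_weights(labels):
    # Same computation as line 15 and line 17 of oversampling.py:
    # one weight per sample, equal to 1 / (number of samples in its class).
    _, counts = np.unique(labels, return_counts=True)
    return (1 / counts)[labels]

# Integer class indices, as in a real training set: this works.
print(inverse_frequency_weights(np.array([0, 0, 0, 1])))

# An empty label array defaults to float64, which NumPy refuses as an index.
empty_labels = np.array([])
print(empty_labels.dtype)  # float64
try:
    inverse_frequency_weights(empty_labels)
except IndexError as err:
    print(err)  # arrays used as indices must be of integer (or boolean) type
```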
I'm using the callback because I'm doing the week 1 assignment on an imbalanced dataset with two labels in a 7:100 ratio.
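For a 7:100 split, the weights computed on lines 15 to 19 of `oversampling.py` (visible in the traceback) work out as below. This is a minimal NumPy sketch with a made-up label array; `len(counts)` stands in for `learn.data.c`, and the sampler that consumes the weights is not included.

```python
import numpy as np

# Made-up labels with the 7:100 split: class 0 is the minority class.
labels = np.array([0] * 7 + [1] * 100)

# Lines 15-17: each sample gets weight 1 / (size of its class).
_, counts = np.unique(labels, return_counts=True)
weights = (1 / counts)[labels]

# Total weight per class is ~1.0 for both classes, so sampling with
# probability proportional to `weights` picks each class equally often on average.
class_mass = np.array([weights[labels == c].sum() for c in (0, 1)])

# Lines 18-19: epoch length = number of classes * size of the largest class.
label_counts = np.bincount(labels)
total_len_oversample = int(len(counts) * np.max(label_counts))

print(class_mass)
print(total_len_oversample)  # 200, i.e. about 100 draws per class per epoch
```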
If I train without `OverSamplingCallback`, the error goes away.
The problem also does not occur when I use `Learner.save()` followed by `Learner.load()` instead of `Learner.export()` followed by `load_learner()`.
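The traceback shows `get_preds` instantiating every entry of `learn.callback_fns`, so the callback evidently comes back with the exported learner, whereas `Learner.save()` writes only the weights (and optionally the optimizer state). Assuming that is the cause, one possible workaround is to drop `OverSamplingCallback` from `callback_fns` after `load_learner()` and before `get_preds`. This is a minimal, untested sketch: `without_callback_fn` is a hypothetical helper, and it assumes each entry is either the callback class itself or a `functools.partial` wrapping it.

```python
from functools import partial

def without_callback_fn(callback_fns, callback_cls):
    """Return a copy of callback_fns without entries that would build callback_cls.

    Assumption: entries are either the class itself or a functools.partial of it.
    """
    return [cb for cb in callback_fns
            if getattr(cb, 'func', cb) is not callback_cls]

# Usage with fastai v1 (not run here; hypothetical placement in the loop above):
# from fastai.callbacks.oversampling import OverSamplingCallback
# learn = load_learner(learner_exported, 'final_vgg19_8.pkl',
#                      test=ImageList.from_folder(subdir)).to_fp16()
# learn.callback_fns = without_callback_fn(learn.callback_fns, OverSamplingCallback)
# preds, y = learn.get_preds(ds_type=DatasetType.Test)
```

Handling `partial` entries matters because other entries in `callback_fns` may be wrapped that way, and those should be kept.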
Thanks in advance for your help.