This is an interesting problem. I haven’t found a solution, but here’s what I’ve found.
Here is the source code for learn.predict:
def predict(self, item, rm_type_tfms=None, with_input=False):
    dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms, num_workers=0)
    inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
    i = getattr(self.dls, 'n_inp', -1)
    inp = (inp,) if i==1 else tuplify(inp)
    dec = self.dls.decode_batch(inp + tuplify(dec_preds))[0]
    dec_inp,dec_targ = map(detuplify, [dec[:i],dec[i:]])
    res = dec_targ,dec_preds[0],preds[0]
    if with_input: res = (dec_inp,) + res
    return res
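To make the n_inp slicing near the end of predict easier to follow, here is a minimal pure-Python sketch of just that step. The tuplify and detuplify functions below are simplified stand-ins I wrote for illustration (the real ones come from fastcore and handle more cases), and split_decoded is a hypothetical helper name, not part of fastai.

```python
def tuplify(o):
    """Simplified stand-in: wrap a non-tuple/list value in a tuple."""
    return tuple(o) if isinstance(o, (list, tuple)) else (o,)

def detuplify(t):
    """Simplified stand-in: unwrap a one-element tuple."""
    return t[0] if len(t) == 1 else t

def split_decoded(dec, n_inp=-1):
    """Split one decoded sample into (inputs, targets) by slicing at n_inp."""
    return tuple(map(detuplify, [dec[:n_inp], dec[n_inp:]]))

# One input and one target: both sides unwrap to bare values.
single = split_decoded(("image", "label"), n_inp=1)
# Two inputs and one target: the inputs stay grouped in a tuple.
double = split_decoded(("cats", "conts", "label"), n_inp=2)
# With the -1 fallback, everything except the last element counts as input.
fallback = split_decoded(("cats", "conts", "label"))

print(single)    # ('image', 'label')
print(double)    # (('cats', 'conts'), 'label')
print(fallback)  # (('cats', 'conts'), 'label')
```

The -1 fallback matters when the DataLoaders object has no n_inp attribute: the last element of the decoded tuple is then treated as the target.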
I pasted that source code into a cell and ran the following (after first running your provided code, which gives me train_x):
item = train_x[0]
dl = learn.dls.test_dl([item], rm_type_tfms=None, num_workers=0)
inp,preds,_,dec_preds = learn.get_preds(dl=dl, with_input=True, with_decoded=True)
i = getattr(learn.dls, 'n_inp', -1)
inp = (inp,) if i==1 else tuplify(inp)
dec = learn.dls.decode_batch(inp + tuplify(dec_preds))[0]
dec_inp,dec_targ = map(detuplify, [dec[:i],dec[i:]])
res = dec_targ,dec_preds[0],preds[0]
This gives the following error:
IndexError: too many indices for tensor of dimension 0
for the following line of code:
inp,preds,_,dec_preds = learn.get_preds(dl=dl, with_input=True, with_decoded=True)
I found this Forums post. Although it isn’t technically related to your situation, I thought I would give its suggestion a try (unsqueeze the train_x[0] value):
item = train_x[0].unsqueeze(dim=0)
dl = learn.dls.test_dl([item], rm_type_tfms=None, num_workers=0)
inp,preds,_,dec_preds = learn.get_preds(dl=dl, with_input=True, with_decoded=True)
i = getattr(learn.dls, 'n_inp', -1)
inp = (inp,) if i==1 else tuplify(inp)
dec = learn.dls.decode_batch(inp + tuplify(dec_preds))[0]
dec_inp,dec_targ = map(detuplify, [dec[:i],dec[i:]])
res = dec_targ,dec_preds[0],preds[0]
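For reference, unsqueeze(dim=0) only adds a leading axis of size 1 and leaves the data itself unchanged. A minimal sketch, assuming the item is a flattened 28x28 image of shape (784,); that shape is my assumption and your train_x[0] may differ:

```python
import torch

# Assumed shape: a flattened 28x28 image; your train_x[0] may differ.
item = torch.zeros(784)
batched = item.unsqueeze(dim=0)  # insert a new axis of size 1 at position 0

print(item.shape)     # torch.Size([784])
print(batched.shape)  # torch.Size([1, 784])
```

So the second attempt hands test_dl an item that already has an explicit leading dimension.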
This resolved the initial error but gave a new error:
AttributeError: 'list' object has no attribute 'decode_batch'
Caused by the following line:
dec = learn.dls.decode_batch(inp + tuplify(dec_preds))[0]
Looking at your DataLoaders, it doesn’t have a decode_batch attribute (I’m not sure why).
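One possible explanation, which I have not verified against your notebook: as far as I understand, fastai's DataLoaders forwards attributes it doesn't define to its training loader, and a basic DataLoader forwards in turn to its dataset. If the dataset is a plain Python list, the lookup would end on that list, which would produce exactly this message. The classes below are toy stand-ins written for illustration, not fastai code.

```python
class Delegating:
    """Hypothetical stand-in: forwards unknown attributes to a default object."""
    def __init__(self, default):
        self.default = default

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        return getattr(self.default, name)

samples = [("x0", "y0"), ("x1", "y1")]  # a plain list used as the dataset
loader = Delegating(samples)            # toy loader forwarding to its dataset
loaders = Delegating(loader)            # toy DataLoaders forwarding to the loader

try:
    loaders.decode_batch
except AttributeError as err:
    message = str(err)

print(message)  # 'list' object has no attribute 'decode_batch'
```

If that is what is happening, the error is about the underlying dataset being a list rather than about DataLoaders itself.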
Here is a colab notebook with the code.
Here are a few more related links that I found. Note the last one, where someone had the same issue but no resolution:
- IndexError with Focal Loss
- Learn.model input? (Tabular)
- Chapter 4 - Further research MNIST
- Beginners guide to MNIST with fast.ai | Kaggle
- MNIST image learner predict with fastai
Not sure if any of this helps, but hopefully you can find a solution.