Lesson 3 official topic

This is an interesting problem. I haven’t found a solution, but here’s what I’ve found so far.

Here is the source code for learn.predict:

def predict(self, item, rm_type_tfms=None, with_input=False):
        dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms, num_workers=0)
        inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
        i = getattr(self.dls, 'n_inp', -1)
        inp = (inp,) if i==1 else tuplify(inp)
        dec = self.dls.decode_batch(inp + tuplify(dec_preds))[0]
        dec_inp,dec_targ = map(detuplify, [dec[:i],dec[i:]])
        res = dec_targ,dec_preds[0],preds[0]
        if with_input: res = (dec_inp,) + res
        return res

I took that source code, pasted it into a cell and ran the following (I ran your provided code first which gives me train_x):

item = train_x[0]

dl = learn.dls.test_dl([item], rm_type_tfms=None, num_workers=0)
inp,preds,_,dec_preds = learn.get_preds(dl=dl, with_input=True, with_decoded=True)
i = getattr(learn.dls, 'n_inp', -1)
inp = (inp,) if i==1 else tuplify(inp)
dec = learn.dls.decode_batch(inp + tuplify(dec_preds))[0]
dec_inp,dec_targ = map(detuplify, [dec[:i],dec[i:]])
res = dec_targ,dec_preds[0],preds[0]

This gives the following error:

IndexError: too many indices for tensor of dimension 0

for the following line of code:

inp,preds,_,dec_preds = learn.get_preds(dl=dl, with_input=True, with_decoded=True)

I found this Forums post which, although not directly related to your situation, suggested a fix I thought was worth trying (unsqueezing the train_x[0] value):

item = train_x[0].unsqueeze(dim=0)

dl = learn.dls.test_dl([item], rm_type_tfms=None, num_workers=0)
inp,preds,_,dec_preds = learn.get_preds(dl=dl, with_input=True, with_decoded=True)
i = getattr(learn.dls, 'n_inp', -1)
inp = (inp,) if i==1 else tuplify(inp)
dec = learn.dls.decode_batch(inp + tuplify(dec_preds))[0]
dec_inp,dec_targ = map(detuplify, [dec[:i],dec[i:]])
res = dec_targ,dec_preds[0],preds[0]
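For reference, here is a small sketch of what the unsqueeze call changes. I'm assuming train_x is a 2D tensor of flattened 28x28 images (the stand-in data below is made up, not your actual train_x):

```python
import torch

# Hypothetical stand-in for train_x: 3 flattened 28x28 images
train_x = torch.zeros(3, 28 * 28)

item = train_x[0]                 # indexing drops the leading (batch) axis
batched = item.unsqueeze(dim=0)   # adds it back as an axis of size 1

print(item.shape)     # torch.Size([784])
print(batched.shape)  # torch.Size([1, 784])
```

So the unsqueezed item carries an extra leading axis, which is the only difference between the two runs.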

Unsqueezing resolved the initial error but raised a new one:

AttributeError: 'list' object has no attribute 'decode_batch'

Caused by the following line:

dec = learn.dls.decode_batch(inp + tuplify(dec_preds))[0]

Looking at your DataLoaders, it doesn’t have a decode_batch attribute (I’m not sure why).
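One possible explanation for why the message mentions a 'list': as far as I understand it (I haven't confirmed this against your notebook), fastai's DataLoaders forwards unknown attributes to its training DataLoader, which in turn forwards them to its dataset. If that dataset is a plain Python list, the lookup would end there. The toy classes below are hypothetical and only imitate that forwarding; they are not fastai's code:

```python
class Delegating:
    """Forward unknown attribute lookups to the attribute named by _default."""
    _default = ""

    def __getattr__(self, name):
        # Only called when normal lookup fails; guard against infinite recursion
        if name == self._default or name.startswith("__"):
            raise AttributeError(name)
        return getattr(getattr(self, self._default), name)


class ToyLoader(Delegating):      # hypothetical stand-in for a DataLoader
    _default = "dataset"
    def __init__(self, dataset):
        self.dataset = dataset


class ToyLoaders(Delegating):     # hypothetical stand-in for DataLoaders
    _default = "train"
    def __init__(self, train):
        self.train = train


# The dataset is a plain list of (x, y) pairs, which is my assumption about your setup
dls = ToyLoaders(ToyLoader([(1, 0), (2, 1)]))

try:
    dls.decode_batch
    message = None
except AttributeError as e:
    message = str(e)

print(message)  # 'list' object has no attribute 'decode_batch'
```

If this is what is happening, decode_batch would only be available when the DataLoaders is built from fastai's own transformed loaders rather than from a plain list.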

Here is a colab notebook with the code.

Dropping a few more related links that I found. Note the last link, where they had the same issue but no resolution:

Not sure if any of this helps but hopefully you can find a solution to this.
