So let's walk through predict, specifically where it calls get_preds:
inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
Here we get back the input we passed in (with_input=True) along with the decoded predictions (with_decoded=True). That decoding comes from the loss function only. decode_batch then does the decoding that comes from our DataBlock.
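To make the "decoded by the loss function" part concrete, here is a minimal sketch in plain Python with no fastai involved. ToyCrossEntropy is a made-up stand-in, assuming a loss that pairs a softmax activation with an argmax decodes (which is how fastai's CrossEntropyLossFlat behaves); the logits are invented for illustration.

```python
import math

class ToyCrossEntropy:
    """Hypothetical stand-in for a loss function exposing `activation` and `decodes`."""

    def activation(self, logits):
        # Softmax: turn raw model outputs into probabilities
        m = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in logits]
        total = sum(exps)
        return [e / total for e in exps]

    def decodes(self, scores):
        # Argmax: turn per-class scores into a predicted class index
        return max(range(len(scores)), key=lambda i: scores[i])

loss_func = ToyCrossEntropy()
logits = [0.5, 2.0, -1.0]             # invented raw output for a single item
probs = loss_func.activation(logits)  # what you would get without decoding
class_idx = loss_func.decodes(probs)  # what with_decoded adds on top
print(probs, class_idx)
```

Note that the loss function only gets you as far as a class index. Turning that index back into something human readable (a label, a decoded input) is the separate DataBlock-side step.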
For a different way to look at it, take the fastinference library I've been building. I rebuilt the get_preds function to make it a bit more efficient, but for all intents and purposes it still acts and behaves the same way the framework does:
for batch in dl:
    with torch.no_grad():
        ...
        if decoded_loss or fully_decoded:
            out = x.model(*batch[:x.dls.n_inp])
            raw.append(out)
            dec_out.append(x.loss_func.decodes(out))
        else:
            raw.append(x.model(*batch[:x.dls.n_inp]))
This is how I get predictions. (In fastai this is all hidden inside get_preds and the GatherPreds callback, so it's hard to figure out. Presume it does this, with a bit more abstraction.)
So we can see that if I want to decode my loss, I decode via the loss_func. At the end I'll go through and get a result similar to predict here:
if not raw_outs:
    try: outs.insert(0, x.loss_func.activation(tensor(raw)).numpy())
    except: outs.insert(0, dec_out)
else:
    outs.insert(0, raw)
if fully_decoded: outs = _fully_decode(x.dls, inps, outs, dec_out, is_multi)
if decoded_loss: outs = _decode_loss(x.dls.vocab, dec_out, outs)
return outs
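If it helps to see the whole flow in one runnable piece, here is a rough sketch in plain Python. Everything in it (toy_model, gather_preds, the two-class vocab) is made up for illustration and is not the actual fastinference or fastai code. In particular, the vocab lookup only stands in for whatever _decode_loss really does, which isn't shown in this post.

```python
import math

def softmax(logits):
    # Activation step: raw outputs -> probabilities
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(scores):
    # Loss-function style decode: scores -> class index
    return max(range(len(scores)), key=lambda i: scores[i])

def toy_model(x):
    # Hypothetical two-class "model": just returns a pair of logits
    return [x, -x]

def gather_preds(batches, vocab, raw_outs=False, decoded_loss=False):
    raw, dec_out = [], []
    for batch in batches:
        for x in batch:
            out = toy_model(x)
            raw.append(out)
            dec_out.append(argmax(out))  # keep the decoded class index
    outs = []
    if raw_outs:
        outs.insert(0, raw)                        # untouched model outputs
    else:
        outs.insert(0, [softmax(o) for o in raw])  # activated outputs
    if decoded_loss:
        # Assumed behaviour: map class indices to labels via the vocab
        outs.insert(0, [vocab[i] for i in dec_out])
    return outs

batches = [[1.0, -2.0], [0.5]]
outs = gather_preds(batches, vocab=["cat", "dog"], decoded_loss=True)
raw_only = gather_preds(batches, vocab=["cat", "dog"], raw_outs=True)
print(outs[0])
```

The point is the ordering: raw outputs first, then the activation, then the loss function's decode, and only after that anything that needs the vocab or the rest of the data pipeline.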
(And I'm going to do a video walkthrough of this whole thing this weekend too, that may help some.)
That's a long explanation, but does this help? To see _fully_decode and _decode_loss, see here. I'm going to go through those in the video, since there's just a lot to explain here (but it's good to keep in mind, because the fastai framework operates in this same way!)
Also, my version has an option to return a fully decoded output similar to learn.predict, which is why it will look different.