Test Seq2Seq prediction after training

I’ve successfully trained a seq2seq model following the instructions in lesson 11. However, I’m struggling to find an example of using this trained model to make predictions on new, manually entered raw input. I can find examples for the language model and the image classifier, but their API doesn’t carry over. I’d appreciate it if anyone could point me in the right direction. Thank you!


What do you mean by “raw manual input”? Seq2Seq can be used for many things, in images as well as NLP. What problem are you trying to solve?

I have the same problem. It seems the seq2seq collate function needs both an input and a target, and at inference time we don’t have the targets.
Here is my collate function:

from fastai.text import *  # fastai v1 star import, as in the course notebooks

def seq2seq_collate(samples:BatchSamples, pad_idx:int=1, pad_first:bool=True, backwards:bool=False) -> Tuple[LongTensor, LongTensor]:
    "Function that collects samples and adds padding. Flips token order if needed"
    samples = to_data(samples)
    # longest input and longest target in the batch (targets that are not str count as length 1)
    max_len_x,max_len_y = max([len(s[0]) for s in samples]),max([len(s[1]) if type(s[1])== str else 1 for s in samples])
    # start from tensors filled with the padding index
    res_x = torch.zeros(len(samples), max_len_x).long() + pad_idx
    res_y = torch.zeros(len(samples), max_len_y).long() + pad_idx
    if backwards: pad_first = not pad_first
    for i,s in enumerate(samples):
        if pad_first:
            res_x[i,-len(s[0]):],res_y[i,-len(s[1]):] = LongTensor(s[0]),LongTensor(s[1])
        else:
            res_x[i,:len(s[0]):],res_y[i,:len(s[1]):] = LongTensor(s[0]),LongTensor(s[1])
    if backwards: res_x,res_y = res_x.flip(1),res_y.flip(1)
    return res_x, res_y

Any help is appreciated.
Thanks.
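The TypeError reported further down most likely comes from calling len() on a target that is a plain int (fastai v1 appears to use 0 as the target for a single item set with one_item). As a framework-free sketch of the same padding logic (pad_first, backwards flip, pad index 1), here is a hypothetical collate helper, collate_maybe_unlabeled, that works on plain Python lists instead of LongTensors. It assumes unlabeled samples carry a placeholder without a length (such as an int) and gives them a single pad token as the target instead of failing. Converting the returned lists into tensors is left out on purpose.

```python
from typing import List, Sequence, Tuple

PAD_IDX = 1  # matches the pad_idx default in the collate function above


def pad_batch(seqs: Sequence[Sequence[int]], pad_idx: int = PAD_IDX,
              pad_first: bool = True) -> List[List[int]]:
    """Pad every sequence to the length of the longest one in the batch."""
    max_len = max(len(s) for s in seqs)
    out = []
    for s in seqs:
        padding = [pad_idx] * (max_len - len(s))
        out.append(padding + list(s) if pad_first else list(s) + padding)
    return out


def collate_maybe_unlabeled(samples: Sequence[Tuple[Sequence[int], object]],
                            pad_idx: int = PAD_IDX, pad_first: bool = True,
                            backwards: bool = False):
    """Hypothetical collate: pads inputs, and pads targets only if they have a length.

    Assumption: at inference time the target is a placeholder such as the int 0,
    so it is replaced by a single pad token instead of triggering
    "object of type 'int' has no len()".
    """
    if backwards:
        pad_first = not pad_first  # same trick as the original: pad the other side, then flip
    xs = [list(x) for x, _ in samples]
    ys = [list(y) if hasattr(y, "__len__") else [pad_idx] for _, y in samples]
    res_x = pad_batch(xs, pad_idx, pad_first)
    res_y = pad_batch(ys, pad_idx, pad_first)
    if backwards:
        res_x = [row[::-1] for row in res_x]
        res_y = [row[::-1] for row in res_y]
    return res_x, res_y


# Usage: an inference-time batch whose targets are int placeholders
x, y = collate_maybe_unlabeled([([5, 6, 7], 0), ([5], 0)])
print(x)  # [[5, 6, 7], [1, 1, 5]]
print(y)  # [[1], [1]]
```

Checking for a length rather than comparing the target's type to str means numericalized targets (lists or arrays of ids) are still padded normally, while anything without a length is treated as "no target".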

@darrenlmd @armheb Were you able to run predictions with the Seq2Seq model on test data? I have also built a Seq2Seq model following the lecture from the fastai NLP course. The model trains well and the metrics are good, but I’m struggling to save the model and predict on text input. I am new to fastai and PyTorch and need some guidance on this.

I used this code:

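def predict(text, learner=learn):
    # one_item runs the raw text through the DataBunch's processors and collate function
    x_test, y = data.one_item(text)
    learner.model.eval()  # disable dropout for inference
    out = learner.model(x_test[0], x_test[1])
    # take the most likely token at each position and turn the ids back into text
    res_test = str(data.train_ds.y.reconstruct(out[0].argmax(1)))
    # strip fastai's special tokens from the output
    for tok in ('xxmaj', 'xxunk', 'xxbos', 'xxup', 'xxeos'):
        res_test = res_test.replace(tok, '')
    print(res_test)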

Hope it helps.
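The post-processing in predict (argmax per position, reconstruct, then strip special tokens with str.replace) can also be done at the token level. The sketch below is a hypothetical, framework-free version: itos stands in for the vocabulary's id-to-token list, logits for the model output of one sentence (one row of scores per position), and stopping at the first xxeos is a design choice of this sketch, not something the code above does.

```python
from typing import List, Sequence

# Hypothetical set of fastai-style special tokens to drop from the output
SPECIAL_TOKENS = {"xxmaj", "xxunk", "xxbos", "xxup", "xxeos", "xxpad"}


def greedy_ids(logits: Sequence[Sequence[float]]) -> List[int]:
    """Pick the highest-scoring vocab id at each output position (like argmax(1))."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def decode(ids: Sequence[int], itos: Sequence[str], eos: str = "xxeos") -> str:
    """Map ids back to tokens, stop at end-of-sequence, and drop special tokens."""
    tokens = []
    for i in ids:
        tok = itos[i]
        if tok == eos:
            break  # ignore everything generated after end-of-sequence
        if tok not in SPECIAL_TOKENS:
            tokens.append(tok)
    return " ".join(tokens)


# Usage with a toy vocabulary and hand-made scores
itos = ["xxunk", "xxpad", "xxbos", "xxeos", "hello", "world"]
logits = [
    [0.0, 0.0, 5.0, 0.0, 0.0, 0.0],  # xxbos
    [0.0, 0.0, 0.0, 0.0, 3.0, 1.0],  # hello
    [0.0, 0.0, 0.0, 0.0, 1.0, 4.0],  # world
    [0.0, 0.0, 0.0, 9.0, 0.0, 0.0],  # xxeos
    [0.0, 0.0, 0.0, 0.0, 7.0, 0.0],  # hello (after eos, dropped)
]
print(decode(greedy_ids(logits), itos))  # hello world
```

Filtering whole tokens avoids the leftover double spaces that str.replace produces and cannot accidentally remove a substring from inside a real word.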

Hey @armheb, I just tried out your example and I am still getting errors. Namely, I get a TypeError: object of type 'int' has no len() error in my seq2seq_collate_fn function when the one_item(text) method is called. How did you overcome this issue?

Same error, same problem here. I can’t figure out how to get a new prediction from new raw input.

What should the type of the “text” argument be?

I think it was a string.

I worked on this with @arianpasquali, and we found that using

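x = data.x.process_one("Sentence example")  # tokenize + numericalize with the training-time processors
x = torch.tensor([x]).to(device)  # extra brackets add a batch dimension; device is e.g. torch.device('cuda')
out = learn.model(x)  # called with the input only, no target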

did the trick!

(Arian, please correct me if I’m wrong about the extra square brackets in line 2.)
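On the extra brackets: torch.tensor(x) would give a 1-D tensor of shape (seq_len,), while torch.tensor([x]) gives shape (1, seq_len), i.e. a batch of one sentence, which is the (batch, seq_len) layout the model sees during training. The framework-free sketch below illustrates both steps under stated assumptions: numericalize is a hypothetical toy stand-in for process_one (the real fastai processors do more, for example inserting special tokens such as xxmaj), stoi is a made-up vocabulary, and shape is a small helper for nested lists.

```python
from typing import Dict, List

UNK_IDX = 0  # hypothetical id for unknown tokens in this toy vocabulary


def numericalize(text: str, stoi: Dict[str, int]) -> List[int]:
    """Toy stand-in for tokenize + numericalize: lowercase, split, map to ids."""
    return [stoi.get(tok, UNK_IDX) for tok in text.lower().split()]


def as_batch(ids: List[int]) -> List[List[int]]:
    """Wrap one sequence in a list so it has shape (batch=1, seq_len)."""
    return [ids]


def shape(nested) -> tuple:
    """Return the dimensions of a rectangular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(dims)


# Usage with a made-up vocabulary
stoi = {"xxunk": 0, "xxpad": 1, "sentence": 9, "example": 10}
ids = numericalize("Sentence example", stoi)
print(ids, shape(ids))  # [9, 10] (2,)
print(shape(as_batch(ids)))  # (1, 2)
```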
