I am working on a project that requires large amounts of text generation. I found the `predict` function too slow, so this afternoon I wrote code to generate text in batches. I'm sharing it with the community in the hope that it's helpful. Please point out any bugs you find! I haven't done extensive tests to confirm that this produces the same distribution as Jeremy's `predict`, but early checks look good.
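```python
import torch

# --- generate 64 independent articles of text in one batch
batch_sample = 64
max_len = 600  # at most 600 generation steps (tokens) per article

vocab = learner.data.train_ds.x.vocab
seqs_gen = [''] * batch_sample
mask = [1] * batch_sample  # 1 = still generating, 0 = this article has hit END

# --- my start-of-text token is 'GO'; find its integer id
go_int = vocab.stoi['GO']
# every sequence starts with GO; shape (batch_sample, 1), long dtype for the embedding
xb = torch.full((batch_sample, 1), go_int, dtype=torch.long, device='cuda')
yb = xb.clone()

# special tokens that should never be appended to the output text
skip = {'xxpad', 'GO', 'END', 'xxfake'}

for step in range(max_len):
    # The RNN keeps its hidden state between forward calls, and we feed the
    # whole sequence again at every step, so clear that state first
    # (fastai's own predict also calls model.reset() before generating).
    learner.model.reset()
    preds = learner.pred_batch(batch=(xb, yb))
    preds = preds[:, -1, :]  # next-token probabilities, (batch_sample, vocab_size)
    # one sampled token id per row, shape (batch_sample, 1); same distribution
    # as Multinomial(probs=preds).sample().nonzero()[:, 1]
    idx = torch.multinomial(preds, num_samples=1)

    # use j here so the outer step counter is not shadowed
    for j in range(batch_sample):
        if mask[j] == 0:
            continue  # this article already ended; ignore further tokens
        letter = vocab.itos[idx[j].item()]
        if letter == 'END':
            mask[j] = 0  # stop appending to this article
            print(f'article {j} ended at step {step}')
        elif letter not in skip:
            seqs_gen[j] += letter

    if not any(mask):
        break  # every article has produced END, no need to keep sampling

    # pred_batch returns CPU tensors, so move the new tokens back to the GPU
    idx = idx.to(xb.device)
    xb = torch.cat((xb, idx), 1)
    yb = torch.cat((yb, idx), 1)
```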
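If you want to check the masking and early-stop logic without fastai or a GPU, here is a minimal pure-PyTorch sketch of the same loop. The `generate_batch` function, its `next_token_probs` argument (assumed to be any callable that maps a `(batch, t)` tensor of token ids to `(batch, vocab)` next-token probabilities) and the three-token toy model are all hypothetical, made up for illustration. Like the code above, it re-feeds the full sequence at every step; a stateful RNN would still need its hidden state reset inside `next_token_probs`.

```python
import torch
from typing import Callable, List


def generate_batch(
    next_token_probs: Callable[[torch.Tensor], torch.Tensor],
    itos: List[str],
    go_int: int,
    end_int: int,
    batch_size: int,
    max_len: int,
    skip: frozenset = frozenset(),
) -> List[str]:
    """Sample batch_size sequences in parallel, stopping each one at END."""
    xb = torch.full((batch_size, 1), go_int, dtype=torch.long)
    seqs = [''] * batch_size
    alive = torch.ones(batch_size, dtype=torch.bool)  # False once a row emits END
    for _ in range(max_len):
        probs = next_token_probs(xb)                    # (batch_size, vocab_size)
        idx = torch.multinomial(probs, num_samples=1)   # (batch_size, 1)
        for j in range(batch_size):
            if not alive[j]:
                continue  # ignore tokens sampled after END
            tok = idx[j, 0].item()
            if tok == end_int:
                alive[j] = False
            elif itos[tok] not in skip:
                seqs[j] += itos[tok]
        if not alive.any():
            break  # all rows finished early
        xb = torch.cat((xb, idx), dim=1)
    return seqs


# Hypothetical toy vocabulary and "model" for testing.
itos = ['GO', 'END', 'a']


def toy_model(xb: torch.Tensor) -> torch.Tensor:
    bs, t = xb.shape
    probs = torch.zeros(bs, len(itos))
    for j in range(bs):
        # row j emits 'a' j times, then END (probability 1, so deterministic)
        probs[j, 2 if t - 1 < j else 1] = 1.0
    return probs


print(generate_batch(toy_model, itos, go_int=0, end_int=1,
                     batch_size=4, max_len=10, skip=frozenset({'GO'})))
# ['', 'a', 'aa', 'aaa']
```

Because the toy model puts all of its probability on a single token, the output is fixed, which makes it easy to confirm that each row stops at its own END and that the loop exits once every row has finished.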