Hello.
I need to add a way to get the probability of a word given a context.
I’ve found this post that is exactly what I need:
I want to get P(“red”|“The car is”).
That is, I need a function that takes the context (“The car is”) and a candidate next word (“red”), and returns the probability of that word coming next.
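For a typical neural language model, this probability can be written as a softmax over the model's outputs. The formula below assumes the model produces one raw score z_j for every word j in the vocabulary V after reading the context c; the probability of a candidate word w is the softmax of those scores, evaluated at w:

```latex
P(w \mid c) = \operatorname{softmax}(z)_w = \frac{\exp(z_w)}{\sum_{j \in V} \exp(z_j)},
\qquad
P(\text{red} \mid \text{The car is}) = \frac{\exp(z_{\text{red}})}{\sum_{j \in V} \exp(z_j)}
```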
The answer in that post wasn’t very clear, and I had trouble working out how to implement this function.
I think I solved it. Below is the library’s `predict` method, which I used as a reference, followed by my function:
```python
def predict(self, text:str, n_words:int=1, no_unk:bool=True, temperature:float=1., min_p:float=None, sep:str=' ',
            decoder=decode_spec_tokens):
    "Return `text` and the `n_words` that come after"
    self.model.reset()
    xb,yb = self.data.one_item(text)
    new_idx = []
    for _ in range(n_words): #progress_bar(range(n_words), leave=False):
        res = self.pred_batch(batch=(xb,yb))[0][-1]
        #if len(new_idx) == 0: self.model[0].select_hidden([0])
        if no_unk: res[self.data.vocab.stoi[UNK]] = 0.
        if min_p is not None:
            if (res >= min_p).float().sum() == 0:
                warn(f"There is no item with probability >= {min_p}, try a lower value.")
            else: res[res < min_p] = 0.
        if temperature != 1.: res.pow_(1 / temperature)
        idx = torch.multinomial(res, 1).item()
        new_idx.append(idx)
        xb = xb.new_tensor([idx])[None]
    return text + sep + sep.join(decoder(self.data.vocab.textify(new_idx, sep=None)))
```
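```python
def get_prob_of_word_in_context(self, context: str, word: str):
    "Return P(`word` | `context`) under the language model, as a Python float."
    self.model.reset()                    # clear the hidden state left over from earlier calls
    xb,yb = self.data.one_item(context)   # numericalize the context as a batch of one
    # pred_batch already applies the activation matching the loss (softmax for the
    # LM's cross-entropy), so this is a probability distribution over the vocab;
    # [0][-1] selects the first item in the batch and the last time step.
    res = self.pred_batch(batch=(xb, yb))[0][-1]
    # stoi maps out-of-vocabulary words to the index of xxunk, so an unknown
    # `word` silently returns P(xxunk | context) instead of raising an error.
    index_of_word = self.data.vocab.stoi[word]
    return res[index_of_word].item()
```

First we reset the model, turn the context into a batch with `one_item`, and run `pred_batch` on it, exactly as `predict` does.

`res` is a vector with one entry per vocabulary word (`len(self.data.vocab.itos)` entries), holding a value for each possible next word. Because `pred_batch` already applies the softmax that goes with the language model’s cross-entropy loss, `res` is already a probability distribution; `predict` relies on the same fact when it compares `res` to `min_p` and samples from it with `torch.multinomial`. Passing it through `F.softmax` again would flatten it toward a uniform distribution. So all that’s left is to get the word’s index with `stoi` and read off its probability.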
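To see why the softmax should be applied exactly once, here is a minimal pure-Python sketch (no fastai or PyTorch) using a hypothetical five-word vocabulary and made-up scores. It normalizes raw scores with a softmax, looks a word up through a `defaultdict`-style `stoi` that sends unknown words to index 0 (my understanding of how fastai v1’s `Vocab.stoi` behaves, where index 0 is `xxunk`), and shows what happens when the softmax is applied a second time to values that are already probabilities.

```python
import math
from collections import defaultdict

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)                       # subtract the max to avoid overflow in exp
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical toy vocabulary; index 0 is the unknown-word token, and
# unknown words map to it, mirroring a defaultdict(int) stoi.
itos = ["xxunk", "red", "blue", "fast", "the"]
stoi = defaultdict(int, {w: i for i, w in enumerate(itos)})

# Made-up next-word scores (logits) for the context "The car is".
logits = [0.0, 3.0, 2.0, 1.0, -1.0]

probs = softmax(logits)                   # softmax applied once: a proper distribution
p_red = probs[stoi["red"]]

twice = softmax(probs)                    # softmax applied to probabilities: flattened
p_red_twice = twice[stoi["red"]]

p_unknown = probs[stoi["purple"]]         # out-of-vocab word falls back to P(xxunk)

print(f"P(red | context), softmax once : {p_red:.3f}")
print(f"P(red | context), softmax twice: {p_red_twice:.3f}")
print(f"uniform over the vocabulary    : {1 / len(itos):.3f}")
```

Applying the softmax twice pulls P(red) much closer to the uniform value of 1/5, which is why it matters whether the model output you start from is raw scores or already probabilities. An out-of-vocabulary word such as "purple" also quietly receives the unknown token’s probability rather than raising an error.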