Confused about whether we need to call model.reset() before each individual prediction, or whether we only need to call it once after calling model.eval(), since we aren't updating any of the weights.
Consider this snippet of code:
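# set model for evaluation (turn off dropout)
m.eval()
# reset hidden state here just once?
m.reset()
preds = []
for i in range(len(examples)):
    # ... or ... reset hidden state prior to each individual prediction?
    # m.reset()
    inp = examples[i]  # assumption: each example is already a model-ready tensor
    pred = m(inp)
    preds.append(pred)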
Again, this is for running predictions on examples one example at a time.
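To make the question concrete, here is a minimal pure-Python sketch of a model that, like a stateful RNN encoder, keeps its hidden state between calls until reset() is called. ToyStatefulRNN and its weights are hypothetical, not fastai or PyTorch code, and whether the real model stores state this way is an assumption. Note that eval() only switches layers such as dropout into inference mode; it does not clear stored hidden state, so with a model like this, resetting once before the loop lets each example start from whatever state the previous one left behind.

```python
import math


class ToyStatefulRNN:
    """Hypothetical one-unit RNN that keeps its hidden state between calls
    until reset() is called (an assumed stand-in for a stateful encoder)."""

    def __init__(self, w_in=0.5, w_hid=0.9):
        self.w_in = w_in    # input weight (made-up value)
        self.w_hid = w_hid  # recurrent weight (made-up value)
        self.reset()

    def reset(self):
        # Forget everything seen so far.
        self.h = 0.0

    def __call__(self, seq):
        # Process the sequence step by step, carrying self.h forward.
        for x in seq:
            self.h = math.tanh(self.w_in * x + self.w_hid * self.h)
        return self.h


examples = [[1.0, 2.0], [1.0, 2.0]]  # the same example twice

m = ToyStatefulRNN()
no_reset = [m(inp) for inp in examples]  # state leaks from example 0 into 1

m = ToyStatefulRNN()
with_reset = []
for inp in examples:
    m.reset()  # start every prediction from a clean state
    with_reset.append(m(inp))

print(no_reset[0] == no_reset[1])      # False
print(with_reset[0] == with_reset[1])  # True
```

If the real model behaves like this toy, the two identical examples only get identical predictions when reset() is called before each one.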
This is a really good question. From experimentation, it doesn't look like reset() is needed when doing predictions on a row-by-row basis as in the code above.
However, I have gotten some odd results with batch prediction, where the order of the rows may be affecting the predictions. That might be a bug in my code rather than anything to do with reset(), though; I'm still trying to isolate it.
I haven't had time to dig into this deeply, but if anyone reading this could confirm the correct answer, and check whether the different ways of making predictions all behave the same way, that would be great.
FYI, if you are using model.predict_batch, it appears as if reset() is called for every batch…
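def predict_batch(m, x):
    # reset the hidden state (if the model has one) before the whole batch is run
    if hasattr(m, 'reset'): m.reset()
    ...  # rest of the function not shown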
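The hasattr guard in that excerpt only resets models that actually carry state. Below is a minimal sketch of the same guard applied to every call, using a hypothetical toy model and a hypothetical predict_one helper (not fastai's actual predict_batch). It suggests why resetting makes predictions independent of the order in which inputs are fed, while skipping the reset does not:

```python
import math


class ToyStatefulRNN:
    """Hypothetical stateful one-unit RNN; its state survives between calls."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.h = 0.0

    def __call__(self, seq):
        for x in seq:
            self.h = math.tanh(0.5 * x + 0.9 * self.h)
        return self.h


def predict_one(m, x):
    # Same guard as the excerpt: only reset models that have a reset() method.
    if hasattr(m, 'reset'):
        m.reset()
    return m(x)


a, b = [1.0, 2.0], [-3.0, 0.5]

# With the reset, feeding (a, b) or (b, a) gives the same per-input results.
m = ToyStatefulRNN()
forward_order = [predict_one(m, x) for x in (a, b)]
reverse_order = [predict_one(m, x) for x in (b, a)][::-1]

# Without it, each result depends on whatever was run before it.
m = ToyStatefulRNN()
unsafe_fwd = [m(x) for x in (a, b)]
m = ToyStatefulRNN()
unsafe_rev = [m(x) for x in (b, a)][::-1]

print(forward_order == reverse_order)  # True
print(unsafe_fwd == unsafe_rev)        # False
```

This is about the order of separate calls; it does not by itself explain order effects between rows inside a single batch.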
Not that I can share.
I was manually using the PyTorch model.
This is even more confusing. The reset will only apply to the first row in the batch, right?
Maybe there is some guidance on the PyTorch forums, since I guess this is really a pure PyTorch question?
No, it will be called before the entire batch is run through the model. Remember that when you call this method you are running batch_size examples through your model simultaneously.
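One way to picture this: the hidden state has one entry per row in the batch, so a single reset() clears the state of every row at once, not just the first. In PyTorch, for example, nn.LSTM's hidden state has shape (num_layers, batch, hidden_size). The sketch below is a hypothetical pure-Python model with a one-number state per row; BatchedToyRNN and its reset(batch_size) signature are assumptions for illustration, not fastai's API.

```python
import math


class BatchedToyRNN:
    """Hypothetical stateful RNN with a one-number hidden state per row.

    self.h has one entry per row, mirroring the batch dimension of a real
    RNN's hidden-state tensor. Call reset(batch_size) before the first batch.
    """

    def __init__(self, w_in=0.5, w_hid=0.9):
        self.w_in, self.w_hid = w_in, w_hid
        self.h = []

    def reset(self, batch_size):
        # Clear the state of every row in the batch.
        self.h = [0.0] * batch_size

    def __call__(self, batch):
        # batch: list of rows, each a list of inputs of the same length.
        for t in range(len(batch[0])):
            self.h = [math.tanh(self.w_in * row[t] + self.w_hid * h)
                      for row, h in zip(batch, self.h)]
        return list(self.h)


batch = [[1.0, 2.0], [-3.0, 0.5], [0.0, 0.0]]
m = BatchedToyRNN()
m.reset(len(batch))
out = m(batch)
print(out)   # one final hidden state per row
m.reset(len(batch))
print(m.h)   # [0.0, 0.0, 0.0]
```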
Yeah, and I find this somewhat confusing. The model is, in effect, a set of operations to apply (plus the associated hidden state), and these operations need to be applied to each row.
I understand that multiple rows can be done simultaneously for performance.
But I don't understand how that is done. I assumed that multiple copies of the same set of operations are created, and that one row doesn't affect the hidden state of another row in the same batch. But thinking it through, I'm not really sure about that. So how does it work?
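Roughly, the second half of that assumption holds, but nothing is copied: the weights are shared by all rows, and only the hidden state has a separate row per example. A batched step computes something like H_new = tanh(X W + H U), and row r of a matrix product depends only on row r of X and H, so rows in the same batch should not affect each other. The plain-Python sketch below (made-up weights, hypothetical helper names) runs a small tanh RNN on a batch and on each row alone, then compares. With real PyTorch kernels the two may differ by tiny floating-point rounding, and padding shorter rows to a common length can change results more noticeably, which may be worth checking for the odd batch results mentioned earlier.

```python
import math


def matmul(a, b):
    # Plain-Python matrix product: (n x k) @ (k x m) -> (n x m).
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def rnn_forward(xs, W, U):
    """Run a tanh RNN over a batch, starting from a zero state.

    xs: list of timesteps; each timestep is a (batch x n_in) matrix.
    W: (n_in x n_hid) input weights; U: (n_hid x n_hid) recurrent weights.
    W and U are shared by every row; only H has one row per example.
    """
    batch_size, n_hid = len(xs[0]), len(U)
    H = [[0.0] * n_hid for _ in range(batch_size)]  # fresh state for all rows
    for X in xs:
        XW, HU = matmul(X, W), matmul(H, U)
        H = [[math.tanh(p + q) for p, q in zip(r1, r2)]
             for r1, r2 in zip(XW, HU)]
    return H


W = [[0.5, -0.2], [0.1, 0.4]]  # made-up weights: n_in = 2, n_hid = 2
U = [[0.9, 0.0], [-0.3, 0.7]]
rows = [
    [[1.0, 0.0], [0.5, 2.0]],    # row 0: two timesteps of two features
    [[-1.0, 1.0], [0.0, -0.5]],  # row 1
]
# Rearrange into timestep-major batches: xs[t][r] == rows[r][t].
xs = [[row[t] for row in rows] for t in range(len(rows[0]))]

batched = rnn_forward(xs, W, U)
one_by_one = [rnn_forward([[step] for step in row], W, U)[0] for row in rows]

same = all(math.isclose(p, q)
           for ra, rb in zip(batched, one_by_one)
           for p, q in zip(ra, rb))
print(same)  # True
```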