I am having a related issue.
Check this out. Wiki: Lesson 1 - #183 by nikhil_no_1
I am using the accuracy function.
This is the source for it.
def accuracy(preds, targs):
    preds = np.argmax(preds, axis=1)
    return (preds==targs).mean()
Moreover, I replaced probs with log_preds wherever it gave an axis error. I just wanted to run through the notebook once. I will now go through each line step by step to understand whether that is the right thing to do.
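For what it's worth, here is a minimal sketch of why that swap probably behaves the way it does. It assumes probs was a 1-D array of class-1 probabilities (I have not confirmed that from the notebook), and the numbers in log_preds and targs are made up purely for illustration. Since exp is monotonic, the argmax over log-probabilities picks the same class as the argmax over probabilities, while a 1-D array has no axis 1 for argmax to reduce over.

```python
import numpy as np

def accuracy(preds, targs):
    # preds: 2-D array of shape (n_samples, n_classes); targs: 1-D class labels
    preds = np.argmax(preds, axis=1)
    return (preds == targs).mean()

# Made-up log-probabilities for 4 samples and 2 classes (illustrative only)
log_preds = np.log(np.array([[0.9, 0.1],
                             [0.2, 0.8],
                             [0.6, 0.4],
                             [0.3, 0.7]]))
targs = np.array([0, 1, 1, 1])

# exp is monotonic, so both calls pick the same class for every row
acc_from_logs = accuracy(log_preds, targs)
acc_from_probs = accuracy(np.exp(log_preds), targs)

# Assumed shape of the original probs: 1-D class-1 probabilities
probs_1d = np.exp(log_preds[:, 1])
try:
    accuracy(probs_1d, targs)
    raised_axis_error = False
except ValueError:  # NumPy's AxisError is a subclass of ValueError
    raised_axis_error = True

print(acc_from_logs, acc_from_probs, raised_axis_error)
```

With these sample values both accuracy calls return 0.75, and the 1-D input raises the axis error. So passing log_preds should give the same accuracy as passing the full 2-D probabilities would, as long as the array really is 2-D.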
-Nikhil