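This is a little more long-winded, but it removes any dependency on the fast.ai framework at inference time: once the underlying PyTorch model has been extracted from the learner, only PyTorch, torchvision, PIL and NumPy are needed. That might prove helpful if you want to put this in an Android or iOS app at some point: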
# Classify a single image with the trained network using only PyTorch,
# torchvision, PIL and NumPy (fast.ai is touched only to extract the model).
# Prints the predicted class index as a 1-element NumPy array.

# Underlying torch.nn.Module wrapped by the fast.ai learner.
torch_model = learn.models.model
# Inference mode: otherwise Dropout stays active and BatchNorm uses (and
# updates) per-batch statistics, giving wrong / non-deterministic predictions.
torch_model.eval()

image_path = f'{PATH}/valid/in-n-out/2.jpg'
# convert('RGB') guards against grayscale, palette or RGBA files, which would
# otherwise yield a tensor with the wrong channel count for Normalize and the
# model's first layer. The `with` block closes the underlying file.
with Image.open(image_path) as raw_img:
    img = raw_img.convert('RGB')

# Standard ImageNet per-channel mean/std used by torchvision's pretrained
# models. NOTE(review): assumes the network was trained with this same
# normalization -- confirm against the learner's transforms.
normalize = torchvision.transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

# `Scale` was deprecated in favour of `Resize` and has been removed from recent
# torchvision releases; use whichever this installation provides.
_Resize = getattr(torchvision.transforms, 'Resize', None) or torchvision.transforms.Scale

# NOTE(review): 256 -> 224 center crop assumes the model was trained at 224 px
# -- confirm against the size used when building the learner.
preprocess = torchvision.transforms.Compose([
    _Resize(256),                          # smaller edge -> 256 px, aspect kept
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),     # PIL image -> float CHW tensor in [0, 1]
    normalize,
])

# Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
img_tensor = preprocess(img).unsqueeze_(0)

# Send the input to wherever the model's weights live instead of always calling
# .cuda(), so the same code runs on a CPU-only machine once the model is on CPU.
_model_on_gpu = next(torch_model.parameters()).data.is_cuda
# Variable is required on PyTorch 0.3 and is a no-op wrapper on >= 0.4.
# (On >= 0.4 this forward pass could also be wrapped in torch.no_grad().)
img_variable = Variable(img_tensor.cuda() if _model_on_gpu else img_tensor)

log_probs = torch_model(img_variable)
# Highest-scoring class per batch row; presumably the outputs are
# log-probabilities (per the name), but argmax is the same for raw scores.
preds = np.argmax(log_probs.cpu().data.numpy(), axis=1)
print(preds)