I found an elegant solution for normalizing a single image and running a prediction on it (see below). However, it relies on tfms_from_stats to do the resize and normalization, which, it seems, is sadly deprecated in V1. How might the same thing be done in V1? I’ve searched the docs and forums for a few hours and haven’t come across anything yet. Specifically, I’m not seeing anything in the current tfms functions that would allow using image_stats.
environment:
BUCKET_NAME: pytorch-serverless
STATE_DICT_NAME: dogscats-resnext50.h5
IMAGE_SIZE: 224
IMAGE_STATS: ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
LABELS_PATH: lib/labels.txt
import ast

# IMAGE_STATS holds a Python tuple literal of two lists, e.g.
# "([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])" — presumably per-channel
# means and standard deviations used for normalization (confirm).
# ast.literal_eval parses that literal exactly like eval did, but cannot
# execute arbitrary code if the environment variable is ever tampered with.
STATS = A(*ast.literal_eval(os.environ['IMAGE_STATS']))
# Image side length passed to the transforms; falls back to 224 when unset.
SZ = int(os.environ.get('IMAGE_SIZE', '224'))
# Keep only the last transform returned by tfms_from_stats — presumably the
# deterministic validation/inference transform in fastai 0.7 (confirm).
TFMS = tfms_from_stats(STATS, SZ)[-1]
class SetupModel(object):
    """Decorator that loads trained weights into a shared model, then forwards calls.

    `model` and `labels` are class attributes created when the class body runs,
    so every decorated function shares the same model and label list.

    Applying the decorator (constructing an instance) does the following:
    1. Downloads the state dict `STATE_DICT_NAME` from `BUCKET_NAME` into /tmp.
    2. Loads it into the shared model and switches the model to eval mode.
    3. Deletes the temporary file, even if loading fails.

    Calling the instance simply invokes the wrapped function.
    """

    model = classification_model()
    labels = get_labels(os.environ['LABELS_PATH'])

    def __init__(self, f):
        self.f = f
        file_path = f'/tmp/{STATE_DICT_NAME}'
        download_file(BUCKET_NAME, STATE_DICT_NAME, file_path)
        try:
            # A callable map_location that returns `storage` unchanged keeps
            # every tensor on the CPU, so GPU-saved weights load on CPU hosts.
            state_dict = torch.load(file_path, map_location=lambda storage, loc: storage)
            self.model.load_state_dict(state_dict)
            self.model.eval()
        finally:
            # Always clean up the temp file so a failed load does not leave
            # the weights behind in /tmp.
            os.remove(file_path)

    def __call__(self, *args, **kwargs):
        return self.f(*args, **kwargs)
def build_pred(label_idx, log, prob):
    """Build a JSON-friendly prediction record.

    Maps `label_idx` to its name via `SetupModel.labels`. Converts `log` and
    `prob` to plain Python floats so the result can be serialized.
    """
    return {
        'label': SetupModel.labels[label_idx],
        'log': float(log),
        'prob': float(prob),
    }
def parse_params(params):
    """Extract the image URL and the number of top predictions from request params.

    `image_url` is URL-decoded with `unquote_plus` and defaults to ''.
    `top_k` defaults to 3. A value below 1 means "all labels", and the
    result is always capped at the number of known labels.
    """
    decoded_url = urllib.parse.unquote_plus(params.get('image_url', ''))
    label_count = len(SetupModel.labels)
    requested = int(params.get('top_k', 3))
    wanted = label_count if requested < 1 else requested
    return {'image_url': decoded_url, 'top_k': min(wanted, label_count)}
def predict(img):
    """Run the shared model on a single image.

    Applies `TFMS` to `img` and converts the result with `T`. It then stacks
    that single tensor into a batch of one and wraps it with `VV_`. Returns
    the model output averaged over the batch dimension (dim 0).
    """
    sample = T(TFMS(img))
    batch_tensor = torch.stack([sample])
    output = SetupModel.model(VV_(batch_tensor))
    return output.mean(0)