I can’t. It isn’t a Docker container that I can set up as I like; it is an inference machine on which I can only use what is already installed.
Besides, you are teaching me things which are interesting per se.
Look at what happens. I slightly changed your code to adjust it to my needs.
from torch import Tensor
from torch import nn
import logging as log

class Lambda(nn.Module):
    "An easy way to create a pytorch layer for a simple `func`."
    def __init__(self, func):
        "create a layer that simply calls `func` with `x`"
        super().__init__()
        self.func=func

    def forward(self, x): return self.func(x)

def Flatten(full:bool=False)->Tensor:
    "Flatten `x` to a single dimension, often used at the end of a model. `full` for rank-1 tensor"
    func = (lambda x: x.view(-1)) if full else (lambda x: x.view(x.size(0), -1))
    return Lambda(func)

def myhead(nf, nc):
    return \
    nn.Sequential(
        nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.AdaptiveMaxPool2d(1)
        ),
        Flatten(),
        nn.BatchNorm1d(nf),
        nn.Linear(nf, 512),
        nn.ReLU(True),
        nn.BatchNorm1d(512),
        nn.Linear(512, nc),
    )
Note that I didn’t make it a Python module; I just wrote it in a notebook cell, for experimenting. nc is the number of classes.
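A side note on the pooling part of the head, as far as I can tell: nesting the two pooling layers in nn.Sequential applies them one after the other, so the channel count stays the same (2048 for a resnet50 body), whereas passing nf=4096 assumes the two pooled outputs are concatenated, which I believe is what fastai’s AdaptiveConcatPool2d does. The sketch below only illustrates the difference in shapes; ConcatPool2d is a hypothetical helper I made up for this post, and the 8-channel input is a stand-in for the real body output.

```python
import torch
from torch import nn


class ConcatPool2d(nn.Module):
    """Concatenate adaptive max and average pooling along the channel axis.

    Hypothetical helper written for this post, not an existing torch class.
    """

    def __init__(self, size: int = 1):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool2d(size)
        self.max = nn.AdaptiveMaxPool2d(size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both poolings see the same input; channels double after the concat.
        return torch.cat([self.max(x), self.avg(x)], dim=1)


features = torch.randn(2, 8, 5, 5)  # stand-in for the body's output
chained = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.AdaptiveMaxPool2d(1))

print(chained(features).shape)         # channel count unchanged
print(ConcatPool2d()(features).shape)  # channel count doubled
```

If that reading is right, a head built with the chained version would need nf equal to the body’s channel count rather than twice that.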
Then, I did:
import torchvision.models
mylearn=create_cnn(data, arch=torchvision.models.resnet50,
                   metrics=accuracy,
                   custom_head=myhead(4096, 3))
That created a resnet50 with a head meant to mirror fastai’s default one.
Then:
import torch
modeltosave=mylearn.model
modeltosave.cpu()
torch.save(modeltosave, '/path/mymodel.pkl')
As you warned, it didn’t work: AttributeError: Can't pickle local object 'Flatten.<locals>.<lambda>'.
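For reference, the same kind of failure can be reproduced without PyTorch. My understanding is that the standard pickle module stores a function as a reference to its importable name, and a lambda created inside another function has no such name. The snippet below is only a minimal sketch with made-up helper names (make_local_lambda, can_pickle).

```python
import math
import pickle


def make_local_lambda():
    # Same pattern as Flatten(): the lambda is created inside a function,
    # so its qualified name is 'make_local_lambda.<locals>.<lambda>'.
    return lambda x: x


def can_pickle(obj) -> bool:
    """Return True if the standard pickle module can serialize `obj`."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        return False


print(can_pickle(math.sqrt))            # importable by name as math.sqrt
print(can_pickle(make_local_lambda()))  # has no importable name
```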
But it does serialize fastai’s Flatten, which is identical to ours, so I cannot figure out why it doesn’t work for our Flatten (maybe @sgugger could answer this).
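One workaround that might sidestep the lambda altogether is to write the layer as a small nn.Module subclass, so that nothing inside it is a local function. This is just a sketch of that idea, not fastai’s implementation, and the class would still have to be importable wherever the file is loaded. Newer PyTorch releases also ship nn.Flatten, but I am assuming it may not be available on the installed version.

```python
import torch
from torch import nn


class Flatten(nn.Module):
    """Flatten to (batch, -1), or to a rank-1 tensor when `full` is True."""

    def __init__(self, full: bool = False):
        super().__init__()
        self.full = full

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.view(-1) if self.full else x.view(x.size(0), -1)


x = torch.zeros(2, 3, 4, 4)
print(Flatten()(x).shape)           # batch dimension kept
print(Flatten(full=True)(x).shape)  # everything in one dimension
```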
However, I installed dill, and then:
import dill
modeltosave=mylearn.model
modeltosave.cpu()
torch.save(modeltosave, '/path/mymodel.pkl', pickle_module=dill)
I received a warning:
serialization.py:251: UserWarning: Couldn't retrieve source code for container of type Lambda. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
And indeed, at inference time, it says:
path/site-packages/dill/_dill.py", line 474, in find_class
return StockUnpickler.find_class(self, module, name)
AttributeError: Can't get attribute 'Lambda' on <module '__main__' from 'predictor.py'>
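Judging by the traceback, the saved file refers to Lambda as an attribute of `__main__`, and predictor.py has no such attribute when it runs. As far as I understand, a pickle normally records a class by module and name rather than embedding its definition. The sketch below shows this with a standard-library class (Fraction is just a convenient example, unrelated to the model).

```python
import pickle
from fractions import Fraction

payload = pickle.dumps(Fraction(1, 2))

# The payload names the class by module and attribute; it does not embed
# the class definition itself.
print(b"fractions" in payload, b"Fraction" in payload)

# Loading re-imports `fractions` and looks up `Fraction` there.
restored = pickle.loads(payload)
print(restored == Fraction(1, 2))
```

If that is what happens here too, the Lambda class would have to be defined or imported in predictor.py under the same module name it had when the model was saved.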
Mmhh… So it saves, but the inference script cannot load the model back.
Any suggestions?
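One option I might try in the meantime, assuming I can recreate the architecture inside predictor.py, is to save only the weights with state_dict() and load them into a freshly built model. The sketch below uses a tiny hypothetical build_model() and an in-memory buffer instead of the real network and path, so it shows only the pattern.

```python
import io

import torch
from torch import nn


def build_model() -> nn.Module:
    # Hypothetical stand-in for rebuilding the real architecture; the code
    # that builds the network must exist on the inference machine as well.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))


trained = build_model()

buffer = io.BytesIO()  # stands in for a file on disk
# Only the parameter tensors are saved; no layer classes or lambdas.
torch.save(trained.state_dict(), buffer)

buffer.seek(0)
restored = build_model()
restored.load_state_dict(torch.load(buffer))
restored.eval()

x = torch.randn(2, 4)
print(torch.equal(trained(x), restored(x)))
```

Since a Lambda layer holds no parameters, it should not matter for this approach that its function cannot be pickled.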