Is it possible to pass my own simple PyTorch model to cnn_learner?
Sure! Use Learner directly instead of cnn_learner: Learner(databunch, myModel, metrics=someAccuracyFunction)
For instance:
Learner(myDatabunch, myModel, metrics=accuracy)
See the docs for all the custom parameters you can pass:
https://docs.fast.ai/basic_train.html#Learner
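In fastai v1 a metric is simply a callable that takes the model's output and the targets and returns a tensor, which is how the built-in accuracy is written. Here is a minimal sketch of a custom one. The name my_accuracy is hypothetical, and it assumes classification-style outputs where each row holds one score per class:

```python
import torch

def my_accuracy(preds: torch.Tensor, targs: torch.Tensor) -> torch.Tensor:
    """Hypothetical metric: fraction of rows whose highest-scoring class equals the target."""
    return (preds.argmax(dim=-1) == targs).float().mean()

# Three predictions over two classes; the argmax per row is [1, 0, 1]
preds = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
targs = torch.tensor([1, 0, 0])
print(my_accuracy(preds, targs))  # 2 of 3 rows match
```

You would then pass it the same way: Learner(myDatabunch, myModel, metrics=my_accuracy).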
Thank you!
How does that work exactly? I'm struggling to implement my own version of EmbeddingDotBias. It's basically the same code, but I get a NotImplementedError when I run fit.
How are you setting it up? (Can I see some code?) Any PyTorch architecture can be passed to Learner as the model, though it may need some modification to work with fastai. (I ran into this occasionally when I tried NTS-Net.)
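One typical modification involves the call signature. fastai v1's loss_batch calls the model as model(*xb), where xb is the list of inputs in a batch. A model whose forward expects one combined tensor therefore needs a thin wrapper. Below is a minimal sketch; the adapter class and the inner nn.Linear stand-in are hypothetical:

```python
import torch
import torch.nn as nn

class SplitInputsAdapter(nn.Module):
    """Hypothetical adapter: accepts the inputs as separate positional args (as model(*xb)
    supplies them) and stacks them into the single tensor the wrapped model expects."""
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        # n tensors of shape (bs,) -> one tensor of shape (bs, n)
        return self.inner(torch.stack(inputs, dim=1))

# Stand-in for an existing model that expects a single (bs, 2) float tensor
inner = nn.Linear(2, 1)
model = SplitInputsAdapter(inner)

xb = [torch.randn(5), torch.randn(5)]  # a batch delivered as a list of two inputs
out = model(*xb)                        # the same unpacking loss_batch performs
print(out.shape)
```

The wrapper leaves the original model untouched. Only the interface between fastai's batch format and the model changes.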
Sure. Here it is:
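import torch
import torch.nn as nn

class CollabChEMBL(nn.Module):
    def __init__(self, n_mols, n_target, n_factors, max_act=12, min_act=1):
        super().__init__()
        self.max_act, self.min_act = max_act, min_act      # Range of the activity values
        self.m_weight = nn.Embedding(n_mols, n_factors)    # One vector per molecule
        self.t_weight = nn.Embedding(n_target, n_factors)  # One vector per target
        self.lin1 = nn.Linear(n_factors*2, 10)
        self.lin2 = nn.Linear(10, 1)
        self.relu1 = nn.ReLU()
        self.drop1 = nn.Dropout()

    def forward(self, mols, targets):
        # Concatenate the embeddings to (bs, n_factors*2), the input size lin1 expects
        x = torch.cat([self.m_weight(mols), self.t_weight(targets)], dim=1)
        x = self.relu1(self.lin1(x))  # First non-linearity
        x = self.drop1(x)
        # Linear output: a ReLU here would clamp x >= 0 and keep the sigmoid above 0.5
        x = self.lin2(x)
        # Scale the sigmoid output into (min_act - 0.5, max_act + 0.5)
        out = torch.sigmoid(x) * (self.max_act - self.min_act + 1) + self.min_act - 0.5
        return out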
I'm showing the bigger architecture because that's what I tried first. I then tried something smaller, but that didn't work either. I'm trying to implement my own collaborative-filtering NN.
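For comparison, here is a sketch of the smaller "dot product plus biases" design that fastai's EmbeddingDotBias is built around. It is written from memory rather than copied from the library. The class name is hypothetical, and the default y_range = (0.5, 12.5) is an assumption chosen to match min_act=1 and max_act=12 in the scaling formula above:

```python
import torch
import torch.nn as nn

class DotBiasSketch(nn.Module):
    """Hypothetical dot-product-plus-bias collab model with a sigmoid output range."""
    def __init__(self, n_mols: int, n_targets: int, n_factors: int = 50,
                 y_range: tuple = (0.5, 12.5)):
        super().__init__()
        self.m_weight = nn.Embedding(n_mols, n_factors)
        self.t_weight = nn.Embedding(n_targets, n_factors)
        self.m_bias = nn.Embedding(n_mols, 1)     # per-molecule offset
        self.t_bias = nn.Embedding(n_targets, 1)  # per-target offset
        self.y_range = y_range

    def forward(self, mols: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Row-wise dot product: multiply element-wise, sum over factors -> (bs,)
        dot = (self.m_weight(mols) * self.t_weight(targets)).sum(dim=1)
        res = dot + self.m_bias(mols).squeeze(-1) + self.t_bias(targets).squeeze(-1)
        lo, hi = self.y_range
        # Squash into (lo, hi)
        return torch.sigmoid(res) * (hi - lo) + lo

model = DotBiasSketch(n_mols=100, n_targets=20)
out = model(torch.tensor([0, 5, 99]), torch.tensor([1, 19, 3]))
print(out.shape)
```

A row-wise dot product uses (a * b).sum(dim=1). A plain a @ b on two (bs, n_factors) matrices fails unless n_factors happens to equal bs.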
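import torch.nn as nn
import torch.optim as optim
from fastai.basic_train import Learner

wd = 1e-2
criterion = nn.MSELoss()  # Mean-squared error: activities are regressed as real values
model = CollabChEMBL(n_mols, n_targets, 50, 12, 1)
print(model)  # This runs OK and shows my arch
# Learner builds its own optimizer from opt_func, applying weight decay wd
learn = Learner(data, model, opt_func=optim.Adam, loss_func=criterion, wd=wd)  # Doesn't give me any error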
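Before handing a model to Learner, it can help to run one optimization step by hand. The steps mirror what loss_batch does in the traceback below: model(*xb), then the loss, then backward and an optimizer step. A problem in forward then surfaces without fastai in the loop. Below is a sketch with random stand-in data; torch's built-in nn.Bilinear stands in for a custom two-input model, and all sizes are made up:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Stand-in for a custom model: nn.Bilinear's forward also takes two positional inputs
model = nn.Bilinear(4, 4, 1)
criterion = nn.MSELoss()
opt = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-2)

# Random stand-in batch of 16 samples with targets in [1, 12]
xb = [torch.randn(16, 4), torch.randn(16, 4)]
yb = torch.rand(16, 1) * 11 + 1

before = model.weight.detach().clone()
out = model(*xb)           # forward pass, same call pattern as fastai's loss_batch
loss = criterion(out, yb)
opt.zero_grad()
loss.backward()
opt.step()                 # parameters move by roughly lr on this first Adam step
print(out.shape, loss.item())
```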
And how are you creating your learner, and what stack trace is it giving you? E.g. Learner(data=data, model=model)?
Like this:
n_mols = 40000
n_targets = 400
model = CollabChEMBL(n_mols, n_targets, 50, 12, 1)
learn = Learner(data, model)  # Doesn't give me any error
And when I run:
learn.fit(1)
NotImplementedError Traceback (most recent call last)
in
----> 1 learn.fit(epochs=1)

/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py in fit(self, epochs, lr, wd, callbacks)
    200 callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
    201 self.cb_fns_registered = True
--> 202 fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
    203
    204 def create_opt(self, lr:Floats, wd:Floats=0.)->None:

/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py in fit(epochs, learn, callbacks, metrics)
     99 for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
    100 xb, yb = cb_handler.on_batch_begin(xb, yb)
--> 101 loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler)
    102 if cb_handler.on_batch_end(loss): break
    103

/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
     24 if not is_listy(xb): xb = [xb]
     25 if not is_listy(yb): yb = [yb]
---> 26 out = model(*xb)
     27 out = cb_handler.on_loss_begin(out)
     28

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491 result = self._slow_forward(*input, **kwargs)
    492 else:
--> 493 result = self.forward(*input, **kwargs)
    494 for hook in self._forward_hooks.values():
    495 hook_result = hook(self, input, result)

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in forward(self, *input)
     86 registered hooks while the latter silently ignores them.
     87 """
---> 88 raise NotImplementedError
     89
     90 def register_buffer(self, name, tensor):

NotImplementedError:
I created the databunch using:
data = CollabDataBunch.from_df(train, user_name=userid,
                               rating_name=ratings, item_name=targetid,
                               test=valid, seed=42)
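One thing worth double-checking with this setup is the embedding sizes. n_mols = 40000 and n_targets = 400 are hard-coded, which only works if every encoded index in the batches is smaller than those numbers. An nn.Embedding with too few rows fails on the first out-of-range index (IndexError on CPU in recent PyTorch versions). If I read the fastai v1 source correctly, collab_learner takes the sizes from data.train_ds.x.classes instead, so that is probably the safer source when using CollabDataBunch. The sketch below shows the same idea in plain Python, deriving sizes from the IDs; the IDs and the encode helper are hypothetical:

```python
import torch
import torch.nn as nn

# Hypothetical raw IDs as they might appear in a ratings table
mol_ids = ["mol_a", "mol_b", "mol_a", "mol_c"]
target_ids = ["tgt_x", "tgt_y", "tgt_y", "tgt_z"]

def encode(ids):
    """Map each distinct ID to a contiguous index 0..n-1, in first-seen order."""
    vocab = {}
    codes = [vocab.setdefault(i, len(vocab)) for i in ids]
    return torch.tensor(codes), len(vocab)

mols, n_mols = encode(mol_ids)           # indices [0, 1, 0, 2], n_mols == 3
targets, n_targets = encode(target_ids)  # indices [0, 1, 1, 2], n_targets == 3

m_emb = nn.Embedding(n_mols, 4)  # exactly one row per distinct molecule
print(m_emb(mols).shape)

raised = False
try:
    nn.Embedding(n_mols - 1, 4)(mols)  # one row short: index 2 is out of range
except (IndexError, RuntimeError):     # older PyTorch versions raised RuntimeError
    raised = True
print(raised)
```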
Oh no, hahahaha. I found out what was wrong! My forward method wasn't indented, so it wasn't part of the class. Now it's running OK, haha.
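For anyone else who hits this: when def forward isn't indented under the class, Python treats it as a module-level function. The class then inherits nn.Module's default forward, which is the raise NotImplementedError at the bottom of the traceback above. Here is a minimal reproduction; both class names are hypothetical:

```python
import torch
import torch.nn as nn

class Indented(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(2, 1)

    def forward(self, x):  # indented under the class: overrides nn.Module.forward
        return self.lin(x)

class NotIndented(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(2, 1)

def forward(self, x):  # at column 0: a plain module-level function, not a method
    return self.lin(x)

x = torch.randn(3, 2)
print(Indented()(x).shape)

failed = False
try:
    NotIndented()(x)  # no forward of its own, so nn.Module's default is used
except NotImplementedError:
    failed = True
print(failed)
```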
Perfect! Glad to see you got it working