Custom Architecture

Is it possible to pass my own simple PyTorch model to cnn_learner?

Sure! Pass it to Learner directly instead of cnn_learner: Learner(databunch, myModel, metrics=someAccuracyFunction)
For instance:
Learner(myDatabunch, myModel, metrics=accuracy)

See here on the docs for all the custom parameters you can include:



How does that work exactly? I’m struggling to implement my own version of EmbeddingDotBias. It is basically the same code, but I’m getting a NotImplementedError when I run fit.

How are you setting it up? (Can I see some code?) :slight_smile: Any PyTorch architecture can be passed in as your model, though it may need some modification to work with fastai (I ran into this occasionally when I tried NTS-Net).

Sure. Here it is:

class CollabChEMBL(nn.Module):
    def __init__(self,n_mols,n_target,n_factors,max_act=12,min_act=1):
        self.m_weight = nn.Embedding(n_mols, n_factors)
        self.t_weight = nn.Embedding(n_target, n_factors)
        self.lin1 = nn.Linear(n_factors*2,10)
        self.lin2 = nn.Linear(10,1)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.drop1 = nn.Dropout()
        self.drop2 = nn.Dropout()
def forward(self, mols,targets):
        dot = self.m_weight(mols)@self.t_weight(targets)
        x = self.relu1(self.lin1(dot)) # First non-linearity
        x = self.drop1(x)
        x = self.relu2(self.lin2(x))
        x = self.drop2(x)
        out = nn.Sigmoid(x) * (self.max_act-self.min_act+1) + self.min_act-0.5
        return out

I’m showing the bigger architecture because it was what I tried first. I then tried something smaller, but that didn’t work either. I’m trying to implement my own collaborative filtering NN.

wd = 1e-2
criterion = nn.MSELoss() # Mean squared error, a regression loss
model = CollabChEMBL(n_mols, n_targets, 50,12,1)
opt = optim.Adam(model.parameters(), 1e-3, weight_decay=wd)
print(model) # This runs OK and shows my architecture

learn = Learner(data,model) # Doesn't give me any error

And how are you creating your learner, and what is the stack trace of the error it’s giving you? E.g. Learner(data=data, model=CollabChEMBL())?


Like this

n_mols = 40000

model = CollabChEMBL(n_mols, n_targets, 50,12,1)
learn = Learner(data,model) # Doesn't give me any error

And when I run fit, I get:

NotImplementedError Traceback (most recent call last)
----> 1

/opt/conda/lib/python3.6/site-packages/fastai/ in fit(self, epochs, lr, wd, callbacks)
    200 callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
    201 self.cb_fns_registered = True
--> 202 fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
    204 def create_opt(self, lr:Floats, wd:Floats=0.)->None:

/opt/conda/lib/python3.6/site-packages/fastai/ in fit(epochs, learn, callbacks, metrics)
     99 for xb,yb in progress_bar(, parent=pbar):
    100 xb, yb = cb_handler.on_batch_begin(xb, yb)
--> 101 loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler)
    102 if cb_handler.on_batch_end(loss): break

/opt/conda/lib/python3.6/site-packages/fastai/ in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
     24 if not is_listy(xb): xb = [xb]
     25 if not is_listy(yb): yb = [yb]
---> 26 out = model(*xb)
     27 out = cb_handler.on_loss_begin(out)

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/ in __call__(self, *input, **kwargs)
    491 result = self._slow_forward(*input, **kwargs)
    492 else:
--> 493 result = self.forward(*input, **kwargs)
    494 for hook in self._forward_hooks.values():
    495 hook_result = hook(self, input, result)

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/ in forward(self, *input)
     86 registered hooks while the latter silently ignores them.
     87 """
---> 88 raise NotImplementedError
     90 def register_buffer(self, name, tensor):


I created the databunch using:

data = CollabDataBunch(train,user_name=userid,
                           test=valid, seed=42)

Oh no, hahaha. I found out what was wrong! My forward method wasn’t indented, so it sat outside the class. Now it’s running OK, hahaha.
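
Why an unindented forward shows up as NotImplementedError can be reproduced without PyTorch. The sketch below uses a hypothetical stand-in `Module` class (not the real nn.Module) whose default forward raises NotImplementedError, as the last frame of the traceback above does. `Broken`, `Fixed`, and `raises_not_implemented` are illustrative names only.

```python
class Module:
    """Stand-in base class (an assumption, not torch.nn.Module itself)."""

    def forward(self, *inputs):
        # Default behaviour when a subclass does not define forward().
        raise NotImplementedError

    def __call__(self, *inputs):
        # Calling the object dispatches to forward().
        return self.forward(*inputs)


class Broken(Module):
    def __init__(self):
        self.scale = 2

# Dedented by mistake: this is a module-level function, not a method of Broken.
def forward(self, x):
    return self.scale * x


class Fixed(Module):
    def __init__(self):
        self.scale = 2

    # Indented under the class, so it overrides Module.forward.
    def forward(self, x):
        return self.scale * x


def raises_not_implemented(model, x):
    """Return True if calling the model hits the base-class forward()."""
    try:
        model(x)
    except NotImplementedError:
        return True
    return False


print(raises_not_implemented(Broken(), 3))  # True
print(Fixed()(3))  # 6
```

Python accepts the broken version without complaint because a top-level `def forward` is a perfectly valid function, so constructing and printing the model still works and the error only appears when the model is called.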


Perfect! :slight_smile: Glad to see you got it working :slight_smile:
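
For reference, here is a minimal sketch of how the posted model might look with forward indented inside the class. It is not the poster's final code, and several choices are assumptions: it calls `super().__init__()`, stores `max_act` and `min_act` on the instance, concatenates the two embeddings (so the width matches `lin1`, which expects `n_factors*2`), uses a single ReLU and dropout, and calls `torch.sigmoid` rather than `nn.Sigmoid(x)`. The sizes in the smoke test are made up.

```python
import torch
import torch.nn as nn


class CollabChEMBL(nn.Module):
    def __init__(self, n_mols, n_targets, n_factors, max_act=12, min_act=1):
        super().__init__()  # must run before submodules are assigned
        self.m_weight = nn.Embedding(n_mols, n_factors)
        self.t_weight = nn.Embedding(n_targets, n_factors)
        self.lin1 = nn.Linear(n_factors * 2, 10)
        self.lin2 = nn.Linear(10, 1)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout()
        # Stored so forward() can rescale the output range.
        self.max_act, self.min_act = max_act, min_act

    # Indented under the class, so it overrides nn.Module.forward.
    def forward(self, mols, targets):
        # Assumption: concatenate the embeddings to get n_factors * 2 features.
        x = torch.cat([self.m_weight(mols), self.t_weight(targets)], dim=1)
        x = self.drop(self.relu(self.lin1(x)))
        x = self.lin2(x).squeeze(-1)  # shape: (batch,)
        # torch.sigmoid is a function; nn.Sigmoid is a module class.
        span = self.max_act - self.min_act + 1
        return torch.sigmoid(x) * span + self.min_act - 0.5


# Smoke test with made-up sizes and random indices.
model = CollabChEMBL(n_mols=100, n_targets=20, n_factors=50)
model.eval()  # disable dropout for the check
mols = torch.randint(0, 100, (8,))
targets = torch.randint(0, 20, (8,))
out = model(mols, targets)
print(out.shape)
```

With the defaults, the scaled sigmoid keeps predictions between 0.5 and 12.5, the same rescaling as in the posted code.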