Hi,
I’ve been working in the lesson 4 further research section (at the end of the lesson) and I’ve hit a wall trying to implement the first task which is implement the Learner fastai class.
Now for this I tried to create everything by myself, which includes creating a model class for my model, an optimization function, … Only thing I’m using from fast ai is really the Dataloaders class.
Now, my problem seems to be that my gradients are always zero (or really close to zero) which means the weights never actualize, which in turn means that my model never improves. Been trying to debug the error all day but haven’t been able to.
Code below:
def init_params(size):
return torch.rand(size).requires_grad_()
class Model:
def __init__(self):
self.w1 = init_params((28*28, 1))
self.b1 = init_params(1)
self.params = self.w1, self.b1
def predict(self, x):
predictions = x@self.w1 + self.b1
return predictions
def loss_function(labels, preds):
preds = preds.sigmoid()
return torch.where(labels==1, 1-preds, preds).mean()
def calculate_gradient(loss, learning_rate, params):
loss.backward()
for weight in params:
weight.data -= learning_rate * weight.grad.data
weight.grad.zero_()
def accuracy(label, pred):
pred = pred.sigmoid() > 0.5
return (label == pred).float().mean()
class Tr_loop:
def __init__(self, dls, model, opt_func, loss_func, metric):
self.dls = dls
self.model = model
self.opt_func = opt_func
self.loss_func = loss_func
self.metric = metric
def train(self, learning_rate):
for x, y in self.dls[0]:
preds = self.model.predict(x)
loss = self.loss_func(y, preds)
self.opt_func(loss, learning_rate, self.model.params)
def validation(self):
accs_list = [self.metric(y, self.model.predict(x)) for x, y in self.dls[1]]
return torch.FloatTensor(accs_list).mean()
def fit(self, epochs, learning_rate):
for _ in range(epochs):
self.train(learning_rate)
metric = self.validation()
print(metric, end=' ')
model = Model()
learner = Tr_loop(dls, model, calculate_gradient, loss_function, accuracy)
Where dls is a data loader object.
Current functions are kind of fixed in the sense that they only work for a two class classfication problem, in this case is the same as the lesson 4 which is classifying a number as either a 3 or a 7.