Fixed! Own fit implementation isn't training (.detach())

I am trying to implement my own train function, but it isn't training my NN: it runs, but the loss doesn't decrease.

def train(epochs, data, m, opt, loss_func):
    for epoch in trange(epochs):

        m.train()
        data_iter = iter(data.trn_dl)
        i, n = 0, len(data.trn_dl)

        with tqdm(total=n) as pbar:

            while i < n:
                x, y = V(next(data_iter), requires_grad=True)

                set_trainable(m, True)
                m.zero_grad()
                y_hat = m(x).detach()
                loss = loss_func(y, y_hat)
                loss.backward()
                opt.step()
                i += 1

        print(f'Loss {to_np(loss)}')

model = SingleModel(Net().cuda()).model

train(5, data, model, optim.RMSprop(model.parameters(), lr=0.1), F.mse_loss)

Using the fit function, I checked whether the crit, optim, data, and model were all good for training. The fit function successfully reduced the loss of the NN.

model = SingleModel(Net().cuda()).model
learn = ConvLearner.from_model_data(model, data)
learn.crit = F.mse_loss
learn.optim = optim.RMSprop(model.parameters(), lr=0.1)

learn.fit(0.1, 5)

I have no idea what the problem could be, as the only differences I see between what I do and what the fit function does are:

  1. Detaching y_hat (I get an error if I don’t do this)
  2. Wrapping the output of next(data_iter) in a Variable with requires_grad=True (I am not sure whether fit makes it require grad, but without it I get another error)

If what I am doing wrong is obvious, then please just point me in the right direction instead of correcting it for me. Thanks!

The problem is detach(). Once you detach the output, gradients can't propagate back through the model.

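To see what this does in practice, here is a minimal sketch in current PyTorch (not the old Variable API used above); the tiny nn.Linear model and the tensor names are just placeholders. With the output detached, backward() never reaches the model's weights, so the optimizer has nothing to apply:

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(4, 1)           # stand-in for the real Net()
x = torch.randn(8, 4)
y = torch.randn(8, 1)

# With detach(): the graph is cut at the output, so backward() never reaches the weights.
y.requires_grad_(True)            # mirrors V(..., requires_grad=True) in the loop above
y_hat = model(x).detach()
loss = F.mse_loss(y_hat, y)
loss.backward()
print(model.weight.grad)          # None -> the optimizer has nothing to update

# Without detach(): gradients reach the weights, so opt.step() can actually train.
model.zero_grad()
y = y.detach()                    # targets should not require grad
y_hat = model(x)
loss = F.mse_loss(y_hat, y)
loss.backward()
print(model.weight.grad.shape)    # torch.Size([1, 4])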

I had added the detach() because, without it, I was getting a RuntimeError: the derivative for 'target' is not implemented. I found https://github.com/pytorch/pytorch/issues/3933, which shows the same error when the order of the arguments to the criterion is wrong, and I quickly realized I had made the same mistake in my code.

To correct:

loss = loss_func(y, y_hat) # before fix
loss = loss_func(y_hat, y) # after fix

This also means that I don't need to detach y_hat, and that
x, y = V(next(data_iter), requires_grad=True)
shouldn't use requires_grad=True.
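For reference, a minimal sketch of the convention (in current PyTorch, with placeholder tensors): the prediction goes first, the target second, and the target comes straight from the data loader without requiring grad. In the old PyTorch version above, passing a grad-requiring tensor as the target is exactly what raised the "derivative for 'target' is not implemented" error.

import torch
import torch.nn.functional as F

y_hat = torch.randn(8, 1, requires_grad=True)   # stands in for m(x)
y = torch.randn(8, 1)                           # target from the dataloader, no grad needed

loss = F.mse_loss(y_hat, y)   # input (prediction) first, target second
loss.backward()               # gradients flow back into y_hat only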

(A restart fixes this; I am only including it for people who run into the same thing.)
Even with the fix, training was bad: the loss kept exploding, then going down, then exploding again. With the learning rate I used with fit (0.1), I was getting losses in the tens of millions. I reduced the learning rate to 0.001 and it was still extremely unstable:

# with Adam lr 0.001 the loss was continuously increasing
# RMSProp lr 0.001:
epoch      trn_loss                                      
    0      [3.436]   
    1      [1.419]                                                  
    2      [1.264]                                                
    3      [263.4]                                                   
    4      [67.033]

After a full system restart, this unstable training went away!

Train function after fix (working well):

# assumes the usual fastai 0.7 imports (V, set_trainable, to_np, tqdm, trange)
def train(epochs, data, m, opt, loss_func):
    for epoch in trange(epochs):

        m.train()
        data_iter = iter(data.trn_dl)
        i, n = 0, len(data.trn_dl)

        with tqdm(total=n) as pbar:

            while i < n:
                x, y = V(next(data_iter))   # no requires_grad on the data

                set_trainable(m, True)
                m.zero_grad()

                y_hat = m(x)                # no detach(), the graph stays intact
                loss = loss_func(y_hat, y)  # prediction first, target second

                loss.backward()
                opt.step()

                i += 1
                pbar.update(1)              # advance the progress bar

        print(f'Loss {to_np(loss)}')
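
For anyone not using the old fastai 0.7 helpers (V, set_trainable, to_np), here is a rough sketch of the same loop in plain PyTorch; train_plain, train_loader, and the other names are placeholders I made up, not part of the original code.

import torch
import torch.nn.functional as F
from tqdm import tqdm, trange

def train_plain(epochs, train_loader, model, opt, loss_func, device='cuda'):
    for epoch in trange(epochs):
        model.train()
        for x, y in tqdm(train_loader):
            x, y = x.to(device), y.to(device)

            opt.zero_grad()                # clear gradients from the previous step
            y_hat = model(x)               # forward pass, no detach()
            loss = loss_func(y_hat, y)     # prediction first, target second
            loss.backward()                # backprop through the intact graph
            opt.step()                     # update the parameters

        print(f'Loss {loss.item():.4f}')

# Usage (assumed names):
# model = Net().cuda()
# train_plain(5, train_loader, model, torch.optim.RMSprop(model.parameters(), lr=1e-3), F.mse_loss)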