How should I train a language model?

I have a set of around 100,000 unlabeled tweets that I want to use to train a language model and, from that, a classifier. Right now, I am doing:

# Language model learner
learn = language_model_learner(data_lm, arch=AWD_LSTM, drop_mult=0.5)

# Gradual unfreezing of the LM
learn.freeze()
learn.fit_one_cycle(cyc_len=1, max_lr=1e-2, moms=(0.8, 0.7))
learn.freeze_to(-2)
learn.fit_one_cycle(cyc_len=1, max_lr=1e-2, moms=(0.8, 0.7))
learn.freeze_to(-3)
learn.fit_one_cycle(cyc_len=1, max_lr=1e-2, moms=(0.8, 0.7))
learn.unfreeze()
learn.fit_one_cycle(cyc_len=3, max_lr=1e-3, moms=(0.8, 0.7))

However, I’m not sure this is correct. I get an output like this:

epoch train_loss valid_loss accuracy time
0 3.560003 3.149161 0.465935 05:58
epoch train_loss valid_loss accuracy time
0 3.271437 3.006000 0.493857 05:58
epoch train_loss valid_loss accuracy time
0 3.111655 3.022310 0.506357 06:40
epoch train_loss valid_loss accuracy time
0 2.868488 3.012175 0.511724 07:21
1 2.733594 3.012394 0.517344 06:36
2 2.567876 3.043180 0.518394 06:15

Am I doing the gradual unfreezing correctly? Also, should I be optimizing for validation loss or accuracy? The validation loss starts going back up while accuracy keeps improving, so I'm not sure which to trust. Would the model with the lowest validation loss give the best final classifier, or would the one with the highest accuracy? The classification problem has 3 different classes with a 40%/55%/5% split.
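For context on why I'm wary of accuracy here (the class names below are just placeholders for my three classes):

```python
# With a 40%/55%/5% class split, always predicting the majority class
# already scores 55% accuracy, so raw accuracy can look decent without
# the model learning anything about the rare class.
split = {"class_a": 0.40, "class_b": 0.55, "class_c": 0.05}
majority_baseline = max(split.values())
print(majority_baseline)  # → 0.55
```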


In general you want your language model's accuracy to be between 40-50% (if you can get above that, great, but don't spend forever training it) before moving on to the classifier, so I'd say that's pretty good :slight_smile:

Higher accuracy is usually better (I've seen perplexity used as a metric too), but it also depends on how the final classifier does :slight_smile:
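For what it's worth, perplexity is just the exponential of the per-token cross-entropy loss, so it moves in lockstep with your validation loss; a quick sketch in plain Python (the loss value is taken from the run you posted above):

```python
import math

def perplexity(cross_entropy_loss: float) -> float:
    """Perplexity is exp of the per-token cross-entropy loss."""
    return math.exp(cross_entropy_loss)

# valid_loss of ~3.04 from the last epoch above corresponds to a
# perplexity of roughly 21 (the model is about as uncertain as a
# uniform choice over ~21 tokens).
print(round(perplexity(3.043180), 2))
```

So optimizing validation loss and optimizing perplexity are the same thing; accuracy is the metric that can diverge from them.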

I did some further research, and it looks like the legit way to do it is:

learn.freeze()
learn.fit_one_cycle(cyc_len=1, max_lr=1e-2, moms=(0.8, 0.7))
learn.unfreeze()
learn.fit_one_cycle(cyc_len=3, max_lr=1e-3, moms=(0.8, 0.7))

and then stop early whenever the validation loss stops improving or starts going back up.
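The early-stopping rule can be sketched generically in plain Python (this just illustrates the logic; fastai also ships an EarlyStoppingCallback that does this for you, whose exact arguments I'd check in the docs):

```python
def early_stop_epoch(valid_losses, patience=1, min_delta=0.0):
    """Return the epoch index at which training should stop, or None.

    Stops once the validation loss has failed to improve on the best
    value seen so far (by at least min_delta) for `patience`
    consecutive epochs.
    """
    best = float("inf")
    epochs_since_improvement = 0
    for epoch, loss in enumerate(valid_losses):
        if loss < best - min_delta:
            best = loss
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            if epochs_since_improvement >= patience:
                return epoch
    return None

# Applied to the unfrozen run above: the valid loss of epoch 1 and 2
# never beats epoch 0, so with patience=1 training stops at epoch 1.
print(early_stop_epoch([3.012175, 3.012394, 3.043180], patience=1))
```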
