Hi all - I wasn’t sure whether this was useful enough to contribute or too trivial. I’ve been training language models and find it useful/interesting to see how the generated text evolves over time, so I created the following callback, which I use often. I’d clean it up a bit before creating a PR, largely the actual printing/writing sections.
Nothing super novel — it’s just a repackaging of existing callback code from the notebooks:
class PrintSentenceEveryNBatches(LearnerCallback):
    """Generate and log sample sentences after every `n_batches` training batches.

    Useful for eyeballing how a language model's generated text evolves during
    training. Samples are printed to stdout and appended to `file_path`.
    """

    SEPARATOR = '--------------------- Sentence generated at batch {} ----------------------------'

    def __init__(self, learn:Learner, n_batches:int=100, seed_text:str='',
                 n_words:int=100, n_sentences:int=2, temperature:float=0.75,
                 file_path:str='./generated_sentences.txt'):
        """
        learn: the Learner being trained (must support `.predict(text, n_words, temperature)`).
        n_batches: generate every this many training batches (must be >= 1).
        seed_text: prompt the generation starts from.
        n_words: number of words to generate per sample.
        n_sentences: number of independent samples per generation event.
        temperature: sampling temperature passed to `learn.predict`.
        file_path: file the generated text is appended to.
        """
        super().__init__(learn)
        if n_batches < 1:
            raise ValueError(f'n_batches must be >= 1, got {n_batches}')
        self.n_batches = n_batches
        self.seed_text = seed_text
        self.n_words = n_words
        self.n_sentences = n_sentences
        self.temperature = temperature
        self.file_path = file_path

    def generate_sentence(self, iteration):
        """Generate `n_sentences` samples, print them, and append them to `file_path`."""
        generated = "\n".join(
            self.learn.predict(self.seed_text, self.n_words, self.temperature)
            for _ in range(self.n_sentences)
        )
        # learn.predict toggles eval mode; switch back so training continues correctly
        self.learn.model.train()
        header = self.SEPARATOR.format(iteration)
        print(header)
        print(generated)
        with open(self.file_path, 'a') as f:
            f.write(header + '\n')
            f.write(generated + '\n')

    def on_batch_end(self, iteration, train, **kwargs):
        """Fire after every `n_batches` training batches (skips validation batches)."""
        if not train:
            return
        # iteration starts at 0, so test (iteration + 1): fires at batches
        # n_batches-1, 2*n_batches-1, ... i.e. after every n_batches batches.
        # (The original `iteration % (n_batches-1)` fired immediately at batch 0,
        # drifted by one each cycle, and raised ZeroDivisionError for n_batches=1.)
        if (iteration + 1) % self.n_batches == 0:
            self.generate_sentence(iteration)
How to use it:
# Build a language-model learner on the prepared DataBunch `data`,
# using the AWD-LSTM architecture trained from scratch (pretrained=False)
# with dropout scaled by drop_mult.
learn = language_model_learner(
    data,
    AWD_LSTM,
    drop_mult=0.3,
    pretrained=False
)
# Sample generation every 1000 batches, seeded with 'pt is', at temperature 1.
print_callback = PrintSentenceEveryNBatches(learn, seed_text='pt is', n_batches=1000, temperature=1)
# One-cycle training with the callback attached; fastai listifies a single callback.
learn.fit_one_cycle(1, 1e-2, moms=(0.8,0.7), callbacks=print_callback)