Sentence Generation Callback

Hi all - wasn’t sure if this was useful enough to contribute or too trivial to bother with :slight_smile: I’ve been training language models and find it useful/interesting to see how the generated text evolves over the course of training, so I wrote the following callback, which I use often. I’d clean it up a bit before opening a PR, mainly the printing/writing sections.

Nothing super novel, just repackaging existing callback code from the notebooks:

from fastai.text import *  # fastai v1: provides Learner, LearnerCallback, language_model_learner

class PrintSentenceEveryNBatches(LearnerCallback):
    "Generate a sentence after n batches."
    def __init__(self, learn:Learner, n_batches:int=100, seed_text:str='', n_words:int=100, n_sentences:int=2, temperature:float=0.75):
        super().__init__(learn) # lets LearnerCallback set up self.learn
        self.n_batches = n_batches
        self.seed_text = seed_text
        self.n_words = n_words
        self.n_sentences = n_sentences
        self.temperature = temperature
    def generate_sentence(self, iteration):
        # pass temperature by keyword so it can't land in another positional slot of predict
        generated = "\n".join(self.learn.predict(self.seed_text, self.n_words, temperature=self.temperature) for _ in range(self.n_sentences))
        self.learn.model.train() # learn.predict toggles eval mode, change it back to train
        header = f'--------------------- Sentence generated at batch {iteration} ----------------------------'
        print(header)
        print(generated)
        with open('./generated_sentences.txt', 'a') as f:
            f.write(header + '\n')
            f.write(generated + '\n')

    def on_batch_end(self, iteration, train, **kwargs):
        # iteration starts from 0, so shift by one; skip validation batches
        if train and (iteration + 1) % self.n_batches == 0:
            self.generate_sentence(iteration)

How to use it:

learn = language_model_learner(

print_callback = PrintSentenceEveryNBatches(learn, seed_text='pt is', n_batches=1000, temperature=1)
learn.fit_one_cycle(1, 1e-2, moms=(0.8,0.7), callbacks=print_callback)
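
In case it helps anyone adapting this, here is a framework-free sketch of the "every n batches" check on its own, so the off-by-one from zero-based counting is easy to see. `should_generate` is a hypothetical helper, not part of fastai, and it assumes `iteration` is a zero-based count of training batches and that validation batches are flagged with `train=False`:

```python
def should_generate(iteration: int, n_batches: int, train: bool = True) -> bool:
    """Return True when a sample should be generated at this batch.

    Hypothetical helper, not part of fastai. Assumes `iteration` is a
    zero-based count of training batches.
    """
    if n_batches < 1:
        raise ValueError("n_batches must be at least 1")
    # Skip validation batches, and shift by one because counting starts at 0.
    return train and (iteration + 1) % n_batches == 0


# Zero-based iterations that would trigger generation in the first 10 batches.
fired = [i for i in range(10) if should_generate(i, n_batches=3)]
print(fired)
```

With `n_batches=3` this fires on the 3rd, 6th and 9th batches. Subtracting 1 from `n_batches` and testing `iteration % n_batches == 0` instead would also fire on the very first batch and would divide by zero for `n_batches=1`.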