The error happens because you are passing an instance of AWD_LSTM instead of the class itself. text_classifier_learner calls get_text_classifier, which looks up meta = _model_meta[arch]. That dict is keyed by the class (not by instances of it), hence the KeyError.
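A quick way to see the difference (the AWD_LSTM constructor arguments below are arbitrary placeholder values, only there to build an instance):

from fastai.text.all import *
from fastai.text.models.core import _model_meta

_model_meta[AWD_LSTM]                    # works: the dict is keyed by the class
_model_meta[AWD_LSTM(100, 20, 10, 2)]    # KeyError: an instance is not a key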
You can also see in the code of the two functions that the architecture gets its vocab_sz from the dataloaders you pass:
def text_classifier_learner(dls, arch, …):
    …
    vocab = _get_text_vocab(dls)
    …
    model = get_text_classifier(arch, len(vocab), …)
    …

def get_text_classifier(arch, vocab_sz, …):
    …
    encoder = SentenceEncoder(seq_len, arch(vocab_sz, **config), pad_idx=pad_idx, max_len=max_len)
    …
So the place to change the vocab size is the dataloaders object (which makes sense, since your data should know how many tokens it has at hand).
If you use the DataBlock API, you can simply add a max_vocab parameter to the TextBlock. If you use a higher-level function to create your dls, you need to recreate that manually; I did that below (simplified, using IMDB) for the TextDataLoaders.from_folder case:
from functools import partial
from fastai.text.all import *

path = untar_data(URLs.IMDB)

dblock = DataBlock(
    # cap the text vocab at 10000 tokens via max_vocab
    blocks=(TextBlock.from_folder(path, vocab=None, is_lm=False, max_vocab=10000), CategoryBlock(vocab=None)),
    get_items=partial(get_text_files, folders=['train', 'test']),
    splitter=GrandparentSplitter(train_name='train', valid_name='test'),
    get_y=parent_label
)
dls = TextDataLoaders.from_dblock(dblock, path, path=path, seq_len=72)

print(len(dls.vocab[0]))  # >> 10008
To do any other customization of AWD_LSTM, as the instantiation in your code suggests, you can pass a config parameter to text_classifier_learner that resembles the default one:
from fastai.text.models.core import _model_meta
my_config = _model_meta[AWD_LSTM]['config_clas'].copy()
my_config['emb_sz'] = 20
…
learn = text_classifier_learner(dls, AWD_LSTM, metrics=accuracy, config=my_config, pretrained=False).to_fp16()
Note that you cannot use the pretrained models, since the configs don't match.
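If you want to double-check that the custom config took effect, you can inspect the model. Assuming the default structure built by get_text_classifier (a SentenceEncoder wrapping the AWD_LSTM), the embedding layer is reachable like this:

emb = learn.model[0].module.encoder   # SentenceEncoder -> AWD_LSTM -> nn.Embedding
print(emb.embedding_dim)              # 20, the emb_sz set in my_config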
Hope that helps 