Should I be posting to issues/bugs?
Not sure how to handle this. I would like to use QRNN, so I set up my LM with it. That seems to work fine, but then using the same config for a classifier fails. Full example:
from fastai import *
from fastai.text import *
path = untar_data(URLs.IMDB_SAMPLE)
data = TextLMDataBunch.from_csv(path, 'texts.csv')
data_cls = TextClasDataBunch.from_csv(path, 'texts.csv')
config = awd_lstm_lm_config.copy()
config['qrnn'] = True
learn = language_model_learner(data, AWD_LSTM, drop_mult=0.5,config=config,pretrained=False)
learn.freeze()
learn.fit(1)
##classifier
data_cls = TextClasDataBunch.from_csv(path, 'texts.csv')
config = awd_lstm_lm_config.copy()
config['qrnn'] = True
learn = text_classifier_learner(data_cls, AWD_LSTM, drop_mult=0.5,config=config)
learn.freeze()
learn.fit(1)
The config is accepted without a problem for the language model, and training the LM works. But when I then try to use the same config in a classifier, I get this error:
TypeError Traceback (most recent call last)
<ipython-input-3-23c1feca5f7b> in <module>
4 config['qrnn'] = True
5
----> 6 learn = text_classifier_learner(data_cls, AWD_LSTM, drop_mult=0.5,config=config)
7 learn.freeze()
8 learn.fit(1)
~/fast_ai/fastai-fork/fastai/text/learner.py in text_classifier_learner(data, arch, bptt, max_len, config, pretrained, drop_mult, lin_ftrs, ps, **learn_kwargs)
285 "Create a `Learner` with a text classifier from `data` and `arch`."
286 model = get_text_classifier(arch, len(data.vocab.itos), data.c, bptt=bptt, max_len=max_len,
--> 287 config=config, drop_mult=drop_mult, lin_ftrs=lin_ftrs, ps=ps)
288 meta = _model_meta[arch]
289 learn = RNNLearner(data, model, split_func=meta['split_clas'], **learn_kwargs)
~/fast_ai/fastai-fork/fastai/text/learner.py in get_text_classifier(arch, vocab_sz, n_class, bptt, max_len, config, drop_mult, lin_ftrs, ps, pad_idx)
276 ps = [config.pop('output_p')] + ps
277 init = config.pop('init') if 'init' in config else None
--> 278 encoder = MultiBatchEncoder(bptt, max_len, arch(vocab_sz, **config), pad_idx=pad_idx)
279 model = SequentialRNN(encoder, PoolingLinearClassifier(layers, ps))
280 return model if init is None else model.apply(init)
TypeError: __init__() got an unexpected keyword argument 'tie_weights'
It appears that `tie_weights` and `out_bias` are unused by the classifier (which makes sense, since those options only apply to the language-model decoder).
Deleting those two keys from the config works around the problem; this snippet works:
config = awd_lstm_lm_config.copy()
config['qrnn'] = True
del config['tie_weights']
del config['out_bias']