Fine-Tuning ULMFiT

(Youcef Djeddar) #1

Hi,

Is it possible to fine-tune ULMFiT on my own dataset and then extract embedding vectors from it?

Thanks


(Abhinav Verma) #2

Yes, that's what it's designed for. One of the fast.ai alumni, @muellerzr, has done this in one of his blog posts: https://muellerzr.github.io/Suspecto/
Since the embedding layer is the first layer of the architecture, I think you can extract the weights directly from it.
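
To illustrate the idea of reading weights from the first layer, here is a minimal sketch in plain PyTorch. The toy model and its sizes are my own assumptions for illustration, not the actual ULMFiT encoder; the point is only that an embedding layer's weight matrix is a lookup table with one row per vocabulary index.

```python
import torch
import torch.nn as nn

# Toy stand-in for a language-model encoder: the embedding layer comes first.
# The sizes are hypothetical and far smaller than a real ULMFiT model.
vocab_sz, emb_sz = 10, 4
toy_encoder = nn.Sequential(
    nn.Embedding(vocab_sz, emb_sz, padding_idx=1),
    nn.LSTM(emb_sz, 8, batch_first=True),
)

# The embedding weights form a (vocab_sz, emb_sz) lookup table.
emb_weights = toy_encoder[0].weight.detach()

# The vector for a token is the row at that token's vocabulary index.
token_idx = 3
token_vec = emb_weights[token_idx]
print(emb_weights.shape, token_vec.shape)
```

Keep in mind that vectors read this way are static: each token gets the same vector regardless of the surrounding text.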


(Sooraj Mangalath Subrahmannian) #3

The link above explains how to fine-tune on your own dataset. To extract embeddings, you could tweak the API in one of the following ways.

Solution 1:

import torch
import fastai
from fastai.text import AWD_LSTM, load_data, MultiBatchEncoder, RNNLearner

lm_data_path = '/home/ubuntu/efs/corpus/fine_tuning_corpus/ft_corpus_250K.csv'
output_path = '/home/ubuntu/efs/fp16_ulmfit/vsz_100k_dp_0.75/'

# Encoder initialization

bptt = 70
max_len = 20 * 70
vocab_sz = 100000
drop_mult = 1

# AWD-LSTM config

config = dict(emb_sz=400, n_hid=1152, n_layers=3, pad_token=1, qrnn=False, bidir=False, output_p=0.4,
              hidden_p=0.3, input_p=0.4, embed_p=0.05, weight_p=0.5)
# Scale every dropout probability by drop_mult
for k in config.keys():
    if k.endswith('_p'): config[k] *= drop_mult
# output_p is not an AWD_LSTM argument, so remove it from the config
ps = [config.pop('output_p')]

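# Get the AWD-LSTM encoder
# Note: `dl` is not defined in this snippet; it should be the DataBunch built
# from the fine-tuning corpus (for example, one restored with load_data).

encoder = MultiBatchEncoder(bptt, max_len, AWD_LSTM(vocab_sz, **config), pad_idx=1)
learn = RNNLearner(dl, encoder)
learn.model.reset()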

# Load the fine-tuned AWD-LSTM

learn.model = learn.model.module  # unwrap the AWD_LSTM from the MultiBatchEncoder
learn.model.load_state_dict(torch.load(f'{output_path}models/fwd_enc.pth', map_location='cuda'))

# Get the embeddings

text = "APPLE is the world 's greenest tech company"
batch = learn.data.one_item(text)
learn.model(batch[0])
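
The last call returns the encoder's raw outputs rather than a single vector. Below is a rough sketch of how one might turn them into a fixed-size embedding. It assumes the fastai v1 convention that AWD_LSTM returns a pair of lists (raw_outputs, outputs) with one tensor per layer, and it uses random stand-in tensors instead of a real model, so treat the shapes as an assumption to check against your own output.

```python
import torch

# Stand-in for the encoder's return value, assuming the fastai v1 convention
# of (raw_outputs, outputs), each a list with one tensor per LSTM layer.
# Shapes follow the config above: n_hid=1152 inner layers, emb_sz=400 last.
bs, seq_len = 1, 9
outputs = [torch.randn(bs, seq_len, 1152),
           torch.randn(bs, seq_len, 1152),
           torch.randn(bs, seq_len, 400)]
raw_outputs = [o.clone() for o in outputs]

last_layer = outputs[-1]               # (bs, seq_len, 400): one vector per token
sentence_vec = last_layer.mean(dim=1)  # (bs, 400): mean-pooled text embedding
final_token_vec = last_layer[:, -1]    # (bs, 400): output at the final token
print(sentence_vec.shape, final_token_vec.shape)
```

Mean pooling and taking the final token are two common choices; which one works better depends on the downstream task.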

Solution 2:

# Use a hook on the classifier model built on top of the encoder

from fastai.callbacks.hooks import hook_output
layer = learn.model[:2][1].layers[0]  # the layer from which you want to extract the output
input_ds = learn.data.one_item(text)
with hook_output(layer) as hook_forward:
    preds = learn.model(input_ds[0])
    print(hook_forward.stored)
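
For anyone who wants to see what the hook does underneath, fastai's hook_output builds on PyTorch forward hooks. Here is a minimal sketch of the same mechanism in plain PyTorch; the toy classifier and its sizes are hypothetical and only stand in for the fine-tuned model.

```python
import torch
import torch.nn as nn

# Toy classifier standing in for the fine-tuned model (hypothetical sizes).
model = nn.Sequential(nn.Linear(6, 4), nn.ReLU(), nn.Linear(4, 2))
layer = model[0]  # the layer whose output we want to capture

stored = {}

def save_output(module, inputs, output):
    # Called by PyTorch right after `layer` finishes its forward pass.
    stored["out"] = output.detach()

handle = layer.register_forward_hook(save_output)
try:
    with torch.no_grad():
        preds = model(torch.randn(1, 6))
finally:
    handle.remove()  # always detach the hook when done

print(stored["out"].shape)
```

Removing the hook afterwards matters; otherwise it keeps firing on every later forward pass.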



(Rishabh Choudhary) #4

In this example post, how is the CSV file formatted? Is each article in one cell of the CSV?


(Abhinav Verma) #5

Yes, each article is one row.
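
As a small sketch of that layout, the snippet below writes two articles to CSV with Python's csv module and reads them back. The column names (label, text) and the article contents are made up for illustration; the point is that proper quoting keeps a whole article in one row even when it contains commas or line breaks.

```python
import csv
import io

# Hypothetical articles; note the embedded comma, quotes, and newline.
articles = [
    ("real", 'First article, with a comma and a "quote".'),
    ("fake", "Second article.\nIt spans two lines."),
]

buffer = io.StringIO()  # stands in for a file opened with newline=""
writer = csv.writer(buffer)
writer.writerow(["label", "text"])  # assumed column names
writer.writerows(articles)

# Reading it back yields exactly one row per article.
buffer.seek(0)
rows = list(csv.DictReader(buffer))
print(len(rows))
```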
