Here is the full code, some of the comments are garbage so ignore those. Also my gist-it isn’t working properly, so that’s why I’m posting it instead of gisting it:
import fastai
from fastai import learner
from fastai import dataset
from fastai import model
from fastai.model import resnet34
from pathlib import Path
from fastai.text import *
import pandas as pd
import numpy as np
import spacy
import json
import re
import html
Creating Data
Usually I see people use completely random data when they don’t have a dataset to show a concept. Instead, I’m going to use a counting dataset that starts at a random number and then counts up 10, wrapping around from “nine” to “zero”.
numbers = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
myData = pd.DataFrame()
def DataGenerator():
numlist = ""
starting_num = random.randint(0,9)
for i in range(10):
if i==0:
numlist = str(numbers[(starting_num+i)%10])
else:
numlist = numlist + " " + str(numbers[(starting_num+i)%10])
return numlist
DataGenerator()
'four five six seven eight nine zero one two three'
for i in range(1000):
myData = myData.append(pd.Series(DataGenerator()), ignore_index=True)
tok = Tokenizer()
texts = myData[0].astype(str)
texts.values.astype(str)[0:10]
array(['one two three four five six seven eight nine zero',
'one two three four five six seven eight nine zero',
'eight nine zero one two three four five six seven',
'nine zero one two three four five six seven eight',
'seven eight nine zero one two three four five six',
'four five six seven eight nine zero one two three',
'seven eight nine zero one two three four five six',
'nine zero one two three four five six seven eight',
'zero one two three four five six seven eight nine',
'seven eight nine zero one two three four five six'],
dtype='<U49')
texts.values.astype(str)[0]
'one two three four five six seven eight nine zero'
data = tok.proc_all_mp(partition_by_cores(texts.values.astype(str)))
Now that we have the data, let’s build a frequency map. This will tell us how many times each word was seen. Since we have 5 numbers (0-9) and are doing 100 sequences, all of these will be close to 500. Usually this will not be the case and this part will help filter out any words that are only seen a low number of time.
freq = Counter(p for o in data for p in o)
itos will be used to translate the number back into a string
itos = [o for o,c in freq.most_common(10) if c > 2]
itos.insert(0, '_pad_')
itos.insert(0, '_unk_')
freq.most_common(10)
[('one', 1000),
('two', 1000),
('three', 1000),
('four', 1000),
('five', 1000),
('six', 1000),
('seven', 1000),
('eight', 1000),
('nine', 1000),
('zero', 1000)]
itos
['_unk_',
'_pad_',
'one',
'two',
'three',
'four',
'five',
'six',
'seven',
'eight',
'nine',
'zero']
len(itos)
12
The stoi variable will create the translator from the strings to the int versions.
stoi = collections.defaultdict(lambda:0, {v:k for k,v in enumerate(itos)})
len(itos)
12
The lambda:0 is telling this that if you don’t know what the word is, give it a value of “0” which we know is tied to ‘unk’ so translating it back, would replace that word with ‘unk’
unknownNumber = stoi["ten"];print("unknownNumber idx: " + str(unknownNumber))
knownNumber = stoi["nine"];print("knownNumber idx: " + str(knownNumber))
print(itos[unknownNumber])
print(itos[knownNumber])
unknownNumber idx: 0
knownNumber idx: 10
_unk_
nine
itos[unknownNumber]
'_unk_'
stoi["nine"]
10
stoi
defaultdict(<function __main__.<lambda>>,
{'_pad_': 1,
'_unk_': 0,
'eight': 9,
'five': 6,
'four': 5,
'nine': 10,
'one': 2,
'seven': 8,
'six': 7,
'ten': 0,
'three': 4,
'two': 3,
'zero': 11})
All I’m doing here is feeding each of my numbers through and turning the string into an int using stoi[wordtotokenize]
tokenized_data = [[stoi[o] for o in i] for i in data]
tokenized_data[0:10]
[[2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
[9, 10, 11, 2, 3, 4, 5, 6, 7, 8],
[10, 11, 2, 3, 4, 5, 6, 7, 8, 9],
[8, 9, 10, 11, 2, 3, 4, 5, 6, 7],
[5, 6, 7, 8, 9, 10, 11, 2, 3, 4],
[8, 9, 10, 11, 2, 3, 4, 5, 6, 7],
[10, 11, 2, 3, 4, 5, 6, 7, 8, 9],
[11, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[8, 9, 10, 11, 2, 3, 4, 5, 6, 7]]
PATH = Path("data/counterExample/")
em_sz,nh,nl = 8,500,3
wd=1e-7
bptt=10
bs=16
opt_fn = partial(optim.Adam, betas=(0.8, 0.99))
combined_tokenized_data = np.concatenate(tokenized_data)
dataloader = LanguageModelLoader(combined_tokenized_data, bs, bptt)
#PATH
#Pad_Idx
#Number of tokens
#dataloader - Training
#dataloader - Validation (Should be different from Training)
modeldata = LanguageModelData(PATH,1,len(stoi), dataloader, dataloader)
modeldata.nt
13
??modeldata.get_model
drops=np.array([0.25, 0.1, 0.2, 0.02, 0.15])*0.7
learner = modeldata.get_model(opt_fn, em_sz, nh, nl,dropouti=drops[0], dropout=drops[1], wdrop=drops[2], dropoute=drops[3], dropouth=drops[4])
learner.metrics = [accuracy]
learner.unfreeze()
learner
SequentialRNN(
(0): RNN_Encoder(
(encoder): Embedding(13, 8, padding_idx=1)
(encoder_with_dropout): EmbeddingDropout(
(embed): Embedding(13, 8, padding_idx=1)
)
(rnns): ModuleList(
(0): WeightDrop(
(module): LSTM(8, 500, dropout=0.105)
)
(1): WeightDrop(
(module): LSTM(500, 500, dropout=0.105)
)
(2): WeightDrop(
(module): LSTM(500, 8, dropout=0.105)
)
)
(dropouti): LockedDropout(
)
(dropouths): ModuleList(
(0): LockedDropout(
)
(1): LockedDropout(
)
(2): LockedDropout(
)
)
)
(1): LinearDecoder(
(decoder): Linear(in_features=8, out_features=13)
(dropout): LockedDropout(
)
)
)
learner.lr_find()
Failed to display Jupyter Widget of type HBox
.
If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean
that the widgets JavaScript is still loading. If this message persists, it
likely means that the widgets JavaScript library is either not installed or
not enabled. See the Jupyter
Widgets Documentation for setup instructions.
If you're reading this message in another frontend (for example, a static
rendering on GitHub or NBViewer),
it may mean that your frontend doesn't currently support widgets.
97%|█████████▋| 59/61 [00:01<00:00, 58.82it/s, loss=10.9]
learner.sched.plot()
lr = 10e-1
learner.fit(lr,10)
Failed to display Jupyter Widget of type HBox
.
If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean
that the widgets JavaScript is still loading. If this message persists, it
likely means that the widgets JavaScript library is either not installed or
not enabled. See the Jupyter
Widgets Documentation for setup instructions.
If you're reading this message in another frontend (for example, a static
rendering on GitHub or NBViewer),
it may mean that your frontend doesn't currently support widgets.
epoch trn_loss val_loss accuracy
0 2.501293 2.401719 0.100253
1 2.450239 2.784427 0.099447
2 2.462207 2.512367 0.100442
3 2.454979 2.717493 0.100404
4 2.436725 2.509735 0.101221
5 2.441772 2.537813 0.09924
6 2.421832 2.389191 0.100319
7 2.455708 2.511775 0.101774
8 2.476664 2.857218 0.100419
9 2.465733 2.50652 0.099912
[2.5065198, 0.099911526761582645]
needPrediction = np.array([[5]])
probs = learner.model(V(needPrediction))
probs[0][-1].exp()
Variable containing:
1.0839e-07
1.1849e-07
3.1182e+02
3.0495e+02
4.0501e+02
5.6112e+02
4.1499e+02
6.6163e+02
2.3161e+02
1.2453e+02
2.1783e+02
1.6167e+02
1.1372e-07
[torch.cuda.FloatTensor of size 13 (GPU 0)]
itos[5]
'four'
itos[to_np(probs[0][-1].exp()).argmax()]
'six'
for i in range(2,12):
needPrediction = np.array([[i]])
probs = learner.model(V(needPrediction))
print(itos[i] + "---->" + itos[to_np(probs[0][-1].exp()).argmax()])
one---->six
two---->six
three---->six
four---->six
five---->six
six---->six
seven---->six
eight---->six
nine---->six
zero---->six
At this point, still not perfect, but getting there, let’s try adding another 100 iterations of fit
learner.fit(lr,10)
Failed to display Jupyter Widget of type HBox
.
If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean
that the widgets JavaScript is still loading. If this message persists, it
likely means that the widgets JavaScript library is either not installed or
not enabled. See the Jupyter
Widgets Documentation for setup instructions.
If you're reading this message in another frontend (for example, a static
rendering on GitHub or NBViewer),
it may mean that your frontend doesn't currently support widgets.
epoch trn_loss val_loss accuracy
0 2.507211 2.473511 0.099996
1 2.567874 2.484493 0.098951
2 2.447682 2.400187 0.099441
3 2.447006 2.650691 0.100153
4 2.449038 2.62591 0.100343
5 2.485107 2.414987 0.098826
6 2.438728 2.444442 0.100094
7 2.42563 2.494963 0.09899
8 2.445234 2.356856 0.099618
9 2.447071 2.997203 0.100324
[2.9972031, 0.10032353916617691]
for i in range(2,12):
needPrediction = np.array([[i]])
probs = learner.model(V(needPrediction))
print(itos[i] + "---->" + itos[to_np(probs[0][-1].exp()).argmax()])
one---->one
two---->one
three---->one
four---->one
five---->one
six---->one
seven---->one
eight---->one
nine---->one
zero---->one
learner.fit(lr,10)
Failed to display Jupyter Widget of type HBox
.
If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean
that the widgets JavaScript is still loading. If this message persists, it
likely means that the widgets JavaScript library is either not installed or
not enabled. See the Jupyter
Widgets Documentation for setup instructions.
If you're reading this message in another frontend (for example, a static
rendering on GitHub or NBViewer),
it may mean that your frontend doesn't currently support widgets.
epoch trn_loss val_loss accuracy
0 2.516302 2.440026 0.098669
1 2.439618 2.33569 0.10036
2 2.465316 2.439928 0.099074
3 2.486601 2.420878 0.10029
4 2.452492 2.521931 0.100927
5 2.491448 2.379579 0.099763
6 2.420096 2.424615 0.102314
7 2.45222 2.546835 0.099964
8 2.514125 2.333285 0.101252
9 2.445579 2.431223 0.100947
[2.4312229, 0.10094661253397583]
for i in range(2,12):
needPrediction = np.array([[i]])
probs = learner.model(V(needPrediction))
print(itos[i] + "---->" + itos[to_np(probs[0][-1].exp()).argmax()])
one---->six
two---->six
three---->six
four---->six
five---->six
six---->six
seven---->six
eight---->six
nine---->six
zero---->six
So more training doesn’t seem to be helping, we still have multiple that are going to the wrong next word
learner.unfreeze()
learner.fit(lr,100)
Failed to display Jupyter Widget of type HBox
.
If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean
that the widgets JavaScript is still loading. If this message persists, it
likely means that the widgets JavaScript library is either not installed or
not enabled. See the Jupyter
Widgets Documentation for setup instructions.
If you're reading this message in another frontend (for example, a static
rendering on GitHub or NBViewer),
it may mean that your frontend doesn't currently support widgets.
epoch trn_loss val_loss accuracy
0 2.532187 2.391899 0.100669
1 2.449502 2.592566 0.100048
2 2.422548 2.481155 0.098664
3 2.433634 2.374926 0.098504
4 2.4846 2.433929 0.099653
5 2.443925 2.489157 0.100965
6 2.458176 2.478071 0.100885
7 2.446534 2.415808 0.102158
8 2.448918 2.361977 0.098269
9 2.477083 2.536442 0.099198
10 2.452269 2.374832 0.099516
11 2.463058 2.559443 0.098423
12 2.464275 2.445957 0.101949
13 2.503255 2.701074 0.100223
14 2.462109 2.351269 0.099997
15 2.452649 2.449541 0.099165
16 2.468501 2.332346 0.100533
17 2.444887 2.346866 0.100175
18 2.42937 2.394043 0.100376
19 2.476886 2.41362 0.099247
20 2.449308 2.459924 0.099959
21 2.434841 2.576449 0.098805
22 2.454262 2.583148 0.09825
23 2.478229 2.414344 0.098201
24 2.443477 2.332206 0.099706
25 2.44875 2.507034 0.100083
26 2.456731 3.052977 0.100177
27 2.473659 2.569483 0.100308
28 2.488046 2.37466 0.100972
29 2.43887 2.727457 0.100739
30 2.463462 2.530313 0.100128
31 2.449768 2.480434 0.099945
32 2.458727 2.518376 0.100101
33 2.459107 2.493311 0.099963
34 2.461484 2.409652 0.100445
35 2.443723 2.419954 0.09979
36 2.427422 2.460212 0.099989
37 2.509038 2.615475 0.100165
38 2.468286 2.453477 0.100413
39 2.43953 2.438079 0.099731
40 2.459077 2.461557 0.09994
41 2.483411 2.4851 0.102056
42 2.458116 2.600188 0.101106
43 2.456078 2.450119 0.100664
44 2.452124 2.348065 0.099687
45 2.48183 2.583075 0.099432
46 2.454912 2.363219 0.099733
47 2.51616 2.376456 0.098833
48 2.434357 2.698936 0.100779
49 2.452154 2.353041 0.100286
50 2.437518 2.416678 0.101911
51 2.524313 2.55522 0.098992
52 2.425512 3.116068 0.0989
53 2.439224 2.339689 0.100026
54 2.444603 2.582581 0.099018
55 2.47307 2.664025 0.097857
56 2.444368 2.633893 0.10105
57 2.436004 2.380004 0.100453
58 2.452654 2.574329 0.10044
59 2.497562 2.333667 0.101639
60 2.471379 2.450145 0.099848
61 2.460531 2.35693 0.099431
62 2.417477 2.712232 0.100774
63 2.469832 2.370589 0.099859
64 2.48184 2.737447 0.099237
65 2.487076 2.45893 0.09972
66 2.434812 2.357924 0.099824
67 2.434282 2.461235 0.099532
68 2.458523 2.419769 0.100679
69 2.444596 2.474225 0.097741
70 2.467982 2.712098 0.10113
71 2.477893 2.455158 0.098918
72 2.442475 2.399636 0.100612
73 2.418096 2.622461 0.101848
74 2.49339 2.422304 0.099105
75 2.418169 2.568468 0.100021
76 2.468581 2.547715 0.100119
77 2.456119 2.536897 0.10083
78 2.434664 2.65585 0.100217
79 2.522499 2.657638 0.100374
80 2.452929 2.461939 0.099831
81 2.426431 2.458987 0.098729
82 2.450306 2.359966 0.10028
83 2.425657 2.353513 0.099618
84 2.503894 2.822604 0.099164
85 2.431455 2.324247 0.100215
86 2.44774 2.777127 0.100655
87 2.447497 2.568796 0.099632
88 2.433973 2.342007 0.09897
89 2.450539 2.323313 0.100807
90 2.434953 2.340156 0.099623
91 2.463191 2.557094 0.099799
92 2.462035 2.458475 0.099565
93 2.484958 2.443302 0.099806
94 2.472418 2.513814 0.099521
95 2.434136 2.383336 0.098641
96 2.453682 2.482803 0.100369
97 2.499124 2.485369 0.100844
98 2.450978 2.352065 0.101033
99 2.426416 2.594633 0.098279
[2.5946326, 0.09827934010107009]
for i in range(2,12):
needPrediction = np.array([[i]])
probs = learner.model(V(needPrediction))
print(itos[i] + "---->" + itos[to_np(probs[0][-1].exp()).argmax()])
one---->two
two---->two
three---->two
four---->two
five---->two
six---->two
seven---->two
eight---->two
nine---->two
zero---->two
learner.fit(lr, 1, wds=wd, use_clr=(20,10), cycle_len=15)
Failed to display Jupyter Widget of type HBox
.
If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean
that the widgets JavaScript is still loading. If this message persists, it
likely means that the widgets JavaScript library is either not installed or
not enabled. See the Jupyter
Widgets Documentation for setup instructions.
If you're reading this message in another frontend (for example, a static
rendering on GitHub or NBViewer),
it may mean that your frontend doesn't currently support widgets.
epoch trn_loss val_loss accuracy
0 2.504131 2.497203 0.101079
1 2.410001 2.320999 0.100073
2 2.35829 2.310918 0.099568
3 2.344692 2.310223 0.100768
4 2.344776 2.306048 0.101331
5 2.341698 2.56541 0.0
6 2.350182 2.30824 0.101532
7 2.339936 2.304673 0.099802
8 2.336478 2.306644 0.100232
9 2.332428 2.30723 0.100003
10 2.332684 2.303474 0.099913
11 2.326382 2.307298 0.099052
12 2.328791 2.303727 0.099682
13 2.327139 2.303137 0.099768
14 2.323202 2.302859 0.100209
[2.3028586, 0.10020882822573185]
for i in range(2,10):
needPrediction = np.array([[i]])
probs = learner.model(V(needPrediction))
print(itos[i] + "---->" + itos[to_np(probs[0][-1].exp()).argmax()])
one---->eight
two---->eight
three---->eight
four---->eight
five---->eight
six---->eight
seven---->eight
eight---->eight