Learner.predict_array on a language model


(Kevin Bird) #1

I have built a language model that I'm hoping will print out SQL-looking text, so I am feeding it the following:

ss="""select *"""
s = [tok.spacy_tok(ss)]
t = np.array([[stoi[o] for o in p] for p in s])
' '.join(s[0])
predictions = learner.predict_array(t)

This gives me the following list:

[array([[ -0.27024,  -8.19738,   0.62141, ...,  -8.40918,  -8.29776,  -9.02088],
        [  0.73325,  -9.69724,   1.94042, ...,  -9.71638,  -9.17598, -10.30418]], dtype=float32),
 [array([[[-0.02815,  0.01443, -0.21755, ..., -0.2932 , -0.83243, -0.68881],
          [-0.11399,  0.58334,  0.01465, ..., -0.42394,  0.0057 , -0.00011]]], dtype=float32),
  array([[[ 0.17183,  0.06068,  0.87668, ...,  0.09799,  0.00023, -0.0269 ],
          [-0.03981, -0.05539, -0.23291, ...,  0.01526,  0.95958,  0.20518]]], dtype=float32),
  array([[[ 0.67286, -0.62926, -0.46802,  0.70781, -0.1038 , -0.56664, -0.75002,  0.59732, -0.71267,  0.67986,
            0.29438, -0.32798, -0.11172, -0.10629, -0.22235,  0.02634, -0.12676,  0.00015, -0.2415 , -0.69331,
            0.1373 ,  0.6395 ,  0.54709, -0.18545,  0.73523, -0.07354,  0.72034, -0.69605, -0.83535, -0.0378 ,
            0.93838,  0.12184, -0.77273, -0.88967,  0.06678, -0.88061, -0.25478, -0.91705,  0.00156, -0.0026 ,
            0.12699, -0.09595,  0.36363, -0.02742, -0.72856, -0.46303,  0.29309,  0.84334,  0.13666, -0.08954,
            0.71309, -0.83131,  0.99123, -0.63005,  0.02297, -0.31556,  0.76686, -0.58607,  0.75435, -0.44936,
            0.05017, -0.02792,  0.34902, -0.32274,  0.18679, -0.20664,  0.03724, -0.51263,  0.88026,  0.00346,
            0.17272, -0.7053 ,  0.03951,  0.00038, -0.38001, -0.42956,  0.67778,  0.5839 , -0.06674, -0.00792,
           -0.34072,  0.27365,  0.58116,  0.00094, -0.06808,  0.90614,  0.5648 ,  0.41149,  0.00279, -0.8009 ,
            0.89737,  0.1683 , -0.47667, -0.60266,  0.00028,  0.00242,  0.50461, -0.00596, -0.39126, -0.00524,
            0.17619,  0.11177,  0.56797,  0.04151,  0.00092,  0.81613,  0.00019, -0.68058,  0.00011, -0.13884,
            0.2879 ,  0.08336,  0.03764, -0.16683, -0.12263,  0.99375,  0.77305, -0.72881, -0.51032, -0.00854,
            0.21516,  0.94627, -0.43949,  0.09651,  0.48232, -0.67047,  0.61994, -0.65298,  0.76164,  0.40813,
           -0.21207,  0.02539,  0.30622,  0.94719, -0.47456,  0.05289,  0.68243, -0.65845, -0.22452,  0.46537,
            0.17224, -0.37007, -0.01055,  0.31053, -0.08247,  0.17951,  0.66865,  0.00286,  0.10626, -0.58866,
           -0.65322, -0.0006 , -0.01518,  0.42873, -0.09785,  0.62467, -0.47047,  0.91923,  0.79256,  0.5828 ,
            0.62863,  0.03925, -0.69479, -0.55773, -0.19763,  0.03652, -0.17079, -0.85278, -0.46987, -0.81744,
            0.00205, -0.03472,  0.22838,  0.57121,  0.11225,  0.15275,  0.75706,  0.02474,  0.10318,  0.83198,
           -0.0023 , -0.44944, -0.45313, -0.00027,  0.53422, -0.0233 ,  0.06128, -0.02725, -0.06369,  0.20423,
           -0.85756, -0.33212, -0.06988,  0.14369,  0.78994, -0.04727,  0.02333, -0.31869,  0.10497, -0.94189,
           -0.06164,  0.31728,  0.10503,  0.00242, -0.30187,  0.16965, -0.59538,  0.0048 , -0.22711, -0.16922,
            0.00835,  0.55194, -0.99398, -0.02675, -0.12815, -0.02341, -0.52103, -0.10805,  0.19744, -0.64536,
           -0.55582, -0.13832, -0.12886,  0.06639,  0.00281, -0.05265,  0.65211,  0.2119 , -0.80961, -0.48294,
            0.09542, -0.85252,  0.20053,  0.03175, -0.97582, -0.23859, -0.25901,  0.77566,  0.04035, -0.73435,
            0.02944, -0.06317, -0.57585,  0.00446, -0.79466,  0.3611 ,  0.59581, -0.2554 ,  0.59862,  0.04779,
            0.16032,  0.03658,  0.03249, -0.00826,  0.5015 , -0.20112, -0.00579,  0.80953,  0.0095 , -0.30357,
           -0.075  ,  0.48389,  0.97771, -0.61917, -0.34016, -0.23435,  0.1816 , -0.61097, -0.00523,  0.81557,
           -0.65464, -0.73878,  0.04386,  0.55917, -0.36094,  0.31521, -0.91728,  0.01842,  0.34646, -0.95734,
           -0.39245, -0.07041, -0.58314, -0.00131, -0.01181, -0.901  ,  0.03733, -0.68288,  0.71937, -0.55193,
            0.05402, -0.99416, -0.5582 , -0.40096,  0.00934,  0.20287, -0.02709,  0.47579, -0.09131,  0.44647,
            0.15844, -0.15828,  0.4887 , -0.27968, -0.01158, -0.40339, -0.88158,  0.48028, -0.31001,  0.01998,
            0.008  ,  0.75216,  0.44623, -0.97026,  0.95654,  0.93987,  0.81009, -0.60017,  0.00245, -0.75555,
           -0.70842, -0.81023,  0.40101, -0.71133,  0.00094, -0.03572,  0.75106,  0.79254,  0.01317,  0.11687,
            0.03002,  0.53937,  0.30929, -0.09701, -0.57614,  0.58046,  0.52487,  0.52834, -0.47754,  0.07432,
            0.36766,  0.12985, -0.65396,  0.54573, -0.95443,  0.50693,  0.48865,  0.78406,  0.10212,  0.22122,
            0.84579, -0.66749, -0.80173, -0.00079, -0.01478,  0.91735, -0.38039, -0.03046, -0.14163,  0.12755,
           -0.02287, -0.2211 , -0.11576, -0.7463 , -0.24429,  0.12826, -0.80301, -0.57252, -0.06721, -0.99647,
           -0.45914, -0.2207 , -0.62894, -0.52069, -0.09315, -0.15881, -0.34433, -0.32729, -0.8492 ,  0.13056,
           -0.38323, -0.02844, -0.01491, -0.08301,  0.01411,  0.99509,  0.0015 ,  0.1794 , -0.21069, -0.00529,
           -0.00715, -0.01159,  0.18546,  0.68294,  0.00909, -0.03601,  0.00338,  0.71442,  0.73567,  0.15405],
          [-0.10745, -0.89728, -0.02773,  0.36846,  0.68909, -0.00238, -0.84575,  0.21863, -0.72444, -0.38227,
            0.92325,  0.82073, -0.09403, -0.02757,  0.76086,  0.86858, -0.04909, -0.12215, -0.42108, -0.25846,
            0.07183,  0.3582 ,  0.28667,  0.19003, -0.0887 , -0.40156,  0.09809,  0.0253 , -0.66314, -0.00102,
           -0.1837 , -0.06704, -0.07694, -0.00983,  0.0421 ,  0.37194, -0.58079, -0.25971, -0.00935, -0.02574,
           -0.00143,  0.0006 ,  0.54139, -0.43277, -0.06122,  0.71352, -0.33764,  0.26673,  0.15375, -0.08849,
           -0.79094,  0.37394, -0.15387,  0.41102,  0.03294, -0.37476,  0.34984,  0.57495,  0.81134, -0.23035,
            0.00533, -0.7472 ,  0.42648, -0.14137,  0.3784 ,  0.63215, -0.03785, -0.75149,  0.53347, -0.77514,
            0.55872, -0.3569 ,  0.00124,  0.93715, -0.39832, -0.40543,  0.71075,  0.00698, -0.56825, -0.01117,
           -0.05005, -0.01759,  0.25676,  0.44227,  0.26843, -0.04639,  0.81999,  0.03353,  0.00346,  0.23105,
            0.38465, -0.15921,  0.94683, -0.00414,  0.00014,  0.0221 , -0.31495, -0.37511,  0.78355, -0.20104,
            0.12596,  0.19275, -0.06768, -0.27088,  0.0865 , -0.84591, -0.09512, -0.92559,  0.10891, -0.02505,
            0.35839,  0.82832,  0.72355, -0.23236,  0.3319 ,  0.98149, -0.78959,  0.40203, -0.07513,  0.47152,
           -0.84441,  0.11772,  0.00124,  0.83422,  0.27929,  0.86386, -0.20588,  0.26875,  0.85652,  0.74349,
           -0.01994,  0.0347 , -0.12332,  0.00867,  0.13142,  0.54586, -0.18044,  0.79878,  0.46989,  0.8288 ,
            0.11206, -0.63212,  0.07159,  0.07711, -0.19824, -0.01445, -0.01615,  0.0213 ,  0.39786, -0.00181,
            0.65992, -0.2276 , -0.75137,  0.76128,  0.20748,  0.00872, -0.04719, -0.84117, -0.20077,  0.18145,
           -0.79947,  0.81305, -0.02757, -0.01545, -0.21821,  0.49826, -0.02486, -0.53021,  0.17162,  0.04734,
            0.17984, -0.84934,  0.08191, -0.3045 ,  0.13092,  0.01098, -0.05554, -0.09635, -0.13986, -0.18509,
           -0.57531,  0.01809, -0.00269,  0.52137, -0.06853, -0.02114,  0.62262, -0.76599,  0.01034,  0.16209,
           -0.0254 ,  0.02735, -0.02615,  0.65093, -0.85351, -0.89152,  0.00799, -0.59672, -0.48846, -0.59858,
           -0.15403,  0.65761,  0.16738,  0.0754 ,  0.52624,  0.82602, -0.02598,  0.40127,  0.45649,  0.11987,
            0.24958,  0.34019, -0.742  ,  0.04091, -0.02229,  0.74436,  0.12138,  0.71557,  0.01476, -0.25715,
            0.08925, -0.08362, -0.3    , -0.06476,  0.09035,  0.7435 , -0.07749, -0.8163 , -0.33656,  0.29714,
           -0.00522, -0.0003 ,  0.87023,  0.02558, -0.99604, -0.1197 , -0.06248, -0.00118,  0.10089, -0.72426,
            0.00437, -0.40108, -0.15225,  0.00475, -0.0592 ,  0.31623, -0.00579, -0.62022, -0.12784, -0.66596,
            0.30865,  0.32293,  0.00099, -0.02548, -0.181  , -0.57369,  0.46369, -0.97714, -0.00414,  0.69337,
           -0.97754, -0.03125,  0.63071,  0.04494, -0.71515, -0.22428, -0.11089, -0.7952 , -0.09471, -0.63099,
           -0.57655,  0.02376,  0.07157, -0.17595,  0.96694,  0.28698,  0.52431,  0.00224,  0.4979 , -0.15099,
           -0.92323, -0.35417, -0.0007 ,  0.03533, -0.98333,  0.80681,  0.14622,  0.80184,  0.63707, -0.34576,
           -0.04199, -0.58708, -0.01058,  0.21514, -0.00114,  0.14158, -0.00103,  0.00136,  0.00237,  0.03363,
            0.36892, -0.65307,  0.43041,  0.15427, -0.40644, -0.71505, -0.50262, -0.5281 ,  0.53987,  0.18279,
            0.10648,  0.15446,  0.01094, -0.00054,  0.0298 ,  0.10558,  0.46539,  0.84262, -0.90441, -0.32831,
           -0.97124,  0.00711,  0.30706, -0.02257,  0.82821,  0.16343, -0.17096,  0.23279,  0.94632,  0.89465,
            0.00422,  0.11399, -0.03034,  0.53426,  0.2359 , -0.72692,  0.97674, -0.29913,  0.09627,  0.47417,
           -0.87099, -0.37598, -0.00867, -0.70182, -0.81085,  0.57523, -0.12937, -0.40587,  0.23322,  0.01819,
           -0.58604,  0.50209,  0.69755, -0.71263, -0.22608,  0.82225,  0.49957,  0.73971,  0.06533, -0.79   ,
           -0.02314, -0.01394, -0.75689,  0.45218, -0.1325 ,  0.62047, -0.10407, -0.08093, -0.0114 , -0.39029,
           -0.59705, -0.94595,  0.44325, -0.37954,  0.00017, -0.00643, -0.02829, -0.97852,  0.00206, -0.39064,
           -0.00037, -0.18224, -0.25577,  0.95344,  0.61798,  0.95757, -0.20175,  0.11794, -0.71991, -0.09784,
           -0.27881, -0.19821, -0.04242, -0.8296 ,  0.00081,  0.60963,  0.62664,  0.72264,  0.01433,  0.02532]]], dtype=float32)],
 [array([[[-0.02815,  0.01443, -0.21755, ..., -0.2932 , -0.83243, -0.68881],
          [-0.11399,  0.58334,  0.01465, ..., -0.42394,  0.0057 , -0.00011]]], dtype=float32),
  array([[[ 0.17183,  0.06068,  0.87668, ...,  0.09799,  0.00023, -0.0269 ],
          [-0.03981, -0.05539, -0.23291, ...,  0.01526,  0.95958,  0.20518]]], dtype=float32),
  array([[[ 0.67286, -0.62926, -0.46802,  0.70781, -0.1038 , -0.56664, -0.75002,  0.59732, -0.71267,  0.67986,
            0.29438, -0.32798, -0.11172, -0.10629, -0.22235,  0.02634, -0.12676,  0.00015, -0.2415 , -0.69331,
            0.1373 ,  0.6395 ,  0.54709, -0.18545,  0.73523, -0.07354,  0.72034, -0.69605, -0.83535, -0.0378 ,
            0.93838,  0.12184, -0.77273, -0.88967,  0.06678, -0.88061, -0.25478, -0.91705,  0.00156, -0.0026 ,
            0.12699, -0.09595,  0.36363, -0.02742, -0.72856, -0.46303,  0.29309,  0.84334,  0.13666, -0.08954,
            0.71309, -0.83131,  0.99123, -0.63005,  0.02297, -0.31556,  0.76686, -0.58607,  0.75435, -0.44936,
            0.05017, -0.02792,  0.34902, -0.32274,  0.18679, -0.20664,  0.03724, -0.51263,  0.88026,  0.00346,
            0.17272, -0.7053 ,  0.03951,  0.00038, -0.38001, -0.42956,  0.67778,  0.5839 , -0.06674, -0.00792,
           -0.34072,  0.27365,  0.58116,  0.00094, -0.06808,  0.90614,  0.5648 ,  0.41149,  0.00279, -0.8009 ,
            0.89737,  0.1683 , -0.47667, -0.60266,  0.00028,  0.00242,  0.50461, -0.00596, -0.39126, -0.00524,
            0.17619,  0.11177,  0.56797,  0.04151,  0.00092,  0.81613,  0.00019, -0.68058,  0.00011, -0.13884,
            0.2879 ,  0.08336,  0.03764, -0.16683, -0.12263,  0.99375,  0.77305, -0.72881, -0.51032, -0.00854,
            0.21516,  0.94627, -0.43949,  0.09651,  0.48232, -0.67047,  0.61994, -0.65298,  0.76164,  0.40813,
           -0.21207,  0.02539,  0.30622,  0.94719, -0.47456,  0.05289,  0.68243, -0.65845, -0.22452,  0.46537,
            0.17224, -0.37007, -0.01055,  0.31053, -0.08247,  0.17951,  0.66865,  0.00286,  0.10626, -0.58866,
           -0.65322, -0.0006 , -0.01518,  0.42873, -0.09785,  0.62467, -0.47047,  0.91923,  0.79256,  0.5828 ,
            0.62863,  0.03925, -0.69479, -0.55773, -0.19763,  0.03652, -0.17079, -0.85278, -0.46987, -0.81744,
            0.00205, -0.03472,  0.22838,  0.57121,  0.11225,  0.15275,  0.75706,  0.02474,  0.10318,  0.83198,
           -0.0023 , -0.44944, -0.45313, -0.00027,  0.53422, -0.0233 ,  0.06128, -0.02725, -0.06369,  0.20423,
           -0.85756, -0.33212, -0.06988,  0.14369,  0.78994, -0.04727,  0.02333, -0.31869,  0.10497, -0.94189,
           -0.06164,  0.31728,  0.10503,  0.00242, -0.30187,  0.16965, -0.59538,  0.0048 , -0.22711, -0.16922,
            0.00835,  0.55194, -0.99398, -0.02675, -0.12815, -0.02341, -0.52103, -0.10805,  0.19744, -0.64536,
           -0.55582, -0.13832, -0.12886,  0.06639,  0.00281, -0.05265,  0.65211,  0.2119 , -0.80961, -0.48294,
            0.09542, -0.85252,  0.20053,  0.03175, -0.97582, -0.23859, -0.25901,  0.77566,  0.04035, -0.73435,
            0.02944, -0.06317, -0.57585,  0.00446, -0.79466,  0.3611 ,  0.59581, -0.2554 ,  0.59862,  0.04779,
            0.16032,  0.03658,  0.03249, -0.00826,  0.5015 , -0.20112, -0.00579,  0.80953,  0.0095 , -0.30357,
           -0.075  ,  0.48389,  0.97771, -0.61917, -0.34016, -0.23435,  0.1816 , -0.61097, -0.00523,  0.81557,
           -0.65464, -0.73878,  0.04386,  0.55917, -0.36094,  0.31521, -0.91728,  0.01842,  0.34646, -0.95734,
           -0.39245, -0.07041, -0.58314, -0.00131, -0.01181, -0.901  ,  0.03733, -0.68288,  0.71937, -0.55193,
            0.05402, -0.99416, -0.5582 , -0.40096,  0.00934,  0.20287, -0.02709,  0.47579, -0.09131,  0.44647,
            0.15844, -0.15828,  0.4887 , -0.27968, -0.01158, -0.40339, -0.88158,  0.48028, -0.31001,  0.01998,
            0.008  ,  0.75216,  0.44623, -0.97026,  0.95654,  0.93987,  0.81009, -0.60017,  0.00245, -0.75555,
           -0.70842, -0.81023,  0.40101, -0.71133,  0.00094, -0.03572,  0.75106,  0.79254,  0.01317,  0.11687,
            0.03002,  0.53937,  0.30929, -0.09701, -0.57614,  0.58046,  0.52487,  0.52834, -0.47754,  0.07432,
            0.36766,  0.12985, -0.65396,  0.54573, -0.95443,  0.50693,  0.48865,  0.78406,  0.10212,  0.22122,
            0.84579, -0.66749, -0.80173, -0.00079, -0.01478,  0.91735, -0.38039, -0.03046, -0.14163,  0.12755,
           -0.02287, -0.2211 , -0.11576, -0.7463 , -0.24429,  0.12826, -0.80301, -0.57252, -0.06721, -0.99647,
           -0.45914, -0.2207 , -0.62894, -0.52069, -0.09315, -0.15881, -0.34433, -0.32729, -0.8492 ,  0.13056,
           -0.38323, -0.02844, -0.01491, -0.08301,  0.01411,  0.99509,  0.0015 ,  0.1794 , -0.21069, -0.00529,
           -0.00715, -0.01159,  0.18546,  0.68294,  0.00909, -0.03601,  0.00338,  0.71442,  0.73567,  0.15405],
          [-0.10745, -0.89728, -0.02773,  0.36846,  0.68909, -0.00238, -0.84575,  0.21863, -0.72444, -0.38227,
            0.92325,  0.82073, -0.09403, -0.02757,  0.76086,  0.86858, -0.04909, -0.12215, -0.42108, -0.25846,
            0.07183,  0.3582 ,  0.28667,  0.19003, -0.0887 , -0.40156,  0.09809,  0.0253 , -0.66314, -0.00102,
           -0.1837 , -0.06704, -0.07694, -0.00983,  0.0421 ,  0.37194, -0.58079, -0.25971, -0.00935, -0.02574,
           -0.00143,  0.0006 ,  0.54139, -0.43277, -0.06122,  0.71352, -0.33764,  0.26673,  0.15375, -0.08849,
           -0.79094,  0.37394, -0.15387,  0.41102,  0.03294, -0.37476,  0.34984,  0.57495,  0.81134, -0.23035,
            0.00533, -0.7472 ,  0.42648, -0.14137,  0.3784 ,  0.63215, -0.03785, -0.75149,  0.53347, -0.77514,
            0.55872, -0.3569 ,  0.00124,  0.93715, -0.39832, -0.40543,  0.71075,  0.00698, -0.56825, -0.01117,
           -0.05005, -0.01759,  0.25676,  0.44227,  0.26843, -0.04639,  0.81999,  0.03353,  0.00346,  0.23105,
            0.38465, -0.15921,  0.94683, -0.00414,  0.00014,  0.0221 , -0.31495, -0.37511,  0.78355, -0.20104,
            0.12596,  0.19275, -0.06768, -0.27088,  0.0865 , -0.84591, -0.09512, -0.92559,  0.10891, -0.02505,
            0.35839,  0.82832,  0.72355, -0.23236,  0.3319 ,  0.98149, -0.78959,  0.40203, -0.07513,  0.47152,
           -0.84441,  0.11772,  0.00124,  0.83422,  0.27929,  0.86386, -0.20588,  0.26875,  0.85652,  0.74349,
           -0.01994,  0.0347 , -0.12332,  0.00867,  0.13142,  0.54586, -0.18044,  0.79878,  0.46989,  0.8288 ,
            0.11206, -0.63212,  0.07159,  0.07711, -0.19824, -0.01445, -0.01615,  0.0213 ,  0.39786, -0.00181,
            0.65992, -0.2276 , -0.75137,  0.76128,  0.20748,  0.00872, -0.04719, -0.84117, -0.20077,  0.18145,
           -0.79947,  0.81305, -0.02757, -0.01545, -0.21821,  0.49826, -0.02486, -0.53021,  0.17162,  0.04734,
            0.17984, -0.84934,  0.08191, -0.3045 ,  0.13092,  0.01098, -0.05554, -0.09635, -0.13986, -0.18509,
           -0.57531,  0.01809, -0.00269,  0.52137, -0.06853, -0.02114,  0.62262, -0.76599,  0.01034,  0.16209,
           -0.0254 ,  0.02735, -0.02615,  0.65093, -0.85351, -0.89152,  0.00799, -0.59672, -0.48846, -0.59858,
           -0.15403,  0.65761,  0.16738,  0.0754 ,  0.52624,  0.82602, -0.02598,  0.40127,  0.45649,  0.11987,
            0.24958,  0.34019, -0.742  ,  0.04091, -0.02229,  0.74436,  0.12138,  0.71557,  0.01476, -0.25715,
            0.08925, -0.08362, -0.3    , -0.06476,  0.09035,  0.7435 , -0.07749, -0.8163 , -0.33656,  0.29714,
           -0.00522, -0.0003 ,  0.87023,  0.02558, -0.99604, -0.1197 , -0.06248, -0.00118,  0.10089, -0.72426,
            0.00437, -0.40108, -0.15225,  0.00475, -0.0592 ,  0.31623, -0.00579, -0.62022, -0.12784, -0.66596,
            0.30865,  0.32293,  0.00099, -0.02548, -0.181  , -0.57369,  0.46369, -0.97714, -0.00414,  0.69337,
           -0.97754, -0.03125,  0.63071,  0.04494, -0.71515, -0.22428, -0.11089, -0.7952 , -0.09471, -0.63099,
           -0.57655,  0.02376,  0.07157, -0.17595,  0.96694,  0.28698,  0.52431,  0.00224,  0.4979 , -0.15099,
           -0.92323, -0.35417, -0.0007 ,  0.03533, -0.98333,  0.80681,  0.14622,  0.80184,  0.63707, -0.34576,
           -0.04199, -0.58708, -0.01058,  0.21514, -0.00114,  0.14158, -0.00103,  0.00136,  0.00237,  0.03363,
            0.36892, -0.65307,  0.43041,  0.15427, -0.40644, -0.71505, -0.50262, -0.5281 ,  0.53987,  0.18279,
            0.10648,  0.15446,  0.01094, -0.00054,  0.0298 ,  0.10558,  0.46539,  0.84262, -0.90441, -0.32831,
           -0.97124,  0.00711,  0.30706, -0.02257,  0.82821,  0.16343, -0.17096,  0.23279,  0.94632,  0.89465,
            0.00422,  0.11399, -0.03034,  0.53426,  0.2359 , -0.72692,  0.97674, -0.29913,  0.09627,  0.47417,
           -0.87099, -0.37598, -0.00867, -0.70182, -0.81085,  0.57523, -0.12937, -0.40587,  0.23322,  0.01819,
           -0.58604,  0.50209,  0.69755, -0.71263, -0.22608,  0.82225,  0.49957,  0.73971,  0.06533, -0.79   ,
           -0.02314, -0.01394, -0.75689,  0.45218, -0.1325 ,  0.62047, -0.10407, -0.08093, -0.0114 , -0.39029,
           -0.59705, -0.94595,  0.44325, -0.37954,  0.00017, -0.00643, -0.02829, -0.97852,  0.00206, -0.39064,
           -0.00037, -0.18224, -0.25577,  0.95344,  0.61798,  0.95757, -0.20175,  0.11794, -0.71991, -0.09784,
           -0.27881, -0.19821, -0.04242, -0.8296 ,  0.00081,  0.60963,  0.62664,  0.72264,  0.01433,  0.02532]]], dtype=float32)]]

Is there a way to interpret this information, or should I be using another method to make my predictions? Any help or a pointer in the right direction would be great.


#2

Have you made any progress on this?
I'm also looking into how to interpret the output of predict_array.


(urmas pitsi) #3

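Try taking element [0] of the result and applying softmax to each row. Each row then becomes a probability distribution over the whole vocabulary: column i is the probability that the next word is the token with index i. The other elements of the result appear to be per-layer hidden activations, which you don't need for picking the next word.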

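Then take the argmax of a row to get the predicted token index. Softmax doesn't change which entry is largest, so you can also argmax the raw scores directly. Look the index up with itos[idx] to get the word; stoi goes the other way, from word to index. The last row is the prediction for the word that follows your whole input, and it should be the next word the model predicts for that input.

One more thing to check: the model expects input shaped (sequence length, batch size). The (1, 2, …) hidden-state arrays above suggest "select" and "*" were fed as two separate one-token sequences, so you probably want to pass t.T (shape (2, 1)) instead of t.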