[SOLVED] tabular_learner returning NaNs in training

I spent a lot of time trying to solve a problem where training would start but every loss would be NaN.

The solution turned out to be quite simple:

Cast all numeric columns in your dataframe to float32 before building the TabularDataBunch from it.
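
For anyone unsure how to do the cast, here is a minimal sketch in pandas. The dataframe and its column names (`age`, `temp`, `label`) are made up for illustration and are not from the original post. It only shows one way to convert every numeric column to float32 while leaving non-numeric columns alone.

```python
import numpy as np
import pandas as pd

# Hypothetical example data; the column names are assumptions for illustration.
df = pd.DataFrame({
    "age": [34, 51, 29],             # integer column
    "temp": [36.6, 38.1, 37.2],      # float64 column
    "label": ["neg", "pos", "neg"],  # non-numeric column, left untouched
})

# Find every numeric column and cast each one to float32.
numeric_cols = df.select_dtypes(include=[np.number]).columns
df = df.astype({col: np.float32 for col in numeric_cols})

print(df.dtypes)
```

If the target column is numeric but meant for classification, it would probably need to be excluded from `numeric_cols` before casting.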


I have this same issue but casting numeric columns in my dataframe to float32 did not fix it :weary:
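
When the cast alone does not help, one thing worth ruling out is missing or infinite values in the continuous columns, since a NaN or inf in the inputs propagates into the loss. The sketch below is only a diagnostic under that assumption, not a confirmed fix for this thread. The helper `report_bad_values` and the sample data are hypothetical.

```python
import numpy as np
import pandas as pd

def report_bad_values(df, cols):
    """Count NaN and +/-inf entries in each of the given columns."""
    values = df[cols].astype("float64")
    return pd.DataFrame({
        "n_nan": values.isna().sum(),     # missing values per column
        "n_inf": np.isinf(values).sum(),  # infinite values per column
    })

# Hypothetical data with one NaN and one inf.
df = pd.DataFrame({
    "age": [34.0, np.nan, 29.0],
    "temp": [36.6, np.inf, 37.2],
})

report = report_bad_values(df, ["age", "temp"])
print(report)
```

Any column with a non-zero count would need to be cleaned or imputed (for example with fastai's `FillMissing` proc) before training.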

Hi @jgilbert2000,
I just started wrestling with this problem. Casting to floats did not work for me either.
Did you ever find a solution?

Hi @TimK, I don’t remember exactly what fixed it because it was a while ago, but here is a sample of my old code from that time.

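import numpy as np
import pandas as pd
import torch
from fastai.tabular import *  # fastai v1: TabularList, tabular_learner, DatasetType, accuracy

df = pd.read_csv(path/filename, usecols=desired_cols)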

# Randomize data
df = df.iloc[np.random.permutation(len(df))]

# Convert dataframe types to float32 for fastai's tabular learner

dep_var = 'lab_result' # target / dependent variable
cat_names = [] # categorical variables
cont_names = ['age','Temp','exam_Plt','exam_WBC'] # continuous variables
# procs = [FillMissing, Categorify, Normalize] # procs didn't work because no categorical variables were used

# Percent of original dataframe
test_pct = 0
valid_pct = 0.2

# Masks for separating dataframe sets
cut_test = int(test_pct * len(df))
cut_valid = int(valid_pct * len(df)) + cut_test

valid_indx = range(cut_test,cut_valid) # range of validation indices, used for fastai

test = TabularList.from_df(df.iloc[cut_test:cut_valid].copy(), cat_names=cat_names, cont_names=cont_names)

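# NOTE: the chain below was cut off after its first line in the original post;
# the remaining calls are a reconstruction based on valid_indx and dep_var above.
data = (TabularList.from_df(df=df, path=path, cat_names=cat_names, cont_names=cont_names)
        .split_by_idx(valid_indx)
        .label_from_df(cols=dep_var)
        .databunch())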

# Tabular Classifier
learn = tabular_learner(data, layers=[200, 100], metrics=accuracy)


preds, targets = learn.get_preds(DatasetType.Valid)
labels = np.argmax(preds, 1)

validation_accuracy = (targets == labels).type(torch.FloatTensor).mean().item()

I was using the following environment at the time:

python 3.7.4
fastai 1.0.57
torch 1.2.0
pandas 0.25.0
numpy 1.16.4

Hope this helps.

Thanks @jgilbert2000