Hi, I’m getting an error when I use TabularProcs with a constructor. I’m not sure if it’s a bug or I’m using them wrong.
If I use FillMissing, it works fine. If I use FillMissing(cat_names, cont_names) I get an error suggesting that FillMissing.apply_test() is being called before FillMissing.apply_train(). (I want to use the constructor so I can set parameters such as fill_strategy.)
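For context, what I'm really after is a way to pass fill_strategy without building the instance myself. One idea, sketched here with a stand-in class rather than fastai itself, is functools.partial: it binds the keyword now but stays a plain callable that can be invoked later with (cat_names, cont_names). StandInProc and its string-valued fill_strategy are hypothetical, and whether fastai's pipeline accepts a partial in procs is an assumption I haven't verified.

```python
from functools import partial


class StandInProc:
    """Hypothetical stand-in for a TabularProc-like class (not fastai code)."""

    def __init__(self, cat_names, cont_names, fill_strategy="median"):
        self.cat_names = cat_names
        self.cont_names = cont_names
        self.fill_strategy = fill_strategy


# Bind the parameter now, leave construction for later.
deferred = partial(StandInProc, fill_strategy="most_common")

# The partial is a callable, not an instance of the class.
is_instance = isinstance(deferred, StandInProc)

# Whoever owns the column names can finish construction later.
proc = deferred(["workclass"], ["age"])
```

If the pipeline only constructs procs that are not already instances, something like this might let the parameter through, but I'd like confirmation that it's the intended usage.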
Reproducible example:
from fastai.tabular import *
path = untar_data(URLs.ADULT_SAMPLE)
# Version of learn() from
# https://github.com/fastai/fastai/blob/1a914b365720aa216a4a377705fe11707a665fe9/tests/test_tabular_train.py
# Only modification is extraction of 'procs' as parameter
def learn(procs=[FillMissing, Categorify, Normalize]):
    df = pd.read_csv(path/'adult.csv')
    # procs = [FillMissing, Categorify, Normalize]
    dep_var = 'salary'
    cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']
    cont_names = ['age', 'fnlwgt', 'education-num']
    test = TabularList.from_df(df.iloc[800:1000].copy(), path=path, cat_names=cat_names, cont_names=cont_names)
    data = (TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)
            .split_by_idx(list(range(800,1000)))
            .label_from_df(cols=dep_var)
            .add_test(test)
            .databunch(num_workers=1))
    learn = tabular_learner(data, layers=[200,100], emb_szs={'native-country': 10}, metrics=accuracy)
    learn.fit_one_cycle(2, 1e-2)
    return learn
# cat_names, cont_names copied from learn()
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']
cont_names = ['age', 'fnlwgt', 'education-num']
# Succeeds as expected
learn()
# Fails (uses FillMissing constructor)
learn([FillMissing(cat_names=cat_names, cont_names=cont_names), Categorify, Normalize])
# AttributeError: 'FillMissing' object has no attribute 'na_dict'
na_dict is set in FillMissing.apply_train(), but I don’t think this method is being called.
Using Categorify(cat_names, cont_names) also fails.
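To make my suspicion concrete, here is a minimal sketch of the kind of dispatch that would produce this error. It uses stand-in code, not the fastai source: StandInProc and process_train are hypothetical names, and the rule "an object that is already an instance is treated as fitted and called with test=True" is my guess at the behaviour, not something I have confirmed.

```python
class StandInProc:
    """Hypothetical stand-in for a TabularProc-like class (not fastai code)."""

    def __init__(self, cat_names, cont_names):
        self.cat_names, self.cont_names = cat_names, cont_names

    def __call__(self, df, test=False):
        # Fit on training data, otherwise reuse the fitted state.
        if test:
            self.apply_test(df)
        else:
            self.apply_train(df)

    def apply_train(self, df):
        # Fitted state is created only here.
        self.na_dict = {}

    def apply_test(self, df):
        # Raises AttributeError if apply_train never ran.
        return self.na_dict


def process_train(procs, df, cat_names, cont_names):
    """Assumed dispatch: instances are treated as already fitted."""
    for i, proc in enumerate(procs):
        if isinstance(proc, StandInProc):
            proc(df, test=True)  # instance: apply_train is skipped
        else:
            proc = proc(cat_names, cont_names)  # class: construct, then fit
            proc(df)
            procs[i] = proc
    return procs


cats, conts = ["workclass"], ["age"]

# Passing the class works: the pipeline constructs and fits it.
fitted = process_train([StandInProc], df=None, cat_names=cats, cont_names=conts)

# Passing an instance fails the same way as my example above.
try:
    process_train([StandInProc(cats, conts)], df=None, cat_names=cats, cont_names=conts)
    error = None
except AttributeError as exc:
    error = str(exc)
```

In this sketch the second call ends with an AttributeError about na_dict, which matches what I see with the real FillMissing.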
We can make a small class Report to check whether apply_train() or apply_test() is being called:
# Make our own TabularProc subclass to check calls to apply_train and apply_test
class Report(TabularProc):
    def apply_train(self, df:DataFrame):
        print("Hello from TRAIN")
    def apply_test(self, df:DataFrame):
        print("Hello from TEST")
# Prints "Hello from TRAIN"
learn([Report, FillMissing, Categorify, Normalize])
# Hello from TRAIN
# Hello from TEST
# Hello from TEST
# Never prints "Hello from TRAIN"
learn([Report(cat_names=cat_names, cont_names=cont_names), FillMissing, Categorify, Normalize])
# Hello from TEST
# Hello from TEST
# Hello from TEST
Am I using FillMissing etc. wrong?
Thank you!