Hello.
I’m experimenting with tabular data and get an error when trying to create a data block with two columns as the target (`y`).
I get the error: `got an unexpected keyword argument 'one_hot'`
# Create a toy dataframe: three continuous inputs, one boolean category, two float targets.
# NOTE(review): `datalen` is not defined in this snippet; it must be set before running it.
df = pd.DataFrame()
df['A'] = np.arange(datalen)
df['B'] = np.arange(datalen) * 10
df['C'] = np.arange(datalen) * 100
df['D'] = np.random.choice(a=[True, False], size=datalen)
# Targets start at 2 instead of 0: `log=True` below asks `FloatList` to log-transform the
# targets, and log(0) would give -inf for the first row.
df['out'] = np.arange(1, datalen + 1) * 2
df['out2'] = np.arange(1, datalen + 1) * 2

# Create the data block; the last 20% of the rows become the validation set.
start = int(0.8 * len(df))
idx = list(range(start, len(df)))
procs = [Categorify]
src = (TabularList.from_df(df, cat_names=['D'], cont_names=['A', 'B', 'C'], procs=procs)
       .split_by_idx(idx)
       .label_from_df(cols=['out', 'out2'], label_cls=FloatList, log=True)
       .databunch(bs=8))
I’ve tracked the error down to the function `label_from_df` in `data_block.py`.
Could it be that this function should look like this instead?
def label_from_df(self, cols:IntsOrStrs=1, **kwargs):
    """Label `self.items` from the values in `cols` in `self.xtra`.

    When several columns are given and the caller has not requested another `label_cls`,
    the labels are treated as one-hot multi-label targets (`MultiCategoryList` with
    `one_hot=True` and `classes=cols`). An explicit `label_cls` (e.g. `FloatList` for
    several regression targets) is respected, and `one_hot`/`classes` are then not
    injected. All `kwargs` are forwarded to `label_from_list`.
    """
    labels = _maybe_squeeze(self.xtra.iloc[:,df_names_to_idx(cols, self.xtra)])
    assert labels.isna().sum().sum() == 0, f"You have NaN values in column(s) {cols} of your dataframe, please fix it."
    label_cls = kwargs.get('label_cls')
    # Only fall back to one-hot multi-category labels when no other label class was asked for:
    # other classes such as `FloatList` do not accept the `one_hot` keyword.
    if is_listy(cols) and len(cols) > 1 and (label_cls is None or label_cls is MultiCategoryList):
        defaults = dict(one_hot=True, label_cls=MultiCategoryList, classes=cols)
        # Caller-supplied kwargs win over the defaults (e.g. a custom `classes`).
        kwargs = {**defaults, **kwargs}
    return self.label_from_list(labels, **kwargs)