I have some raw data like this:
date | open | high | low | close | volume | |
---|---|---|---|---|---|---|
0 | 2005-01-04 | 5.18 | 5.22 | 5.02 | 5.15 | 346551.0 |
1 | 2005-01-05 | 5.08 | 5.38 | 5.08 | 5.26 | 605342.0 |
2 | 2005-01-06 | 5.38 | 5.50 | 5.25 | 5.28 | 582162.0 |
3 | 2005-01-07 | 5.19 | 5.45 | 5.18 | 5.26 | 437115.0 |
4 | 2005-01-10 | 5.15 | 5.34 | 5.15 | 5.29 | 387913.0 |
5 | 2005-01-11 | 5.31 | 5.57 | 5.27 | 5.51 | 1209972.0 |
6 | 2005-01-12 | 5.50 | 5.51 | 5.37 | 5.43 | 536727.0 |
7 | 2005-01-13 | 5.46 | 5.62 | 5.43 | 5.57 | 1164331.0 |
8 | 2005-01-14 | 5.62 | 5.63 | 5.42 | 5.44 | 1004858.0 |
9 | 2005-01-17 | 5.40 | 5.40 | 5.16 | 5.27 | 935954.0 |
after some feature engineering, it becomes (I don't think this step causes the exception; I'm just showing the data-processing procedure):
date | open | high | low | close | volume | cma5 | cma10 | cma20 | cma60 | cma120 | cma250 | vma5 | vma10 | vma20 | vma60 | vma120 | vma250 | target | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
249 | 2006-01-13 | 3.58 | 3.62 | 3.53 | 3.60 | 1226823.0 | 3.606 | 3.465 | 3.3305 | 3.454333 | 3.531333 | 3.862895 | 2237795.8 | 1832271.3 | 1520951.65 | 9.338882e+05 | 1.072407e+06 | 8.186527e+05 | 3.58 |
250 | 2006-01-16 | 3.56 | 3.60 | 3.48 | 3.50 | 578826.0 | 3.562 | 3.490 | 3.3415 | 3.448500 | 3.537333 | 3.856295 | 1225406.0 | 1799309.3 | 1528996.55 | 9.287565e+05 | 1.074718e+06 | 8.195818e+05 | 3.60 |
251 | 2006-01-17 | 3.48 | 3.50 | 3.40 | 3.47 | 624485.0 | 3.536 | 3.519 | 3.3515 | 3.438833 | 3.542000 | 3.849135 | 1025033.6 | 1797463.3 | 1315508.90 | 9.046011e+05 | 1.075360e+06 | 8.196584e+05 | 3.50 |
252 | 2006-01-18 | 3.47 | 3.57 | 3.47 | 3.53 | 740180.0 | 3.536 | 3.550 | 3.3690 | 3.433000 | 3.547833 | 3.842135 | 883772.8 | 1819289.1 | 1244595.10 | 8.485467e+05 | 1.079651e+06 | 8.202904e+05 | 3.47 |
253 | 2006-01-19 | 3.54 | 3.58 | 3.46 | 3.55 | 1107908.0 | 3.530 | 3.573 | 3.3875 | 3.430000 | 3.552833 | 3.835295 | 855644.4 | 1777605.9 | 1262985.00 | 8.440694e+05 | 1.084204e+06 | 8.229736e+05 | 3.53 |
target here is simply the same as df.close.shift(1)
then I build the databunch, and put it into training:
dep_var = 'target'
cat_names = []
# Drop BOTH the non-numeric date column AND the dependent variable.
# The original `df.columns.drop('date')` left 'target' inside cont_names,
# which (a) leaks the label into the model inputs and (b) breaks inference:
# at predict time the tabular procs (FillMissing/Normalize) look up every
# cont name on the supplied row, so a row processed without the target
# column fails — presumably the source of the KeyError seen below; verify.
cont_names = df.columns.drop(['date', dep_var]).tolist()
procs = [FillMissing, Categorify, Normalize]
train = (TabularList.from_df(df, path=os.path.dirname(csv_file),
                             cat_names=cat_names,
                             cont_names=cont_names,
                             procs=procs)
         # hold out the last 200 rows as the validation set
         .split_by_idx(list(range(len(df)-200, len(df))))
         # log=True regresses on log(target); predictions come back in log space
         .label_from_df(cols=dep_var, label_cls=FloatList, log=True)
         .databunch())
train.show_batch(rows=10)
# The databunch variable is `train`, not `data` — the original
# `tabular_learner(data, ...)` referenced an undefined name.
learn = tabular_learner(train, layers=[1000,100], metrics=mse)
lr = 1e-2
callbacks = [
    SaveModelCallback(learn, every='improvement', monitor='mse', name='best'),
    EarlyStoppingCallback(learn, monitor='mse', min_delta=0.1, patience=3)
]
learn.callbacks = callbacks
learn.fit(10, lr)
but when I try to make some predictions:
# NOTE(review): valid_ds[0] is an (x, y) tuple from the LabelList, not a raw
# DataFrame row — presumably predict() expects the row itself; verify.
learn.predict(train.valid_ds[0])
or
learn.predict(df.iloc[0]) # a raw row (pandas Series), as the fastai docs show
both will raise same exception as the following:
KeyError Traceback (most recent call last)
~/miniconda3/envs/lynch/lib/python3.8/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
2601 try:
-> 2602 return self._engine.get_loc(key)
2603 except KeyError:
pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas/_libs/index_class_helper.pxi in pandas._libs.index.Int64Engine._check_type()
KeyError: 'o'
During handling of the above exception, another exception occurred:
KeyError Traceback (most recent call last)
<ipython-input-225-876bf4b11578> in <module>
----> 1 learn.predict(train.valid_ds[0])
~/miniconda3/envs/lynch/lib/python3.8/site-packages/fastai/basic_train.py in predict(self, item, return_x, batch_first, with_dropout, **kwargs)
370 def predict(self, item:ItemBase, return_x:bool=False, batch_first:bool=True, with_dropout:bool=False, **kwargs):
371 "Return predicted class, label and probabilities for `item`."
--> 372 batch = self.data.one_item(item)
373 res = self.pred_batch(batch=batch, with_dropout=with_dropout)
374 raw_pred,x = grab_idx(res,0,batch_first=batch_first),batch[0]
~/miniconda3/envs/lynch/lib/python3.8/site-packages/fastai/basic_data.py in one_item(self, item, detach, denorm, cpu)
179 "Get `item` into a batch. Optionally `detach` and `denorm`."
180 ds = self.single_ds
--> 181 with ds.set_item(item):
182 return self.one_batch(ds_type=DatasetType.Single, detach=detach, denorm=denorm, cpu=cpu)
183
~/miniconda3/envs/lynch/lib/python3.8/contextlib.py in __enter__(self)
111 del self.args, self.kwds, self.func
112 try:
--> 113 return next(self.gen)
114 except StopIteration:
115 raise RuntimeError("generator didn't yield") from None
~/miniconda3/envs/lynch/lib/python3.8/site-packages/fastai/data_block.py in set_item(self, item)
609 def set_item(self,item):
610 "For inference, will briefly replace the dataset with one that only contains `item`."
--> 611 self.item = self.x.process_one(item)
612 yield None
613 self.item = None
~/miniconda3/envs/lynch/lib/python3.8/site-packages/fastai/data_block.py in process_one(self, item, processor)
89 if processor is not None: self.processor = processor
90 self.processor = listify(self.processor)
---> 91 for p in self.processor: item = p.process_one(item)
92 return item
93
~/miniconda3/envs/lynch/lib/python3.8/site-packages/fastai/tabular/data.py in process_one(self, item)
42 def process_one(self, item):
43 df = pd.DataFrame([item,item])
---> 44 for proc in self.procs: proc(df, test=True)
45 if len(self.cat_names) != 0:
46 codes = np.stack([c.cat.codes.values for n,c in df[self.cat_names].items()], 1).astype(np.int64) + 1
~/miniconda3/envs/lynch/lib/python3.8/site-packages/fastai/tabular/transform.py in __call__(self, df, test)
122 "Apply the correct function to `df` depending on `test`."
123 func = self.apply_test if test else self.apply_train
--> 124 func(df)
125
126 def apply_train(self, df:DataFrame):
~/miniconda3/envs/lynch/lib/python3.8/site-packages/fastai/tabular/transform.py in apply_test(self, df)
175 if name+'_na' not in self.cat_names: self.cat_names.append(name+'_na')
176 df[name] = df[name].fillna(self.na_dict[name])
--> 177 elif pd.isnull(df[name]).sum() != 0:
178 raise Exception(f"""There are nan values in field {name} but there were none in the training set.
179 Please fix those manually.""")
~/miniconda3/envs/lynch/lib/python3.8/site-packages/pandas/core/frame.py in __getitem__(self, key)
2915 if self.columns.nlevels > 1:
2916 return self._getitem_multilevel(key)
-> 2917 indexer = self.columns.get_loc(key)
2918 if is_integer(indexer):
2919 indexer = [indexer]
~/miniconda3/envs/lynch/lib/python3.8/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
2602 return self._engine.get_loc(key)
2603 except KeyError:
-> 2604 return self._engine.get_loc(self._maybe_cast_indexer(key))
2605 indexer = self.get_indexer([key], method=method, tolerance=tolerance)
2606 if indexer.ndim > 1 or indexer.size > 1:
pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas/_libs/index_class_helper.pxi in pandas._libs.index.Int64Engine._check_type()
KeyError: 'o'
Any idea?