class ColumnarModelData(ModelData):
    """ModelData for columnar (tabular) datasets.

    Wraps the train, validation and optional test datasets in DataLoaders:

    * the training loader uses batch size ``bs`` and is shuffled unless
      ``shuffle=False``;
    * the validation loader uses batch size ``bs * 2`` and is never shuffled;
    * the test loader (only built when a test dataset is given) uses batch
      size ``bs`` and is never shuffled.
    """

    def __init__(self, path, trn_ds, val_ds, bs, test_ds=None, shuffle=True,
                 num_workers=1):
        """Build the DataLoaders and hand them to ``ModelData.__init__``.

        Args:
            path: passed through unchanged to ``ModelData.__init__``.
            trn_ds: training dataset.
            val_ds: validation dataset.
            bs: batch size for the training and test loaders; the validation
                loader uses ``bs * 2``.
            test_ds: optional test dataset; when ``None`` the test loader
                passed to ``ModelData`` is ``None``.
            shuffle: whether the training loader shuffles its samples.
            num_workers: number of loader worker processes used for every
                DataLoader. Defaults to 1, the value that was previously
                hard-coded, so existing callers are unaffected.
        """
        trn_dl = DataLoader(trn_ds, bs, shuffle=shuffle, num_workers=num_workers)
        val_dl = DataLoader(val_ds, bs * 2, shuffle=False, num_workers=num_workers)
        test_dl = (DataLoader(test_ds, bs, shuffle=False, num_workers=num_workers)
                   if test_ds is not None else None)
        super().__init__(path, trn_dl, val_dl, test_dl)

    @classmethod
    def from_arrays(cls, path, val_idxs, xs, y, is_reg=True, is_multi=False,
                    bs=64, test_xs=None, shuffle=True, num_workers=1):
        """Build a ColumnarModelData from in-memory arrays.

        ``xs`` and ``y`` are split into validation and training parts with
        ``split_by_idx(val_idxs, xs, y)``. Each part is wrapped in a
        ``PassthruDataset`` whose positional arguments are the rows of the
        transposed inputs (``*xs.T``), followed by the targets.

        NOTE(review): this assumes ``xs`` / ``test_xs`` are 2-D arrays laid
        out as (rows, columns), so that ``.T`` yields one entry per column.
        Confirm this against callers.

        Args:
            path: passed through to the constructor.
            val_idxs: indices selecting the validation rows.
            xs: input features.
            y: targets.
            is_reg, is_multi: forwarded to every ``PassthruDataset``.
            bs: batch size (see ``__init__``).
            test_xs: optional test features. They get dummy targets
                ``[0] * len(test_xs)`` because no labels are available.
            shuffle: whether the training loader shuffles.
            num_workers: loader worker processes (see ``__init__``).
        """
        ((val_xs, trn_xs), (val_y, trn_y)) = split_by_idx(val_idxs, xs, y)

        def make_ds(feats, targets):
            # One positional argument per column of ``feats``, then the targets.
            return PassthruDataset(*(feats.T), targets, is_reg=is_reg, is_multi=is_multi)

        test_ds = make_ds(test_xs, [0] * len(test_xs)) if test_xs is not None else None
        return cls(path, make_ds(trn_xs, trn_y), make_ds(val_xs, val_y),
                   bs=bs, shuffle=shuffle, test_ds=test_ds, num_workers=num_workers)

    @classmethod
    def from_data_frames(cls, path, trn_df, val_df, trn_y, val_y, cat_flds, bs,
                         is_reg, is_multi, test_df=None, shuffle=True, num_workers=1):
        """Build a ColumnarModelData from pre-split data frames.

        Each frame is converted with ``ColumnarDataset.from_data_frame``, using
        the same categorical field list ``cat_flds`` and the same
        ``is_reg`` / ``is_multi`` flags. The optional test frame is converted
        with ``None`` as its targets.

        Args:
            path: passed through to the constructor.
            trn_df, val_df: training / validation frames.
            trn_y, val_y: training / validation targets.
            cat_flds: categorical field names forwarded to ``from_data_frame``.
            bs: batch size (see ``__init__``).
            is_reg, is_multi: forwarded to every ``ColumnarDataset``.
            test_df: optional test frame.
            shuffle: whether the training loader shuffles.
            num_workers: loader worker processes (see ``__init__``).
        """
        trn_ds = ColumnarDataset.from_data_frame(trn_df, cat_flds, trn_y, is_reg, is_multi)
        val_ds = ColumnarDataset.from_data_frame(val_df, cat_flds, val_y, is_reg, is_multi)
        test_ds = (ColumnarDataset.from_data_frame(test_df, cat_flds, None, is_reg, is_multi)
                   if test_df is not None else None)
        return cls(path, trn_ds, val_ds, bs, test_ds=test_ds, shuffle=shuffle,
                   num_workers=num_workers)
