2D Tabular Input with fastai DataBlock.dataloaders

Objective: Use fastai’s DataBlock.dataloaders to process two-dimensional tabular data.

Hi all, I have a two-dimensional time series that I’d like to work on with fastai. Each sample is of the form (time lags, features). I believe I could flatten the two-dimensional samples and then use views to reassemble them after the dataloader. However, I’d really like to be able to load the 2D arrays directly.
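
For reference, here is a minimal sketch of the flatten-and-views workaround I mean (the 20 and 15 are just the dummy dimensions I use below, nothing fastai-specific):

import numpy as np
import torch

# flatten each (20, 15) sample into a 1D vector of length 300 before loading
sample = np.random.rand(20, 15)
flat = sample.reshape(-1)          # shape (300,)

# ...then, after the dataloader, reassemble the 2D shape (e.g. inside the model)
x = torch.from_numpy(flat)
x2d = x.view(20, 15)               # back to (time lags, features)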

I have some tentative code that seems to work, but I’m wondering (1) whether there is a better solution and (2) whether this is actually a correct solution at all. The code is as follows:

import numpy as np
from fastai.data.all import *

# create dummy data and labels:
# dummy_data is a list of 1000 numpy arrays of shape (20, 15),
# dummy_labels a matching list of scalar regression targets
dummy_data = []
dummy_labels = []
for i in range(1000):
    dummy_data.append(np.random.rand(20, 15))
    dummy_labels.append(np.random.rand(1))

# zip dummy_data with dummy_labels into (x, y) pairs;
# the tensor transform seems to be necessary here, and it also
# stacks the arrays into a single (1000, 20, 15) tensor
inputs = L(zip(tensor(dummy_data), dummy_labels))

# create a block for handling the 2D input
def TabularBlock2D():
    return TransformBlock(type_tfms=ToTensor)

# create getter transforms for get_x and get_y
class GetX(ItemTransform):
    def encodes(self, x): return x[0]
class GetY(ItemTransform):
    def encodes(self, x): return x[1]
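
(As an aside, fastai appears to ship a built-in ItemGetter transform in fastai.data.transforms that does this same tuple indexing; if so, the custom classes above could be replaced by something like the following, though I haven’t verified it behaves identically:)

# possibly equivalent, using fastai's built-in ItemGetter
get_x_alt = ItemGetter(0)   # returns item[0], the 2D input tensor
get_y_alt = ItemGetter(1)   # returns item[1], the label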

# the datablock
dblock = DataBlock(blocks=(TabularBlock2D, RegressionBlock),
                   splitter=IndexSplitter(range(800, 1000)),
                   get_x=GetX(),
                   get_y=GetY())

# use DataBlock.dataloaders to build the DataLoaders from the inputs
dls = dblock.dataloaders(inputs, bs=64, shuffle=False)
# check that split is correct
print("Train dataset contains {} items".format(len(dls.train_ds)))
print("Valid dataset contains {} items".format(len(dls.valid_ds)))

Train dataset contains 800 items
Valid dataset contains 200 items

# examine one of the validation inputs
dls.valid_ds[0]

(tensor([[0.8814, 0.2982, 0.8461,  ..., 0.7467, 0.6164, 0.8925],
         [0.7628, 0.9055, 0.4871,  ..., 0.2055, 0.5100, 0.7208],
         ...,
         [0.8080, 0.6560, 0.6805,  ..., 0.4176, 0.9376, 0.1637]]),
 tensor([0.8210]))

(output truncated for readability)
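
For a further sanity check, grabbing one batch should confirm the shapes. The expected shapes in the comments are my assumption based on the dummy data above, not pasted output:

# grab one batch from the training DataLoader and inspect the shapes
xb, yb = dls.one_batch()
print(xb.shape)   # expecting torch.Size([64, 20, 15])
print(yb.shape)   # expecting torch.Size([64, 1])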

Does this implementation seem alright? Or is there a much better way to do this?