I’m trying to merge two datasets, and an error left me very confused. I will reproduce this error in a simple way as below:
# define two datasets
import torch.utils.data as data
class test_data1(data.Dataset):
def __init__(self):
super(test_data1, self).__init__()
def __getitem__(self, index):
inputs = {}
inputs['img'] = 1
inputs['mask'] = 2
inputs['filename'] = 3
inputs['task'] = 4
return inputs
def __len__(self):
return 50
class test_data2(data.Dataset):
def __init__(self):
super(test_data2, self).__init__()
def __getitem__(self, index):
inputs = {}
inputs['depth'] = 5
inputs['apple'] = 6
inputs['banana'] = 7
inputs['task'] = 8
return inputs
def __len__(self):
return 50
# call ConcatDataset to combine this two datasets
all_train_dataset = torch.utils.data.ConcatDataset([datasets.test_data1(), datasets.test_data2()])
train_loader = DataLoader(
all_train_dataset, self.opt.batch_size, True,
num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
for batch_idx, inputs in enumerate(self.train_loader):
print(inputs['task'])
And the following error occurred:
How can i solve this error?