Hi,
I am trying to write my own data generator (a PyTorch `Dataset`) to train a fastai UNet for image segmentation. My code is as follows:
# Each CSV row pairs an input image path ('input_img') with its mask path ('mask_img').
df = pd.read_csv('/path/to/csv/data.csv')
X, y = (list(df[column]) for column in ('input_img', 'mask_img'))
# Hold out a third of the pairs for validation; the fixed seed makes the split repeatable.
X_train, X_valid, y_train, y_valid = train_test_split(
    X,
    y,
    test_size=0.33,
    random_state=42,
)
class ToTensor:
    """Convert an (H, W, C) or (H, W) NumPy array into a (C, H, W) torch tensor.

    The array's dtype is kept as-is; ``torch.from_numpy`` shares memory with
    the (possibly transposed) array rather than copying it.

    Raises:
        ValueError: if the array is neither 2-D nor 3-D.
    """

    def __call__(self, img):
        if img.ndim == 3:
            # Channels-last (OpenCV layout) -> channels-first (PyTorch layout).
            img = img.transpose((2, 0, 1))
        elif img.ndim == 2:
            # Single-channel images (e.g. greyscale masks from cv2) have no
            # channel axis; add one so the result is still (C, H, W).
            img = img[None, :, :]
        else:
            raise ValueError(
                f"expected a 2-D (H, W) or 3-D (H, W, C) array, got shape {img.shape}"
            )
        return torch.from_numpy(img)
class NumbersDataset(Dataset):
    """Segmentation dataset pairing input images with binary masks.

    Each item is read from disk with OpenCV and resized to ``size``. The mask
    is converted to grey and binarised: any non-zero grey level becomes 1.

    Args:
        inputs: Paths to the input images.
        labels: Paths to the mask images, aligned index-by-index with ``inputs``.
        transform: Optional callable applied to both the image and the mask.
        size: Target ``(width, height)`` passed to ``cv2.resize``. The default
            reproduces the previously hard-coded 427x240.

    Raises:
        ValueError: if ``inputs`` and ``labels`` differ in length.
    """

    def __init__(self, inputs, labels, transform=None, size=(427, 240)):
        if len(inputs) != len(labels):
            raise ValueError(
                f"got {len(inputs)} input paths but {len(labels)} mask paths"
            )
        self.X = inputs
        self.y = labels
        self.transform = transform
        self.size = tuple(size)
        # Two labels appear in the masks: 0 (background) and 1 (foreground).
        self.c = 2

    def __len__(self):
        return len(self.X)

    @staticmethod
    def _read(path):
        """Read a colour image with OpenCV, raising instead of returning None."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(f"could not read image: {path}")
        return img

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for item ``idx``.

        Without a transform, the image is the resized BGR array and the mask
        is a 0/1 array of shape (height, width, 1).
        """
        img_train = self._read(self.X[idx])
        img_mask = self._read(self.y[idx])

        img_train = cv2.resize(img_train, self.size, interpolation=cv2.INTER_LANCZOS4)

        # Resize the mask with nearest-neighbour. Lanczos ringing would create
        # new non-zero grey levels at object edges, and the `> 0` threshold
        # below would turn them into foreground, inflating the mask.
        img_mask = cv2.cvtColor(img_mask, cv2.COLOR_BGR2GRAY)
        img_mask = cv2.resize(img_mask, self.size, interpolation=cv2.INTER_NEAREST)
        bin_mask = (img_mask > 0).astype(img_mask.dtype)
        # Keep an explicit trailing channel axis: (height, width, 1).
        bin_mask = bin_mask[:, :, None]

        # NOTE(review): img_train stays uint8 BGR here. The UNet presumably
        # expects float (and probably RGB, normalised) input — confirm and
        # convert in `transform` if needed.
        if self.transform is not None:
            img_train = self.transform(img_train)
            bin_mask = self.transform(bin_mask)
        return img_train, bin_mask
if __name__ == '__main__':
    to_tensor = transforms.Compose([ToTensor()])
    dataset_train = NumbersDataset(X_train, y_train, to_tensor)
    dataset_valid = NumbersDataset(X_valid, y_valid, to_tensor)

    # DataBunch.create takes *datasets* and builds its own DataLoaders from
    # them (batch size via `bs`). Passing DataLoaders here would make it wrap
    # a loader inside a loader.
    datas = DataBunch.create(train_ds=dataset_train, valid_ds=dataset_valid, bs=4)

    # The masks hold two labels (0 and 1), matching NumbersDataset.c.
    # NOTE(review): fastai uses data.c as the UNet's number of output classes.
    # Confirm the loss expects 2 channels (cross-entropy style) rather than 1
    # (BCE style).
    datas.c = dataset_train.c

    # NOTE(review): the reported "new() argument after * must be an iterable,
    # not builtin_function_or_method" most likely comes from fastai v1's
    # unet_learner reading `data.train_ds[0][0].size` and unpacking it. On a
    # fastai Image that attribute is an (H, W) tuple, but on the plain torch
    # tensor this dataset returns, `.size` is a method. Returning fastai
    # Image/ImageSegment items, or building the data with fastai's
    # SegmentationItemList, avoids it — confirm against the installed fastai
    # version.
    learner = unet_learner(datas, models.resnet34)
The CSV file contains the paths of the input images in its first column (`input_img`) and the corresponding mask paths in its second column (`mask_img`). I get the following error:
TypeError: new() argument after * must be an iterable, not builtin_function_or_method
What should I do?