I have a weird use case that I can't seem to get working. My data consists of images that are grouped into studies. Each study consists of either a frontal X-ray image alone, or a frontal and a lateral X-ray image. When both images are available, I basically want the forward pass to return `torch.max(pred_frontal, pred_lateral)`. When only one image is available, it should return only the prediction for the frontal image. For example:
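```python
def forward(self, x1, x2=None):
    ftrs_frontal = self.encoder(x1)
    pred_frontal = self.head(ftrs_frontal)
    if x2 is not None:  # use `is not`, not `!=`, when comparing a tensor with None
        ftrs_lateral = self.encoder(x2)
        pred_lateral = self.head(ftrs_lateral)
        # element-wise max of the frontal and lateral predictions
        return torch.max(pred_frontal, pred_lateral)
    return pred_frontal
```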
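A minimal runnable sketch of the same forward pass, assuming a toy encoder (`nn.Flatten`) and head (`nn.Linear`) in place of the real ones. It also shows a hypothetical alternative to keeping the study types apart: if frontal-only studies are padded with a placeholder lateral image (as the `black.jpg` entries in the CSV suggest), a boolean `has_lateral` mask can choose, per sample, between the max over both views and the frontal prediction alone, so both study types could share a batch. `TwoViewModel` and the `has_lateral` argument are illustrative names, not part of the original code.

```python
import torch
import torch.nn as nn


class TwoViewModel(nn.Module):
    """Hypothetical two-view model: one shared encoder and head for both views."""

    def __init__(self, encoder: nn.Module, head: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.head = head

    def forward(self, x_frontal, x_lateral=None, has_lateral=None):
        pred_frontal = self.head(self.encoder(x_frontal))
        if x_lateral is None:  # frontal-only batch
            return pred_frontal
        pred_lateral = self.head(self.encoder(x_lateral))
        combined = torch.max(pred_frontal, pred_lateral)  # element-wise max
        if has_lateral is None:  # every study in the batch has a lateral view
            return combined
        # Mixed batch: reshape the (batch,) bool mask so it broadcasts over outputs
        mask = has_lateral.view(-1, *([1] * (pred_frontal.dim() - 1)))
        return torch.where(mask, combined, pred_frontal)


# Toy usage: 4 studies of 3x8x8 "images", 5 output labels
torch.manual_seed(0)
model = TwoViewModel(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))
x_frontal = torch.randn(4, 3, 8, 8)
x_lateral = torch.randn(4, 3, 8, 8)  # rows 1 and 3 stand in for placeholders
has_lateral = torch.tensor([True, False, True, False])
out = model(x_frontal, x_lateral, has_lateral)
```

The trade-off of the masked variant is that the encoder also runs on the placeholder images, which costs compute and, if the encoder uses BatchNorm, lets those images influence the batch statistics.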
Now the problem is that the forward function receives batches as parameters, and I don't want these two cases batched together. What I need is separate batches for the two kinds of input, so that they never mix. What I've been doing so far is training on one study group first and then on the other, which isn't ideal.
```python
from fastai.vision.all import *  # provides pd, TfmdLists, DataLoaders, ToTensor, ...

def get_only_lateral_studies_data_loader(df_path):
    df = pd.read_csv(df_path)
    # 'black.jpg' in the Lateral column marks a study without a lateral image
    has_lateral = df['Lateral'] != 'black.jpg'
    train_df = df.loc[(df['valid'] == False) & has_lateral].reset_index()
    valid_df = df.loc[(df['valid'] == True) & has_lateral].reset_index()
    # StudyTransform is my own transform that loads both images of a study
    train_tl = TfmdLists(range(len(train_df)), StudyTransform(train_df))
    valid_tl = TfmdLists(range(len(valid_df)), StudyTransform(valid_df))
    dls = DataLoaders.from_dsets(
        train_tl, valid_tl,
        after_item=[ToTensor],
        after_batch=[IntToFloatTensor, Normalize.from_stats(*imagenet_stats), *aug_transforms()])
    dls = dls.cuda()
    return dls
```
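```python
def get_only_frontal_studies_data_loader(df_path):
    df = pd.read_csv(df_path)
    # Frontal-only studies: the lateral slot holds the 'black.jpg' placeholder.
    # .copy() avoids pandas' SettingWithCopyWarning on the assignment below.
    df = df.loc[df['Lateral'] == 'black.jpg'].copy()
    df[target_label] = df[target_label].astype(bool)  # target_label, path: globals
    return ImageDataLoaders.from_df(df=df, path=path, fn_col='Frontal', valid_col='valid',
                                    label_col=target_label, batch_tfms=aug_transforms())
```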
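Instead of maintaining two loaders, one option is a single dataset with a batch sampler that only ever groups indices from the same study type, while still shuffling the order of batches across types. The sketch below is plain Python and assumes a per-sample group label (for example a boolean "has a lateral image"); `GroupedBatchSampler` is a hypothetical name, not a PyTorch or fastai class.

```python
import random
from typing import Dict, Hashable, Iterator, List, Sequence


class GroupedBatchSampler:
    """Yield lists of dataset indices; every list comes from a single group.

    groups[i] is the group label of sample i. Indices are shuffled within each
    group, and the resulting batches are shuffled together, so training
    alternates between study types instead of finishing one group first.
    """

    def __init__(self, groups: Sequence[Hashable], batch_size: int,
                 shuffle: bool = True, drop_last: bool = False, seed: int = 0):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.rng = random.Random(seed)
        self.index_groups: Dict[Hashable, List[int]] = {}
        for idx, group in enumerate(groups):
            self.index_groups.setdefault(group, []).append(idx)

    def _batches(self) -> List[List[int]]:
        batches = []
        for idxs in self.index_groups.values():
            idxs = list(idxs)
            if self.shuffle:
                self.rng.shuffle(idxs)
            for start in range(0, len(idxs), self.batch_size):
                batch = idxs[start:start + self.batch_size]
                if self.drop_last and len(batch) < self.batch_size:
                    continue  # skip the incomplete final batch of this group
                batches.append(batch)
        if self.shuffle:
            self.rng.shuffle(batches)  # interleave groups across the epoch
        return batches

    def __iter__(self) -> Iterator[List[int]]:
        return iter(self._batches())

    def __len__(self) -> int:
        sizes = [len(v) for v in self.index_groups.values()]
        if self.drop_last:
            return sum(n // self.batch_size for n in sizes)
        return sum(-(-n // self.batch_size) for n in sizes)  # ceiling division
```

With plain PyTorch this could be passed as `DataLoader(dataset, batch_sampler=GroupedBatchSampler(groups, batch_size=32))`, where `groups` is assumed to be something like `(df['Lateral'] != 'black.jpg').tolist()` aligned with the dataset's indices. The dataset and collate function would still need to return a structure per group that the model can tell apart; hooking this into fastai's own `DataLoader` is not shown here.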
Is there a way to combine these DataLoaders and get one batch from each that I can feed into the network, or am I thinking about this the wrong way?
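One way to keep the two existing loaders is to alternate between them, taking one whole batch from each in turn. Below is a minimal sketch of such a round-robin iterator; `round_robin` is a hypothetical helper that works on any iterables, and it tags each batch with the index of the loader it came from so the training step knows whether a lateral tensor is present.

```python
from typing import Any, Iterable, Iterator, Tuple


def round_robin(*loaders: Iterable) -> Iterator[Tuple[int, Any]]:
    """Alternate between loaders, yielding (loader_index, batch).

    Each batch comes entirely from one loader, so study types never mix.
    When one loader is exhausted, the others continue until all are empty.
    """
    iterators = [(i, iter(dl)) for i, dl in enumerate(loaders)]
    while iterators:
        for entry in list(iterators):  # iterate over a copy so removal is safe
            i, it = entry
            try:
                yield i, next(it)
            except StopIteration:
                iterators.remove(entry)
```

For example, with hypothetical variables `lateral_dls` and `frontal_dls` holding the results of the two functions above, a loop over `round_robin(lateral_dls.train, frontal_dls.train)` could call `model(x1, x2)` for index 0 and `model(x1)` for index 1, assuming the lateral loader yields `(x1, x2, y)` and the frontal one `(x, y)`. Since fastai's `Learner.fit` iterates a single training DataLoader, this would most likely need a hand-written training loop or a custom wrapper.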