Hi fellows,
Recently, I have been trying to use `fit_one_cycle` to train a model with a custom PyTorch Dataset and DataLoader. The custom dataset loads samples of size (256, 256, 9) from HDF5 files.
Code:
import time, torch, os, h5py
from torch.utils.data import Dataset


class hdf5_dataset(Dataset):
    """Map-style dataset that reads (image, label) pairs from two HDF5 files.

    The files are located by plain string concatenation:
    ``path + data_type + '_data.h5'`` and ``path + data_type + '_label.h5'``,
    so ``path`` must already end with a path separator (or filename prefix).

    Each top-level key of the data file is treated as one sample. The i-th key
    of the data file is paired with the i-th key of the label file.
    NOTE(review): this assumes both files list their keys in the same order
    and have the same number of entries -- confirm against how they were written.

    Args:
        path: Prefix prepended to ``'<data_type>_data.h5'`` / ``'<data_type>_label.h5'``.
        data_type: Split name used in the file names (e.g. ``'train'``, ``'valid'``).
        transform: Optional callable applied to the image only.
    """

    def __init__(self, path, data_type='train', transform=None):
        self.file_path = path
        # File handles are opened lazily on first item access, not here, so
        # that an open h5py handle is not created before DataLoader worker
        # processes are started.
        self.data = None
        self.label = None
        self.data_type = data_type
        # Hard-coded value kept from the original code; presumably the number
        # of target classes -- TODO confirm.
        self.c = 17
        # Open the data file briefly just to count its top-level entries.
        with h5py.File(self.file_path + data_type + '_data.h5', 'r') as file:
            self.len = len(file)
        self.transform = transform

    def __len__(self):
        """Return the number of top-level entries found in the data file."""
        return self.len

    def _open_files(self):
        """Open the data/label files and cache their key lists, once each."""
        if self.data is None:
            self.data = h5py.File(self.file_path + self.data_type + '_data.h5', 'r')
            self.data_list = list(self.data.keys())
        if self.label is None:
            self.label = h5py.File(self.file_path + self.data_type + '_label.h5', 'r')
            self.label_list = list(self.label.keys())

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        Raises:
            IndexError: if ``idx`` is outside the cached key lists.
        """
        self._open_files()
        # ``Dataset.value`` was removed in h5py 3.0; ``dataset[()]`` reads the
        # full dataset contents and works on both old and new h5py versions.
        image = self.data[self.data_list[idx]][()]
        label = self.label[self.label_list[idx]][()]
        if self.transform:
            image = self.transform(image)
        return image, label
# Prefix for the HDF5 files. hdf5_dataset concatenates this string directly
# with '<split>_data.h5' / '<split>_label.h5', so it must end with a path
# separator (or be a filename prefix).
path = "path_to_hdf5_files"
batch_size = 64
num_workers = 2

trainset = hdf5_dataset(path, 'train')
validset = hdf5_dataset(path, 'valid')
print(len(trainset), len(validset))

# DataLoader was used below without ever being imported.
from torch.utils.data import DataLoader

train_dl = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
# Validation data does not need to be shuffled.
valid_dl = DataLoader(validset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

from fastai.vision.data import DataLoaders
data = DataLoaders(train_dl, valid_dl)

from fastai.vision.learner import cnn_learner, error_rate
import torch
from torchvision import models

# Plain PyTorch DataLoaders carry no `c` attribute, so fastai cannot infer the
# size of the model head and raises
# "AssertionError: `n_out` is not defined, and could not be inferred from data".
# Pass the class count explicitly.
# NOTE(review): with plain PyTorch loaders fastai presumably cannot infer the
# loss function or attach its normalization transform either, and the samples
# have 9 channels rather than 3 -- `loss_func=`, `normalize=False` and `n_in=9`
# may also be needed; confirm against the fastai docs.
learner_original = cnn_learner(data, models.resnet34, n_out=trainset.c, metrics=error_rate, pretrained=True)
torch.cuda.set_device(0)
learner_original.model.cuda()

# Train only the new head first, then fine-tune the whole network.
learner_original.freeze()
learner_original.fit_one_cycle(5)

learner_original.unfreeze()
learner_original.fit_one_cycle(5)
AssertionError: `n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`
Question:
I assumed that fastai.vision.data.DataLoaders could wrap two torch.utils.data.DataLoader objects and be used to build a learner, but obviously I was wrong. So, if I want to build a custom dataloader that loads data from HDF5 or NumPy files (not images) for fastai’s Learner, how should I do it?
Many thanks!