Hello everyone, after several re-reads of chapter 4 I finally moved on to the Further Research exercises and modified the notebook to handle the full MNIST dataset.
I split the “training” images into a training set and a validation set (keeping the “testing” folder aside so I can test on images the model has never seen), but once I get to the fitting part my accuracy is very bad (it stabilizes around 20%).
I suspect I missed something when creating the dataloaders, but I’m not sure how to verify that.
This is how I split the files and created the arrays of images.
import copy
import random

imgs = []   # all images, per digit
timgs = []  # training images
vimgs = []  # validation images
for i in range(10):
    sorted_imgs = (path/'training'/str(i)).ls().sorted()
    shuffle_imgs = copy.deepcopy(sorted_imgs)
    random.shuffle(shuffle_imgs)
    vn = int(len(sorted_imgs) / 100 * 20)  # hold out 20% for validation
    imgs.append(sorted_imgs)
    vimgs.append(shuffle_imgs[:vn])
    timgs.append(shuffle_imgs[vn:])
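To rule out a bad split, a quick sanity check on the same slicing logic (using a small list of made-up filenames as a stand-in for one digit’s file list) would confirm the 80/20 proportions and that no image lands in both sets:

```python
import copy
import random

# Hypothetical stand-in for one digit's file list
sorted_imgs = [f"img_{n}.png" for n in range(60)]

shuffle_imgs = copy.deepcopy(sorted_imgs)
random.shuffle(shuffle_imgs)

vn = int(len(sorted_imgs) / 100 * 20)  # 20% for validation
val, trn = shuffle_imgs[:vn], shuffle_imgs[vn:]

assert len(val) + len(trn) == len(sorted_imgs)  # nothing lost
assert not set(val) & set(trn)                  # no overlap
print(len(trn), len(val))  # → 48 12
```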
And this is how I created the tensors for the datasets.
train_tensors = []
validation_tensors = []
train_stack = []
validation_stack = []
for i in range(10):
    train_tensors.append([tensor(Image.open(o)) for o in timgs[i]])
    validation_tensors.append([tensor(Image.open(o)) for o in vimgs[i]])
    # stack each digit's images and scale pixel values to the 0-1 range
    train_stack.append(torch.stack(train_tensors[i]).float()/255)
    validation_stack.append(torch.stack(validation_tensors[i]).float()/255)
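A similar shape-and-range check on the stacked tensors would catch a loading mistake before training. A minimal sketch, assuming torch is installed and using random tensors as stand-ins for the loaded images:

```python
import torch

# Five fake 28x28 "images" with uint8 pixel values, standing in for the
# tensors produced by Image.open on the real files
fake_imgs = [torch.randint(0, 256, (28, 28), dtype=torch.uint8)
             for _ in range(5)]

stack = torch.stack(fake_imgs).float() / 255

assert tuple(stack.shape) == (5, 28, 28)          # one row per image
assert 0.0 <= stack.min() <= stack.max() <= 1.0   # pixels normalized
```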
The rest is basically vanilla code from the chapter. I wanted to display a confusion matrix to see where things were going wrong (I don’t know why, but that flat 20% accuracy is suspicious), but it throws an error when I try to create the ClassificationInterpretation.
learn = Learner(dls, simple_net, opt_func=SGD,
                loss_func=mnist_loss, metrics=batch_accuracy)
learn.fit(40, 0.1)
interp = ClassificationInterpretation.from_learner(learn)
Any pointers?