Hi! I’ve completely rebuilt my MNIST model with CrossEntropyLoss, but I’ve run into something strange: I’m getting 89% accuracy after just the first epoch.

Also, I can’t use `predict`; it fails with

`AttributeError: 'list' object has no attribute 'decode_batch'`

and I have no idea where that comes from.

Here’s my code:
```python
from fastai.vision.all import *
from fastbook import *

path = untar_data(URLs.MNIST)
path.ls()
# (#2) [Path('/storage/data/mnist_png/training'),Path('/storage/data/mnist_png/testing')]

path_train = path/'training'
path_valid = path/'testing'

train_x = get_image_files(path_train).sorted()
train_x = [tensor(Image.open(element)).float() / 255 for element in train_x]
train_x = torch.stack(train_x).view(-1, 28*28)
train_y = [int(element.parent.name) for element in get_image_files(path_train).sorted()]
train_y = tensor(train_y)
dl = DataLoader(list(zip(train_x, train_y)), batch_size=256, shuffle=True)

valid_x = get_image_files(path_valid).sorted()
valid_x = [tensor(Image.open(element)).float() / 255 for element in valid_x]
valid_x = torch.stack(valid_x).view(-1, 28*28)
valid_y = [int(element.parent.name) for element in get_image_files(path_valid).sorted()]
valid_y = tensor(valid_y)
dl_valid = DataLoader(list(zip(valid_x, valid_y)), batch_size=256, shuffle=True)

dls = DataLoaders(dl, dl_valid)

simple_net = nn.Sequential(
    nn.Linear(28*28, 50),
    nn.ReLU(),
    nn.Linear(50, 10)
)

learner = Learner(dls, simple_net, opt_func=SGD,
                  loss_func=nn.CrossEntropyLoss(), metrics=accuracy, lr=0.001)
learner.fit(10, 0.1)
```
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.461113 | 0.376712 | 0.896700 | 00:01 |
| 1 | 0.337488 | 0.320396 | 0.906000 | 00:01 |
| 2 | 0.304718 | 0.276135 | 0.921400 | 00:01 |
| 3 | 0.275385 | 0.254352 | 0.927000 | 00:01 |
| 4 | 0.240319 | 0.235078 | 0.931800 | 00:01 |
| 5 | 0.225280 | 0.219778 | 0.936700 | 00:01 |
| 6 | 0.218840 | 0.209284 | 0.939000 | 00:01 |
| 7 | 0.201014 | 0.194134 | 0.942900 | 00:01 |
| 8 | 0.190363 | 0.185128 | 0.944700 | 00:01 |
| 9 | 0.184855 | 0.178727 | 0.948300 | 00:01 |
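In the meantime I’ve been getting predictions by calling the model directly instead of `learner.predict`. Here’s a self-contained sketch of what I mean; the `net` and `x` below are stand-ins for my `simple_net` and one flattened row of `valid_x`, not the actual objects from the code above:

```python
import torch
from torch import nn

# stand-in for simple_net from the code above
net = nn.Sequential(
    nn.Linear(28*28, 50),
    nn.ReLU(),
    nn.Linear(50, 10)
)

# stand-in for a single flattened 28x28 image (one row of valid_x)
x = torch.rand(1, 28*28)

with torch.no_grad():
    logits = net(x)              # shape (1, 10): one raw score per digit
    pred = logits.argmax(dim=1)  # index of the highest score = predicted digit

print(pred.item())  # an int in 0..9
```

This works fine, so the model itself seems OK; it’s only `predict` that trips over the plain `DataLoader`s.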