Hi guys,
I’m trying to implement the training process for the complete MNIST set, as suggested at the end of chapter 4.
For a batch size of 5, the model returns a tensor of shape [5, 10], so 5 images x 10 probabilities each, and that seems OK. But when I try to calculate the loss I get this error: The size of tensor a (5) must match the size of tensor b (10) at non-singleton dimension 1
Tensor a is the labels, and there are 5 of them because the batch size is 5. So I can’t understand how to calculate the cross entropy loss between 5x10 probabilities and only 5 labels. Should I transform the 5 labels into 5x10 probabilities?
bs = 5
dl = DataLoader(dset, batch_size=bs, shuffle=True, random=True)
w1 = torch.randn((28*28, 30)).requires_grad_()
w2 = torch.randn((30, 10)).requires_grad_()
b1 = torch.randn(1).requires_grad_()
b2 = torch.randn(1).requires_grad_()
lr = 0.00001

def softmax(preds):
    ex = torch.exp(preds)
    return ex / torch.sum(ex, axis=0)

def model(x):
    res = x@w1 + b1
    res = res.max(tensor(0.0)) # ReLU
    res = res@w2 + b2
    res = softmax(res)
    return res

def loss_func(pred, y): # cross entropy
    return -torch.sum(y * torch.log(pred))

def forward():
    global w1,w2
    global b1,b2
    global bs
    for x,y in dl:
        x = x.view((bs, 28*28))
        pred = model(x)
        print(x.shape)
        print(y.shape)
        print(pred.shape)
        loss = loss_func(pred, y)
        loss.backward()
        w1.data -= (lr * w1.grad)
        w2.data -= (lr * w2.grad)
        b1.data -= (lr * b1.grad)
        b2.data -= (lr * b2.grad)
        w1.grad = None
        w2.grad = None
        b1.grad = None
        b2.grad = None

forward()
And this is the output:
torch.Size([5, 784])
torch.Size([5])
torch.Size([5, 10])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-93-e9fb94dcb8e7> in <module>()
46 b2.grad = None
47
---> 48 forward()
1 frames
<ipython-input-93-e9fb94dcb8e7> in loss_func(pred, y)
18
19 def loss_func(pred, y):
---> 20 return -torch.sum(y * torch.log(pred))
21
22 def my_accuracy(xb, yb):
RuntimeError: The size of tensor a (5) must match the size of tensor b (10) at non-singleton dimension 1
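For what it's worth, here is a minimal sketch of two common ways the shapes could be made to line up. I'm assuming the labels are integer class indices from 0 to 9 and using random stand-in data rather than the real MNIST batch, so the names `logits`, `loss_one_hot` and `loss_indexed` are made up for illustration. One option expands the 5 labels into a 5x10 one-hot tensor; the other skips that and just picks out the probability of the correct class in each row. Both should give the same number, and PyTorch's built-in `F.cross_entropy` (which takes raw scores, not probabilities) is included as a cross-check. Note the sketch applies softmax over the class axis (`dim=1`) so that each row sums to 1.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, n_classes = 5, 10

# Stand-ins for a real batch (assumption): random scores and integer labels in [0, 10).
logits = torch.randn(bs, n_classes)
y = torch.randint(0, n_classes, (bs,))

# Softmax over the class axis (dim=1), so each row of pred sums to 1.
pred = torch.softmax(logits, dim=1)

def loss_one_hot(pred, y):
    # Expand the [bs] labels to a [bs, n_classes] one-hot tensor,
    # so the elementwise product with log(pred) has matching shapes.
    y_one_hot = F.one_hot(y, num_classes=pred.shape[1]).float()
    return -torch.sum(y_one_hot * torch.log(pred))

def loss_indexed(pred, y):
    # Equivalent: pick the predicted probability of the true class in each row.
    picked = pred[torch.arange(pred.shape[0]), y]
    return -torch.sum(torch.log(picked))

loss_a = loss_one_hot(pred, y)
loss_b = loss_indexed(pred, y)

# Built-in reference: takes the raw scores and applies log-softmax itself.
loss_ref = F.cross_entropy(logits, y, reduction="sum")

print(loss_a.item(), loss_b.item(), loss_ref.item())
```

The indexed version avoids building the one-hot tensor, which is mostly zeros, but the one-hot version is closer to the textbook cross entropy formula.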