Howdy!
I’m trying to get a single-layer image recognition network to identify types of sodas and energy drinks.
I’m getting a size mismatch between the input matrix and my linear layer. I want the input to be [64 x 150528], but PyTorch is telling me that it’s [43008 x 224]. (See the bottom for the runtime error.)
The problem seems to be that m1 (the bad matrix) starts out as [64 x 3 x 224 x 224] and then gets collapsed from left to right, like so: [(64 x 3 x 224) x 224] = [43008 x 224]. But that’s a mistake: it folds the batch size (that’s the 64) into the product, instead of multiplying together the image components (RGB channels and image resolution). What I want is for it to collapse from right to left: [64 x (3 x 224 x 224)] = [64 x 150528]. Or else reshape the data to fit what I want.
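To double-check the arithmetic, here is a quick plain-Python sketch (no PyTorch needed). The variable names are only for illustration; the shape is the one reported for a batch further down.

```python
from math import prod

# Shape of one batch: (batch, channels, height, width)
shape = (64, 3, 224, 224)

# What the error reports: every leading dimension folded into the rows,
# with only the last dimension kept as columns.
rows_seen = prod(shape[:-1])   # 64 * 3 * 224
cols_seen = shape[-1]

# What the linear layer needs: the batch kept as rows,
# and everything else folded into the columns.
rows_wanted = shape[0]
cols_wanted = prod(shape[1:])  # 3 * 224 * 224

print((rows_seen, cols_seen))      # (43008, 224)
print((rows_wanted, cols_wanted))  # (64, 150528)
```

Both layouts hold the same number of elements, so this looks like a question of how the batch is reshaped rather than of missing or extra data.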
Does anyone know how to reshape (flatten) the data coming out of an ImageDataBunch? Or otherwise fix this?
Thanks in advance.
Here’s my code:
from fastai.basics import *
from fastai.vision import *
path = Path('data/drinks')
[PosixPath('data/drinks/cleaned.csv')]
data = ImageDataBunch.from_csv(path, ds_tfms=get_transforms(do_flip=False), csv_labels='cleaned.csv', size=224)
class Soda_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(150528, 5, bias=True)

    def forward(self, xb): return self.lin(xb)
model = Soda_Logistic().cuda()
model
Soda_Logistic(
  (lin): Linear(in_features=150528, out_features=5, bias=True)
)
x,y = next(iter(data.train_dl))
x.shape, y.shape
(torch.Size([64, 3, 224, 224]), torch.Size([64]))
lr = 2e-2
loss_func=nn.CrossEntropyLoss()
def Update(x,y,lr):
    wd = 1e-5
    y_hat = model(x)
    w2 = 0.
    for p in model.parameters(): w2 += (p**2).sum()
    loss = loss_func(y_hat,y) + w2*wd
    loss.backward()
    with torch.no_grad():
        for p in model.parameters():
            p.sub_(lr * p.grad)
            p.grad.zero_()
    return loss.item()
losses = [Update(x,y,lr) for x,y in data.train_dl]
And here’s my output (specifically, the runtime error):
RuntimeError: size mismatch, m1: [43008 x 224], m2: [150528 x 5] at /opt/conda/conda-bld/pytorch_1579022060824/work/aten/src/THC/generic/THCTensorMathBlas.cu:290
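As far as I can tell, nn.Linear only acts on the last dimension of its input, which would explain why the 224 ends up as the column count. One direction that looks plausible is flattening each batch inside forward so the linear layer sees [batch x 150528]. A minimal sketch, assuming a random CPU tensor in place of a real batch and a hypothetical class name SodaLogisticFlat (not part of my code above); I haven't confirmed this is the recommended fastai way:

```python
import torch
from torch import nn

class SodaLogisticFlat(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, n_in=3 * 224 * 224, n_out=5):
        super().__init__()
        self.lin = nn.Linear(n_in, n_out, bias=True)

    def forward(self, xb):
        # Keep the batch dimension and fold channels, height and width
        # into a single feature axis: [B, 3, 224, 224] -> [B, 150528].
        return self.lin(xb.flatten(start_dim=1))

# Random tensor standing in for one batch from data.train_dl
# (assumption: a batch of 4 instead of 64, to keep the example light).
xb = torch.randn(4, 3, 224, 224)
out = SodaLogisticFlat()(xb)
print(out.shape)
```

If that holds up, the same flatten call inside Soda_Logistic.forward should leave the rest of the training loop unchanged.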