Hi,
I am receiving the following error when trying to run get_preds.
probs, pred, idxs = learn.get_preds(dl=test_dl, with_decoded=True, act=activation_function)
Error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [48], in <cell line: 2>()
1 #Expecting 4d got 3d, need another dimension for the batch size
----> 2 probs, pred, idxs = learn.get_preds(dl=test_dl, with_decoded=True, act=activation_function)
File /usr/local/lib/python3.9/dist-packages/fastai/learner.py:300, in Learner.get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
298 if with_loss: ctx_mgrs.append(self.loss_not_reduced())
299 with ContextManagers(ctx_mgrs):
--> 300 self._do_epoch_validate(dl=dl)
301 if act is None: act = getcallable(self.loss_func, 'activation')
302 res = cb.all_tensors()
File /usr/local/lib/python3.9/dist-packages/fastai/learner.py:236, in Learner._do_epoch_validate(self, ds_idx, dl)
234 if dl is None: dl = self.dls[ds_idx]
235 self.dl = dl
--> 236 with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
File /usr/local/lib/python3.9/dist-packages/fastai/learner.py:193, in Learner._with_events(self, f, event_type, ex, final)
192 def _with_events(self, f, event_type, ex, final=noop):
--> 193 try: self(f'before_{event_type}'); f()
194 except ex: self(f'after_cancel_{event_type}')
195 self(f'after_{event_type}'); final()
File /usr/local/lib/python3.9/dist-packages/fastai/learner.py:199, in Learner.all_batches(self)
197 def all_batches(self):
198 self.n_iter = len(self.dl)
--> 199 for o in enumerate(self.dl): self.one_batch(*o)
File /usr/local/lib/python3.9/dist-packages/fastai/learner.py:227, in Learner.one_batch(self, i, b)
225 b = self._set_device(b)
226 self._split(b)
--> 227 self._with_events(self._do_one_batch, 'batch', CancelBatchException)
File /usr/local/lib/python3.9/dist-packages/fastai/learner.py:193, in Learner._with_events(self, f, event_type, ex, final)
192 def _with_events(self, f, event_type, ex, final=noop):
--> 193 try: self(f'before_{event_type}'); f()
194 except ex: self(f'after_cancel_{event_type}')
195 self(f'after_{event_type}'); final()
File /usr/local/lib/python3.9/dist-packages/fastai/learner.py:205, in Learner._do_one_batch(self)
204 def _do_one_batch(self):
--> 205 self.pred = self.model(*self.xb)
206 self('after_pred')
207 if len(self.yb):
File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File /usr/local/lib/python3.9/dist-packages/fastai/layers.py:408, in SequentialEx.forward(self, x)
406 for l in self.layers:
407 res.orig = x
--> 408 nres = l(res)
409 # We have to remove res.orig to avoid hanging refs and therefore memory leaks
410 res.orig, nres.orig = None, None
File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/container.py:139, in Sequential.forward(self, input)
137 def forward(self, input):
138 for module in self:
--> 139 input = module(input)
140 return input
File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/batchnorm.py:135, in _BatchNorm.forward(self, input)
134 def forward(self, input: Tensor) -> Tensor:
--> 135 self._check_input_dim(input)
137 # exponential_average_factor is set to self.momentum
138 # (when it is available) only so that it gets updated
139 # in ONNX graph when this node is exported to ONNX.
140 if self.momentum is None:
File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/batchnorm.py:407, in BatchNorm2d._check_input_dim(self, input)
405 def _check_input_dim(self, input):
406 if input.dim() != 4:
--> 407 raise ValueError("expected 4D input (got {}D input)".format(input.dim()))
ValueError: expected 4D input (got 3D input)
I have looked at the batches from both dataloaders and can see the discrepancy.
Training dataloader:
batch = dl.one_batch()
for i in range(len(batch)):
    print(batch[i].shape)
Output:
torch.Size([4, 8, 256, 256])
torch.Size([4, 256, 256])
After successfully training the unet learner, I grab the test_dl from the learner:
test_dl = learn.dls.test_dl(test_files, bs=4)
test_batch = test_dl.one_batch()
for i in range(len(test_batch)):
    print(test_batch[i].shape)
Output:
torch.Size([8, 256, 256])
torch.Size([8, 256, 256])
torch.Size([8, 256, 256])
torch.Size([8, 256, 256])
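So each test item seems to come back as its own 3D tensor rather than being stacked into one 4D batch. As a sanity check (an untested sketch, nothing more), stacking those pieces by hand should produce the 4D shape BatchNorm2d is asking for:

# Sketch only: stack the four 3D item tensors into a single 4D batch and run
# the model directly, bypassing the DataLoader's collation.
xb = torch.stack(list(test_batch))   # should be torch.Size([4, 8, 256, 256])
xb = xb.to(next(learn.model.parameters()).device)
with torch.no_grad():
    out = learn.model(xb)
print(out.shape)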
I have attempted to set create_batch in the test_dl like so:
def test_batch(xs):
    stacked_xs = torch.stack(xs)
    return stacked_xs
test_dl = learn.dls.test_dl(test_files[:6], bs=4, verbose=True, create_batch=test_batch)
But I end up with the same error message, so I'm not sure why the test items aren't being collated into a single batch.
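To see what the DataLoader is actually handing to create_batch, the next thing I plan to try (a sketch, not yet run) is logging the incoming items and then deferring to fastai's default fa_collate:

from fastai.data.load import fa_collate

def debug_collate(items):
    # Log what arrives at create_batch time, then collate the way fastai
    # normally would.
    for it in items:
        parts = it if isinstance(it, tuple) else (it,)
        print(type(it), [getattr(p, 'shape', None) for p in parts])
    return fa_collate(items)

test_dl = learn.dls.test_dl(test_files[:6], bs=4, create_batch=debug_collate)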
Here’s the code for the dataloader and learner.
class SmashTransform(ItemTransform):
    def encodes(self, inp):
        # Concatenate the per-month tensors for one item along the channel dimension.
        xs = inp[0]
        x_smash = None
        for x in xs:
            if x_smash is None:
                x_smash = x
            else:
                x_smash = torch.cat((x_smash, x))
        if len(inp) == 2:
            return x_smash, inp[1]
        else:
            return x_smash
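For context, the transform only concatenates the per-month tensors along the channel dimension, so an equivalent (and slightly more compact) encodes would be roughly:

def encodes(self, inp):
    # Concatenate every per-month tensor along dim 0 (the channel dimension).
    x_smash = torch.cat(list(inp[0]), dim=0)
    return (x_smash, inp[1]) if len(inp) == 2 else x_smash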
def get_dataloader(file_set, months=MONTHS):
    blocks = []
    blocks.append(TransformBlock([partial(get_all_files_month, months=months)]))
    blocks.append(TransformBlock([label_func, partial(open_tif, cls=TensorMask)]))
    db = DataBlock(blocks=tuple(blocks),
                   splitter=RandomSplitter(valid_pct=0.2, seed=42),
                   n_inp=1,
                   item_tfms=[SmashTransform()]
                   #get_x=[partial(get_filenames, month=month) for month in months]
                   )
    ds = db.datasets(source=file_set)
    dl = db.dataloaders(source=file_set, bs=4)
    return dl
def get_learner(dl, arch=resnet34, no_months=6):
    return unet_learner(dl,
                        arch,
                        n_in=4 * no_months,
                        n_out=1,
                        loss_func=BCEWithLogitsLossFlat(),
                        #splitter=sat_splitter
                        metrics=[Recall(axis=1), F1Score(axis=1)]
                        )
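One workaround I'm also considering (untested, and it assumes the test files can still go through label_func) is building the test loader with the same DataBlock as training, so the item_tfms and collation path are identical, and then predicting on that DataLoader:

# Hypothetical workaround: push the test files through the same pipeline as
# training so they get batched the same way.
test_dls = get_dataloader(test_files)
probs, pred, idxs = learn.get_preds(dl=test_dls.valid, with_decoded=True, act=activation_function)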