Expected input batch_size (40) to match target batch_size (3600)

I’m a little confused by this error when I call fit_one_cycle on my TextLearner.

I suspect it has something to do with my loss function, since swapping CrossEntropyLossFlat() for nn.CrossEntropyLoss() gives me this error instead:

ValueError: Expected input batch_size (40) to match target batch_size (90).

Where am I going wrong? Do I need a new loss function?

Error

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-16-c1ed9d459fcc> in <module>
----> 1 learn.fit_one_cycle(5, 5e-3)

~/.local/lib/python3.8/site-packages/fastai/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt)
    110     scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
    111               'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 112     self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
    113 
    114 # Cell

~/.local/lib/python3.8/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    204             self.opt.set_hypers(lr=self.lr if lr is None else lr)
    205             self.n_epoch = n_epoch
--> 206             self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
    207 
    208     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

~/.local/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153 
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

~/.local/lib/python3.8/site-packages/fastai/learner.py in _do_fit(self)
    195         for epoch in range(self.n_epoch):
    196             self.epoch=epoch
--> 197             self._with_events(self._do_epoch, 'epoch', CancelEpochException)
    198 
    199     def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):

~/.local/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153 
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

~/.local/lib/python3.8/site-packages/fastai/learner.py in _do_epoch(self)
    189 
    190     def _do_epoch(self):
--> 191         self._do_epoch_train()
    192         self._do_epoch_validate()
    193 

~/.local/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_train(self)
    181     def _do_epoch_train(self):
    182         self.dl = self.dls.train
--> 183         self._with_events(self.all_batches, 'train', CancelTrainException)
    184 
    185     def _do_epoch_validate(self, ds_idx=1, dl=None):

~/.local/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153 
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

~/.local/lib/python3.8/site-packages/fastai/learner.py in all_batches(self)
    159     def all_batches(self):
    160         self.n_iter = len(self.dl)
--> 161         for o in enumerate(self.dl): self.one_batch(*o)
    162 
    163     def _do_one_batch(self):

~/.local/lib/python3.8/site-packages/fastai/learner.py in one_batch(self, i, b)
    177         self.iter = i
    178         self._split(b)
--> 179         self._with_events(self._do_one_batch, 'batch', CancelBatchException)
    180 
    181     def _do_epoch_train(self):

~/.local/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    153 
    154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
    156         except ex: self(f'after_cancel_{event_type}')
    157         finally:   self(f'after_{event_type}')        ;final()

~/.local/lib/python3.8/site-packages/fastai/learner.py in _do_one_batch(self)
    164         self.pred = self.model(*self.xb)
    165         self('after_pred')
--> 166         if len(self.yb): self.loss = self.loss_func(self.pred, *self.yb)
    167         self('after_loss')
    168         if not self.training or not len(self.yb): return

~/.local/lib/python3.8/site-packages/fastai/losses.py in __call__(self, inp, targ, **kwargs)
     31         if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
     32         if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
---> 33         return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
     34 
     35 # Cell

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/.local/lib/python3.8/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    959 
    960     def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 961         return F.cross_entropy(input, target, weight=self.weight,
    962                                ignore_index=self.ignore_index, reduction=self.reduction)
    963 

~/.local/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2460         tens_ops = (input, target)
   2461         if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
-> 2462             return handle_torch_function(
   2463                 cross_entropy, tens_ops, input, target, weight=weight,
   2464                 size_average=size_average, ignore_index=ignore_index, reduce=reduce,

~/.local/lib/python3.8/site-packages/torch/overrides.py in handle_torch_function(public_api, relevant_args, *args, **kwargs)
   1058         # Use `public_api` instead of `implementation` so __torch_function__
   1059         # implementations can do equality/identity comparisons.
-> 1060         result = overloaded_arg.__torch_function__(public_api, types, args, kwargs)
   1061 
   1062         if result is not NotImplemented:

~/.local/lib/python3.8/site-packages/fastai/torch_core.py in __torch_function__(self, func, types, args, kwargs)
    317 #         if func.__name__[0]!='_': print(func, types, args, kwargs)
    318 #         with torch._C.DisableTorchFunction(): ret = _convert(func(*args, **(kwargs or {})), self.__class__)
--> 319         ret = super().__torch_function__(func, types, args=args, kwargs=kwargs)
    320         if isinstance(ret, TensorBase): ret.set_meta(self, as_copy=True)
    321         return ret

~/.local/lib/python3.8/site-packages/torch/tensor.py in __torch_function__(cls, func, types, args, kwargs)
    993 
    994         with _C.DisableTorchFunction():
--> 995             ret = func(*args, **kwargs)
    996             return _convert(ret, cls)
    997 

~/.local/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2466     if size_average is not None or reduce is not None:
   2467         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2468     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2469 
   2470 

~/.local/lib/python3.8/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2259 
   2260     if input.size(0) != target.size(0):
-> 2261         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
   2262                          .format(input.size(0), target.size(0)))
   2263     if dim == 2:

ValueError: Expected input batch_size (40) to match target batch_size (3600).


Code

from fastai.text.all import *

path = Path('data/human')
files = get_files(path)

bs = 90
sl = 40

cut = int(len(files)*0.8)
splits = [list(range(cut)), list(range(cut,len(files)))]

class CodeTokenizer():
    def __init__(self, split_char=' ', **kwargs): self.split_char=split_char
    def __call__(self, items): return (t.split(self.split_char) for t in items)

rules = [replace_rep, replace_wrep, spec_add_spaces, replace_all_caps, replace_maj, lowercase]
tfms = [Tokenizer(tok=CodeTokenizer()), Numericalize]

dsets = Datasets(files, [tfms], splits=splits, dl_type=LMDataLoader)

dls = dsets.dataloaders(bs=bs, seq_len=sl)

dls.show_batch(max_n=2)

model = AWD_LSTM(len(dls.vocab), emb_sz=400, n_hid=1152, n_layers=2)
learn = TextLearner(dls, model, loss_func=CrossEntropyLossFlat(), metrics=[accuracy])

learn.fit_one_cycle(5, 5e-3)
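For reference, the custom tokenizer above simply splits each raw text on one character and returns a generator of token lists, one list per text. A quick standalone check of that behaviour (the sample strings are made up for illustration):

```python
class CodeTokenizer():
    "Same class as in the question: split each text on a single character."
    def __init__(self, split_char=' ', **kwargs): self.split_char = split_char
    def __call__(self, items): return (t.split(self.split_char) for t in items)

texts = ["def add ( a , b ) :", "return a  + b"]   # hypothetical sample texts
toks = list(CodeTokenizer()(texts))                # the call returns a generator
print(toks)
# [['def', 'add', '(', 'a', ',', 'b', ')', ':'], ['return', 'a', '', '+', 'b']]
```

Note the empty string in the second list: str.split(' ') keeps an empty token for every repeated space, whereas str.split() with no argument would collapse them.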

This is likely a shape error: the shape CrossEntropyLoss expects is different from the one it actually gets. CrossEntropyLossFlat works around this by flattening the input and target so the dimensions match.
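To make the flattening concrete, here is a plain-PyTorch sketch of what it does for a language model, assuming the model returns one score per vocabulary token at every position, i.e. predictions of shape (bs, sl, vocab_sz). The vocabulary size is an arbitrary stand-in, and this mimics the idea of CrossEntropyLossFlat rather than reproducing fastai's code.

```python
import torch
import torch.nn.functional as F

bs, sl, vocab_sz = 90, 40, 1000                 # vocab_sz is a made-up value

pred = torch.randn(bs, sl, vocab_sz)            # one score per token, per position
targ = torch.randint(0, vocab_sz, (bs, sl))     # next-token ids, same (bs, sl) layout

# Flatten so every (sample, position) pair becomes one row for cross-entropy
flat_pred = pred.view(-1, pred.shape[-1])       # (bs*sl, vocab_sz) = (3600, 1000)
flat_targ = targ.view(-1)                       # (bs*sl,) = (3600,)
loss = F.cross_entropy(flat_pred, flat_targ)

# Equivalent without flattening: PyTorch wants the class dimension second
loss_nd = F.cross_entropy(pred.permute(0, 2, 1), targ)
print(flat_pred.shape, flat_targ.shape, loss.item(), loss_nd.item())
```

Either way, the loss only lines up if the first dimension of the prediction matches the first dimension of the target.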


Totally agree. Do you happen to have 90 classes?

From the provided code, it seems that 40 is the sequence length and 90 is the batch size…
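Those two numbers also account for the target sizes in both error messages. A minimal sketch with a dummy (bs, sl) target tensor: flattening it, as CrossEntropyLossFlat does, gives bs*sl = 3600 rows, while nn.CrossEntropyLoss reads the target's first dimension as the batch size, giving 90.

```python
import torch

bs, sl = 90, 40
targ = torch.zeros(bs, sl, dtype=torch.long)   # dummy stand-in for a (bs, sl) target batch

# Flattened target (CrossEntropyLossFlat): "batch size" becomes bs*sl
flat_n = targ.view(-1).size(0)
# Unflattened target (nn.CrossEntropyLoss): "batch size" is the first dimension
plain_n = targ.size(0)
print(flat_n, plain_n)   # 3600 90
```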


Yeah, I thought that CrossEntropyLossFlat() would fix things but apparently not.

For context, using nn.CrossEntropyLoss() gives me "Expected input batch_size (40) to match target batch_size (90)." instead.

Could you call:

b = dls.one_batch()

and check the sizes of the input and target? (The batch b contains both.)

b = dls.one_batch()
x, y = b
x.size(), y.size()

Gives

(torch.Size([90, 40]), torch.Size([90, 40]))

Which I think is what I would expect: batch size x sequence length, as explained in https://github.com/fastai/fastbook/blob/master/10_nlp.ipynb (Putting Our Texts into Batches for a Language Model).
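A simplified sketch of the idea in that chapter, for reference: the numericalized token stream is cut into bs contiguous streams, and each target is the input shifted by one token, so x and y both come out as (bs, sl). The helper name lm_batches and the fake corpus are hypothetical, and this is not fastai's LMDataLoader (which also shuffles and concatenates documents).

```python
import torch

def lm_batches(tokens, bs, sl):
    """Yield (x, y) batches where y is x shifted one token ahead.
    Simplified illustration only, not fastai's LMDataLoader."""
    tokens = torch.as_tensor(tokens)
    n = (len(tokens) - 1) // bs               # tokens per stream, keeping 1 spare for the shift
    x_all = tokens[:n * bs].view(bs, n)       # bs rows, each a contiguous stream
    y_all = tokens[1:n * bs + 1].view(bs, n)  # same streams, one token later
    for i in range(0, n, sl):
        yield x_all[:, i:i + sl], y_all[:, i:i + sl]

tokens = torch.arange(90 * 40 * 3 + 1)        # fake numericalized corpus
x, y = next(lm_batches(tokens, bs=90, sl=40))
print(x.shape, y.shape)                       # torch.Size([90, 40]) torch.Size([90, 40])
```

Because every position has its own next-token target, a language model has to produce a prediction for every position, not one per sequence.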

If I write a custom loss_func that logs its arguments, I can see the shapes of the input and target being passed in.

def loss_func(inp, tar):
    print(inp.shape) # torch.Size([40, 400]) i.e. sl, emb_sz
    print(tar.shape) # torch.Size([90, 40]) i.e. bs, sl

These are not what I expected. I predicted that inp would have shape batch_size x emb_sz, because I thought my model would output one one-hot-encoded embedding per sequence, representing the predicted word. I predicted that tar would have the same shape.

Are these predictions wrong, or do I need to fix something else in my notebook?
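One plausible reading of those shapes, offered as an assumption rather than a diagnosis: a bare AWD_LSTM returns a single tensor of encoder activations with shape (bs, sl, emb_sz) = (90, 40, 400). If the RNN callback that TextLearner adds expects a tuple of outputs and keeps only pred[0], indexing a plain tensor instead drops the batch dimension and leaves (40, 400), which is what loss_func printed. The plain-PyTorch simulation below hard-codes that indexing step (a hypothetical stand-in for the callback) and reproduces both mismatches.

```python
import torch
import torch.nn.functional as F

bs, sl, emb_sz = 90, 40, 400
encoder_out = torch.randn(bs, sl, emb_sz)     # assumed output of a bare AWD_LSTM
targ = torch.randint(0, emb_sz, (bs, sl))     # (bs, sl) target batch

# Hypothetical stand-in for a callback that keeps pred[0]:
# on a tensor (not a tuple), [0] selects the first sample and drops the batch dim
pred = encoder_out[0]                         # (40, 400)

def try_loss(inp, t):
    "Return the error message if cross_entropy rejects the shapes, else None."
    try:
        F.cross_entropy(inp, t)
    except (ValueError, RuntimeError) as e:   # exception type varies across PyTorch versions
        return str(e)
    return None

# CrossEntropyLossFlat-style: input to (-1, C), target to (-1,) -> 40 vs 3600
msg_flat = try_loss(pred.view(-1, pred.shape[-1]), targ.view(-1))
# nn.CrossEntropyLoss-style: no flattening -> 40 vs 90
msg_plain = try_loss(pred, targ)
print(msg_flat)
print(msg_plain)
```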

I’ve discovered that using the get_language_model fixes my shape problem.

# Does not work
model = AWD_LSTM(len(dls.vocab), emb_sz=400, n_hid=1152, n_layers=2)

# Works!
model = get_language_model(AWD_LSTM, len(dls.vocab))

Looking at the source for get_language_model, it looks like it's doing some things under the hood that I wasn't aware of.

Primarily, it looks like there is a ‘decoder’ step that I didn’t know about. This seems to explain the shape issue. I’ll read up on this after lunch over at https://machinelearningmastery.com/encoder-decoder-recurrent-neural-network-models-neural-machine-translation/

I generally do this:

x,y = dls.one_batch()
out = model(x)
out.shape

to make sure the model is doing its job.
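The same check can be turned into an assertion that fails before training starts. A sketch, assuming a language-model batch where y has shape (bs, sl) and the model should return (bs, sl, vocab_sz); check_lm_shapes and EncoderOnly are hypothetical names, not fastai API. The toy model deliberately repeats the mistake from the question: it returns hidden activations instead of vocabulary scores.

```python
import torch
import torch.nn as nn

def check_lm_shapes(model, x, y, vocab_sz):
    "Hypothetical helper: fail early if the model output can't line up with y."
    out = model(x)
    if isinstance(out, (tuple, list)):   # assume predictions come first if a tuple is returned
        out = out[0]
    expected = (*y.shape, vocab_sz)
    assert tuple(out.shape) == expected, f"model output {tuple(out.shape)}, expected {expected}"
    return out.shape

class EncoderOnly(nn.Module):
    "Toy stand-in for a bare recurrent encoder with no decoder on top."
    def __init__(self, vocab_sz, emb_sz):
        super().__init__()
        self.emb = nn.Embedding(vocab_sz, emb_sz)
        self.rnn = nn.LSTM(emb_sz, emb_sz, batch_first=True)
    def forward(self, x):
        out, _ = self.rnn(self.emb(x))   # (bs, sl, emb_sz)
        return out

vocab_sz = 1000
x = torch.randint(0, vocab_sz, (90, 40))
y = torch.randint(0, vocab_sz, (90, 40))
msg = None
try:
    check_lm_shapes(EncoderOnly(vocab_sz, 400), x, y, vocab_sz)
except AssertionError as e:
    msg = str(e)
print(msg)   # model output (90, 40, 400), expected (90, 40, 1000)
```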


The AWD-LSTM is a recurrent encoder: like nn.LSTM, it outputs hidden activations (for its last layer these have size emb_sz, the 400 in the shapes you printed), not a score for each token in your vocabulary, so you need to convert them to your desired output size. get_language_model does this by adding a LinearDecoder; it also adds some goodies such as dropout.
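A minimal plain-PyTorch sketch of that encoder + decoder structure, assuming a toy encoder (nn.Embedding + nn.LSTM) in place of AWD_LSTM. The name TinyLM, the layer sizes and the dropout value are made up; this is not fastai's LinearDecoder, only the same idea: dropout on the encoder output, then a linear layer from the hidden size to the vocabulary size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLM(nn.Module):
    "Toy language model: recurrent encoder + linear decoder (illustration only)."
    def __init__(self, vocab_sz, emb_sz=400, out_p=0.1):
        super().__init__()
        self.emb = nn.Embedding(vocab_sz, emb_sz)
        self.rnn = nn.LSTM(emb_sz, emb_sz, batch_first=True)  # encoder output: (bs, sl, emb_sz)
        self.drop = nn.Dropout(out_p)
        self.decoder = nn.Linear(emb_sz, vocab_sz)              # hidden size -> vocab scores

    def forward(self, x):
        hidden, _ = self.rnn(self.emb(x))
        return self.decoder(self.drop(hidden))                 # (bs, sl, vocab_sz)

vocab_sz = 1000
model = TinyLM(vocab_sz)
x = torch.randint(0, vocab_sz, (90, 40))
y = torch.randint(0, vocab_sz, (90, 40))
out = model(x)
# Flatten both sides, as CrossEntropyLossFlat does: (3600, vocab_sz) vs (3600,)
loss = F.cross_entropy(out.view(-1, vocab_sz), y.view(-1))
print(out.shape, loss.item())
```

With the decoder in place, the prediction keeps the batch and sequence dimensions of the target, so the flattened sizes agree.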


Thanks! I think I am starting to get it now.