Hi, I’m using fastai (with the blurr wrappers) to fine-tune a DistilBERT model, and the error below occurs when the loss is calculated. Could you give me some advice on how to deal with it?
My code:
from functools import partial

import torch
from torch import nn
from fastai.text.all import *
from blurr.modeling.all import *  # HF_BaseModelWrapper, HF_BaseModelCallback, hf_splitter (path may differ per blurr version)
from transformers import AutoModel, AutoTokenizer

# NUM_CLASSES, EPOCH and dls are defined earlier in my script.
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-whole-word-masking")
distil_model = AutoModel.from_pretrained("bandainamco-mirai/distilbert-base-japanese")

class DistilBertClassifier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.distil_bert = distil_model
        # classification head over the 768-dim DistilBERT hidden states
        self.linear = nn.Linear(768, NUM_CLASSES)
        nn.init.normal_(self.linear.weight, std=0.02)
        nn.init.normal_(self.linear.bias, 0)

    def forward(self, input_ids):
        vec, _ = self.distil_bert(input_ids)
        vec = vec[:, 0, :]  # hidden state at the [CLS] position
        out = self.linear(vec)
        return out

model = DistilBertClassifier()

learn = Learner(dls,
                HF_BaseModelWrapper(model),
                opt_func=partial(Adam, decouple_wd=True),
                loss_func=CrossEntropyLossFlat(),
                metrics=[accuracy, F1Score(average='macro')],
                cbs=[HF_BaseModelCallback],
                splitter=hf_splitter)
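For reference, here is the quick shape check I can run on a single batch outside the Learner. This is just a sketch: it assumes dls.one_batch() returns the tokenized inputs as a dict with an 'input_ids' key plus integer labels, which may differ depending on the blurr version.

# Sanity-check shapes on one batch, outside the training loop.
xb, yb = dls.one_batch()
print(yb.shape)                      # targets: I expect torch.Size([64])

raw = distil_model(xb['input_ids'])  # the 'input_ids' key is an assumption
print(type(raw))                     # tuple vs ModelOutput changes how the unpacking behaves

pred = model(xb['input_ids'])
print(pred.shape)                    # should be (64, NUM_CLASSES); (1, NUM_CLASSES) would reproduce the error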
The error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
~/tweets/test.py in <module>
    113 learn.freeze()
    114 # learn.lr_find(suggestions=True)
--> 115 learn.fit_one_cycle(EPOCH)
~/anaconda3/lib/python3.8/site-packages/fastcore/logargs.py in _f(*args, **kwargs)
54 init_args.update(log)
55 setattr(inst, 'init_args', init_args)
---> 56 return inst if to_return else f(*args, **kwargs)
57 return _f
~/anaconda3/lib/python3.8/site-packages/fastai/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt)
111 scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
112 'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 113 self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
114
115 # Cell
~/anaconda3/lib/python3.8/site-packages/fastcore/logargs.py in _f(*args, **kwargs)
54 init_args.update(log)
55 setattr(inst, 'init_args', init_args)
---> 56 return inst if to_return else f(*args, **kwargs)
57 return _f
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
205 self.opt.set_hypers(lr=self.lr if lr is None else lr)
206 self.n_epoch,self.loss = n_epoch,tensor(0.)
--> 207 self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
208
209 def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
153
154 def _with_events(self, f, event_type, ex, final=noop):
--> 155 try: self(f'before_{event_type}') ;f()
156 except ex: self(f'after_cancel_{event_type}')
157 finally: self(f'after_{event_type}') ;final()
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _do_fit(self)
195 for epoch in range(self.n_epoch):
196 self.epoch=epoch
--> 197 self._with_events(self._do_epoch, 'epoch', CancelEpochException)
198
199 @log_args(but='cbs')
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
153
154 def _with_events(self, f, event_type, ex, final=noop):
--> 155 try: self(f'before_{event_type}') ;f()
156 except ex: self(f'after_cancel_{event_type}')
157 finally: self(f'after_{event_type}') ;final()
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _do_epoch(self)
189
190 def _do_epoch(self):
--> 191 self._do_epoch_train()
192 self._do_epoch_validate()
193
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_train(self)
181 def _do_epoch_train(self):
182 self.dl = self.dls.train
--> 183 self._with_events(self.all_batches, 'train', CancelTrainException)
184
185 def _do_epoch_validate(self, ds_idx=1, dl=None):
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
153
154 def _with_events(self, f, event_type, ex, final=noop):
--> 155 try: self(f'before_{event_type}') ;f()
156 except ex: self(f'after_cancel_{event_type}')
157 finally: self(f'after_{event_type}') ;final()
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in all_batches(self)
159 def all_batches(self):
160 self.n_iter = len(self.dl)
--> 161 for o in enumerate(self.dl): self.one_batch(*o)
162
163 def _do_one_batch(self):
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in one_batch(self, i, b)
177 self.iter = i
178 self._split(b)
--> 179 self._with_events(self._do_one_batch, 'batch', CancelBatchException)
180
181 def _do_epoch_train(self):
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
153
154 def _with_events(self, f, event_type, ex, final=noop):
--> 155 try: self(f'before_{event_type}') ;f()
156 except ex: self(f'after_cancel_{event_type}')
157 finally: self(f'after_{event_type}') ;final()
~/anaconda3/lib/python3.8/site-packages/fastai/learner.py in _do_one_batch(self)
164 self.pred = self.model(*self.xb)
165 self('after_pred')
--> 166 if len(self.yb): self.loss = self.loss_func(self.pred, *self.yb)
167 self('after_loss')
168 if not self.training or not len(self.yb): return
~/anaconda3/lib/python3.8/site-packages/fastai/layers.py in __call__(self, inp, targ, **kwargs)
295 if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
296 if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
--> 297 return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
298
299 # Cell
~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
945
946 def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 947 return F.cross_entropy(input, target, weight=self.weight,
948 ignore_index=self.ignore_index, reduction=self.reduction)
949
~/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2420 if size_average is not None or reduce is not None:
2421 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2422 return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
2423
2424
~/anaconda3/lib/python3.8/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
2213
2214 if input.size(0) != target.size(0):
-> 2215 raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
2216 .format(input.size(0), target.size(0)))
2217 if dim == 2:
ValueError: Expected input batch_size (1) to match target batch_size (64).
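If I read the last frame correctly, the logits reaching F.cross_entropy have batch size 1 while the targets have my full batch size of 64, so the batch dimension seems to get collapsed somewhere in my forward pass. Could the vec, _ = self.distil_bert(input_ids) unpacking be the culprit? This is the variant I would test next (a sketch; out[0] assumes the first element of the Hugging Face output is last_hidden_state with shape (batch, seq_len, 768)):

def forward(self, input_ids):
    out = self.distil_bert(input_ids)
    vec = out[0]        # last_hidden_state; integer indexing works for both tuple and ModelOutput returns
    vec = vec[:, 0, :]  # (batch, 768) vector at the [CLS] position
    return self.linear(vec)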