CUDA error when training a text segmentation model

So I’m pretty new to deep learning and I’m trying to train a model to find text in images.
I found a mask dataset (COCO_TS) that is based on about 15,000 images from the COCO dataset and managed to create the dataloaders. The masks are saved as .png files, with black representing “no text” and white representing “text”.

from fastai.vision.all import *

path = Path('/storage/data')
path_coco_new = path/"COCO_NEW"   # the image folder, also passed to dataloaders below

def label_func_x(fn): return path_coco_new/f"{fn.stem}.jpg"          # image
def label_func_y(fn): return path/"COCO_TS_labels"/f"{fn.stem}.png"  # mask

dblock = DataBlock(blocks=(ImageBlock, MaskBlock),
                   get_items=get_image_files,
                   get_x=label_func_x,
                   get_y=label_func_y,
                   splitter=RandomSplitter(),
                   item_tfms=Resize(256, ResizeMethod.Pad, pad_mode='zeros'))

dls = dblock.dataloaders(path_coco_new, path=path, bs=8)

Up to this point everything works, and when I use show_batch() I can see the images with the masks overlaid.
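In hindsight (see the update at the bottom), this is where the problem was hiding: show_batch() rescales the mask colors for display, so a mask with values 0/255 looks exactly like one with values 0/1. A minimal raw-value check, assuming numpy and PIL are available:

import numpy as np
from PIL import Image

mask_files = get_image_files(path/"COCO_TS_labels")  # all mask .png files
m = np.array(Image.open(mask_files[0]))
print(np.unique(m))  # a 2-class mask should print [0 1]; here it prints [0 255]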
But when I try to train the model with the following code:

learn = unet_learner(dls, resnet34, n_out=2)
learn.fine_tune(1)

learn.fine_tune fails with the following error:

epoch	train_loss	valid_loss	time
0	0.000000	00:01
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-21-fb4fd0bdbed6> in <module>
     1 learn = unet_learner(dls, resnet34, n_out = 2)
     2 
----> 3 learn.fine_tune(1)

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastcore/logargs.py in _f(*args, **kwargs)
    54         init_args.update(log)
    55         setattr(inst, 'init_args', init_args)
---> 56         return inst if to_return else f(*args, **kwargs)
    57     return _f

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/callback/schedule.py in fine_tune(self, epochs, base_lr, freeze_epochs, lr_mult, pct_start, div, **kwargs)
   159     "Fine tune with `freeze` for `freeze_epochs` then with `unfreeze` from `epochs` using discriminative LR"
   160     self.freeze()
--> 161     self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)
   162     base_lr /= 2
   163     self.unfreeze()

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastcore/logargs.py in _f(*args, **kwargs)
    54         init_args.update(log)
    55         setattr(inst, 'init_args', init_args)
---> 56         return inst if to_return else f(*args, **kwargs)
    57     return _f

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/callback/schedule.py in fit_one_cycle(self, n_epoch, lr_max, div, div_final, pct_start, wd, moms, cbs, reset_opt)
   111     scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),
   112               'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}
--> 113     self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
   114 
   115 # Cell

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastcore/logargs.py in _f(*args, **kwargs)
    54         init_args.update(log)
    55         setattr(inst, 'init_args', init_args)
---> 56         return inst if to_return else f(*args, **kwargs)
    57     return _f

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
   205             self.opt.set_hypers(lr=self.lr if lr is None else lr)
   206             self.n_epoch = n_epoch
--> 207             self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
   208 
   209     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
   153 
   154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
   156         except ex: self(f'after_cancel_{event_type}')
   157         finally:   self(f'after_{event_type}')        ;final()

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _do_fit(self)
   195         for epoch in range(self.n_epoch):
   196             self.epoch=epoch
--> 197             self._with_events(self._do_epoch, 'epoch', CancelEpochException)
   198 
   199     @log_args(but='cbs')

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
   153 
   154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
   156         except ex: self(f'after_cancel_{event_type}')
   157         finally:   self(f'after_{event_type}')        ;final()

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _do_epoch(self)
   189 
   190     def _do_epoch(self):
--> 191         self._do_epoch_train()
   192         self._do_epoch_validate()
   193 

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_train(self)
   181     def _do_epoch_train(self):
   182         self.dl = self.dls.train
--> 183         self._with_events(self.all_batches, 'train', CancelTrainException)
   184 
   185     def _do_epoch_validate(self, ds_idx=1, dl=None):

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
   153 
   154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
   156         except ex: self(f'after_cancel_{event_type}')
   157         finally:   self(f'after_{event_type}')        ;final()

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in all_batches(self)
   159     def all_batches(self):
   160         self.n_iter = len(self.dl)
--> 161         for o in enumerate(self.dl): self.one_batch(*o)
   162 
   163     def _do_one_batch(self):

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in one_batch(self, i, b)
   177         self.iter = i
   178         self._split(b)
--> 179         self._with_events(self._do_one_batch, 'batch', CancelBatchException)
   180 
   181     def _do_epoch_train(self):

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
   153 
   154     def _with_events(self, f, event_type, ex, final=noop):
--> 155         try:       self(f'before_{event_type}')       ;f()
   156         except ex: self(f'after_cancel_{event_type}')
   157         finally:   self(f'after_{event_type}')        ;final()

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _do_one_batch(self)
   164         self.pred = self.model(*self.xb)
   165         self('after_pred')
--> 166         if len(self.yb): self.loss = self.loss_func(self.pred, *self.yb)
   167         self('after_loss')
   168         if not self.training or not len(self.yb): return

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/losses.py in __call__(self, inp, targ, **kwargs)
    32         if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
    33         if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
---> 34         return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)
    35 
    36 # Cell

/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   720             result = self._slow_forward(*input, **kwargs)
   721         else:
--> 722             result = self.forward(*input, **kwargs)
   723         for hook in itertools.chain(
   724                 _global_forward_hooks.values(),

/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
   945 
   946     def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 947         return F.cross_entropy(input, target, weight=self.weight,
   948                                ignore_index=self.ignore_index, reduction=self.reduction)
   949 

/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
  2420     if size_average is not None or reduce is not None:
  2421         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2422     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  2423 
  2424 

/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
  2216                          .format(input.size(0), target.size(0)))
  2217     if dim == 2:
-> 2218         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
  2219     elif dim == 4:
  2220         ret = torch._C._nn.nll_loss2d(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

RuntimeError: cuda runtime error (710) : device-side assert triggered at /opt/conda/conda-bld/pytorch_1595629395347/work/aten/src/THCUNN/generic/ClassNLLCriterion.cu:118

I’m not sure what the problem is and couldn’t find anything that got it to work.
Any help would be greatly appreciated!
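Edit: a tip for anyone who lands here with a similarly opaque device-side assert. CUDA reports these errors asynchronously, so the stack trace above doesn’t point at the real cause. A sketch of two ways to get a readable message (assuming the dls and learn objects from above):

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # must be set before CUDA is initialized

# or reproduce a single batch on the CPU, where the loss function reports
# the offending target value instead of a device-side assert:
xb, yb = dls.cpu().one_batch()
pred = learn.model.cpu()(xb)
loss = learn.loss_func(pred, yb)  # raises a clear out-of-bounds target error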

If anyone has similar problems, here is a quick update:
The problem was with the masks: they were labeled 0 (black) for no text and 255 (white) for text, but with n_out=2 the loss function expects class indices 0 and 1, and a target of 255 triggers the device-side assert in nll_loss. Writing a quick script to change the masks to 0 and 1 did the trick.
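For reference, a minimal version of that script, assuming numpy and PIL and the folder layout from above (it overwrites the mask files in place, so keep a backup):

import numpy as np
from PIL import Image
from pathlib import Path

mask_dir = Path('/storage/data/COCO_TS_labels')
for fn in mask_dir.glob('*.png'):
    m = np.array(Image.open(fn))
    m = (m > 0).astype(np.uint8)   # 255 (text) -> 1, 0 (no text) stays 0
    Image.fromarray(m).save(fn)    # overwrite with class indices 0/1

Rewriting the files keeps the dataloader code above unchanged; the same remapping could also be done at load time, but this was the simplest fix.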