Modifying a 3-band model for additional bands

In lesson 6 of the 2020 course, Jeremy is asked about using pre-trained models on images with 4 bands. His answer: “I don’t know if there’s a tutorial, but we can certainly make sure somebody on the forum has shown how to do it. It’s super straightforward, it should be pretty much automatic.”

I work in the GIS space, so this got me really interested: many of the problems I’m working on would benefit from additional bands, but unfortunately I can’t find much more on this topic. What I would like to do is take a pre-trained 3-band model and modify it to work with any number of bands. Ideally I would like to see an example of how to do this for both u-net segmentation and image classification. I was hoping someone could point me in the direction of an end-to-end example.

This is what I have found so far:

This article talks conceptually about the process but does not share the code:
transfer learning from rgb to multi band imagery

This article explains how to build a u-net-type network from scratch that works with 4-band images; however, it does not use a pre-trained model.
Creating a Very Simple U-Net Model with PyTorch for Semantic Segmentation of Satellite Images

Lastly, this article gets pretty close but also does not use a 3-band pre-trained model.
How to implement augmentations for Multispectral Satellite Images Segmentation using Fastai-v2 and Albumentations

Hi Nick,

Please see the n_in parameter to cnn_learner and unet_learner. It specifies the number of input channels (bands) the model should accept; fastai adapts the pretrained weights to match.

The docs are paltry. You will probably need to study the code to understand exactly what n_in is doing.
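A minimal sketch of the call (dls here is a stand-in for your own DataLoaders built from 4-band imagery):

from fastai.vision.all import *

# `dls` is hypothetical: your own DataLoaders yielding 4-band tensors
learn = unet_learner(dls, resnet34, n_in=4, n_out=2, pretrained=True)

# classification is analogous:
# learn = cnn_learner(dls, resnet34, n_in=4, pretrained=True)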

HTH, :slightly_smiling_face:


Hi Malcolm,

Thanks for the reply. The ‘n_in’ parameter is definitely part of the puzzle, but I still need some way to modify a pretrained 3-band model to handle 4-band (or more) images.

At this point I’m starting to think I should just side step the issue by splitting my imagery up into groups of 3 bands and training a model on each group then combining the results. It’s not very elegant but should be better than only using 3 bands.

I would, however, love a better solution…

The n_in parameter does exactly what you are asking for (if I understand your question). It works for all the standard pretrained models provided by fastai.

If you need to expand the number of channels of your own pretrained model, the fastai code shows you how to do it.
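The gist is small enough to sketch in plain PyTorch (this mirrors what fastai’s _update_first_layer does, as I read the code: widen the first conv, copy the pretrained RGB filters, zero the extra channels):

import torch
import torch.nn as nn
from torchvision.models import resnet18

n_in = 4
model = resnet18(pretrained=True)
old = model.conv1   # Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
new = nn.Conv2d(n_in, old.out_channels, kernel_size=old.kernel_size,
                stride=old.stride, padding=old.padding, bias=old.bias is not None)
with torch.no_grad():
    new.weight[:, :3] = old.weight   # reuse the pretrained RGB filters
    new.weight[:, 3:].zero_()        # extra band(s) start at zero and are learned
model.conv1 = new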

HTH, :slightly_smiling_face:

Hi Malcolm,

Thanks for that, but I’m not having much luck. I was following along with this tutorial, and they do indeed use n_in=4 when defining the learner; however, they also use pretrained=False. When I change that to True I get this error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-13-5a8507582839> in <module>
     7     return torch.nn.functional.cross_entropy(pred, targ.squeeze(1).type(torch.long))
     8 
----> 9 learn = unet_learner(dl, resnet18, n_in=4, n_out=2, pretrained=True, loss_func=loss_fn, metrics=acc_metric)

~/fastai/fastai/vision/learner.py in unet_learner(dls, arch, normalize, n_out, pretrained, config, loss_func, opt_func, lr, splitter, cbs, metrics, path, model_dir, wd, wd_bn_bias, train_bn, moms, **kwargs)
   217     n_out = ifnone(n_out, get_c(dls))
   218     assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
--> 219     img_size = dls.one_batch()[0].shape[-2:]
   220     assert img_size, "image size could not be inferred from data"
   221     model = create_unet_model(arch, n_out, img_size, pretrained=pretrained, **kwargs)

~/fastai/fastai/data/load.py in one_batch(self)
   135     def one_batch(self):
   136         if self.n is not None and len(self)==0: raise ValueError(f'This DataLoader does not contain any batches')
--> 137         with self.fake_l.no_multiproc(): res = first(self)
   138         if hasattr(self, 'it'): delattr(self, 'it')
   139         return res

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastcore/basics.py in first(x, f, negate, **kwargs)
   545     x = iter(x)
   546     if f: x = filter_ex(x, f=f, negate=negate, gen=True, **kwargs)
--> 547     return next(x, None)
   548 
   549 # Cell

~/fastai/fastai/data/load.py in __iter__(self)
   101         for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
   102             if self.device is not None: b = to_device(b, self.device)
--> 103             yield self.after_batch(b)
   104         self.after_iter()
   105         if hasattr(self, 'it'): del(self.it)

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastcore/transform.py in __call__(self, o)
   196         self.fs.append(t)
   197 
--> 198     def __call__(self, o): return compose_tfms(o, tfms=self.fs, split_idx=self.split_idx)
   199     def __repr__(self): return f"Pipeline: {' -> '.join([f.name for f in self.fs if f.name != 'noop'])}"
   200     def __getitem__(self,i): return self.fs[i]

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastcore/transform.py in compose_tfms(x, tfms, is_enc, reverse, **kwargs)
   148     for f in tfms:
   149         if not is_enc: f = f.decode
--> 150         x = f(x, **kwargs)
   151     return x
   152 

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastcore/transform.py in __call__(self, x, **kwargs)
    71     @property
    72     def name(self): return getattr(self, '_name', _get_name(self))
---> 73     def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)
    74     def decode  (self, x, **kwargs): return self._call('decodes', x, **kwargs)
    75     def __repr__(self): return f'{self.name}:\nencodes: {self.encodes}decodes: {self.decodes}'

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastcore/transform.py in _call(self, fn, x, split_idx, **kwargs)
    81     def _call(self, fn, x, split_idx=None, **kwargs):
    82         if split_idx!=self.split_idx and self.split_idx is not None: return x
---> 83         return self._do_call(getattr(self, fn), x, **kwargs)
    84 
    85     def _do_call(self, f, x, **kwargs):

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastcore/transform.py in _do_call(self, f, x, **kwargs)
    88             ret = f.returns(x) if hasattr(f,'returns') else None
    89             return retain_type(f(x, **kwargs), x, ret)
---> 90         res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
    91         return retain_type(res, x)
    92 

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastcore/transform.py in <genexpr>(.0)
    88             ret = f.returns(x) if hasattr(f,'returns') else None
    89             return retain_type(f(x, **kwargs), x, ret)
---> 90         res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
    91         return retain_type(res, x)
    92 

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastcore/transform.py in _do_call(self, f, x, **kwargs)
    87             if f is None: return x
    88             ret = f.returns(x) if hasattr(f,'returns') else None
---> 89             return retain_type(f(x, **kwargs), x, ret)
    90         res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)
    91         return retain_type(res, x)

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastcore/dispatch.py in __call__(self, *args, **kwargs)
   116         elif self.inst is not None: f = MethodType(f, self.inst)
   117         elif self.owner is not None: f = MethodType(f, self.owner)
--> 118         return f(*args, **kwargs)
   119 
   120     def __get__(self, inst, owner):

~/fastai/fastai/data/transforms.py in encodes(self, x)
   357             self.mean,self.std = x.mean(self.axes, keepdim=True),x.std(self.axes, keepdim=True)+1e-7
   358 
--> 359     def encodes(self, x:TensorImage): return (x-self.mean) / self.std
   360     def decodes(self, x:TensorImage):
   361         f = to_cpu if x.device.type=='cpu' else noop

~/fastai/fastai/torch_core.py in __torch_function__(self, func, types, args, kwargs)
   323         convert=False
   324         if _torch_handled(args, self._opt, func): convert,types = type(self),(torch.Tensor,)
--> 325         res = super().__torch_function__(func, types, args=args, kwargs=kwargs)
   326         if convert: res = convert(res)
   327         if isinstance(res, TensorBase): res.set_meta(self, as_copy=True)

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/torch/tensor.py in __torch_function__(cls, func, types, args, kwargs)
   993 
   994         with _C.DisableTorchFunction():
--> 995             ret = func(*args, **kwargs)
   996             return _convert(ret, cls)
   997 

RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 1

*Edit: added the full error message.

Hi Nick. I’m guessing that the line that generates the error is
learn = unet_learner(dl, resnet18, n_in=4, n_out=2, pretrained=False, loss_func=loss_fn, metrics=acc_metric)

The tutorial you linked unfortunately requires me to join the Kaggle competition and download the data locally. That’s a lot of work! I am willing to do so if necessary, but maybe a shortcut will be enough.

Could you post what this gives…
dl.one_batch()[0].shape

Along with the offending line and the complete stack trace. Also, please take a look at the link below; it gives many helpful debugging and posting tips.
https://forums.fast.ai/t/how-to-debug-your-code-and-ask-for-help-with-fastai-v2/64196

If you can do a bit of pre-debugging, it will save hassle for both of us by narrowing down the possibilities.

You might even have found a bug in fastai. Wouldn’t that be exciting!

:slightly_smiling_face:

Thanks for the help, Malcolm. Sorry, I should have specified: it was indeed this line that was throwing the error

learn = unet_learner(dl, resnet18, n_in=4, n_out=2, pretrained=False, loss_func=loss_fn, metrics=acc_metric)

But it only happens when setting

pretrained=True

Running the code below

dl.one_batch()[0].shape

returns

torch.Size([12, 4, 384, 384])

Which looks reasonable, I think: batch size = 12, bands = 4, and image size = 384×384.

While poking around I have also found a couple of other problems. Even running the tutorial in its default state appears to be problematic. When running

learn.lr_find()

I’m getting the error

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-d81c6bd29d71> in <module>
----> 1 learn.lr_find()

~/fastai/fastai/callback/schedule.py in lr_find(self, start_lr, end_lr, num_it, stop_div, show_plot, suggestions)
    220     n_epoch = num_it//len(self.dls.train) + 1
    221     cb=LRFinder(start_lr=start_lr, end_lr=end_lr, num_it=num_it, stop_div=stop_div)
--> 222     with self.no_logging(): self.fit(n_epoch, cbs=cb)
    223     if show_plot: self.recorder.plot_lr_find()
    224     if suggestions:

~/fastai/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    209             self.opt.set_hypers(lr=self.lr if lr is None else lr)
    210             self.n_epoch = n_epoch
--> 211             self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
    212 
    213     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

~/fastai/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    158 
    159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
    161         except ex: self(f'after_cancel_{event_type}')
    162         self(f'after_{event_type}');  final()

~/fastai/fastai/learner.py in _do_fit(self)
    200         for epoch in range(self.n_epoch):
    201             self.epoch=epoch
--> 202             self._with_events(self._do_epoch, 'epoch', CancelEpochException)
    203 
    204     def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):

~/fastai/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    158 
    159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
    161         except ex: self(f'after_cancel_{event_type}')
    162         self(f'after_{event_type}');  final()

~/fastai/fastai/learner.py in _do_epoch(self)
    194 
    195     def _do_epoch(self):
--> 196         self._do_epoch_train()
    197         self._do_epoch_validate()
    198 

~/fastai/fastai/learner.py in _do_epoch_train(self)
    186     def _do_epoch_train(self):
    187         self.dl = self.dls.train
--> 188         self._with_events(self.all_batches, 'train', CancelTrainException)
    189 
    190     def _do_epoch_validate(self, ds_idx=1, dl=None):

~/fastai/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    158 
    159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
    161         except ex: self(f'after_cancel_{event_type}')
    162         self(f'after_{event_type}');  final()

~/fastai/fastai/learner.py in all_batches(self)
    164     def all_batches(self):
    165         self.n_iter = len(self.dl)
--> 166         for o in enumerate(self.dl): self.one_batch(*o)
    167 
    168     def _do_one_batch(self):

~/fastai/fastai/learner.py in one_batch(self, i, b)
    182         self.iter = i
    183         self._split(b)
--> 184         self._with_events(self._do_one_batch, 'batch', CancelBatchException)
    185 
    186     def _do_epoch_train(self):

~/fastai/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    158 
    159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
    161         except ex: self(f'after_cancel_{event_type}')
    162         self(f'after_{event_type}');  final()

~/fastai/fastai/learner.py in _do_one_batch(self)
    170         self('after_pred')
    171         if len(self.yb):
--> 172             self.loss_grad = self.loss_func(self.pred, *self.yb)
    173             self.loss = self.loss_grad.clone()
    174         self('after_loss')

<ipython-input-11-1da375f2a6ef> in loss_fn(pred, targ)
      5 def loss_fn(pred, targ):
      6     targ[targ==255] = 1
----> 7     return torch.nn.functional.cross_entropy(pred, targ.squeeze(1).type(torch.long))
      8 
      9 learn = unet_learner(dl, resnet18, n_in=4, n_out=2, pretrained=False, loss_func=loss_fn, metrics=acc_metric)

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2460         tens_ops = (input, target)
   2461         if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
-> 2462             return handle_torch_function(
   2463                 cross_entropy, tens_ops, input, target, weight=weight,
   2464                 size_average=size_average, ignore_index=ignore_index, reduce=reduce,

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/torch/overrides.py in handle_torch_function(public_api, relevant_args, *args, **kwargs)
   1064 
   1065     func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
-> 1066     raise TypeError("no implementation found for '{}' on types that implement "
   1067                     '__torch_function__: {}'
   1068                     .format(func_name, list(map(type, overloaded_args))))

TypeError: no implementation found for 'torch.nn.functional.cross_entropy' on types that implement __torch_function__: [<class 'fastai.torch_core.TensorImage'>, <class 'fastai.torch_core.TensorMask'>]

I’m getting this same error on my local machine and on Kaggle, considering that the code apparently worked with Fastai 2.0.13 and torch 1.6.0+cu101 (taken from the Kaggle print statements) I’m guessing something in torch or Fastai has changed. When searching around for the error above I found this thread, which appears to be the same issue, however apparently this has been fixed in the latest release which I have just installed, but I’m still getting the same error.

I have also tried modifying the script to use 3 bands, just to rule out some possible issues. I commented out the line below, which appears to result in a dataloader with 3 bands.

map_filename(red_filename, str1='red', str2='nir'),

I also changed

n_in=4

to

n_in=3

However, this returns the same error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-13-d81c6bd29d71> in <module>
----> 1 learn.lr_find()

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastai/callback/schedule.py in lr_find(self, start_lr, end_lr, num_it, stop_div, show_plot, suggestions)
   220     n_epoch = num_it//len(self.dls.train) + 1
   221     cb=LRFinder(start_lr=start_lr, end_lr=end_lr, num_it=num_it, stop_div=stop_div)
--> 222     with self.no_logging(): self.fit(n_epoch, cbs=cb)
   223     if show_plot: self.recorder.plot_lr_find()
   224     if suggestions:

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
   209             self.opt.set_hypers(lr=self.lr if lr is None else lr)
   210             self.n_epoch = n_epoch
--> 211             self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
   212 
   213     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
   158 
   159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
   161         except ex: self(f'after_cancel_{event_type}')
   162         self(f'after_{event_type}');  final()

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastai/learner.py in _do_fit(self)
   200         for epoch in range(self.n_epoch):
   201             self.epoch=epoch
--> 202             self._with_events(self._do_epoch, 'epoch', CancelEpochException)
   203 
   204     def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
   158 
   159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
   161         except ex: self(f'after_cancel_{event_type}')
   162         self(f'after_{event_type}');  final()

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastai/learner.py in _do_epoch(self)
   194 
   195     def _do_epoch(self):
--> 196         self._do_epoch_train()
   197         self._do_epoch_validate()
   198 

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_train(self)
   186     def _do_epoch_train(self):
   187         self.dl = self.dls.train
--> 188         self._with_events(self.all_batches, 'train', CancelTrainException)
   189 
   190     def _do_epoch_validate(self, ds_idx=1, dl=None):

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
   158 
   159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
   161         except ex: self(f'after_cancel_{event_type}')
   162         self(f'after_{event_type}');  final()

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastai/learner.py in all_batches(self)
   164     def all_batches(self):
   165         self.n_iter = len(self.dl)
--> 166         for o in enumerate(self.dl): self.one_batch(*o)
   167 
   168     def _do_one_batch(self):

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastai/learner.py in one_batch(self, i, b)
   182         self.iter = i
   183         self._split(b)
--> 184         self._with_events(self._do_one_batch, 'batch', CancelBatchException)
   185 
   186     def _do_epoch_train(self):

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
   158 
   159     def _with_events(self, f, event_type, ex, final=noop):
--> 160         try: self(f'before_{event_type}');  f()
   161         except ex: self(f'after_cancel_{event_type}')
   162         self(f'after_{event_type}');  final()

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/fastai/learner.py in _do_one_batch(self)
   170         self('after_pred')
   171         if len(self.yb):
--> 172             self.loss_grad = self.loss_func(self.pred, *self.yb)
   173             self.loss = self.loss_grad.clone()
   174         self('after_loss')

<ipython-input-12-9b2c6d2023e8> in loss_fn(pred, targ)
     5 def loss_fn(pred, targ):
     6     targ[targ==255] = 1
----> 7     return torch.nn.functional.cross_entropy(pred, targ.squeeze(1).type(torch.long))
     8 
     9 learn = unet_learner(dl, resnet18, n_in=3, n_out=2, pretrained=True, loss_func=loss_fn, metrics=acc_metric).to_fp16()

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
  2460         tens_ops = (input, target)
  2461         if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
-> 2462             return handle_torch_function(
  2463                 cross_entropy, tens_ops, input, target, weight=weight,
  2464                 size_average=size_average, ignore_index=ignore_index, reduce=reduce,

~/anaconda3/envs/fastaf/lib/python3.8/site-packages/torch/overrides.py in handle_torch_function(public_api, relevant_args, *args, **kwargs)
  1064 
  1065     func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
-> 1066     raise TypeError("no implementation found for '{}' on types that implement "
  1067                     '__torch_function__: {}'
  1068                     .format(func_name, list(map(type, overloaded_args))))

TypeError: no implementation found for 'torch.nn.functional.cross_entropy' on types that implement __torch_function__: [<class 'fastai.torch_core.TensorImage'>, <class 'fastai.torch_core.TensorMask'>]

From my limited understanding, I think this points to the custom loss function as the issue, but I don’t understand it well enough to make any progress.
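To at least see what is reaching the loss function, I suppose one could temporarily print the types (a diagnostic sketch only):

def loss_fn(pred, targ):
    print(type(pred), type(targ))   # temporary: show which tensor subclasses arrive here
    targ[targ==255] = 1
    return torch.nn.functional.cross_entropy(pred, targ.squeeze(1).type(torch.long))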

Also, just linking in @cordmaur, as this is his tutorial, so I figured he may like to know what’s going on.

Hi Nick. I certainly feel your frustration. I have many times been stuck on unfathomable errors, with apparently no way forward.

I suspect that more than one thing is going on. I made a very simple 4-band test case from your tutorial notebook, using random data. It runs perfectly on my local system, including lr_find() and fit_one_cycle(). I suggest that you paste the code into your notebook, split it into cells, and see whether it runs.

import torch
from fastai.vision.all import *

# 100 random 4-band "images" paired with random binary masks
tdata = [(torch.rand((4,50,60)), torch.randint(0,2,(50,60))) for i in range(100)]

dl1 = DataLoader(tdata, bs=3)
dl1.n_inp = 1                  # one input tensor per item; the rest is the target
dl = DataLoaders(dl1, dl1)     # reuse the same loader for train and valid

xb,yb = dl.one_batch()
xb.shape,yb.shape

def acc_metric(input, target):
    target = target.squeeze(1)
    return (input.argmax(dim=1)==target).float().mean()

def loss_fn(pred, targ):
    targ[targ==255] = 1
    return torch.nn.functional.cross_entropy(pred, targ.squeeze(1).type(torch.long))

learn = unet_learner(dl, resnet18, n_in=4, n_out=2, pretrained=True, loss_func=loss_fn, metrics=acc_metric, normalize=False)

learn.model(xb).shape   # sanity check: one forward pass on a batch

learn.lr_find()

learn.fit_one_cycle(1)

If the code fails, probably you have some sort of obtuse fastai/PyTorch configuration error. The “no implementation found” error is one clue. What happens when you simply run torch.nn.functional.cross_entropy in a notebook cell?

Also note that I added normalize=False to the unet_learner() parameters. With normalize=True, unet_learner() barfs with yet another error that neither of us has ever seen before. Don’t ask me why.
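On reflection, here is a guess: with pretrained=True, normalize=True makes fastai add a Normalize transform built from the 3-channel imagenet_stats, and those stats cannot broadcast against a 4-band batch. That would also explain your earlier “size of tensor a (4) must match the size of tensor b (3)” error. If you want normalization back, you could presumably build a 4-channel Normalize yourself. A sketch (the numbers are placeholders; compute the real statistics from your own training data):

from fastai.vision.all import *

# placeholder 4-band statistics -- compute the real values from your data
means = [0.485, 0.456, 0.406, 0.5]
stds  = [0.229, 0.224, 0.225, 0.25]
norm = Normalize.from_stats(means, stds)

# pass `norm` in with your batch_tfms when building the DataLoaders, and keep
# normalize=False so fastai doesn't stack the 3-channel version on top of it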

So I suggest seeing whether the simplified code above runs for you, and going forward from there. Unfortunately, diagnosing arcane Python module import errors is beyond my competence. If it turns out to be some sort of configuration problem, hopefully someone on the forums will be able to help you.

Goodnight! :slightly_smiling_face:


Thanks Malcolm, I’m running all of this in a fresh conda environment so your sample code worked without a problem.

After playing around for a while I realised that your code passes plain torch.Tensor objects, while the tutorial I was following uses fastai.torch_core.TensorImage. All I needed to do was change two lines in the tutorial to return plain torch tensors instead of the fastai types. The lines I changed were:

Within the ‘open_ms_tif’ function

# return TensorImage(ms_img) 
return torch.from_numpy(ms_img)

and from within the ‘SegmentationAlbumentationsTransform’ function

# return TensorImage(aug['image'].transpose(2,0,1)), TensorMask(aug['mask'])
return torch.from_numpy(aug['image'].transpose(2,0,1)), TensorMask(aug['mask'])

Then it worked perfectly with n_in=4 and pretrained=True :smiley:

So it appears that this info from Zachary on this thread is still correct:

The second is an issue with the newest PyTorch, it won’t just readily accept types anymore like it used to before (so long as it was a tensor).
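For anyone else who lands here: instead of changing the data pipeline, it looks like you could also strip the fastai types inside the loss function itself by casting to a common base class (an untested sketch):

from fastai.torch_core import TensorBase

def loss_fn(pred, targ):
    # casting to TensorBase avoids the mixed-subclass __torch_function__ dispatch error
    pred, targ = TensorBase(pred), TensorBase(targ)
    targ[targ==255] = 1
    return torch.nn.functional.cross_entropy(pred, targ.squeeze(1).type(torch.long))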

Thanks for all the help, Malcolm :+1:. This would have been very difficult without you.


Hi Nick. I am glad to have helped you find the solution. Good luck with your project!

Malcolm :slightly_smiling_face:
