Using A Pair Of Images As The X-Values

Hey all. I’m looking to train a classification model that takes two images as input. I’ve already tried combining them into one image (stacking them vertically into an image twice as tall, one on top and one on the bottom), but this has yielded unimpressive results.

I’ve also tried using a tuple of two ImageBlocks when creating my DataBlock, but that doesn’t seem to work. My code is as follows:

def get_x(r): return r['Image1'], r['Image2']
def get_y(r): return r['Y']

dblock = DataBlock(blocks=((ImageBlock, ImageBlock), CategoryBlock),
                   get_x = get_x, get_y = get_y)
dsets = dblock.datasets(df)

dls = dblock.dataloaders(df, bs = 4)

# df has 3 columns: Image1, Image2, and Y. Image1 and Image2 are both paths to saved images, and Y is the category of that row

Running this leads to the message “Could not do one pass in your dataloader, there is something wrong in it”. Is there a mistake I’ve made, or should this problem be approached differently altogether?

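You don’t have to pass the double input as a tuple. Instead, list the blocks flat, give get_x a list with one getter per image, and set n_inp=2 so the DataBlock knows the first two blocks are inputs and the last one is the target. I think the following would work: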

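from fastai.vision.all import *

def get_x_1(r): return r['Image1']
def get_x_2(r): return r['Image2']

dblock = DataBlock(blocks=(ImageBlock, ImageBlock, CategoryBlock),
                   get_x=[get_x_1, get_x_2],  # one getter per input block
                   get_y=get_y,
                   n_inp=2)                   # the first two blocks are inputs
dsets = dblock.datasets(df)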

If it still fails, use dblock.summary(df) for a more meaningful error message.
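To make the role of n_inp concrete, here is a plain-Python sketch (no fastai involved) of what happens to one row: each getter produces one element of a tuple, and n_inp decides how many leading elements count as model inputs. The helpers `build_sample` and `split_sample` are hypothetical names, the row is made up, and this mirrors the idea rather than fastai's actual implementation.

```python
def get_x_1(r): return r['Image1']
def get_x_2(r): return r['Image2']
def get_y(r): return r['Y']

def build_sample(row, getters):
    """Hypothetical helper: apply one getter per block, giving one tuple element per block."""
    return tuple(g(row) for g in getters)

def split_sample(sample, n_inp):
    """Hypothetical helper: the first n_inp elements are inputs, the rest are targets."""
    return sample[:n_inp], sample[n_inp:]

# A dict stands in for one DataFrame row; r['Image1'] works the same way on both.
row = {'Image1': 'pair_0_a.png', 'Image2': 'pair_0_b.png', 'Y': 'class_a'}
sample = build_sample(row, [get_x_1, get_x_2, get_y])
inputs, targets = split_sample(sample, n_inp=2)
print(inputs)   # ('pair_0_a.png', 'pair_0_b.png')
print(targets)  # ('class_a',)
```

With n_inp=1 the second image would be treated as a target instead, which is why the value matters.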


That worked! The DataBlock is now created successfully, but I get a new error when I try to train a learner, so I assume I’m making another mistake. My updated code is:

def get_x_1(r): return r['Image1']
def get_x_2(r): return r['Image2']
def get_y(r): return r['Y']

dblock = DataBlock(blocks=(ImageBlock, ImageBlock, CategoryBlock),
                   get_x = [get_x_1, get_x_2], 
                   get_y = get_y, 
                   n_inp = 2)
dsets = dblock.datasets(df)

dls = dblock.dataloaders(df, bs = 2)

learn = cnn_learner(dls, resnet50, metrics = accuracy)
lr_min,lr_steep = learn.lr_find()

And the accompanying error is as follows:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-80-30710f09bb32> in <module>
      1 learn = cnn_learner(dls, resnet50, metrics = accuracy)
----> 2 lr_min,lr_steep = learn.lr_find()

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/callback/schedule.py in lr_find(self, start_lr, end_lr, num_it, stop_div, show_plot, suggestions)
    222     n_epoch = num_it//len(self.dls.train) + 1
    223     cb=LRFinder(start_lr=start_lr, end_lr=end_lr, num_it=num_it, stop_div=stop_div)
--> 224     with self.no_logging(): self.fit(n_epoch, cbs=cb)
    225     if show_plot: self.recorder.plot_lr_find()
    226     if suggestions:

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in fit(self, n_epoch, lr, wd, cbs, reset_opt)
    203             self.opt.set_hypers(lr=self.lr if lr is None else lr)
    204             self.n_epoch = n_epoch
--> 205             self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
    206 
    207     def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    152 
    153     def _with_events(self, f, event_type, ex, final=noop):
--> 154         try:       self(f'before_{event_type}')       ;f()
    155         except ex: self(f'after_cancel_{event_type}')
    156         finally:   self(f'after_{event_type}')        ;final()

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _do_fit(self)
    194         for epoch in range(self.n_epoch):
    195             self.epoch=epoch
--> 196             self._with_events(self._do_epoch, 'epoch', CancelEpochException)
    197 
    198     def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    152 
    153     def _with_events(self, f, event_type, ex, final=noop):
--> 154         try:       self(f'before_{event_type}')       ;f()
    155         except ex: self(f'after_cancel_{event_type}')
    156         finally:   self(f'after_{event_type}')        ;final()

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _do_epoch(self)
    188 
    189     def _do_epoch(self):
--> 190         self._do_epoch_train()
    191         self._do_epoch_validate()
    192 

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _do_epoch_train(self)
    180     def _do_epoch_train(self):
    181         self.dl = self.dls.train
--> 182         self._with_events(self.all_batches, 'train', CancelTrainException)
    183 
    184     def _do_epoch_validate(self, ds_idx=1, dl=None):

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    152 
    153     def _with_events(self, f, event_type, ex, final=noop):
--> 154         try:       self(f'before_{event_type}')       ;f()
    155         except ex: self(f'after_cancel_{event_type}')
    156         finally:   self(f'after_{event_type}')        ;final()

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in all_batches(self)
    158     def all_batches(self):
    159         self.n_iter = len(self.dl)
--> 160         for o in enumerate(self.dl): self.one_batch(*o)
    161 
    162     def _do_one_batch(self):

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in one_batch(self, i, b)
    176         self.iter = i
    177         self._split(b)
--> 178         self._with_events(self._do_one_batch, 'batch', CancelBatchException)
    179 
    180     def _do_epoch_train(self):

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
    152 
    153     def _with_events(self, f, event_type, ex, final=noop):
--> 154         try:       self(f'before_{event_type}')       ;f()
    155         except ex: self(f'after_cancel_{event_type}')
    156         finally:   self(f'after_{event_type}')        ;final()

/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py in _do_one_batch(self)
    161 
    162     def _do_one_batch(self):
--> 163         self.pred = self.model(*self.xb)
    164         self('after_pred')
    165         if len(self.yb): self.loss = self.loss_func(self.pred, *self.yb)

/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

TypeError: forward() takes 2 positional arguments but 3 were given

Would you happen to know the solution to this as well?

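Yes. Your model is now called with both images, so its forward receives three positional arguments (self, Image1, Image2), but a standard resnet50 forward only accepts two (self and a single image). That is what "forward() takes 2 positional arguments but 3 were given" means.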
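The mismatch can be reproduced without PyTorch. In this sketch, `SingleInputModel` is a hypothetical stand-in whose __call__ hands its positional arguments to forward (much like nn.Module does), and the call unpacks two inputs the same way the traceback's `self.model(*self.xb)` does.

```python
class SingleInputModel:
    """Hypothetical stand-in for a model whose forward takes one input."""
    def forward(self, x):
        return x

    def __call__(self, *args):
        # Like nn.Module, delegate the call to forward()
        return self.forward(*args)

xb = ("image1_batch", "image2_batch")  # with n_inp=2 the learner holds two inputs
model = SingleInputModel()
try:
    model(*xb)  # mirrors self.model(*self.xb) in the learner
    error_message = None
except TypeError as e:
    error_message = str(e)
print(error_message)
```

The printed message contains "takes 2 positional arguments but 3 were given"; the exact prefix depends on the Python version. Either giving forward a second parameter or collapsing xb back into a single tensor removes the mismatch.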

There are two possible solutions.

1. Adapt the model’s forward function to accept two inputs:
def forward(self, x1, x2):
    x = torch.cat([x1, x2], 1)  # stack the two images along the channel dimension
    # ...then continue with the rest of the original forward pass

2. Use a Callback to stack the inputs:

from fastai.vision.all import *

class ConcatImagesCallback(Callback):
    """
    Takes multiple images and concatenates them in the channel dim.
    Example:
        Two tensors of size (10, 3, 25, 25) become a single tensor of
        size (10, 6, 25, 25).
    """
    def before_batch(self):
        # Replace the tuple of inputs with a 1-tuple holding the concatenated tensor
        self.learn.xb = (torch.cat(self.learn.xb, dim=1), )
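For the first option, here is a minimal self-contained PyTorch sketch of a model whose forward accepts two images. `TwoImageNet` and its tiny convolutional body are hypothetical placeholders rather than resnet50; in practice the body would be a CNN whose first layer accepts 6 channels. The concatenation step is the same one used in both options.

```python
import torch
from torch import nn

class TwoImageNet(nn.Module):
    """Hypothetical two-input classifier: concat along channels, then a small CNN."""
    def __init__(self, n_classes: int):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=3, padding=1),  # 6 = 3 channels x 2 images
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # (bs, 16, H, W) -> (bs, 16, 1, 1)
            nn.Flatten(),             # -> (bs, 16)
        )
        self.head = nn.Linear(16, n_classes)

    def forward(self, x1, x2):
        x = torch.cat([x1, x2], dim=1)  # two (bs, 3, H, W) tensors -> (bs, 6, H, W)
        return self.head(self.body(x))

model = TwoImageNet(n_classes=4)
img1 = torch.randn(2, 3, 64, 64)
img2 = torch.randn(2, 3, 64, 64)
logits = model(img1, img2)
print(logits.shape)  # torch.Size([2, 4])
```

Since the learner calls `self.model(*self.xb)`, a model like this should slot into a plain fastai Learner (for example `Learner(dls, model, metrics=accuracy)`) with the two-input DataLoaders unchanged, though that combination is untested here.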

You’ll probably also have to adapt the model’s input channels: its first convolution expects 3 channels but will now receive 6.
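One way to adapt the input channels, sketched below in plain PyTorch, is to build a new first convolution with 6 input channels and reuse the existing 3-channel weights by cycling through them and rescaling. `expand_conv_in_channels` is a hypothetical helper, and this weight-reuse scheme is just one reasonable choice that assumes an ordinary convolution (groups=1). Recent fastai 2.x versions of cnn_learner also accept an n_in argument (e.g. n_in=6) that adapts the first layer for you using fastai's own scheme.

```python
import torch
from torch import nn

def expand_conv_in_channels(conv: nn.Conv2d, n_in: int) -> nn.Conv2d:
    """Hypothetical helper: copy `conv` into a new conv that accepts n_in channels.

    Input channel i of the new kernel reuses channel i % conv.in_channels of the
    old kernel, scaled by old_in / n_in, so feeding the same image twice gives
    the same output as the original layer. Assumes groups == 1.
    """
    new_conv = nn.Conv2d(
        n_in, conv.out_channels,
        kernel_size=conv.kernel_size, stride=conv.stride,
        padding=conv.padding, dilation=conv.dilation,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        idx = torch.arange(n_in) % conv.in_channels  # e.g. [0, 1, 2, 0, 1, 2]
        new_conv.weight.copy_(conv.weight[:, idx] * conv.in_channels / n_in)
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv

# Stand-in for resnet50's first layer (a 7x7, stride-2, 3->64 conv without bias)
old_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
new_conv = expand_conv_in_channels(old_conv, n_in=6)
print(new_conv.weight.shape)  # torch.Size([64, 6, 7, 7])
```

On a torchvision resnet50 this would be applied as `model.conv1 = expand_conv_in_channels(model.conv1, 6)`. In the model cnn_learner builds, the first conv sits inside the body instead (likely `learn.model[0][0]`, worth confirming by printing the model).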


Apologies - I know I’m asking a lot of follow-up questions. I’m not experienced enough with fastai to make these changes. Do you know if there’s a more in-depth tutorial on how to do this, or a way to find the full text of the forward function that a resnet50 normally uses?

Using ‘??forward’ in Jupyter doesn’t seem to show anything.

No problem :)
The forward function is called when you pass a tensor to a model.
For example, if you have a tensor x and a model m, you would call m(x) to pass the tensor to the model.
The model is a Python class. In Python, calling m(x) executes m.__call__(x); PyTorch’s nn.Module implements __call__ so that it runs any registered hooks and then calls m.forward(x).
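To see how m(x) reaches forward, here is a small sketch with a hypothetical `Doubler` module: calling m(x) goes through nn.Module.__call__, which runs the registered forward hook, whereas calling m.forward(x) directly is a plain method call that skips the hooks.

```python
import torch
from torch import nn

class Doubler(nn.Module):
    """Hypothetical toy module whose forward doubles its input."""
    def forward(self, x):
        return 2 * x

m = Doubler()
hook_calls = []
# A forward hook runs after forward() whenever the module is called as m(x)
m.register_forward_hook(lambda module, inputs, output: hook_calls.append(output.sum().item()))

x = torch.ones(3)
y_call = m(x)            # nn.Module.__call__ -> forward() + hooks
y_direct = m.forward(x)  # plain method call, no hooks
print(torch.equal(y_call, y_direct), len(hook_calls))  # True 1
```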
So to inspect the forward function of any PyTorch module, look it up on an instance: type module.forward?? in Jupyter. A bare ??forward shows nothing because there is no global name called forward.
For resnet50, you need to type:

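from torchvision.models import resnet50  # also available via from fastai.vision.all import *
resnet = resnet50()
resnet.forward??
# In recent torchvision versions forward just delegates to _forward_impl; run resnet._forward_impl?? in its own cell to see the layers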
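Outside Jupyter, the standard-library inspect module gives the same source text. The sketch below uses a small nn.Sequential as a stand-in. As far as I know, the model cnn_learner builds is an nn.Sequential(body, head), so inspecting learn.model.forward would show Sequential’s forward, which passes a single input through each layer in turn; that is why it cannot take two images without changes.

```python
import inspect
from torch import nn

# Stand-in model; a cnn_learner model is (as far as I know) also an nn.Sequential
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Equivalent of `model.forward??` in Jupyter: print the source of the bound method
source = inspect.getsource(model.forward)
print(source)
```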