show_results() fails for "listy" outputs

I’d like my neural network to output an ImageTuple as defined in the Siamese tutorial. This works fairly well, including the show and show_batch methods. However, show_results fails, for the following reason.

Learner.show_results() calls TfmdDL.show_results(), so let’s start there:

```python
def show_results(self, b, out, max_n=9, ctxs=None, show=True, **kwargs):
    x,y,its = self.show_batch(b, max_n=max_n, show=False)
    b_out = type(b)(b[:self.n_inp] + (tuple(out) if is_listy(out) else (out,)))
    x1,y1,outs = self.show_batch(b_out, max_n=max_n, show=False)
    res = (x,x1,None,None) if its is None else (x, y, its, outs.itemgot(slice(self.n_inp,None)))
    if not show: return res
    show_results(*res, ctxs=ctxs, max_n=max_n, **kwargs)
```

Note that my batch has the form
(TensorImage, ImageTuple(TensorImage, TensorImage, ...))
with a single input (n_inp = 1), and the model output is
ImageTuple(TensorImage, TensorImage, ...).
Since is_listy(ImageTuple(x, y, ...)) evaluates to True, tuple(out) unpacks the output into (x, y, ...), so b_out = (img, x, y, ...). That is a different structure from b, which is one input followed by one ImageTuple target, and so the second call to self.show_batch fails.
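
To see the mismatch in isolation, here is a pure-Python sketch of the b_out line that runs without fastai. It rests on assumptions for illustration only: is_listy below is a simplified stand-in for fastcore's version (which also accepts L, slices and generators), ImageTuple is reduced to a bare tuple subclass, and strings stand in for tensors.

```python
class ImageTuple(tuple):
    """Simplified stand-in for the Siamese tutorial's ImageTuple (a tuple subclass)."""

def is_listy(x):
    # Simplified stand-in for fastcore's is_listy: tuples and lists count as "listy"
    return isinstance(x, (tuple, list))

def build_b_out(b, out, n_inp=1):
    # Mirrors: type(b)(b[:self.n_inp] + (tuple(out) if is_listy(out) else (out,)))
    return type(b)(b[:n_inp] + (tuple(out) if is_listy(out) else (out,)))

b = ("img", ImageTuple(("t0", "t1", "t2")))  # (input, target) as the DataLoader yields it
out = ImageTuple(("p0", "p1", "p2"))          # model output of the same type as the target

b_out = build_b_out(b, out)
print(b_out)               # ('img', 'p0', 'p1', 'p2'): the output was flattened into the batch
print(len(b), len(b_out))  # 2 4

# A single, non-listy output (e.g. one concatenated tensor) keeps the (input, target) layout
b_single = build_b_out(b, "concat")
print(b_single)            # ('img', 'concat')
```

Once the network emits one tensor instead of a tuple, as in the ConcatTensor approach below, is_listy(out) is False and b_out keeps the same two-element layout as b.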

Unfortunately, I don’t know how to fix this! The issue seems closely related to this one and may even share a similar resolution.

Phew, I managed to get it to work with ImageTuple. The trick is to define a tensor class plus a reversible transform whose order places it after ToTensor, so that every transform that distributes correctly over tuples (such as cropping or resizing) still runs on each image separately. Once those operations are done, the transform concatenates the images channel-wise into a single tensor, which is what the network acts on; call it ConcatTensor, for example. Because the model now outputs a single tensor rather than a tuple, is_listy(out) is False and show_results keeps the batch structure intact. You can then type-dispatch show_batch and show_results on ConcatTensor, and the reversible transform takes care of the rest. In code:

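```python
from fastai.vision.all import *
# ImageTuple is assumed to be defined as in the Siamese tutorial

class ConcatTensor(TensorImage):
    # Subclasses TensorImage for access to GPU transformations
    def show(self, **kwargs): return separate_tensor(self).show(**kwargs)

def separate_tensor(x:ConcatTensor, ch=0):
    # Split the channel dim back into consecutive, non-overlapping 3-channel images
    return ImageTuple(TensorImage(t) for t in torch.split(x, 3, dim=ch))

class ToTensorAlpha(ItemTransform):
    order = 6  # runs after ToTensor, so per-image item transforms have already been applied
    def encodes(self, xy):
        x,y = xy
        return (x, ConcatTensor(torch.cat(y, 0)))  # concatenate the target images channel-wise
    def decodes(self, xy):
        x,y = xy
        return (x, separate_tensor(y))  # same (input, target) layout that encodes received
```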
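
The encode/decode pair is only reversible if splitting undoes the concatenation exactly, in consecutive, non-overlapping chunks of 3 channels. Below is a pure-Python sketch of that invariant, assuming each image is modeled as a tuple of 3 channel labels rather than a tensor; cat_channels and split_channels are hypothetical helpers mirroring torch.cat(y, 0) and torch.split(x, 3, dim=0), not fastai API.

```python
def cat_channels(images):
    # Analogue of torch.cat(y, 0): put every image's channels along one axis
    return tuple(ch for img in images for ch in img)

def split_channels(channels, size=3):
    # Analogue of torch.split(x, 3, dim=0): consecutive, non-overlapping chunks
    return tuple(channels[i:i + size] for i in range(0, len(channels), size))

def encodes(xy):
    x, y = xy
    return (x, cat_channels(y))

def decodes(xy):
    x, y = xy
    # Return the same (input, target) pair layout that encodes received
    return (x, split_channels(y))

item = ("img", (("r0", "g0", "b0"), ("r1", "g1", "b1")))
enc = encodes(item)
print(enc[1])                # ('r0', 'g0', 'b0', 'r1', 'g1', 'b1')
print(decodes(enc) == item)  # True

# Stepping the slice start by 1 instead of 3 produces overlapping, mixed-up "images"
naive = tuple(enc[1][i:i + 3] for i in range(len(enc[1]) // 3))
print(naive)                 # (('r0', 'g0', 'b0'), ('g0', 'b0', 'r1'))
```

Returning a plain (x, y) pair from decodes, rather than wrapping both in another tuple type, is what lets show_batch see the same structure it would for an undecorated item.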

FWIW, cracking this was a really good intro to the mid-level API. Count me among the converted; the whole system’s really nice.