Elegant min/max per channel

Is there an elegant way to get stats (such as min or max) per channel of a tensor?

Let’s say I have this tensor:

a = torch.randn(2, 3, 4)
a
tensor([[[ 0.0064,  0.7634,  0.2181, -2.4037],
         [-0.0605,  0.2597, -0.4989,  1.0030],
         [ 1.0533,  1.0601,  1.4312, -0.6003]],

        [[ 0.1680, -1.4486, -0.3730, -0.6980],
         [ 0.3079,  0.7013,  0.9557, -0.4858],
         [ 0.1248, -2.0350,  1.2599, -1.4085]]])

What I would like is the following result:

torch.min(a[0,:,:]), torch.min(a[1,:,:])
(tensor(-2.4037), tensor(-2.0350))

but without having to specify each channel explicitly.

You could possibly do this:

torch.min(a[:1,:,:], dim=2)

Thanks! This doesn’t yield the same result though:

torch.return_types.min(
values=tensor([[-2.4037, -0.4989, -0.6003]]),
indices=tensor([[3, 2, 3]]))
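
For anyone wondering why the output differs: a short sketch, using the values posted above, of what that call actually reduces. The slice `a[:1]` keeps only the first channel, and `dim=2` reduces only the last dimension, so you get one minimum per row of channel 0 rather than one per channel. The variable names `first` and `row_min` are just illustrative.

```python
import torch

# The example tensor from the question, typed in by hand.
a = torch.tensor([[[ 0.0064,  0.7634,  0.2181, -2.4037],
                   [-0.0605,  0.2597, -0.4989,  1.0030],
                   [ 1.0533,  1.0601,  1.4312, -0.6003]],
                  [[ 0.1680, -1.4486, -0.3730, -0.6980],
                   [ 0.3079,  0.7013,  0.9557, -0.4858],
                   [ 0.1248, -2.0350,  1.2599, -1.4085]]])

first = a[:1, :, :]                # only channel 0 survives the slice
row_min = torch.min(first, dim=2)  # reduces the last dimension only

print(first.shape)           # torch.Size([1, 3, 4])
print(row_min.values.shape)  # torch.Size([1, 3])
```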

This yields the correct values, but it’s quite ugly:

torch.min(torch.min(a[:,:,:], dim=1).values, dim=1).values
tensor([-2.4037, -2.0350])
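
A possible way to avoid the nesting, assuming a PyTorch release that ships `torch.amin` and `torch.amax`: both accept a tuple of dimensions and return only the values (no indices). This is a sketch on random data that checks the two approaches agree; `nested` and `direct` are illustrative names.

```python
import torch

a = torch.randn(2, 3, 4)

# The nested version from above: reduce dim 1 twice, because after the
# first reduction the old dim 2 has become the new dim 1.
nested = torch.min(torch.min(a, dim=1).values, dim=1).values

# Reduce over dims 1 and 2 in a single call.
direct = torch.amin(a, dim=(1, 2))
direct_max = torch.amax(a, dim=(1, 2))

print(torch.equal(nested, direct))
```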

Looking on other forums, it seems likely that torch has no way to do this:

Map function along arbitrary dimension with pytorch? numpy?

Apply function along dimension of tensor?


In case someone is curious, this is a somewhat clunky general solution:

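def map_along(f, x, dim):
    # Apply f to each slice of x along `dim` and collect the results.
    mapped = []
    for channel in range(x.shape[dim]):
        # index_select expects a 1-D index tensor
        subx = torch.index_select(x, dim, torch.tensor([channel]))
        mapped.append(f(subx))
    # stack the per-slice result tensors into one tensor
    return torch.stack(mapped)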
map_along(torch.min, a, 0)
tensor([-2.4037, -2.0350])
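
A slightly shorter variant of the same idea, as a sketch: `Tensor.unbind` splits a tensor into its slices along a dimension, so the loop can become a list comprehension. The helper name `map_along_unbind` is made up for this example, and it assumes `f` returns a tensor of the same shape for every slice.

```python
import torch

def map_along_unbind(f, x, dim):
    # Hypothetical helper: apply f to every slice of x along `dim`
    # and stack the results into a single tensor.
    return torch.stack([f(s) for s in x.unbind(dim)])

a = torch.randn(2, 3, 4)
per_channel = map_along_unbind(torch.min, a, 0)  # one value per slice of dim 0
per_row = map_along_unbind(torch.min, a, 1)      # one value per slice of dim 1
print(per_channel.shape, per_row.shape)
```

This is still a Python loop under the hood, so for plain reductions such as min or max a built-in reduction over the relevant dimensions is likely faster.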

How about flattening the last two dimensions and then calling max(dim=1)?

a.flatten(1).max(dim=1).values
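
To make the shapes explicit, here is a small sketch on random data: `flatten(1)` merges every dimension from index 1 onward, and a reduction with `dim=` returns a `(values, indices)` pair, which is why `.values` is needed. The variable names are illustrative.

```python
import torch

a = torch.randn(2, 3, 4)

flat = a.flatten(1)       # (2, 3, 4) -> (2, 12): one row per channel
result = flat.max(dim=1)  # named tuple with .values and .indices

print(flat.shape)           # torch.Size([2, 12])
print(result.values.shape)  # torch.Size([2])
```

The `indices` refer to positions in the flattened row of 12 elements, not to (row, column) positions in the original channel.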


I was going to say the same:

a.view(a.shape[0],-1).min(1).values

There you go, @msp. Let me know if this works for you.


Thanks guys!

@Pomo

With this I don’t get the right result:

tensor([1.4312, 1.2599])

@kshitijpatil09

This works:

tensor([-2.4037, -2.0350])

Interesting, I will have to study .view a bit more.

Because one says min and the other max.


It’s just PyTorch’s way of saying reshape 😉
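
One caveat, as a hedged sketch: `view` and `reshape` give the same result on a contiguous tensor like the one in this thread, but `view` never copies data, so it can fail on a non-contiguous tensor where `reshape` falls back to a copy. The transpose below is only there to produce a non-contiguous tensor for the demonstration.

```python
import torch

a = torch.randn(2, 3, 4)

v = a.view(a.shape[0], -1)     # shares memory with a
r = a.reshape(a.shape[0], -1)  # same values as the view here
print(torch.equal(v, r))

t = a.transpose(1, 2)          # shape (2, 4, 3), non-contiguous
view_failed = False
try:
    t.view(2, -1)              # cannot merge these dims without a copy
except RuntimeError:
    view_failed = True

reshaped = t.reshape(2, -1)    # works; copies the data when needed
print(view_failed, reshaped.shape)
```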


🤦 Right 😹 👍