Torch equivalent of np.where

Hi,

I am working with 4-dimensional tensors and am using for loops to check a condition. How can I implement the code below more efficiently?

## object_temp shape -> (24, 7, 7, 30)
## object shape      -> (24, 7, 7)
## target shape      -> (24, 7, 7, 30)

for i in range(object.shape[0]):
    for j in range(object.shape[1]):
        for k in range(object.shape[2]):
            if (target[i, j, k, 4] > 0) | (target[i, j, k, 9] > 0):
                object[i, j, k] = 1
                object_temp[i, j, k, :] = 1

Make a 3D boolean mask from the condition using only vectorized (non-looping) tensor operations, then assign object[mask] = 1.

Hope this is right, because I did not test it! :slightly_smiling_face:
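To illustrate what boolean-mask assignment does, here is a minimal sketch on a small 2D tensor with made-up values (not the thread's data): x[mask] = value changes only the positions where mask is True and leaves every other entry untouched.

```python
import torch

# Small example tensor (illustrative values only).
x = torch.tensor([[1, -2, 3],
                  [-4, 5, -6]])

# Boolean mask with the same shape as x: True where the entry is negative.
mask = x < 0

# Masked assignment writes 0 only where mask is True; other entries are unchanged.
x[mask] = 0
print(x)
# tensor([[1, 0, 3],
#         [0, 5, 0]])
```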

I am only creating the mask. I am new to PyTorch, so my knowledge is limited; could you point me to the functions I should use?
I am creating a mask for this condition:

((target[i,j,k,4] >0) | (target[i,j,k,9] >0))
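A sketch of how the per-element condition maps onto whole-tensor operations, assuming random data with the shapes from the question: indexing with ... (Ellipsis) keeps all leading dimensions, > produces a bool tensor, and | combines two bool tensors element-wise. torch.logical_or is shown as an equivalent, more explicit alternative.

```python
import torch

# Hypothetical random data with the shape from the question.
target = torch.randint(2, (24, 7, 7, 30))

# target[..., 4] selects channel 4 for every (i, j, k), giving shape (24, 7, 7).
# Each comparison yields a bool tensor; `|` is an element-wise OR on bool tensors.
mask = (target[..., 4] > 0) | (target[..., 9] > 0)

# torch.logical_or is an equivalent spelling of the same operation.
mask_alt = torch.logical_or(target[..., 4] > 0, target[..., 9] > 0)

print(mask.shape, mask.dtype)       # torch.Size([24, 7, 7]) torch.bool
print(torch.equal(mask, mask_alt))  # True
```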

Hi,
I figured out how to create a mask for 3 dimensions, but how can I extend it to 4 dimensions?

Below is my code:

## target shape -> (24, 7, 7, 30)
## temp shape   -> (24, 7, 7)

### Wanted: object_temp
## object_temp shape -> (24, 7, 7, 30)

temp = (target[..., 4] > 0) | (target[..., 9] > 0)
object[temp] = 1

How can I extend this idea to create object_temp variable?

How about…

mask = (target[..., 4] > 0) | (target[..., 9] > 0)
Object[mask] = 1
object_temp[mask, :] = 1

But you should verify that this gives the correct answer. It might be wrong! :upside_down_face:

Hint: a good way to learn PyTorch is to experiment directly in a Jupyter notebook. The PyTorch tutorials are a good resource, too. :slightly_smiling_face:
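One way to do that check, sketched here with random test data: run the original loops and the masked version on copies of the same tensors, then compare them with torch.equal. The names obj and obj_temp are illustrative renames that avoid shadowing Python's built-in object.

```python
import torch

torch.manual_seed(0)  # reproducible random data

# Random inputs with the shapes from the thread.
target = torch.randint(2, (24, 7, 7, 30))
obj = torch.randint(10, (24, 7, 7))
obj_temp = torch.randint(10, (24, 7, 7, 30))

# Reference result: the original triple loop, run on copies.
obj_loop, obj_temp_loop = obj.clone(), obj_temp.clone()
for i in range(obj.shape[0]):
    for j in range(obj.shape[1]):
        for k in range(obj.shape[2]):
            if (target[i, j, k, 4] > 0) | (target[i, j, k, 9] > 0):
                obj_loop[i, j, k] = 1
                obj_temp_loop[i, j, k, :] = 1

# Vectorized result: one 3D mask applied with boolean indexing.
obj_vec, obj_temp_vec = obj.clone(), obj_temp.clone()
mask = (target[..., 4] > 0) | (target[..., 9] > 0)
obj_vec[mask] = 1
obj_temp_vec[mask, :] = 1  # the 3D mask selects whole rows along the last dim

print(torch.equal(obj_loop, obj_vec), torch.equal(obj_temp_loop, obj_temp_vec))
# True True
```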

@Pomo, the implementation is giving an error.

I think it needs a 4-dimensional tensor of True and False values only.

The solution to the above problem is the tensor method expand_as() combined with unsqueeze().
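A sketch of that approach, assuming random tensors with the shapes from the thread: unsqueeze(-1) turns the (24, 7, 7) mask into (24, 7, 7, 1), and expand_as() broadcasts it to (24, 7, 7, 30) as a view without copying. For comparison, it also uses masked_fill_(), which broadcasts the (24, 7, 7, 1) mask by itself, so no explicit expand is needed there.

```python
import torch

torch.manual_seed(0)
target = torch.randint(2, (24, 7, 7, 30))
object_temp = torch.randint(10, (24, 7, 7, 30))

# 3D mask, shape (24, 7, 7).
mask3d = (target[..., 4] > 0) | (target[..., 9] > 0)

# (24, 7, 7) -> (24, 7, 7, 1) -> expanded view of shape (24, 7, 7, 30).
mask4d = mask3d.unsqueeze(-1).expand_as(object_temp)
print(mask4d.shape)  # torch.Size([24, 7, 7, 30])

# Option A: index with the full-shape 4D bool mask.
out_a = object_temp.clone()
out_a[mask4d] = 1

# Option B: masked_fill_ broadcasts the (24, 7, 7, 1) mask against the tensor.
out_b = object_temp.clone().masked_fill_(mask3d.unsqueeze(-1), 1)

print(torch.equal(out_a, out_b))  # True
```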

Is there something I do not understand about your question? The following code does not throw an error and gives the same result as your code using loops:

import torch

## object_temp shape -> (24, 7, 7, 30)
## Object shape      -> (24, 7, 7)
## target shape      -> (24, 7, 7, 30)

object_temp = torch.randint(10, (24, 7, 7, 30))
Object = torch.randint(10, (24, 7, 7))
target = torch.randint(2, (24, 7, 7, 30))

mask = (target[..., 4] > 0) | (target[..., 9] > 0)
Object[mask] = 1
object_temp[mask, :] = 1

I changed “object” to “Object” because “object” is the name of a built-in class in Python. It is not a reserved keyword, but using it as a variable name shadows the built-in. :slightly_smiling_face:
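Since the thread title asks for an equivalent of np.where: torch.where(condition, input, other) selects element-wise from input where condition is True and from other elsewhere, broadcasting all three. The sketch below, assuming the same kind of random setup as above, builds new tensors instead of modifying Object and object_temp in place. torch.ones_like is used for the fill value so the example does not depend on scalar support in torch.where.

```python
import torch

torch.manual_seed(0)
target = torch.randint(2, (24, 7, 7, 30))
Object = torch.randint(10, (24, 7, 7))
object_temp = torch.randint(10, (24, 7, 7, 30))

mask = (target[..., 4] > 0) | (target[..., 9] > 0)  # shape (24, 7, 7)

# Out-of-place: 1 where mask is True, the original value elsewhere.
Object_new = torch.where(mask, torch.ones_like(Object), Object)

# For the 4D tensor, add a trailing dim so the mask broadcasts over the last axis.
object_temp_new = torch.where(mask.unsqueeze(-1), torch.ones_like(object_temp), object_temp)
```

The in-place boolean indexing from the posts above avoids allocating new tensors; torch.where is the more direct analogue of np.where when the original tensors should stay unchanged.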