 # Torch equivalent of np.where

Hi,

I am working with 4-dimensional tensors and use nested for loops to check a condition on each element. How can I implement the code below efficiently?

```python
## object_temp shape -> (24, 7, 7, 30)
## object shape      -> (24, 7, 7)
## target shape      -> (24, 7, 7, 30)

for i in range(object.shape[0]):
    for j in range(object.shape[1]):
        for k in range(object.shape[2]):
            # flag cells where channel 4 or channel 9 of target is positive
            if (target[i, j, k, 4] > 0) | (target[i, j, k, 9] > 0):
                object[i, j, k] = 1
                object_temp[i, j, k, :] = 1
```

Build a 3D boolean mask from the condition using only vectorized (non-looping) tensor operations, and then assign `object[mask] = 1`.

Hope this is right, because I did not test it!

I am new to PyTorch, so my knowledge is limited. Can you point me to the functions I should use? So far I am only creating a mask, for the condition:

`(target[i,j,k,4] > 0) | (target[i,j,k,9] > 0)`
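A minimal sketch of the suggested 3D mask, on a small made-up `target` (the `(1, 2, 2, 10)` shape and the values are illustrative assumptions, not from the question). The comparisons and `|` run element-wise over every cell at once, so the result is a boolean mask with the shape of the first three dimensions. The variable is named `obj` here only to avoid shadowing Python's built-in `object`.

```python
import torch

# Hypothetical small example: batch of 1, 2x2 grid, 10 channels
target = torch.zeros(1, 2, 2, 10)
target[0, 0, 0, 4] = 1.0   # channel 4 positive in cell (0, 0, 0)
target[0, 1, 1, 9] = 2.0   # channel 9 positive in cell (0, 1, 1)

obj = torch.zeros(1, 2, 2)

# Element-wise condition over every (i, j, k) cell: shape (1, 2, 2), dtype torch.bool
mask = (target[..., 4] > 0) | (target[..., 9] > 0)

# Boolean-mask assignment replaces the three nested loops
obj[mask] = 1
print(obj)  # ones at (0, 0, 0) and (0, 1, 1), zeros elsewhere
```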

Hi,

I figured out how to create a mask for 3 dimensions, but how can I extend it to 4 dimensions?

Below is my code:

```python
## target shape -> (24, 7, 7, 30)
## temp shape   -> (24, 7, 7)

temp = (target[..., 4] > 0) | (target[..., 9] > 0)
object[temp] = 1

## Wanted: object_temp, shape -> (24, 7, 7, 30)
```

How can I extend this idea to create the `object_temp` variable?

```python
mask = (target[..., 4] > 0) | (target[..., 9] > 0)
```

But you should verify that this gives the correct answer; it might be wrong! Hint: a good way to learn PyTorch is to experiment directly in a Jupyter notebook. The PyTorch tutorials are a good resource, too.

@Pomo, the implementation is giving an error.

I think it needs a 4-dimensional tensor of `True` and `False` values.

The solution to the above problem is to combine `unsqueeze()` with `expand_as()`: add a trailing dimension to the 3D mask and expand it to the shape of `object_temp`.
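A minimal sketch of the `unsqueeze()` + `expand_as()` approach, assuming random integer tensors with the shapes from the question (the values are arbitrary test data). `unsqueeze(-1)` turns the `(24, 7, 7)` mask into `(24, 7, 7, 1)`, and `expand_as` broadcasts that size-1 dimension to 30 without copying memory.

```python
import torch

# Random tensors with the shapes from the question (values are arbitrary)
target = torch.randint(2, (24, 7, 7, 30))
object_temp = torch.randint(10, (24, 7, 7, 30))
before = object_temp.clone()  # kept only to check that unmasked cells are unchanged

# 3D mask, shape (24, 7, 7)
mask = (target[..., 4] > 0) | (target[..., 9] > 0)

# unsqueeze(-1) adds a trailing size-1 dim -> (24, 7, 7, 1);
# expand_as repeats it along that dim (as a view) -> (24, 7, 7, 30)
mask4d = mask.unsqueeze(-1).expand_as(object_temp)

# Every channel of a masked cell is set to 1, like object_temp[i, j, k, :] = 1
object_temp[mask4d] = 1
```

The expanded mask is only read, never written, so the fact that `expand_as` returns a view with repeated memory does not cause problems here.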

Is there something I do not understand about your question? The following code does not throw an error and gives the same result as your code using loops:

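```python
import torch

## object_temp shape -> (24, 7, 7, 30)
## Object shape      -> (24, 7, 7)
## target shape      -> (24, 7, 7, 30)

object_temp = torch.randint(10, (24, 7, 7, 30))
Object = torch.randint(10, (24, 7, 7))
target = torch.randint(2, (24, 7, 7, 30))

# 3D boolean mask over the first three dimensions
mask = (target[..., 4] > 0) | (target[..., 9] > 0)

# A 3D mask indexes the leading dims, so it also sets whole rows of object_temp
Object[mask] = 1
object_temp[mask] = 1
```

I did change `object` to `Object` because `object` is a Python built-in name; shadowing it is legal but confusing.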
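To check the claim that the masked version matches the loops, here is a sketch that runs both on the same random data (the seed and values are arbitrary assumptions). Because the thread title asks for an equivalent of `np.where`, it also shows `torch.where(condition, input, other)` as an out-of-place alternative. That variant is swapped in here and was not used in the thread.

```python
import torch

torch.manual_seed(0)  # reproducible random test data

target = torch.randint(2, (24, 7, 7, 30))
Object = torch.randint(10, (24, 7, 7))
object_temp = torch.randint(10, (24, 7, 7, 30))

# Reference result: the original triple loop, run on copies
obj_loop, tmp_loop = Object.clone(), object_temp.clone()
for i in range(obj_loop.shape[0]):
    for j in range(obj_loop.shape[1]):
        for k in range(obj_loop.shape[2]):
            if (target[i, j, k, 4] > 0) | (target[i, j, k, 9] > 0):
                obj_loop[i, j, k] = 1
                tmp_loop[i, j, k, :] = 1

mask = (target[..., 4] > 0) | (target[..., 9] > 0)

# Variant 1: in-place boolean-mask assignment
obj_vec, tmp_vec = Object.clone(), object_temp.clone()
obj_vec[mask] = 1
tmp_vec[mask] = 1  # a 3D mask on a 4D tensor selects whole rows along the last dim

# Variant 2: out-of-place torch.where, the closest analogue of np.where;
# mask.unsqueeze(-1) has shape (24, 7, 7, 1) and broadcasts against (24, 7, 7, 30)
obj_where = torch.where(mask, torch.ones_like(Object), Object)
tmp_where = torch.where(mask.unsqueeze(-1), torch.ones_like(object_temp), object_temp)

print(torch.equal(obj_loop, obj_vec), torch.equal(tmp_loop, tmp_vec))
print(torch.equal(obj_loop, obj_where), torch.equal(tmp_loop, tmp_where))
```

The in-place variant modifies the existing tensors, while `torch.where` returns new ones and leaves the inputs untouched.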