Tensor.replacing(with:where:): replacing on false

This is an interesting direction! A new case in TensorRange may not be a good fit because a TensorRange only applies to one dimension.

For now, we can start thinking about adding a subscript that takes a boolean tensor. A settable subscript also requires a getter, so what’s unclear to me is whether the getter should replace the scalars under false with zero by default:

// Option 1
public extension Tensor where Scalar: Numeric {
    subscript(mask: Tensor<Bool>) -> Tensor {
        get {
            return Tensor(0).broadcast(like: self).replacing(with: self, where: mask)
        }
        set {
            self = replacing(with: newValue, where: mask)
        }
    }
}

Or, we could make it take a default scalar that specifies the value under false.

// Option 2, take a default scalar
public extension Tensor where Scalar: AdditiveArithmetic {
    subscript(mask: Tensor<Bool>, otherwise scalarOnFalse: Scalar = .zero) -> Tensor {
        get {
            return Tensor(scalarOnFalse).broadcast(like: self).replacing(with: self, where: mask)
        }
        set {
            self = replacing(with: newValue, where: mask)
        }
    }
}

Or, we could make it take a tensor that supplies the values under false.

// Option 3, take a non-default tensor, achieving `replacing(with:where:)`'s full functionality.
public extension Tensor where Scalar: AdditiveArithmetic {
    subscript(mask: Tensor<Bool>, otherwise scalarsOnFalse: Tensor) -> Tensor {
        get {
            return scalarsOnFalse.replacing(with: self, where: mask)
        }
        set {
            self = replacing(with: newValue, where: mask)
        }
    }
}

I personally prefer option 2, as option 3 could be harder to use since it takes two tensors.
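
To make the option 2 semantics concrete, here is a minimal sketch that runs without TensorFlow. `MiniTensor` is a hypothetical one-dimensional stand-in that I made up for illustration, not the real `Tensor`, and its `replacing(with:where:)` only imitates the behavior I'd expect from the real method (take the other value where the mask is true). The mask is a plain `[Bool]` here rather than a `Tensor<Bool>`.

```swift
// Hypothetical stand-in for Tensor: a flat, 1-D container used only to
// illustrate the masking semantics. This is not the Swift for TensorFlow type.
struct MiniTensor<Scalar: AdditiveArithmetic>: Equatable {
    var scalars: [Scalar]

    init(_ scalars: [Scalar]) { self.scalars = scalars }

    // Imitates `replacing(with:where:)`: take `other`'s element where the
    // mask is true, and keep our own element where it is false.
    func replacing(with other: MiniTensor, where mask: [Bool]) -> MiniTensor {
        precondition(other.scalars.count == scalars.count
            && mask.count == scalars.count, "shapes must match")
        var result = scalars
        for i in result.indices where mask[i] {
            result[i] = other.scalars[i]
        }
        return MiniTensor(result)
    }

    // Option 2: a mask subscript with a default scalar for the false lanes.
    subscript(mask: [Bool], otherwise scalarOnFalse: Scalar = .zero) -> MiniTensor {
        get {
            // Start from a tensor filled with the fallback scalar, then copy
            // in our own elements wherever the mask is true.
            let filler = MiniTensor(Array(repeating: scalarOnFalse, count: scalars.count))
            return filler.replacing(with: self, where: mask)
        }
        set {
            // A setter cannot return a value, so assign to `self` instead.
            // The fallback scalar is irrelevant when setting.
            self = replacing(with: newValue, where: mask)
        }
    }
}

var x = MiniTensor([1, 2, 3, 4])
let mask = [true, false, true, false]

let kept = x[mask]                            // [1, 0, 3, 0]
let keptOrMinusOne = x[mask, otherwise: -1]   // [1, -1, 3, -1]

x[mask] = MiniTensor([10, 20, 30, 40])        // x is now [10, 2, 30, 4]
```

One thing the sketch makes visible is that the `otherwise:` argument only matters for the getter. In the setter it is accepted but ignored, which might be a small argument for documenting it carefully if we go with option 2.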
