This is an interesting direction! A new case in `TensorRange` may not be a good fit, because a `TensorRange` only applies to one dimension.
For now, we can start thinking about adding a subscript that takes a boolean tensor. Every subscript also requires a getter, so what's unclear to me is whether scalars under `false` should be replaced with zero by default:
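```swift
// Option 1: scalars under `false` become zero.
public extension Tensor where Scalar: Numeric {
    subscript(mask: Tensor<Bool>) -> Tensor {
        get {
            // Keep `self` where `mask` is true; zero everywhere else.
            return Tensor(0).broadcast(like: self).replacing(with: self, where: mask)
        }
        set {
            // Overwrite the scalars under `true` with those of `newValue`.
            self = replacing(with: newValue, where: mask)
        }
    }
}
```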
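The intended getter/setter semantics of Option 1 can be checked without the TensorFlow module. The sketch below uses a hypothetical `MaskedVector` type backed by a plain Swift array; it is an assumption for illustration only, ignores shapes and broadcasting, and requires the mask to have the same length as the data:

```swift
// Hypothetical stand-in for `Tensor`: a flat 1-D array of scalars.
// Shapes and broadcasting are deliberately not modeled here.
struct MaskedVector<Scalar: Numeric> {
    var scalars: [Scalar]

    // Mirrors Option 1 (assumed semantics, not the TensorFlow API).
    subscript(mask: [Bool]) -> MaskedVector {
        get {
            precondition(mask.count == scalars.count, "mask must match the data length")
            // Keep scalars under `true`; fill positions under `false` with zero.
            return MaskedVector(scalars: scalars.indices.map { mask[$0] ? scalars[$0] : 0 })
        }
        set {
            precondition(mask.count == scalars.count, "mask must match the data length")
            precondition(newValue.scalars.count == scalars.count, "value must match the data length")
            // Like `replacing(with:where:)`: copy from `newValue` only where `mask` is true.
            for i in scalars.indices where mask[i] {
                scalars[i] = newValue.scalars[i]
            }
        }
    }
}

var v = MaskedVector(scalars: [1.0, 2.0, 3.0, 4.0])
let mask = [true, false, true, false]
print(v[mask].scalars)  // [1.0, 0.0, 3.0, 0.0]
v[mask] = MaskedVector(scalars: [10.0, 20.0, 30.0, 40.0])
print(v.scalars)        // [10.0, 2.0, 30.0, 4.0]
```

Reading through the mask zero-fills the positions under `false`, while writing through it leaves those positions untouched, so the zero only ever appears in the getter's result.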
Or, we could make it take a default scalar that specifies the value under `false`.
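```swift
// Option 2: take a scalar for the values under `false`, defaulting to zero.
public extension Tensor where Scalar: AdditiveArithmetic {
    subscript(mask: Tensor<Bool>, otherwise scalarOnFalse: Scalar = .zero) -> Tensor {
        get {
            // Keep `self` where `mask` is true; `scalarOnFalse` everywhere else.
            return Tensor(scalarOnFalse).broadcast(like: self).replacing(with: self, where: mask)
        }
        set {
            // Overwrite the scalars under `true` with those of `newValue`.
            self = replacing(with: newValue, where: mask)
        }
    }
}
```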
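```swift
// Option 3: take a required tensor for the values under `false`,
// exposing the full functionality of `replacing(with:where:)`.
public extension Tensor where Scalar: AdditiveArithmetic {
    subscript(mask: Tensor<Bool>, otherwise scalarsOnFalse: Tensor) -> Tensor {
        get {
            // Keep `self` where `mask` is true; `scalarsOnFalse` everywhere else.
            return scalarsOnFalse.replacing(with: self, where: mask)
        }
        set {
            // Overwrite the scalars under `true` with those of `newValue`.
            self = replacing(with: newValue, where: mask)
        }
    }
}
```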
I personally prefer Option 2, as Option 3 could be harder to use since it takes two tensors (the mask and the values under `false`).
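To compare the call sites of Options 2 and 3, here is a getter-only sketch of both over a hypothetical array-backed `FillVector` type. The type and its names are assumptions for illustration; broadcasting is not modeled, so every argument must have the same length as the data:

```swift
// Hypothetical 1-D stand-in for `Tensor`; broadcasting is not modeled.
struct FillVector<Scalar: Numeric> {
    var scalars: [Scalar]

    // Option 2 (sketch): one fill scalar for every position under `false`.
    subscript(mask: [Bool], otherwise scalarOnFalse: Scalar = .zero) -> FillVector {
        precondition(mask.count == scalars.count, "mask must match the data length")
        return FillVector(scalars: scalars.indices.map { mask[$0] ? scalars[$0] : scalarOnFalse })
    }

    // Option 3 (sketch): a whole second vector supplies the values under `false`.
    subscript(mask: [Bool], otherwise scalarsOnFalse: FillVector) -> FillVector {
        precondition(mask.count == scalars.count, "mask must match the data length")
        precondition(scalarsOnFalse.scalars.count == scalars.count, "fallback must match the data length")
        return FillVector(scalars: scalars.indices.map { mask[$0] ? scalars[$0] : scalarsOnFalse.scalars[$0] })
    }
}

let x = FillVector(scalars: [1.0, 2.0, 3.0])
let mask = [true, false, true]
print(x[mask].scalars)                                            // [1.0, 0.0, 3.0]
print(x[mask, otherwise: -1].scalars)                             // [1.0, -1.0, 3.0]
print(x[mask, otherwise: FillVector(scalars: [7.0, 8.0, 9.0])].scalars)  // [1.0, 8.0, 3.0]
```

With Option 2 the common case stays `x[mask]` and a custom fill is a single scalar, whereas Option 3 needs a second full-size tensor at every call site.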