How can discrete mean shift clustering be implemented in a backpropagation-friendly way?

I am attempting to implement the LaserNet paper for LiDAR 3D detection, and most of it is straightforward except for the mean shift clustering part discussed in section 3.4. I'll try to summarize it below, but spending a minute looking at the actual equations in the paper would probably be best.

To briefly summarize: the method makes bounding box predictions in the form of a per-class box prediction for each pixel in a dense 2D range image from the sensor's point of view. The network can be used with just this, but the authors note that adding a mean shift clustering step can greatly improve performance. To do this, the predictions from the range image are projected into a 2D top-down xy grid of bins. The mean of each cell is then computed as a weighted combination of the cell and its 8 neighbors, and if the new mean falls into an adjacent cell, all of the points assigned to the cell are moved into that adjacent cell. After 3 iterations of this, the predictions within each cell are combined into a single prediction, and every pixel in the original range image corresponding to that cell is updated with the combined prediction. The loss is then computed on the updated range image.
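
Here is a rough sketch (PyTorch, all names mine) of how I currently picture one iteration of the binning and mean update. I've replaced the paper's Gaussian kernel with a plain average over the 3×3 neighborhood, and I've left out the part where points get reassigned to adjacent cells, so this only illustrates the structure I have in mind, not the exact math:

```python
import torch
import torch.nn.functional as F

def mean_shift_iteration(centers, grid_size=0.5, num_bins=400):
    """One simplified iteration: bin predicted centers into an xy grid,
    compute per-cell means, then smooth each mean over its 3x3 neighborhood.
    Point reassignment between cells is omitted."""
    # Float xy position -> integer bin index (assumes coordinates are already
    # shifted so that everything is >= 0).
    idx = torch.floor(centers / grid_size).long().clamp(0, num_bins - 1)  # (N, 2)
    flat = idx[:, 0] * num_bins + idx[:, 1]                               # (N,)

    # Per-cell sum and count of the predicted centers via index_add_.
    sums = torch.zeros(num_bins * num_bins, 2)
    cnts = torch.zeros(num_bins * num_bins)
    sums.index_add_(0, flat, centers)
    cnts.index_add_(0, flat, torch.ones(flat.shape[0]))
    means = sums / cnts.clamp(min=1.0).unsqueeze(1)                       # (B*B, 2)

    # Weight each cell's mean with its 8 neighbors. The paper uses a Gaussian
    # kernel; a box filter over occupied cells is used here just for brevity.
    means_2d = means.view(num_bins, num_bins, 2).permute(2, 0, 1).unsqueeze(0)  # (1, 2, B, B)
    cnts_2d = cnts.view(1, 1, num_bins, num_bins)                               # (1, 1, B, B)
    num = F.conv2d(means_2d * cnts_2d, torch.ones(2, 1, 3, 3), padding=1, groups=2)
    den = F.conv2d(cnts_2d, torch.ones(1, 1, 3, 3), padding=1).clamp(min=1.0)
    shifted = (num / den).squeeze(0).permute(1, 2, 0).reshape(-1, 2)            # (B*B, 2)

    # Every point picks up the shifted mean of the cell it currently occupies.
    return shifted[flat]
```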

Implementing this kind of algorithm forward-pass-only in vanilla Python or C++ would be pretty easy for me, but I'm unsure how to represent it in torch or tf. For one, the xy grid naturally lends itself to a tensor, except that its innermost dimension (the corresponding points in the range image) has variable size. Additionally, assigning range-image points to bins essentially means converting a float position to an int index into the xy tensor, which seemingly can only be done one point at a time. How does this binning process not impede backpropagation?
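
For concreteness, this is roughly the naive per-point loop I would write for the binning step (names are made up). The `.long()` conversion clearly has no gradient, and the dict of variable-length lists isn't a tensor at all, which is where I lose track of how the graph stays connected back to the network outputs:

```python
import torch

centers = torch.randn(1000, 2, requires_grad=True)  # per-pixel predicted xy centers
grid_size = 0.5

# Assign every predicted center to a top-down grid cell, one point at a time.
bins = {}
for i in range(centers.shape[0]):
    # Float position -> integer cell index: floor + cast, which is not differentiable.
    cell = tuple(torch.floor(centers[i] / grid_size).long().tolist())
    # Each cell holds a variable-length list of range-image pixel indices.
    bins.setdefault(cell, []).append(i)
```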