Improving meanshift

Hi all. I’m opening this topic to discuss meanshift: the state of the art, reproducibility, and possible improvements.

I did a quick search online and found some interesting new approaches. I personally find MeanShift++ a very interesting direction, since its runtime grows linearly with the dataset size!
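
If I understand the paper correctly, the linear runtime comes from binning points into a grid and averaging over neighbouring cells instead of over all n points. Here is a rough NumPy sketch of that idea (my own reading of it, not the authors' code; `h` is an assumed grid cell width):

```python
import numpy as np
from collections import defaultdict
from itertools import product

def grid_meanshift_step(X, h):
    """One grid-based step: shift each point to the mean of all points whose
    grid cell is within one cell of its own. Binning is O(n), and each point
    only looks at a constant 3**d neighbouring cells, so the whole step is
    linear in n (for fixed dimension d)."""
    n, d = X.shape
    sums = defaultdict(lambda: np.zeros(d))   # cell -> sum of points
    counts = defaultdict(int)                 # cell -> number of points
    keys = [tuple(k) for k in np.floor(X / h).astype(int)]
    for k, x in zip(keys, X):
        sums[k] += x
        counts[k] += 1
    offsets = list(product((-1, 0, 1), repeat=d))  # the 3**d neighbour offsets
    X_new = np.empty_like(X)
    for i, k in enumerate(keys):
        s, c = np.zeros(d), 0
        for off in offsets:
            nb = tuple(a + b for a, b in zip(k, off))
            if nb in counts:
                s += sums[nb]
                c += counts[nb]
        X_new[i] = s / c
    return X_new
```

Iterating this a few times should collapse points onto cluster modes, same as classic meanshift but without the O(n²) pairwise distances.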

4 Likes

I recommend having a go at implementing some ideas before reading these papers BTW.

3 Likes

I implemented the miniai meanshift in JAX (just to learn), but it is slower than what we had:

PyTorch: 3.93 ms ± 185 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
JAX: 62.3 ms ± 1.74 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
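
For reference, the dense formulation I ported looks roughly like this (a minimal sketch rather than my exact notebook code; the bandwidth `bw=2.5` is just a placeholder value):

```python
import jax
import jax.numpy as jnp

@jax.jit
def meanshift_step(X, bw=2.5):
    # Squared distances between every pair of points: shape (n, n)
    sqdist = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)
    # Gaussian kernel weights with bandwidth bw
    w = jnp.exp(-0.5 * sqdist / bw**2)
    # Move each point to the kernel-weighted mean of all points
    return (w @ X) / w.sum(axis=1, keepdims=True)
```

Running `X = meanshift_step(X)` for a handful of iterations moves the points onto the cluster modes.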

Matrix multiplication is 44 times faster in PyTorch than in JAX.

Correction:

PyTorch was 1.6 times faster than JAX for matrix multiplication.
JAX meanshift is 2.24 times faster than the batched version of PyTorch meanshift:

PyTorch: 3.8 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
JAX: 1.69 ms ± 49.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
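
In case anyone wants to reproduce these numbers: both frameworks dispatch GPU work asynchronously, so naive timing mostly measures kernel launches rather than the actual computation. I suspect (though I'm not certain) that this is what skewed my original 44× figure. A minimal sketch of a synchronized comparison:

```python
import time
import torch
import jax.numpy as jnp

n = 2048
a_t = torch.randn(n, n, device="cuda")
a_j = jnp.asarray(a_t.cpu().numpy())

# PyTorch: CUDA kernels launch asynchronously, so synchronize
# before starting and stopping the clock.
a_t @ a_t                        # warm-up
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(10):
    a_t @ a_t
torch.cuda.synchronize()
print("torch:", (time.perf_counter() - t0) / 10)

# JAX: dispatch is asynchronous too; block_until_ready() waits
# for the result to actually be computed.
(a_j @ a_j).block_until_ready()  # warm-up (also triggers compilation)
t0 = time.perf_counter()
for _ in range(10):
    (a_j @ a_j).block_until_ready()
print("jax:", (time.perf_counter() - t0) / 10)
```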

1 Like

GPU or CPU?

GPU. I will share the notebook, but today was a crazy day where I live (Brasília, Brazil’s capital).

Oh… interesting times for you, I expect.

1 Like

There was an error in the implementation. Here are the notebooks with the results above:
github.com/fredguth/jaxai

A day of infamy, I would say. It is so sad to see what these 21st-century fascists believe and all the damage they can do to our society (social-network AI algorithms are much to blame, in my opinion, but not the only culprits, of course).

The positive side of this story is that they did not accomplish what they wanted, and despite the damage to buildings and works of art, nobody was killed.