Meanshift Clustering in Pytorch

(Brendan Fortuner) #1

Here’s a thread for those interested in helping Jeremy implement Meanshift clustering in Pytorch. @jeremy, perhaps you can elaborate on specific deliverables for this project?

(Jeremy Howard) #2

Thanks for creating this thread. I know that @Matthew and @bckenstler are interested too, FYI.

Next step is to learn about locality sensitive hashing (LSH) and/or spill trees, and try to implement one of them (batch-wise!) in Pytorch. LSH is a pretty simple algorithm to implement, so maybe see if you can start there? The only problem may be if it turns out that some of the bit-twiddling ops aren’t in Pytorch - in that case, spill trees may be a better option.
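As a rough illustration of the bit-twiddling concern, here is a minimal sketch of random-hyperplane LSH done batch-wise in PyTorch. It assumes the sign-of-random-projection variant of LSH and packs each table's sign bits into one integer by multiplying by powers of two and summing, so no bitwise ops are needed. The function name `lsh_codes` and its parameters are hypothetical.

```python
import torch

def lsh_codes(X, planes, batch_size=4096):
    """Hash each row of X (n x d) to one integer bucket id per hash table.

    planes: (num_tables, hash_size, d) random Gaussian hyperplanes.
    Returns a LongTensor of shape (n, num_tables).
    """
    num_tables, hash_size, _ = planes.shape
    # powers of two that pack hash_size sign bits into one integer;
    # keep hash_size <= 62 so the code fits in an int64
    weights = 2 ** torch.arange(hash_size, device=X.device)
    out = []
    for start in range(0, X.shape[0], batch_size):
        xb = X[start:start + batch_size]                  # (b, d)
        # project the batch onto every hyperplane of every table: (b, num_tables, hash_size)
        proj = torch.einsum('bd,thd->bth', xb, planes)
        bits = (proj > 0).long()                          # 1 = positive side of the plane
        out.append((bits * weights).sum(-1))              # (b, num_tables)
    return torch.cat(out, 0)

torch.manual_seed(0)
X = torch.randn(10_000, 16)
planes = torch.randn(4, 8, 16)   # 4 tables, 8 bits each -> bucket ids in [0, 256)
codes = lsh_codes(X, planes)
```

Because every step is a batched tensor op, moving `X` and `planes` to the GPU with `.cuda()` should be the only change needed there.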

Once you have a working GPU-accelerated approximate nearest neighbors (ANN) implementation, you can make sure that both it and the mean-shift code can handle a wide range of datasets. E.g. what if the dimensionality of the dataset is really big? (We should probably choose the batch size dynamically so that it fits in GPU RAM.)
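One hedged way to choose the batch size dynamically is to work out how many rows of the distance matrix fit in a memory budget. This sketch assumes the dominant cost per batch row is one row of a (batch x n_points) float32 weight matrix plus one row of the (batch x n_dims) data slice, and keeps a safety margin; `choose_batch_size`, `budget_bytes` and `safety` are hypothetical names. On a CUDA machine, recent PyTorch versions can supply the budget via `torch.cuda.mem_get_info()`, which returns `(free, total)` bytes.

```python
def choose_batch_size(n_points, n_dims, budget_bytes, bytes_per_elem=4, safety=0.5):
    """Largest batch size whose assumed working set fits in safety * budget_bytes.

    Assumed working set per batch row: one row of the (batch x n_points)
    distance/weight matrix plus one row of the (batch x n_dims) data slice.
    """
    per_row = (n_points + n_dims) * bytes_per_elem
    usable = int(budget_bytes * safety)
    return max(1, min(n_points, usable // per_row))

# e.g. a 100k-point, 64-dim dataset with 4 GiB of free GPU memory
print(choose_batch_size(100_000, 64, 4 * 2**30))   # 5365
```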

Then combine the two pieces so that the ANN is used as a step in the clustering - and make sure it still works well. For that, it would be good to find (or come up with) some good test cases to measure the accuracy and speed of the algorithm, so we can compare.
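For measuring accuracy, one common yardstick is recall@k: the fraction of each point's true k nearest neighbours that the approximate method also returns. A minimal sketch, assuming both methods return an (n x k) LongTensor of neighbour indices with no repeats per row; `exact_knn` and `recall_at_k` are hypothetical helpers, and the ground truth is a brute-force `torch.cdist` + `torch.topk` search.

```python
import torch

def exact_knn(X, k):
    """Brute-force k nearest neighbours of every row (the row itself included)."""
    d = torch.cdist(X, X)                               # (n, n) Euclidean distances
    return d.topk(k, dim=1, largest=False).indices

def recall_at_k(approx_idx, exact_idx):
    """Mean fraction of true neighbours recovered per query row."""
    k = exact_idx.shape[1]
    # compare every approximate index with every exact one: (n, k_approx, k)
    hits = (approx_idx.unsqueeze(2) == exact_idx.unsqueeze(1)).any(dim=1)
    return hits.float().sum(1).div(k).mean().item()

torch.manual_seed(0)
X = torch.randn(500, 10)
exact = exact_knn(X, 5)
print(recall_at_k(exact, exact))   # 1.0 by construction
```

Timing the approximate and exact searches on the same `X` then gives the speed half of the comparison.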

Then we should think about the little niceties, like:

  • Auto-select the kernel bandwidth (optionally)
  • Loop until the algorithm is stable, rather than a fixed # times (maybe make the “stable” measurement a parameter too)
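Both niceties can be sketched together. The version below is an illustrative batch-wise Gaussian-kernel mean shift (the standard, non-blurring variant that shifts points towards the kernel-weighted mean of the original data), not the mean-shift code referred to in this thread. It assumes a simple bandwidth heuristic, a quantile of pairwise distances within a random subsample (simpler than scikit-learn's `estimate_bandwidth`), and stops once no point moves more than `tol`. The names `estimate_bandwidth_torch`, `meanshift`, `quantile`, `tol` and `max_iter` are hypothetical.

```python
import torch

def estimate_bandwidth_torch(X, quantile=0.3, n_samples=1000):
    """Heuristic bandwidth: a quantile of pairwise distances within a random subsample."""
    idx = torch.randperm(X.shape[0], device=X.device)[:n_samples]
    sub = X[idx]
    return torch.quantile(torch.cdist(sub, sub).flatten(), quantile).item()

def meanshift(X, bandwidth=None, tol=1e-3, max_iter=100, batch_size=1000):
    """Gaussian-kernel mean shift that loops until no point moves more than `tol`."""
    if bandwidth is None:
        bandwidth = estimate_bandwidth_torch(X)
    Y = X.clone()
    for it in range(1, max_iter + 1):          # assumes max_iter >= 1
        new_Y = torch.empty_like(Y)
        for s in range(0, Y.shape[0], batch_size):
            yb = Y[s:s + batch_size]
            # kernel weights between this batch of shifted points and all data points: (b, n)
            w = torch.exp(-torch.cdist(yb, X) ** 2 / (2 * bandwidth ** 2))
            new_Y[s:s + batch_size] = (w @ X) / w.sum(1, keepdim=True)
        shift = (new_Y - Y).norm(dim=1).max().item()
        Y = new_Y
        if shift < tol:                        # the "stable" criterion, exposed as a parameter
            break
    return Y, it

torch.manual_seed(0)
X = torch.cat([torch.randn(100, 2) * 0.3,
               torch.randn(100, 2) * 0.3 + 5.0])
modes, n_iter = meanshift(X, bandwidth=1.0)
```

Calling `meanshift(X)` without a bandwidth uses the heuristic instead; the returned iteration count makes it easy to see how quickly a given `tol` is reached.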

(nima) #3

I wonder if it makes sense to take an existing numpy-based implementation (e.g. …) and then apply Pytorch where appropriate.

A quick look at the file makes it seem relatively straightforward, though I'm not sure how to store the actual hash tables with Pytorch.
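On storing the hash tables: one option that avoids Python dicts is to keep each table as tensors, sorting the points by their integer bucket id so that every bucket becomes a contiguous slice. This sketch assumes the hashes have already been packed into integer bucket ids (one per point per table); `build_table` and `bucket_members` are hypothetical names.

```python
import torch

def build_table(codes):
    """codes: (n,) LongTensor of bucket ids for one hash table."""
    order = torch.argsort(codes)                          # point indices grouped by bucket
    keys, counts = torch.unique_consecutive(codes[order], return_counts=True)
    starts = torch.cumsum(counts, 0) - counts             # offset of each bucket inside `order`
    return keys, starts, counts, order

def bucket_members(table, code):
    """Indices of all points whose bucket id equals `code` (empty if none)."""
    keys, starts, counts, order = table
    pos = torch.searchsorted(keys, torch.tensor([code]))  # keys are sorted ascending
    if pos.item() < keys.numel() and keys[pos].item() == code:
        s, c = starts[pos].item(), counts[pos].item()
        return order[s:s + c]
    return order[:0]

codes = torch.tensor([3, 1, 3, 0, 1, 3])    # bucket id of each point in one table
table = build_table(codes)
print(sorted(bucket_members(table, 3).tolist()))   # [0, 2, 5]
```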

Scikit-learn also has an implementation that we could take a look at.

(David Gutman) #4

There was another implementation I came across too:

(Jeremy Howard) #5

@davecg @nima they would both be interesting to try - in fact you could start by just using the numpy implementation inside the current pytorch clustering loop, and see what the performance looks like (for a variety of # rows and # cols). If it turns out that the approx nearest neighbors timing isn’t significant, then there wouldn’t even be a need to port it to pytorch!
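To see whether the nearest-neighbour step matters, a tiny timing harness over a grid of row and column counts is enough. This sketch times a brute-force `torch.cdist` + `torch.topk` search as a stand-in; in practice you would swap in the numpy LSH (or other ANN) call. `time_knn` and the grid values are hypothetical.

```python
import time
import torch

def time_knn(n_rows, n_cols, k=10, repeats=3):
    """Best-of-`repeats` wall time in seconds for one brute-force k-NN pass."""
    X = torch.randn(n_rows, n_cols)
    best = float('inf')
    for _ in range(repeats):
        t0 = time.perf_counter()
        torch.cdist(X, X).topk(k, dim=1, largest=False)
        best = min(best, time.perf_counter() - t0)
    return best

for n_rows in (1_000, 2_000):
    for n_cols in (2, 64, 256):
        print(f"rows={n_rows:>6} cols={n_cols:>4}: {time_knn(n_rows, n_cols):.4f}s")
```

If you time GPU tensors this way, call `torch.cuda.synchronize()` before reading the clock, since CUDA kernels run asynchronously.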

(Kent) #6

A question about using the Pytorch implementation: can it only be used with Pytorch, or can it be mixed with other frameworks? For example, if one builds a model in Keras with the Tensorflow backend and wants to call this Pytorch implementation of Meanshift Clustering, will it cause any conflicts between the two libraries (e.g. Tensorflow and Pytorch competing for the same GPU resources)?

(Jeremy Howard) #7

Potentially that would cause challenges with GPU memory, although if you were careful to clear the GPU RAM at appropriate times that would be OK.

Combining CPU libs with Pytorch, OTOH, should be entirely straightforward.
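On the PyTorch side, clearing the GPU RAM mostly means dropping references to tensors and then asking PyTorch's caching allocator to return its unused cached blocks to the driver with `torch.cuda.empty_cache()`, so another library in the same process (such as TensorFlow) can use them. A minimal sketch; the helper name `release_gpu_memory` is hypothetical, and it does nothing beyond garbage collection on machines without CUDA.

```python
import gc
import torch

def release_gpu_memory():
    """Hand PyTorch's cached-but-unused GPU memory back for other libraries."""
    gc.collect()                      # collect unreachable objects that may still hold tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()      # return cached blocks to the CUDA driver

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.ones(1000, 1000, device=device)
del x                                 # the last reference must go first
release_gpu_memory()
```

Memory held by tensors that are still referenced cannot be released this way.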

(Kent) #8

I see, so to avoid such potential conflicts, it would be best to implement it in Keras if the model is written in Keras. Are there major challenges in doing this in Keras?

(Jeremy Howard) #9

Keras would be tough. You need something lower level IMHO, i.e. Pytorch, TF, or Theano. I’d suggest Pytorch.

(Suresh ) #10

I tried working on this, but did not get very far on the GPU optimization. However, here is my concise version of the code from … in case anyone else is attempting to solve this.

```python
import numpy as np
import torch


class PyTorchLSHash(object):
    def __init__(self, hash_size, input_dim, num_hashtables=1):
        self.hash_size = hash_size
        self.input_dim = input_dim
        # one random projection matrix (hash_size x input_dim) per hash table
        self.uniform_planes = [np.random.randn(hash_size, input_dim)
                               for _ in range(num_hashtables)]
        self.hash_tables = [dict() for _ in range(num_hashtables)]

    def _hash(self, planes, input_point):
        input_point = np.array(input_point)  # for faster dot product
        projections = np.dot(planes, input_point)
        # one bit per hyperplane: which side of the plane the point falls on
        return "".join(['1' if i > 0 else '0' for i in projections])

    def index(self, input_point):
        value = tuple(input_point)  # hashable, so it can go into a set in query()
        for i, table in enumerate(self.hash_tables):
            table.setdefault(self._hash(self.uniform_planes[i], input_point), []).append(value)

    def query(self, query_point, num_results=None):
        # collect every stored point that shares a bucket with the query in any table
        candidates = set()
        for i, table in enumerate(self.hash_tables):
            binary_hash = self._hash(self.uniform_planes[i], query_point)
            candidates.update(table.get(binary_hash, []))
        # rank the candidates by exact squared Euclidean distance
        d_func = PyTorchLSHash.torch_euclidean_dist_square
        query_point = torch.LongTensor(query_point)
        candidates = [(ix, d_func(query_point, torch.LongTensor(ix)))
                      for ix in candidates]
        candidates.sort(key=lambda x: x[1])
        return candidates[:num_results] if num_results else candidates

    @staticmethod
    def torch_euclidean_dist_square(x, y):
        diff = x - y
        return torch.dot(diff, diff).item()
```

You can test it with the code below:

```python
# initial setup parameters
# hash_size = 6
hash_size = 2
input_vector_size = 8
# num_samples = 10000000  # number of samples to insert in the array
num_samples = 5
# let's generate random values in [1, 30) for each vector
a = np.random.randint(1, 30, (num_samples, input_vector_size)).tolist()

# simple known query item...
b = a[-1:] + np.ones((1, input_vector_size), dtype=int)  # slightly perturb the last element for the search
print(b, a[-1:])
query_item = b[0].tolist()  # if found, its squared distance should be input_vector_size

# now insert items in the index...
pylsh = PyTorchLSHash(hash_size, input_vector_size)
for x in a:
    pylsh.index(x)

# finally, query the item (%time is an IPython/Jupyter magic)
%time query_result_gpu = pylsh.query(query_item, 2)
query_result_gpu, query_item
```

(Suresh ) #11

Adding my notes for the above code. I don’t have the time to work on this today, but if someone wants to continue, here are my thoughts.

The hash function is not the bottleneck, but as a starting point, I was able to replace

```python
input_point = np.array(input_point)  # for faster dot product
projections = np.dot(planes, input_point)
```

with

```python
input_point = torch.DoubleTensor(input_point).view(self.input_dim, 1)
planes = torch.from_numpy(planes)
projections = torch.mm(planes, input_point)  # (hash_size, input_dim) @ (input_dim, 1)
projections = projections.view(-1).numpy()   # back to a flat numpy array of hash_size values
```

Also, the core of the optimization lies in these few lines:

```python
query_point = torch.LongTensor(query_point)
candidates = [(ix, d_func(query_point, torch.LongTensor(ix)))
              for ix in candidates]

def torch_euclidean_dist_square(x, y):
    diff = x - y
    return torch.dot(diff, diff)
```

For optimization, my thinking is that we vectorize the code first, then move to the GPU and process batches.
For the vectorization, we should be able to go:

1. candidates (set) -> candidates (array)
2. diff = query - candidates (array), via broadcasting
3. row-wise dot product on the diff array (the squared distances)
4. pair each candidate with its dot-product result

Also, I don't think sorting is the bottleneck, but it would be nice if we could do it on the GPU.
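Sorting can stay on the device too: if only the `num_results` closest candidates are needed, `torch.topk(..., largest=False)` returns them without a full sort (and `torch.sort` is there when the whole ranking is wanted). A sketch assuming the candidates are already stacked in a 2-D tensor on the same device as the query; `nearest_candidates` is a hypothetical helper.

```python
import torch

def nearest_candidates(candidates, query, num_results):
    """candidates: (m, d) tensor; query: (d,) tensor on the same device.

    Returns (indices, squared_distances) of the num_results closest rows, nearest first.
    """
    diff = candidates - query.unsqueeze(0)              # broadcast to (m, d)
    sq_dist = (diff * diff).sum(1)                      # (m,) squared distances
    k = min(num_results, sq_dist.numel())
    vals, idx = torch.topk(sq_dist, k, largest=False)   # k smallest, ascending
    return idx, vals

cands = torch.tensor([[0., 0.], [3., 4.], [1., 1.], [10., 0.]])
idx, sq = nearest_candidates(cands, torch.tensor([0., 0.]), 2)
print(idx.tolist(), sq.tolist())   # [0, 2] [0.0, 2.0]
```

Calling `.cuda()` on both inputs runs the same code on the GPU.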

(Jeremy Howard) #12

I just had a brief look at using sklearn.neighbors.KDTree for a sample problem in the notebook (this won’t scale to higher dimensions - it's just for some quick testing). Building the index each epoch appears to have a negligible performance impact, so this approach seems likely to be effective. Using the returned indexes will take some effort - I think torch.gather() looks the most likely approach, or else just index_mask().
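On using the returned indexes: `KDTree.query(X, k=k)` returns a `(distances, indices)` pair of NumPy arrays of shape (n, k), and `torch.from_numpy` turns the indices into a tensor. The sketch below then gathers each point's neighbours with `index_select` (plain advanced indexing `X[idx]` does the same) and computes a Gaussian-weighted mean over just those k neighbours. To stay self-contained it builds the index array with a brute-force `torch.cdist` + `topk` stand-in instead of scikit-learn; `knn_meanshift_step` is a hypothetical name.

```python
import torch

def knn_meanshift_step(X, idx, bandwidth):
    """One mean-shift update using only each point's k neighbours.

    X: (n, d) float tensor; idx: (n, k) LongTensor of neighbour row indices.
    """
    n, k = idx.shape
    nbrs = X.index_select(0, idx.reshape(-1)).view(n, k, -1)   # (n, k, d)
    sq = ((nbrs - X.unsqueeze(1)) ** 2).sum(-1)               # (n, k) squared distances
    w = torch.exp(-sq / (2 * bandwidth ** 2))                  # Gaussian kernel weights
    return (w.unsqueeze(-1) * nbrs).sum(1) / w.sum(1, keepdim=True)

# stand-in for: _, ind = KDTree(X.numpy()).query(X.numpy(), k=8); idx = torch.from_numpy(ind)
torch.manual_seed(0)
X = torch.randn(200, 3)
idx = torch.cdist(X, X).topk(8, dim=1, largest=False).indices
X_next = knn_meanshift_step(X, idx, bandwidth=1.0)
```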

(Suresh ) #13

Am I correct in reading that the KDTree won't scale to larger dimensions, or were you commenting on the code above?

(Suresh ) #14

I was able to vectorize the above to the lines below.

```python
import numpy as np
import torch

# using numpy alone:
# diff_t = np.array(list(candidates)) - query_point
# result = (diff_t * diff_t).sum(-1)  # just a dot product

# using pytorch
input_dim = input_vector_size
candidates = list(candidates)  # fix one iteration order for the set, so the zip below lines up
diff_t = torch.sub(torch.FloatTensor(candidates),
                   torch.FloatTensor(query_point).view(1, input_dim))
dp = (diff_t * diff_t).sum(1)  # row-wise dot product = squared distance
result = dp.numpy().flatten().tolist()  # convert the tensor to a Python list

candidates = list(zip(candidates, result))
```

This gave me a 3x to 5x improvement over the previous code (5 sec -> 1 sec for 10M items). However, blindly adding .cuda() alone causes a memory error.

(Jeremy Howard) #15

Yeah you’ll want to do it in batches - see my pytorch meanshift code for a simple example of that.
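A minimal sketch of that batching applied to the vectorized distance code: stream the candidate rows to the device in fixed-size chunks so only one (batch x d) slice lives in GPU memory at a time, and keep just the best `k` seen so far. It assumes the candidates are a CPU float32 tensor; `batched_knn` and `batch_size` are hypothetical names.

```python
import torch

def batched_knn(candidates, query, k, batch_size=1_000_000, device=None):
    """k nearest rows of `candidates` (n, d) to `query` (d,), computed chunk by chunk."""
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    q = query.to(device).unsqueeze(0)                           # (1, d)
    best_d = torch.empty(0, device=device)
    best_i = torch.empty(0, dtype=torch.long, device=device)
    for start in range(0, candidates.shape[0], batch_size):
        chunk = candidates[start:start + batch_size].to(device)  # only this slice on the device
        d = ((chunk - q) ** 2).sum(1)                            # squared distances for the chunk
        ids = torch.arange(start, start + chunk.shape[0], device=device)
        # merge the chunk with the running best and keep the k smallest
        all_d = torch.cat([best_d, d])
        all_i = torch.cat([best_i, ids])
        best_d, pos = torch.topk(all_d, min(k, all_d.numel()), largest=False)
        best_i = all_i[pos]
    return best_i.cpu(), best_d.cpu()

torch.manual_seed(0)
cands = torch.randn(100_000, 8)
query = torch.randn(8)
idx, sq_dist = batched_knn(cands, query, k=5, batch_size=10_000)
```

The batch size trades GPU memory for the number of host-to-device copies; the running top-k keeps the extra memory per chunk at O(k).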

(Jeremy Howard) #16

unsqueeze is an easier way to do this FYI.
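For reference, `unsqueeze(0)` inserts a new leading dimension of size 1 without having to spell out `input_dim`, so it can replace the `.view(1, input_dim)` call in the vectorized snippet; broadcasting then does the rest. The toy values here are made up for illustration.

```python
import torch

query_point = [1.0, 2.0, 3.0]
candidates = [[1.0, 2.0, 3.0], [2.0, 2.0, 3.0], [0.0, 0.0, 0.0]]

q = torch.FloatTensor(query_point).unsqueeze(0)   # shape (1, 3), same as .view(1, 3)
diff_t = torch.FloatTensor(candidates) - q        # broadcasts to (3, 3)
dp = (diff_t * diff_t).sum(1)                     # squared distance per candidate
print(q.shape, dp.tolist())                       # torch.Size([1, 3]) [0.0, 1.0, 14.0]
```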

(Jeremy Howard) #17


(Suresh ) #18

Thanks for the unsqueeze tip.

I knew that I had to do batch processing for a speedup, but my mental math on memory says it should fit: 10 million / 10 * 8 dimensions is not that big. I am wondering if importing tensorflow eats the memory… my header imports are as follows:

```python
import numpy as np
import importlib
import utils2; importlib.reload(utils2)
from utils2 import *

import torch_utils; importlib.reload(torch_utils)
from torch_utils import *
```
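The mental math can be made explicit. Assuming the 10M-item, 8-dimension test from earlier in the thread, each dense tensor of that shape costs rows × columns × bytes per element, and the vectorized code keeps several such tensors alive at once (the candidates, `diff_t`, and the temporary from `diff_t * diff_t`). The helper name `tensor_mb` is hypothetical.

```python
def tensor_mb(n_rows, n_cols, bytes_per_elem):
    """Size of a dense n_rows x n_cols tensor in megabytes (10**6 bytes)."""
    return n_rows * n_cols * bytes_per_elem / 1e6

n, d = 10_000_000, 8
print(tensor_mb(n, d, 4))        # float32 (FloatTensor), all 10M rows: 320.0
print(tensor_mb(n, d, 8))        # int64 (LongTensor) or float64: 640.0
print(tensor_mb(n // 10, d, 4))  # float32, one tenth of the rows: 32.0
```

So three float32 copies of the full data are roughly 1 GB, before counting anything else already holding GPU memory.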

(Jeremy Howard) #19

If you use tensorflow at all, by default it’ll grab (nearly) all of your GPU memory. So it could well be that.

Nonetheless, we do need to ensure we can handle datasets that are bigger than RAM…
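For datasets bigger than RAM, one hedged approach is to keep the data on disk as a NumPy memory-mapped `.npy` file and stream it through the GPU in chunks, so only one chunk is resident at a time. The file layout (float32 `.npy`), the helper name `meanshift_update_from_disk`, and the per-chunk work (the running weighted sum and weight total that one Gaussian mean-shift update needs) are assumptions for illustration.

```python
import os
import tempfile
import numpy as np
import torch

def meanshift_update_from_disk(path, Q, bandwidth, chunk_rows=100_000, device='cpu'):
    """One Gaussian mean-shift update of float32 points Q (m, d) against data in a .npy file.

    The file is memory-mapped, so only `chunk_rows` rows of it are in RAM at once.
    """
    X = np.load(path, mmap_mode='r')                          # lazy: nothing read yet
    Q = Q.to(device)
    num = torch.zeros_like(Q)                                 # running sum of w * x
    den = torch.zeros(Q.shape[0], 1, dtype=Q.dtype, device=device)  # running sum of w
    for start in range(0, X.shape[0], chunk_rows):
        # copy one chunk from disk into RAM, then onto the device
        chunk = torch.from_numpy(np.array(X[start:start + chunk_rows], dtype=np.float32)).to(device)
        w = torch.exp(-torch.cdist(Q, chunk) ** 2 / (2 * bandwidth ** 2))   # (m, rows)
        num += w @ chunk
        den += w.sum(1, keepdim=True)
    return num / den

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'data.npy')
    data = np.random.randn(10_000, 4).astype(np.float32)
    np.save(path, data)
    Q = torch.from_numpy(data[:5].copy())
    shifted = meanshift_update_from_disk(path, Q, bandwidth=1.0, chunk_rows=1_000)
```

Looping over batches of `Q` as well would make both sides out-of-core.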

(Brendan Fortuner) #20

We had a hard time with this at the hackathon when looping to process video frames. Is there some tweak to the tensorflow configuration we can make to avoid this? More generally, I’ve had a better experience with Theano so far.