Thank you, Vincent, for sharing! I have always wondered why we use the Gram matrix; I had some intuition that it has something to do with second-order statistics, but I could never come up with a better loss function myself. This is the insight I was looking for!
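If I read the math right, what the code below computes is the closed-form squared 2-Wasserstein distance between the two Gaussians fitted to the feature statistics:

$$
W_2^2\big(\mathcal{N}(\mu_1,\Sigma_1),\, \mathcal{N}(\mu_2,\Sigma_2)\big)
= \lVert \mu_1 - \mu_2 \rVert^2
+ \operatorname{tr}\!\Big(\Sigma_1 + \Sigma_2 - 2\big(\Sigma_1^{1/2}\,\Sigma_2\,\Sigma_1^{1/2}\big)^{1/2}\Big)
$$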
For the curious, here is my PyTorch implementation (ported from TensorFlow):
import torch

def torch_moments(x):
    # Per-channel mean and channel covariance over spatial locations.
    # Accepts (c, w, h), batched (1, c, w, h), or single-channel (w, h) inputs.
    if len(x.shape) == 3:
        c, w, h = x.shape
    elif len(x.shape) == 4:
        x = x[0, :, :, :]          # take the first element of the batch
        c, w, h = x.shape
    elif len(x.shape) == 2:
        w, h = x.shape
        c = 1
        x = x.unsqueeze(0)         # add a channel dim so the permute below works
    n = w * h
    x = x.permute(1, 2, 0)               # (c, w, h) -> (w, h, c)
    flat = torch.reshape(x, (n, c))      # one row per spatial location
    mu = torch.mean(flat, dim=0, keepdim=True)
    cov = torch.matmul(torch.transpose(flat - mu, 0, 1), flat - mu) / n
    return mu, cov
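As a quick shape check (the tensor here is just made-up data, not from the original post): for a (c, w, h) input, torch_moments returns a (1, c) mean and a (c, c) covariance.

feats = torch.randn(64, 32, 32)      # e.g. one feature map: (channels, width, height)
mu, cov = torch_moments(feats)
print(mu.shape, cov.shape)           # torch.Size([1, 64]) torch.Size([64, 64])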
def wdist(m1, m2):
    # Squared 2-Wasserstein distance between the Gaussians fitted to the two feature maps.
    mean_stl, cov_stl = torch_moments(m1)
    # Matrix square root of the style covariance via its eigendecomposition
    # (torch.symeig is deprecated in newer PyTorch in favor of torch.linalg.eigh).
    eigvals, eigvects = torch.symeig(cov_stl, eigenvectors=True)
    eigroot_mat = torch.diag(torch.sqrt(torch.max(eigvals, torch.tensor([0.]))))
    root_cov_stl = torch.matmul(torch.matmul(eigvects, eigroot_mat), torch.transpose(eigvects, 1, 0))
    tr_cov_stl = torch.sum(torch.max(eigvals, torch.tensor([0.])))

    mean_synth, cov_synth = torch_moments(m2)
    tr_cov_synth = torch.sum(torch.max(torch.symeig(cov_synth, eigenvectors=True)[0], torch.tensor([0.])))

    mean_diff_squared = torch.sum(torch.square(mean_stl - mean_synth))
    # Overlap term tr((cov_stl^1/2 · cov_synth · cov_stl^1/2)^1/2);
    # its eigenvalues are floored at 0.1 before the sqrt, which keeps the gradient finite.
    cov_prod = torch.matmul(torch.matmul(root_cov_stl, cov_synth), root_cov_stl)
    var_overlap = torch.sum(torch.sqrt(torch.max(torch.symeig(cov_prod, eigenvectors=True)[0], torch.tensor([0.1]))))
    dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2 * var_overlap
    return dist
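And a quick example of how I call it (the shapes and tensors below are made up for illustration, e.g. one VGG-style feature map per image):

style_feats = torch.randn(64, 32, 32)                       # (channels, width, height) of the style image
synth_feats = torch.randn(64, 32, 32, requires_grad=True)   # same layer for the synthesized image
loss = wdist(style_feats, synth_feats)
loss.backward()   # gradients flow back to synth_feats through the moments and eigendecompositions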