From ce74df5c0c06c5cfbe79812b71b4d7f4b4e425b6 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Fri, 27 Jun 2025 13:28:12 +0200 Subject: [PATCH] Fix iterating for best centroid when algorithm is neighbour aware and decrease SAMPLES_PER_CLUSTER_DEFAULT (#130069) * KMeansIntermediate shares assignments --- .../vectors/cluster/HierarchicalKMeans.java | 40 ++------------ .../vectors/cluster/KMeansIntermediate.java | 4 -- .../codec/vectors/cluster/KMeansLocal.java | 52 +++++++++++-------- 3 files changed, 34 insertions(+), 62 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java index 6d50e5c473d0..6f7705bfcc1a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java @@ -10,7 +10,6 @@ package org.elasticsearch.index.codec.vectors.cluster; import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.util.VectorUtil; import java.io.IOException; @@ -21,7 +20,7 @@ public class HierarchicalKMeans { static final int MAXK = 128; static final int MAX_ITERATIONS_DEFAULT = 6; - static final int SAMPLES_PER_CLUSTER_DEFAULT = 256; + static final int SAMPLES_PER_CLUSTER_DEFAULT = 64; static final float DEFAULT_SOAR_LAMBDA = 1.0f; final int dimension; @@ -67,8 +66,7 @@ public class HierarchicalKMeans { // partition the space KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize); if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) { - float f = Math.min((float) samplesPerCluster / targetSize, 1.0f); - int localSampleSize = (int) (f * vectors.size()); + int localSampleSize = Math.min(kMeansIntermediate.centroids().length * samplesPerCluster, vectors.size()); KMeansLocal kMeansLocal = new 
KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA); kMeansLocal.cluster(vectors, kMeansIntermediate, true); } @@ -86,42 +84,16 @@ public class HierarchicalKMeans { // TODO: instead of creating a sub-cluster assignments reuse the parent array each time int[] assignments = new int[vectors.size()]; - KMeansLocal kmeans = new KMeansLocal(m, maxIterations); float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k); - KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids); + KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc); kmeans.cluster(vectors, kMeansIntermediate); // TODO: consider adding cluster size counts to the kmeans algo // handle assignment here so we can track distance and cluster size int[] centroidVectorCount = new int[centroids.length]; - float[][] nextCentroids = new float[centroids.length][dimension]; - for (int i = 0; i < vectors.size(); i++) { - float smallest = Float.MAX_VALUE; - int centroidIdx = -1; - float[] vector = vectors.vectorValue(i); - for (int j = 0; j < centroids.length; j++) { - float[] centroid = centroids[j]; - float d = VectorUtil.squareDistance(vector, centroid); - if (d < smallest) { - smallest = d; - centroidIdx = j; - } - } - centroidVectorCount[centroidIdx]++; - for (int j = 0; j < dimension; j++) { - nextCentroids[centroidIdx][j] += vector[j]; - } - assignments[i] = centroidIdx; - } - - // update centroids based on assignments of all vectors - for (int i = 0; i < centroids.length; i++) { - if (centroidVectorCount[i] > 0) { - for (int j = 0; j < dimension; j++) { - centroids[i][j] = nextCentroids[i][j] / centroidVectorCount[i]; - } - } + for (int assigment : assignments) { + centroidVectorCount[assigment]++; } int effectiveK = 0; @@ -131,8 +103,6 @@ public class HierarchicalKMeans { } } - kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc); - if (effectiveK == 1) { return 
kMeansIntermediate; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansIntermediate.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansIntermediate.java index 75caa5c7d328..e44112610812 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansIntermediate.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansIntermediate.java @@ -31,10 +31,6 @@ class KMeansIntermediate extends KMeansResult { this(new float[0][0], new int[0], i -> i, new int[0]); } - KMeansIntermediate(float[][] centroids) { - this(centroids, new int[0], i -> i, new int[0]); - } - KMeansIntermediate(float[][] centroids, int[] assignments) { this(centroids, assignments, i -> i, new int[0]); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java index b1303b7124b2..892830bdcb34 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java @@ -87,17 +87,17 @@ class KMeansLocal { for (int i = 0; i < sampleSize; i++) { float[] vector = vectors.vectorValue(i); - int[] neighborOffsets = null; - int centroidIdx = -1; + final int assignment = assignments[i]; + final int bestCentroidOffset; if (neighborhoods != null) { - neighborOffsets = neighborhoods.get(assignments[i]); - centroidIdx = assignments[i]; + bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods.get(assignment)); + } else { + bestCentroidOffset = getBestCentroid(centroids, vector); } - int bestCentroidOffset = getBestCentroidOffset(centroids, vector, centroidIdx, neighborOffsets); - if (assignments[i] != bestCentroidOffset) { + if (assignment != bestCentroidOffset) { + assignments[i] = bestCentroidOffset; changed = true; } - 
assignments[i] = bestCentroidOffset; centroidCounts[bestCentroidOffset]++; for (int d = 0; d < dim; d++) { nextCentroids[bestCentroidOffset][d] += vector[d]; @@ -116,23 +116,28 @@ class KMeansLocal { return changed; } - int getBestCentroidOffset(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) { + int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) { int bestCentroidOffset = centroidIdx; - float minDsq; - if (centroidIdx > 0 && centroidIdx < centroids.length) { - minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]); - } else { - minDsq = Float.MAX_VALUE; + assert centroidIdx >= 0 && centroidIdx < centroids.length; + float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]); + for (int offset : centroidOffsets) { + float dsq = VectorUtil.squareDistance(vector, centroids[offset]); + if (dsq < minDsq) { + minDsq = dsq; + bestCentroidOffset = offset; + } } + return bestCentroidOffset; + } - int k = 0; - for (int j = 0; j < centroids.length; j++) { - if (centroidOffsets == null || j == centroidOffsets[k]) { - float dsq = VectorUtil.squareDistance(vector, centroids[j]); - if (dsq < minDsq) { - minDsq = dsq; - bestCentroidOffset = j; - } + int getBestCentroid(float[][] centroids, float[] vector) { + int bestCentroidOffset = 0; + float minDsq = Float.MAX_VALUE; + for (int i = 0; i < centroids.length; i++) { + float dsq = VectorUtil.squareDistance(vector, centroids[i]); + if (dsq < minDsq) { + minDsq = dsq; + bestCentroidOffset = i; } } return bestCentroidOffset; @@ -271,7 +276,8 @@ class KMeansLocal { return; } - int[] assignments = new int[n]; + int[] assignments = kMeansIntermediate.assignments(); + assert assignments.length == n; float[][] nextCentroids = new float[centroids.length][vectors.dimension()]; for (int i = 0; i < maxIterations; i++) { if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) { @@ -291,7 
+297,7 @@ class KMeansLocal { * @param maxIterations the max iterations to shift centroids */ public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException { - KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids); + KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, new int[vectors.size()], vectors::ordToDoc); KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations); kMeans.cluster(vectors, kMeansIntermediate); }