Fix iterating for best centroid when algorithm is neighbour aware and decrease SAMPLES_PER_CLUSTER_DEFAULT (#130069)

* KMeansIntermediate shares assigments
This commit is contained in:
Ignacio Vera 2025-06-27 13:28:12 +02:00 committed by GitHub
parent 7c213baf4d
commit ce74df5c0c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 34 additions and 62 deletions

View file

@ -10,7 +10,6 @@
package org.elasticsearch.index.codec.vectors.cluster; package org.elasticsearch.index.codec.vectors.cluster;
import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.VectorUtil;
import java.io.IOException; import java.io.IOException;
@ -21,7 +20,7 @@ public class HierarchicalKMeans {
static final int MAXK = 128; static final int MAXK = 128;
static final int MAX_ITERATIONS_DEFAULT = 6; static final int MAX_ITERATIONS_DEFAULT = 6;
static final int SAMPLES_PER_CLUSTER_DEFAULT = 256; static final int SAMPLES_PER_CLUSTER_DEFAULT = 64;
static final float DEFAULT_SOAR_LAMBDA = 1.0f; static final float DEFAULT_SOAR_LAMBDA = 1.0f;
final int dimension; final int dimension;
@ -67,8 +66,7 @@ public class HierarchicalKMeans {
// partition the space // partition the space
KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize); KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize);
if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) { if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) {
float f = Math.min((float) samplesPerCluster / targetSize, 1.0f); int localSampleSize = Math.min(kMeansIntermediate.centroids().length * samplesPerCluster, vectors.size());
int localSampleSize = (int) (f * vectors.size());
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA); KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA);
kMeansLocal.cluster(vectors, kMeansIntermediate, true); kMeansLocal.cluster(vectors, kMeansIntermediate, true);
} }
@ -86,42 +84,16 @@ public class HierarchicalKMeans {
// TODO: instead of creating a sub-cluster assignments reuse the parent array each time // TODO: instead of creating a sub-cluster assignments reuse the parent array each time
int[] assignments = new int[vectors.size()]; int[] assignments = new int[vectors.size()];
KMeansLocal kmeans = new KMeansLocal(m, maxIterations); KMeansLocal kmeans = new KMeansLocal(m, maxIterations);
float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k); float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k);
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids); KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
kmeans.cluster(vectors, kMeansIntermediate); kmeans.cluster(vectors, kMeansIntermediate);
// TODO: consider adding cluster size counts to the kmeans algo // TODO: consider adding cluster size counts to the kmeans algo
// handle assignment here so we can track distance and cluster size // handle assignment here so we can track distance and cluster size
int[] centroidVectorCount = new int[centroids.length]; int[] centroidVectorCount = new int[centroids.length];
float[][] nextCentroids = new float[centroids.length][dimension]; for (int assigment : assignments) {
for (int i = 0; i < vectors.size(); i++) { centroidVectorCount[assigment]++;
float smallest = Float.MAX_VALUE;
int centroidIdx = -1;
float[] vector = vectors.vectorValue(i);
for (int j = 0; j < centroids.length; j++) {
float[] centroid = centroids[j];
float d = VectorUtil.squareDistance(vector, centroid);
if (d < smallest) {
smallest = d;
centroidIdx = j;
}
}
centroidVectorCount[centroidIdx]++;
for (int j = 0; j < dimension; j++) {
nextCentroids[centroidIdx][j] += vector[j];
}
assignments[i] = centroidIdx;
}
// update centroids based on assignments of all vectors
for (int i = 0; i < centroids.length; i++) {
if (centroidVectorCount[i] > 0) {
for (int j = 0; j < dimension; j++) {
centroids[i][j] = nextCentroids[i][j] / centroidVectorCount[i];
}
}
} }
int effectiveK = 0; int effectiveK = 0;
@ -131,8 +103,6 @@ public class HierarchicalKMeans {
} }
} }
kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
if (effectiveK == 1) { if (effectiveK == 1) {
return kMeansIntermediate; return kMeansIntermediate;
} }

View file

@ -31,10 +31,6 @@ class KMeansIntermediate extends KMeansResult {
this(new float[0][0], new int[0], i -> i, new int[0]); this(new float[0][0], new int[0], i -> i, new int[0]);
} }
KMeansIntermediate(float[][] centroids) {
this(centroids, new int[0], i -> i, new int[0]);
}
KMeansIntermediate(float[][] centroids, int[] assignments) { KMeansIntermediate(float[][] centroids, int[] assignments) {
this(centroids, assignments, i -> i, new int[0]); this(centroids, assignments, i -> i, new int[0]);
} }

View file

@ -87,17 +87,17 @@ class KMeansLocal {
for (int i = 0; i < sampleSize; i++) { for (int i = 0; i < sampleSize; i++) {
float[] vector = vectors.vectorValue(i); float[] vector = vectors.vectorValue(i);
int[] neighborOffsets = null; final int assignment = assignments[i];
int centroidIdx = -1; final int bestCentroidOffset;
if (neighborhoods != null) { if (neighborhoods != null) {
neighborOffsets = neighborhoods.get(assignments[i]); bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods.get(assignment));
centroidIdx = assignments[i]; } else {
bestCentroidOffset = getBestCentroid(centroids, vector);
} }
int bestCentroidOffset = getBestCentroidOffset(centroids, vector, centroidIdx, neighborOffsets); if (assignment != bestCentroidOffset) {
if (assignments[i] != bestCentroidOffset) { assignments[i] = bestCentroidOffset;
changed = true; changed = true;
} }
assignments[i] = bestCentroidOffset;
centroidCounts[bestCentroidOffset]++; centroidCounts[bestCentroidOffset]++;
for (int d = 0; d < dim; d++) { for (int d = 0; d < dim; d++) {
nextCentroids[bestCentroidOffset][d] += vector[d]; nextCentroids[bestCentroidOffset][d] += vector[d];
@ -116,23 +116,28 @@ class KMeansLocal {
return changed; return changed;
} }
int getBestCentroidOffset(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) { int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
int bestCentroidOffset = centroidIdx; int bestCentroidOffset = centroidIdx;
float minDsq; assert centroidIdx >= 0 && centroidIdx < centroids.length;
if (centroidIdx > 0 && centroidIdx < centroids.length) { float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]); for (int offset : centroidOffsets) {
} else { float dsq = VectorUtil.squareDistance(vector, centroids[offset]);
minDsq = Float.MAX_VALUE; if (dsq < minDsq) {
minDsq = dsq;
bestCentroidOffset = offset;
}
} }
return bestCentroidOffset;
}
int k = 0; int getBestCentroid(float[][] centroids, float[] vector) {
for (int j = 0; j < centroids.length; j++) { int bestCentroidOffset = 0;
if (centroidOffsets == null || j == centroidOffsets[k]) { float minDsq = Float.MAX_VALUE;
float dsq = VectorUtil.squareDistance(vector, centroids[j]); for (int i = 0; i < centroids.length; i++) {
if (dsq < minDsq) { float dsq = VectorUtil.squareDistance(vector, centroids[i]);
minDsq = dsq; if (dsq < minDsq) {
bestCentroidOffset = j; minDsq = dsq;
} bestCentroidOffset = i;
} }
} }
return bestCentroidOffset; return bestCentroidOffset;
@ -271,7 +276,8 @@ class KMeansLocal {
return; return;
} }
int[] assignments = new int[n]; int[] assignments = kMeansIntermediate.assignments();
assert assignments.length == n;
float[][] nextCentroids = new float[centroids.length][vectors.dimension()]; float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
for (int i = 0; i < maxIterations; i++) { for (int i = 0; i < maxIterations; i++) {
if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) { if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) {
@ -291,7 +297,7 @@ class KMeansLocal {
* @param maxIterations the max iterations to shift centroids * @param maxIterations the max iterations to shift centroids
*/ */
public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException { public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException {
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids); KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, new int[vectors.size()], vectors::ordToDoc);
KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations); KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations);
kMeans.cluster(vectors, kMeansIntermediate); kMeans.cluster(vectors, kMeansIntermediate);
} }