Fix iterating for best centroid when algorithm is neighbour aware and decrease SAMPLES_PER_CLUSTER_DEFAULT (#130069)

* KMeansIntermediate shares assignments
This commit is contained in:
Ignacio Vera 2025-06-27 13:28:12 +02:00 committed by GitHub
parent 7c213baf4d
commit ce74df5c0c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 34 additions and 62 deletions

View file

@ -10,7 +10,6 @@
package org.elasticsearch.index.codec.vectors.cluster;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.VectorUtil;
import java.io.IOException;
@ -21,7 +20,7 @@ public class HierarchicalKMeans {
static final int MAXK = 128;
static final int MAX_ITERATIONS_DEFAULT = 6;
static final int SAMPLES_PER_CLUSTER_DEFAULT = 256;
static final int SAMPLES_PER_CLUSTER_DEFAULT = 64;
static final float DEFAULT_SOAR_LAMBDA = 1.0f;
final int dimension;
@ -67,8 +66,7 @@ public class HierarchicalKMeans {
// partition the space
KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize);
if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) {
float f = Math.min((float) samplesPerCluster / targetSize, 1.0f);
int localSampleSize = (int) (f * vectors.size());
int localSampleSize = Math.min(kMeansIntermediate.centroids().length * samplesPerCluster, vectors.size());
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA);
kMeansLocal.cluster(vectors, kMeansIntermediate, true);
}
@ -86,42 +84,16 @@ public class HierarchicalKMeans {
// TODO: instead of creating a new sub-cluster assignments array, reuse the parent array each time
int[] assignments = new int[vectors.size()];
KMeansLocal kmeans = new KMeansLocal(m, maxIterations);
float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k);
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
kmeans.cluster(vectors, kMeansIntermediate);
// TODO: consider adding cluster size counts to the kmeans algo
// handle assignment here so we can track distance and cluster size
int[] centroidVectorCount = new int[centroids.length];
float[][] nextCentroids = new float[centroids.length][dimension];
for (int i = 0; i < vectors.size(); i++) {
float smallest = Float.MAX_VALUE;
int centroidIdx = -1;
float[] vector = vectors.vectorValue(i);
for (int j = 0; j < centroids.length; j++) {
float[] centroid = centroids[j];
float d = VectorUtil.squareDistance(vector, centroid);
if (d < smallest) {
smallest = d;
centroidIdx = j;
}
}
centroidVectorCount[centroidIdx]++;
for (int j = 0; j < dimension; j++) {
nextCentroids[centroidIdx][j] += vector[j];
}
assignments[i] = centroidIdx;
}
// update centroids based on assignments of all vectors
for (int i = 0; i < centroids.length; i++) {
if (centroidVectorCount[i] > 0) {
for (int j = 0; j < dimension; j++) {
centroids[i][j] = nextCentroids[i][j] / centroidVectorCount[i];
}
}
for (int assigment : assignments) {
centroidVectorCount[assigment]++;
}
int effectiveK = 0;
@ -131,8 +103,6 @@ public class HierarchicalKMeans {
}
}
kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
if (effectiveK == 1) {
return kMeansIntermediate;
}

View file

@ -31,10 +31,6 @@ class KMeansIntermediate extends KMeansResult {
this(new float[0][0], new int[0], i -> i, new int[0]);
}
KMeansIntermediate(float[][] centroids) {
this(centroids, new int[0], i -> i, new int[0]);
}
KMeansIntermediate(float[][] centroids, int[] assignments) {
this(centroids, assignments, i -> i, new int[0]);
}

View file

@ -87,17 +87,17 @@ class KMeansLocal {
for (int i = 0; i < sampleSize; i++) {
float[] vector = vectors.vectorValue(i);
int[] neighborOffsets = null;
int centroidIdx = -1;
final int assignment = assignments[i];
final int bestCentroidOffset;
if (neighborhoods != null) {
neighborOffsets = neighborhoods.get(assignments[i]);
centroidIdx = assignments[i];
bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods.get(assignment));
} else {
bestCentroidOffset = getBestCentroid(centroids, vector);
}
int bestCentroidOffset = getBestCentroidOffset(centroids, vector, centroidIdx, neighborOffsets);
if (assignments[i] != bestCentroidOffset) {
if (assignment != bestCentroidOffset) {
assignments[i] = bestCentroidOffset;
changed = true;
}
assignments[i] = bestCentroidOffset;
centroidCounts[bestCentroidOffset]++;
for (int d = 0; d < dim; d++) {
nextCentroids[bestCentroidOffset][d] += vector[d];
@ -116,23 +116,28 @@ class KMeansLocal {
return changed;
}
int getBestCentroidOffset(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
int bestCentroidOffset = centroidIdx;
float minDsq;
if (centroidIdx > 0 && centroidIdx < centroids.length) {
minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
} else {
minDsq = Float.MAX_VALUE;
assert centroidIdx >= 0 && centroidIdx < centroids.length;
float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
for (int offset : centroidOffsets) {
float dsq = VectorUtil.squareDistance(vector, centroids[offset]);
if (dsq < minDsq) {
minDsq = dsq;
bestCentroidOffset = offset;
}
}
return bestCentroidOffset;
}
int k = 0;
for (int j = 0; j < centroids.length; j++) {
if (centroidOffsets == null || j == centroidOffsets[k]) {
float dsq = VectorUtil.squareDistance(vector, centroids[j]);
if (dsq < minDsq) {
minDsq = dsq;
bestCentroidOffset = j;
}
int getBestCentroid(float[][] centroids, float[] vector) {
int bestCentroidOffset = 0;
float minDsq = Float.MAX_VALUE;
for (int i = 0; i < centroids.length; i++) {
float dsq = VectorUtil.squareDistance(vector, centroids[i]);
if (dsq < minDsq) {
minDsq = dsq;
bestCentroidOffset = i;
}
}
return bestCentroidOffset;
@ -271,7 +276,8 @@ class KMeansLocal {
return;
}
int[] assignments = new int[n];
int[] assignments = kMeansIntermediate.assignments();
assert assignments.length == n;
float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
for (int i = 0; i < maxIterations; i++) {
if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) {
@ -291,7 +297,7 @@ class KMeansLocal {
* @param maxIterations the max iterations to shift centroids
*/
public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException {
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, new int[vectors.size()], vectors::ordToDoc);
KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations);
kMeans.cluster(vectors, kMeansIntermediate);
}