From 1324ee0115bb99e0ec59864e8b45f905160f0b31 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 13 May 2025 18:47:59 -0400 Subject: [PATCH] Reapply "Adds new unexposed and experimental IVF format (#127528)" (#128005) (#128051) This reverts commit 8a17a5ed5f68073803853fbda81d6a74e8b7e695. reapplying ivf format, but with a fix. --- .../benchmark/vector/OSQScorerBenchmark.java | 2 +- .../ES91OSQVectorsScorer.java | 2 +- .../elasticsearch/simdvec/ESVectorUtil.java | 42 + .../DefaultESVectorUtilSupport.java | 12 + .../DefaultESVectorizationProvider.java | 1 + .../vectorization/ESVectorUtilSupport.java | 3 + .../ESVectorizationProvider.java | 1 + .../ESVectorizationProvider.java | 1 + .../MemorySegmentES91OSQVectorsScorer.java | 1 + .../PanamaESVectorUtilSupport.java | 43 + .../PanamaESVectorizationProvider.java | 1 + .../simdvec/ESVectorUtilTests.java | 16 + .../ES91OSQVectorScorerTests.java | 1 + server/src/main/java/module-info.java | 3 +- .../vectors/DefaultIVFVectorsReader.java | 420 ++++++++++ .../vectors/DefaultIVFVectorsWriter.java | 736 ++++++++++++++++++ .../index/codec/vectors/IVFVectorsFormat.java | 110 +++ .../index/codec/vectors/IVFVectorsReader.java | 354 +++++++++ .../index/codec/vectors/IVFVectorsWriter.java | 486 ++++++++++++ .../index/codec/vectors/NeighborQueue.java | 159 ++++ .../org.apache.lucene.codecs.KnnVectorsFormat | 1 + .../codec/vectors/IVFVectorsFormatTests.java | 65 ++ .../codec/vectors/NeighborQueueTests.java | 119 +++ 23 files changed, 2576 insertions(+), 3 deletions(-) rename libs/simdvec/src/main/java/org/elasticsearch/simdvec/{internal/vectorization => }/ES91OSQVectorsScorer.java (99%) create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/NeighborQueueTests.java diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java index 486919f10bcf..85ca13e6e875 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java @@ -17,7 +17,7 @@ import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; import org.elasticsearch.common.logging.LogConfigurator; -import org.elasticsearch.simdvec.internal.vectorization.ES91OSQVectorsScorer; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java 
b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java similarity index 99% rename from libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java rename to libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java index 839e5f29a114..be55c48dbe44 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java @@ -6,7 +6,7 @@ * your election, the "Elastic License 2.0", the "GNU Affero General Public * License v3.0 only", or the "Server Side Public License, v 1". */ -package org.elasticsearch.simdvec.internal.vectorization; +package org.elasticsearch.simdvec; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java index 41bf6ff58d14..50b8e18c3d22 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java @@ -9,11 +9,13 @@ package org.elasticsearch.simdvec; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.Constants; import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport; import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider; +import java.io.IOException; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; @@ -41,6 +43,10 @@ public class ESVectorUtil { private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport(); + public static ES91OSQVectorsScorer getES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException { + return ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(input, dimension); + } + public static long ipByteBinByte(byte[] q, byte[] d) { if (q.length != d.length * B_QUERY) { throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length); @@ -211,4 +217,40 @@ public class ESVectorUtil { assert stats.length == 6; IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats); } + + /** + * Calculates the difference between two vectors and stores the result in a third vector. 
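+     * <p>A minimal usage sketch (values are illustrative, not from the original patch):
+     * <pre>
+     *   float[] v1 = {1f, 2f, 3f};
+     *   float[] v2 = {0.5f, 1f, 1.5f};
+     *   float[] residual = new float[3];
+     *   ESVectorUtil.subtract(v1, v2, residual); // residual is now {0.5f, 1f, 1.5f}
+     * </pre>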
+     * @param v1 the first vector
+     * @param v2 the second vector
+     * @param result the result vector, must be the same length as the input vectors
+     */
+    public static void subtract(float[] v1, float[] v2, float[] result) {
+        if (v1.length != v2.length) {
+            throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + v2.length);
+        }
+        if (result.length != v1.length) {
+            throw new IllegalArgumentException("vector dimensions differ: " + result.length + "!=" + v1.length);
+        }
+        for (int i = 0; i < v1.length; i++) {
+            result[i] = v1[i] - v2[i];
+        }
+    }
+
+    /**
+     * Calculates the spill-over (SOAR) score for a vector and a centroid, given the vector's
+     * residual with its actual nearest centroid.
+     * @param v1 the vector
+     * @param centroid the centroid
+     * @param originalResidual the residual with the actual nearest centroid
+     * @return the spill-over (SOAR) score
+     */
+    public static float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
+        if (v1.length != centroid.length) {
+            throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + centroid.length);
+        }
+        if (originalResidual.length != v1.length) {
+            throw new IllegalArgumentException("vector dimensions differ: " + originalResidual.length + "!=" + v1.length);
+        }
+        return IMPL.soarResidual(v1, centroid, originalResidual);
+    }
 }
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java
index ce8fce7e68b7..846472876f37 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java
@@ -138,6 +138,18 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
         stats[5] = centroidDot;
     }
 
+    @Override
+    public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
+        assert v1.length == centroid.length;
+        assert v1.length == originalResidual.length;
+        float proj = 0;
+        for (int i = 0; i < v1.length; i++) {
+            float djk = v1[i] - centroid[i];
+            proj = fma(djk, originalResidual[i], proj);
+        }
+        return proj;
+    }
+
     public static int ipByteBitImpl(byte[] q, byte[] d) {
         return ipByteBitImpl(q, d, 0);
     }
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java
index e8ff6f83f217..51a78d3cd6c3 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java
@@ -10,6 +10,7 @@
 package org.elasticsearch.simdvec.internal.vectorization;
 
 import org.apache.lucene.store.IndexInput;
+import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
 
 import java.io.IOException;
 
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java
index b2615c55e64e..8aa50e8c4280 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java
+++ 
b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java @@ -28,4 +28,7 @@ public interface ESVectorUtilSupport { void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats); void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats); + + float soarResidual(float[] v1, float[] centroid, float[] originalResidual); + } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java index 7f4e62f156a3..8c040484c7c0 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java @@ -10,6 +10,7 @@ package org.elasticsearch.simdvec.internal.vectorization; import org.apache.lucene.store.IndexInput; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; import java.util.Objects; diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java index af5df094659e..ea4180b59565 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java @@ -13,6 +13,7 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Constants; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; import java.util.Locale; diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java index 0bf3b8da22d2..46daa074c5e5 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java @@ -20,6 +20,7 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; import java.lang.foreign.MemorySegment; diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java index cf856c5322f0..1d8f59f85567 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java @@ -367,6 +367,49 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport { return (1f - lambda) * xe * xe / norm2 + lambda * e; } + @Override + public float 
soarResidual(float[] v1, float[] centroid, float[] originalResidual) { + assert v1.length == centroid.length; + assert v1.length == originalResidual.length; + float proj = 0; + int i = 0; + if (v1.length > 2 * FLOAT_SPECIES.length()) { + FloatVector projVec1 = FloatVector.zero(FLOAT_SPECIES); + FloatVector projVec2 = FloatVector.zero(FLOAT_SPECIES); + int unrolledLimit = FLOAT_SPECIES.loopBound(v1.length) - FLOAT_SPECIES.length(); + for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) { + // one + FloatVector v1Vec0 = FloatVector.fromArray(FLOAT_SPECIES, v1, i); + FloatVector centroidVec0 = FloatVector.fromArray(FLOAT_SPECIES, centroid, i); + FloatVector originalResidualVec0 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i); + FloatVector djkVec0 = v1Vec0.sub(centroidVec0); + projVec1 = fma(djkVec0, originalResidualVec0, projVec1); + + // two + FloatVector v1Vec1 = FloatVector.fromArray(FLOAT_SPECIES, v1, i + FLOAT_SPECIES.length()); + FloatVector centroidVec1 = FloatVector.fromArray(FLOAT_SPECIES, centroid, i + FLOAT_SPECIES.length()); + FloatVector originalResidualVec1 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i + FLOAT_SPECIES.length()); + FloatVector djkVec1 = v1Vec1.sub(centroidVec1); + projVec2 = fma(djkVec1, originalResidualVec1, projVec2); + } + // vector tail + for (; i < FLOAT_SPECIES.loopBound(v1.length); i += FLOAT_SPECIES.length()) { + FloatVector v1Vec = FloatVector.fromArray(FLOAT_SPECIES, v1, i); + FloatVector centroidVec = FloatVector.fromArray(FLOAT_SPECIES, centroid, i); + FloatVector originalResidualVec = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i); + FloatVector djkVec = v1Vec.sub(centroidVec); + projVec1 = fma(djkVec, originalResidualVec, projVec1); + } + proj += projVec1.add(projVec2).reduceLanes(ADD); + } + // tail + for (; i < v1.length; i++) { + float djk = v1[i] - centroid[i]; + proj = fma(djk, originalResidual[i], proj); + } + return proj; + } + private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256; diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java index c409be4fb37d..5ff8c19c90a5 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java @@ -11,6 +11,7 @@ package org.elasticsearch.simdvec.internal.vectorization; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; import java.lang.foreign.MemorySegment; diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java index 0c99fad2d3d5..abd4e3b0be04 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java @@ -268,6 +268,22 @@ public class ESVectorUtilTests extends BaseVectorizationTests { } } + public void testSoarOverspillScore() { + int size = random().nextInt(128, 512); + float deltaEps = 1e-5f * size; + var vector = new float[size]; + var centroid = new 
float[size]; + var preResidual = new float[size]; + for (int i = 0; i < size; ++i) { + vector[i] = random().nextFloat(); + centroid[i] = random().nextFloat(); + preResidual[i] = random().nextFloat(); + } + var expected = defaultedProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual); + var result = defOrPanamaProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual); + assertEquals(expected, result, deltaEps); + } + void testIpByteBinImpl(ToLongBiFunction ipByteBinFunc) { int iterations = atLeast(50); for (int i = 0; i < iterations; i++) { diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java index 53b14ae4910c..5544c0686fa5 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java @@ -16,6 +16,7 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import static org.hamcrest.Matchers.lessThan; diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 7edefe9fae58..8da4f403c29b 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -454,7 +454,8 @@ module org.elasticsearch.server { org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat, - org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; + org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.IVFVectorsFormat; provides org.apache.lucene.codecs.Codec with diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java new file mode 100644 index 000000000000..e09cf474d09e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -0,0 +1,420 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.NeighborQueue; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; +import org.elasticsearch.simdvec.ESVectorUtil; + +import java.io.IOException; +import java.util.function.IntPredicate; + +import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS; +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte; +import static org.elasticsearch.simdvec.ES91OSQVectorsScorer.BULK_SIZE; + +/** + * Default implementation of {@link IVFVectorsReader}. It scores the posting lists centroids using + * brute force and then scores the top ones using the posting list. + */ +public class DefaultIVFVectorsReader extends IVFVectorsReader { + private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1); + + public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException { + super(state, rawVectorsReader); + } + + @Override + CentroidQueryScorer getCentroidScorer( + FieldInfo fieldInfo, + int numCentroids, + IndexInput centroids, + float[] targetQuery, + IndexInput clusters + ) throws IOException { + FieldEntry fieldEntry = fields.get(fieldInfo.number); + float[] globalCentroid = fieldEntry.globalCentroid(); + float globalCentroidDp = fieldEntry.globalCentroidDp(); + OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + byte[] quantized = new byte[targetQuery.length]; + float[] targetScratch = ArrayUtil.copyArray(targetQuery); + OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize( + targetScratch, + quantized, + (byte) 4, + globalCentroid + ); + return new CentroidQueryScorer() { + int currentCentroid = -1; + private final byte[] quantizedCentroid = new byte[fieldInfo.getVectorDimension()]; + private final float[] centroid = new float[fieldInfo.getVectorDimension()]; + private final float[] centroidCorrectiveValues = new float[3]; + private int quantizedCentroidComponentSum; + private final long centroidByteSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES; + + @Override + public int size() { + return numCentroids; + } + + @Override + public float[] centroid(int centroidOrdinal) throws IOException { + readQuantizedAndRawCentroid(centroidOrdinal); + return centroid; + } + + private void readQuantizedAndRawCentroid(int centroidOrdinal) throws IOException { + if (centroidOrdinal == currentCentroid) { + return; + } + centroids.seek(centroidOrdinal * centroidByteSize); + quantizedCentroidComponentSum = readQuantizedValue(centroids, quantizedCentroid, 
centroidCorrectiveValues); + centroids.seek(numCentroids * centroidByteSize + (long) Float.BYTES * quantizedCentroid.length * centroidOrdinal); + centroids.readFloats(centroid, 0, centroid.length); + currentCentroid = centroidOrdinal; + } + + @Override + public float score(int centroidOrdinal) throws IOException { + readQuantizedAndRawCentroid(centroidOrdinal); + return int4QuantizedScore( + quantized, + queryParams, + fieldInfo.getVectorDimension(), + quantizedCentroid, + centroidCorrectiveValues, + quantizedCentroidComponentSum, + globalCentroidDp, + fieldInfo.getVectorSimilarityFunction() + ); + } + }; + } + + @Override + protected FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) { + FieldEntry entry = fields.get(info.number); + if (entry == null) { + return null; + } + return new OffHeapCentroidFloatVectorValues(numCentroids, indexInput, info.getVectorDimension()); + } + + @Override + NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe) + throws IOException { + NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true); + // TODO Off heap scoring for quantized centroids? + for (int centroid = 0; centroid < centroidQueryScorer.size(); centroid++) { + neighborQueue.add(centroid, centroidQueryScorer.score(centroid)); + } + return neighborQueue; + } + + @Override + PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring) + throws IOException { + FieldEntry entry = fields.get(fieldInfo.number); + return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, needsScoring); + } + + // TODO can we do this in off-heap blocks? + static float int4QuantizedScore( + byte[] quantizedQuery, + OptimizedScalarQuantizer.QuantizationResult queryCorrections, + int dims, + byte[] binaryCode, + float[] targetCorrections, + int targetComponentSum, + float centroidDp, + VectorSimilarityFunction similarityFunction + ) { + float qcDist = VectorUtil.int4DotProduct(quantizedQuery, binaryCode); + float ax = targetCorrections[0]; + // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary + float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE; + float ay = queryCorrections.lowerInterval(); + float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; + float y1 = queryCorrections.quantizedComponentSum(); + float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist; + if (similarityFunction == EUCLIDEAN) { + score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score; + return Math.max(1 / (1f + score), 0); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp; + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + return VectorUtil.scaleMaxInnerProductScore(score); + } + return Math.max((1f + score) / 2f, 0); + } + } + + static class OffHeapCentroidFloatVectorValues extends FloatVectorValues { + private final int numCentroids; + private final IndexInput input; + private final int dimension; + private final float[] centroid; + private final long centroidByteSize; + private int ord = -1; + + OffHeapCentroidFloatVectorValues(int numCentroids, IndexInput input, int dimension) { + 
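+            // Assumed file layout (mirroring the writer): all quantized centroids come
+            // first, each occupying `dimension` bytes plus three float corrective terms
+            // and one short component sum, followed by the raw float32 centroids that
+            // this class actually reads.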
this.numCentroids = numCentroids; + this.input = input; + this.dimension = dimension; + this.centroid = new float[dimension]; + this.centroidByteSize = dimension + 3 * Float.BYTES + Short.BYTES; + } + + @Override + public float[] vectorValue(int ord) throws IOException { + if (ord < 0 || ord >= numCentroids) { + throw new IllegalArgumentException("ord must be in [0, " + numCentroids + "]"); + } + if (ord == this.ord) { + return centroid; + } + readQuantizedCentroid(ord); + return centroid; + } + + private void readQuantizedCentroid(int centroidOrdinal) throws IOException { + if (centroidOrdinal == ord) { + return; + } + input.seek(numCentroids * centroidByteSize + (long) Float.BYTES * dimension * centroidOrdinal); + input.readFloats(centroid, 0, centroid.length); + ord = centroidOrdinal; + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return numCentroids; + } + + @Override + public FloatVectorValues copy() throws IOException { + return new OffHeapCentroidFloatVectorValues(numCentroids, input.clone(), dimension); + } + } + + private static class MemorySegmentPostingsVisitor implements PostingVisitor { + final long quantizedByteLength; + final IndexInput indexInput; + final float[] target; + final FieldEntry entry; + final FieldInfo fieldInfo; + final IntPredicate needsScoring; + private final ES91OSQVectorsScorer osqVectorsScorer; + final float[] scores = new float[BULK_SIZE]; + final float[] correctionsLower = new float[BULK_SIZE]; + final float[] correctionsUpper = new float[BULK_SIZE]; + final int[] correctionsSum = new int[BULK_SIZE]; + final float[] correctionsAdd = new float[BULK_SIZE]; + + int[] docIdsScratch = new int[0]; + int vectors; + boolean quantized = false; + float centroidDp; + float[] centroid; + long slicePos; + OptimizedScalarQuantizer.QuantizationResult queryCorrections; + DocIdsWriter docIdsWriter = new DocIdsWriter(); + + final float[] scratch; + final byte[] quantizationScratch; + final byte[] quantizedQueryScratch; + final OptimizedScalarQuantizer quantizer; + final float[] correctiveValues = new float[3]; + final long quantizedVectorByteSize; + + MemorySegmentPostingsVisitor( + float[] target, + IndexInput indexInput, + FieldEntry entry, + FieldInfo fieldInfo, + IntPredicate needsScoring + ) throws IOException { + this.target = target; + this.indexInput = indexInput; + this.entry = entry; + this.fieldInfo = fieldInfo; + this.needsScoring = needsScoring; + + scratch = new float[target.length]; + quantizationScratch = new byte[target.length]; + final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64); + quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8]; + quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES; + quantizedVectorByteSize = (discretizedDimensions / 8); + quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + osqVectorsScorer = ESVectorUtil.getES91OSQVectorsScorer(indexInput, fieldInfo.getVectorDimension()); + } + + @Override + public int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException { + quantized = false; + indexInput.seek(entry.postingListOffsets()[centroidOrdinal]); + vectors = indexInput.readVInt(); + centroidDp = Float.intBitsToFloat(indexInput.readInt()); + this.centroid = centroid; + // read the doc ids + docIdsScratch = vectors > docIdsScratch.length ? 
new int[vectors] : docIdsScratch; + docIdsWriter.readInts(indexInput, vectors, docIdsScratch); + slicePos = indexInput.getFilePointer(); + return vectors; + } + + void scoreIndividually(int offset) throws IOException { + // score individually, first the quantized byte chunk + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[j + offset]; + if (doc != -1) { + indexInput.seek(slicePos + (offset * quantizedByteLength) + (j * quantizedVectorByteSize)); + float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch); + scores[j] = qcDist; + } + } + // read in all corrections + indexInput.seek(slicePos + (offset * quantizedByteLength) + (BULK_SIZE * quantizedVectorByteSize)); + indexInput.readFloats(correctionsLower, 0, BULK_SIZE); + indexInput.readFloats(correctionsUpper, 0, BULK_SIZE); + for (int j = 0; j < BULK_SIZE; j++) { + correctionsSum[j] = Short.toUnsignedInt(indexInput.readShort()); + } + indexInput.readFloats(correctionsAdd, 0, BULK_SIZE); + // Now apply corrections + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[offset + j]; + if (doc != -1) { + scores[j] = osqVectorsScorer.score( + queryCorrections, + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + correctionsLower[j], + correctionsUpper[j], + correctionsSum[j], + correctionsAdd[j], + scores[j] + ); + } + } + } + + @Override + public int visit(KnnCollector knnCollector) throws IOException { + // block processing + int scoredDocs = 0; + int limit = vectors - BULK_SIZE + 1; + int i = 0; + for (; i < limit; i += BULK_SIZE) { + int docsToScore = BULK_SIZE; + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[i + j]; + if (needsScoring.test(doc) == false) { + docIdsScratch[i + j] = -1; + docsToScore--; + } + } + if (docsToScore == 0) { + continue; + } + quantizeQueryIfNecessary(); + indexInput.seek(slicePos + i * quantizedByteLength); + if (docsToScore < BULK_SIZE / 2) { + scoreIndividually(i); + } else { + osqVectorsScorer.scoreBulk( + quantizedQueryScratch, + queryCorrections, + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + scores + ); + } + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[i + j]; + if (doc != -1) { + scoredDocs++; + knnCollector.collect(doc, scores[j]); + } + } + } + // process tail + for (; i < vectors; i++) { + int doc = docIdsScratch[i]; + if (needsScoring.test(doc)) { + quantizeQueryIfNecessary(); + indexInput.seek(slicePos + i * quantizedByteLength); + float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch); + indexInput.readFloats(correctiveValues, 0, 3); + final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort()); + float score = osqVectorsScorer.score( + queryCorrections, + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + correctiveValues[0], + correctiveValues[1], + quantizedComponentSum, + correctiveValues[2], + qcDist + ); + scoredDocs++; + knnCollector.collect(doc, score); + } + } + if (scoredDocs > 0) { + knnCollector.incVisitedCount(scoredDocs); + } + return scoredDocs; + } + + private void quantizeQueryIfNecessary() { + if (quantized == false) { + System.arraycopy(target, 0, scratch, 0, target.length); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(scratch); + } + queryCorrections = quantizer.scalarQuantize(scratch, quantizationScratch, (byte) 4, centroid); + transposeHalfByte(quantizationScratch, quantizedQueryScratch); + quantized = true; + } + } + } + + static int readQuantizedValue(IndexInput indexInput, byte[] binaryValue, float[] 
corrections) throws IOException {
+        assert corrections.length == 3;
+        indexInput.readBytes(binaryValue, 0, binaryValue.length);
+        corrections[0] = Float.intBitsToFloat(indexInput.readInt());
+        corrections[1] = Float.intBitsToFloat(indexInput.readInt());
+        corrections[2] = Float.intBitsToFloat(indexInput.readInt());
+        return Short.toUnsignedInt(indexInput.readShort());
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
new file mode 100644
index 000000000000..1c431b01e611
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
@@ -0,0 +1,736 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.codec.vectors;
+
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.MergeState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.internal.hppc.IntArrayList;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.InfoStream;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
+import org.elasticsearch.simdvec.ESVectorUtil;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
+import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
+import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary;
+import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.IVF_VECTOR_COMPONENT;
+
+/**
+ * Default implementation of {@link IVFVectorsWriter}. It uses the {@link KMeans} algorithm to
+ * partition the vector space, and then stores the centroids and posting lists in a sequential
+ * fashion.
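+ * <p>Vectors may additionally be spilled to a second-best posting list using SOAR
+ * (spilling with orthogonality-amplified residuals). Informally, with r = v - c_1 the
+ * residual to the nearest centroid c_1, the secondary cost evaluated below for a
+ * candidate centroid c_j is dist(v, c_j)^2 + SOAR_LAMBDA * ((v - c_j) . r)^2 / dist(v, c_1)^2.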
+ */ +public class DefaultIVFVectorsWriter extends IVFVectorsWriter { + + static final float SOAR_LAMBDA = 1.0f; + // What percentage of the centroids do we do a second check on for SOAR assignment + static final float EXT_SOAR_LIMIT_CHECK_RATIO = 0.10f; + + private final int vectorPerCluster; + + public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate, int vectorPerCluster) throws IOException { + super(state, rawVectorDelegate); + this.vectorPerCluster = vectorPerCluster; + } + + @Override + CentroidAssignmentScorer calculateAndWriteCentroids( + FieldInfo fieldInfo, + FloatVectorValues floatVectorValues, + IndexOutput centroidOutput, + float[] globalCentroid + ) throws IOException { + if (floatVectorValues.size() == 0) { + return CentroidAssignmentScorer.EMPTY; + } + // calculate the centroids + int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; + int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters); + final KMeans.Results kMeans = KMeans.cluster( + floatVectorValues, + desiredClusters, + false, + 42L, + KMeans.KmeansInitializationMethod.PLUS_PLUS, + null, + fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE, + 1, + 15, + desiredClusters * 256 + ); + float[][] centroids = kMeans.centroids(); + // write them + writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput); + return new OnHeapCentroidAssignmentScorer(centroids); + } + + @Override + long[] buildAndWritePostingsLists( + FieldInfo fieldInfo, + InfoStream infoStream, + CentroidAssignmentScorer randomCentroidScorer, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput + ) throws IOException { + IntArrayList[] clusters = new IntArrayList[randomCentroidScorer.size()]; + for (int i = 0; i < randomCentroidScorer.size(); i++) { + clusters[i] = new IntArrayList(floatVectorValues.size() / randomCentroidScorer.size() / 4); + } + assignCentroids(randomCentroidScorer, floatVectorValues, clusters); + if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + printClusterQualityStatistics(clusters, infoStream); + } + // write the posting lists + final long[] offsets = new long[randomCentroidScorer.size()]; + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer); + DocIdsWriter docIdsWriter = new DocIdsWriter(); + for (int i = 0; i < randomCentroidScorer.size(); i++) { + float[] centroid = randomCentroidScorer.centroid(i); + binarizedByteVectorValues.centroid = centroid; + // TODO sort by distance to the centroid + IntArrayList cluster = clusters[i]; + // TODO align??? 
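+            // Sketch of the posting list bytes written below (descriptive, not a normative spec):
+            //   vInt   number of vectors in this cluster
+            //   int    float bits of dot(centroid, centroid)
+            //   doc ids, encoded by DocIdsWriter
+            //   quantized vectors plus OSQ corrective terms, in BULK_SIZE blocks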
+            offsets[i] = postingsOutput.getFilePointer();
+            int size = cluster.size();
+            postingsOutput.writeVInt(size);
+            postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+            // TODO we might want to consider putting the docIds in a separate file
+            // to aid with only having to fetch vectors from slower storage when they are required
+            // keeping them in the same file indicates we pull the entire file into cache
+            docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), cluster.size(), postingsOutput);
+            writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
+        }
+        return offsets;
+    }
+
+    private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
+        throws IOException {
+        int limit = cluster.size() - ES91OSQVectorsScorer.BULK_SIZE + 1;
+        int cidx = 0;
+        OptimizedScalarQuantizer.QuantizationResult[] corrections =
+            new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE];
+        // Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE.
+        for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) {
+            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+                int ord = cluster.get(cidx + j);
+                byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
+                // write vector
+                postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
+                corrections[j] = binarizedByteVectorValues.getCorrectiveTerms(ord);
+            }
+            // write corrections
+            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+                postingsOutput.writeInt(Float.floatToIntBits(corrections[j].lowerInterval()));
+            }
+            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+                postingsOutput.writeInt(Float.floatToIntBits(corrections[j].upperInterval()));
+            }
+            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+                int targetComponentSum = corrections[j].quantizedComponentSum();
+                assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
+                postingsOutput.writeShort((short) targetComponentSum);
+            }
+            for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+                postingsOutput.writeInt(Float.floatToIntBits(corrections[j].additionalCorrection()));
+            }
+        }
+        // write tail
+        for (; cidx < cluster.size(); cidx++) {
+            int ord = cluster.get(cidx);
+            // write the vector and its corrective terms in one shot
+            byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
+            OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord);
+            writeQuantizedValue(postingsOutput, binaryValue, correction);
+        }
+    }
+
+    @Override
+    CentroidAssignmentScorer createCentroidScorer(
+        IndexInput centroidsInput,
+        int numCentroids,
+        FieldInfo fieldInfo,
+        float[] globalCentroid
+    ) {
+        return new OffHeapCentroidAssignmentScorer(centroidsInput, numCentroids, fieldInfo);
+    }
+
+    static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput)
+        throws IOException {
+        final OptimizedScalarQuantizer osq = new
OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()]; + float[] centroidScratch = new float[fieldInfo.getVectorDimension()]; + // TODO do we want to store these distances as well for future use? + float[] distances = new float[centroids.length]; + for (int i = 0; i < centroids.length; i++) { + distances[i] = VectorUtil.squareDistance(centroids[i], globalCentroid); + } + // sort the centroids by distance to globalCentroid, nearest (smallest distance), to furthest + // (largest) + for (int i = 0; i < centroids.length; i++) { + for (int j = i + 1; j < centroids.length; j++) { + if (distances[i] > distances[j]) { + float[] tmp = centroids[i]; + centroids[i] = centroids[j]; + centroids[j] = tmp; + float tmpDistance = distances[i]; + distances[i] = distances[j]; + distances[j] = tmpDistance; + } + } + } + for (float[] centroid : centroids) { + System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length); + OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize( + centroidScratch, + quantizedScratch, + (byte) 4, + globalCentroid + ); + writeQuantizedValue(centroidOutput, quantizedScratch, result); + } + final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (float[] centroid : centroids) { + buffer.asFloatBuffer().put(centroid); + centroidOutput.writeBytes(buffer.array(), buffer.array().length); + } + } + + static float[][] gatherInitCentroids( + List centroidList, + List segmentCentroids, + int desiredClusters, + FieldInfo fieldInfo, + MergeState mergeState + ) throws IOException { + if (centroidList.size() == 0) { + return null; + } + long startTime = System.nanoTime(); + // sort centroid list by floatvector size + FloatVectorValues baseSegment = centroidList.get(0); + for (var l : centroidList) { + if (l.size() > baseSegment.size()) { + baseSegment = l; + } + } + float[] scratch = new float[fieldInfo.getVectorDimension()]; + float minimumDistance = Float.MAX_VALUE; + for (int j = 0; j < baseSegment.size(); j++) { + System.arraycopy(baseSegment.vectorValue(j), 0, scratch, 0, baseSegment.dimension()); + for (int k = j + 1; k < baseSegment.size(); k++) { + float d = VectorUtil.squareDistance(scratch, baseSegment.vectorValue(k)); + if (d < minimumDistance) { + minimumDistance = d; + } + } + } + if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + mergeState.infoStream.message( + IVF_VECTOR_COMPONENT, + "Agglomerative cluster min distance: " + minimumDistance + " From biggest segment: " + baseSegment.size() + ); + } + int[] labels = new int[segmentCentroids.size()]; + // loop over segments + int clusterIdx = 0; + // keep track of all inter-centroid distances, + // using less than centroid * centroid space (e.g. 
not keeping track of duplicates) + for (int i = 0; i < segmentCentroids.size(); i++) { + if (labels[i] == 0) { + clusterIdx += 1; + labels[i] = clusterIdx; + } + SegmentCentroid segmentCentroid = segmentCentroids.get(i); + System.arraycopy( + centroidList.get(segmentCentroid.segment()).vectorValue(segmentCentroid.centroid), + 0, + scratch, + 0, + baseSegment.dimension() + ); + for (int j = i + 1; j < segmentCentroids.size(); j++) { + float d = VectorUtil.squareDistance( + scratch, + centroidList.get(segmentCentroids.get(j).segment()).vectorValue(segmentCentroids.get(j).centroid()) + ); + if (d < minimumDistance / 2) { + if (labels[j] == 0) { + labels[j] = labels[i]; + } else { + for (int k = 0; k < labels.length; k++) { + if (labels[k] == labels[j]) { + labels[k] = labels[i]; + } + } + } + } + } + } + float[][] initCentroids = new float[clusterIdx][fieldInfo.getVectorDimension()]; + int[] sum = new int[clusterIdx]; + for (int i = 0; i < segmentCentroids.size(); i++) { + SegmentCentroid segmentCentroid = segmentCentroids.get(i); + int label = labels[i]; + FloatVectorValues segment = centroidList.get(segmentCentroid.segment()); + float[] vector = segment.vectorValue(segmentCentroid.centroid); + for (int j = 0; j < vector.length; j++) { + initCentroids[label - 1][j] += (vector[j] * segmentCentroid.centroidSize); + } + sum[label - 1] += segmentCentroid.centroidSize; + } + for (int i = 0; i < initCentroids.length; i++) { + if (sum[i] == 0 || sum[i] == 1) { + continue; + } + for (int j = 0; j < initCentroids[i].length; j++) { + initCentroids[i][j] /= sum[i]; + } + } + if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + mergeState.infoStream.message( + IVF_VECTOR_COMPONENT, + "Agglomerative cluster time ms: " + ((System.nanoTime() - startTime) / 1000000.0) + ); + mergeState.infoStream.message( + IVF_VECTOR_COMPONENT, + "Gathered initCentroids:" + initCentroids.length + " for desired: " + desiredClusters + ); + } + return initCentroids; + } + + record SegmentCentroid(int segment, int centroid, int centroidSize) {} + + /** + * Calculate the centroids for the given field and write them to the given + * temporary centroid output. + * When merging, we first bootstrap the KMeans algorithm with the centroids contained in the merging segments. + * To prevent centroids that are too similar from having an outsized impact, all centroids that are closer than + * the largest segments intra-cluster distance are merged into a single centroid. + * The resulting centroids are then used to initialize the KMeans algorithm. 
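+     * <p>The merged seed centroids are size-weighted means of the surviving per-segment
+     * centroids, i.e. (a sketch of the accumulation performed in gatherInitCentroids):
+     * <pre>
+     *   merged[j] = sum_i(centroid_i[j] * size_i) / sum_i(size_i)
+     * </pre>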
+ * + * @param fieldInfo merging field info + * @param floatVectorValues the float vector values to merge + * @param temporaryCentroidOutput the temporary centroid output + * @param mergeState the merge state + * @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids + * @return the number of centroids written + * @throws IOException if an I/O error occurs + */ + @Override + protected int calculateAndWriteCentroids( + FieldInfo fieldInfo, + FloatVectorValues floatVectorValues, + IndexOutput temporaryCentroidOutput, + MergeState mergeState, + float[] globalCentroid + ) throws IOException { + if (floatVectorValues.size() == 0) { + return 0; + } + int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; + int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters); + // init centroids from merge state + List centroidList = new ArrayList<>(); + List segmentCentroids = new ArrayList<>(desiredClusters); + + int segmentIdx = 0; + for (var reader : mergeState.knnVectorsReaders) { + IVFVectorsReader ivfVectorsReader = IVFVectorsFormat.getIVFReader(reader, fieldInfo.name); + if (ivfVectorsReader == null) { + continue; + } + + FloatVectorValues centroid = ivfVectorsReader.getCentroids(fieldInfo); + if (centroid == null) { + continue; + } + centroidList.add(centroid); + for (int i = 0; i < centroid.size(); i++) { + int size = ivfVectorsReader.centroidSize(fieldInfo.name, i); + if (size == 0) { + continue; + } + segmentCentroids.add(new SegmentCentroid(segmentIdx, i, size)); + } + segmentIdx++; + } + + float[][] initCentroids = gatherInitCentroids(centroidList, segmentCentroids, desiredClusters, fieldInfo, mergeState); + + // FIXME: run a custom version of KMeans that is just better... + long nanoTime = System.nanoTime(); + final KMeans.Results kMeans = KMeans.cluster( + floatVectorValues, + desiredClusters, + false, + 42L, + KMeans.KmeansInitializationMethod.PLUS_PLUS, + initCentroids, + fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE, + 1, + 5, + desiredClusters * 64 + ); + if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "KMeans time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0)); + } + float[][] centroids = kMeans.centroids(); + + // write them + // calculate the global centroid from all the centroids: + for (float[] centroid : centroids) { + for (int j = 0; j < centroid.length; j++) { + globalCentroid[j] += centroid[j]; + } + } + for (int j = 0; j < globalCentroid.length; j++) { + globalCentroid[j] /= centroids.length; + } + writeCentroids(centroids, fieldInfo, globalCentroid, temporaryCentroidOutput); + return centroids.length; + } + + @Override + long[] buildAndWritePostingsLists( + FieldInfo fieldInfo, + CentroidAssignmentScorer centroidAssignmentScorer, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput, + MergeState mergeState + ) throws IOException { + IntArrayList[] clusters = new IntArrayList[centroidAssignmentScorer.size()]; + for (int i = 0; i < centroidAssignmentScorer.size(); i++) { + clusters[i] = new IntArrayList(floatVectorValues.size() / centroidAssignmentScorer.size() / 4); + } + long nanoTime = System.nanoTime(); + // Can we do a pre-filter by finding the nearest centroids to the original vector centroids? + // We need to be careful on vecOrd vs. 
doc as we need random access to the raw vector for posting list writing + assignCentroids(centroidAssignmentScorer, floatVectorValues, clusters); + if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "assignCentroids time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0)); + } + + if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + printClusterQualityStatistics(clusters, mergeState.infoStream); + } + // write the posting lists + final long[] offsets = new long[centroidAssignmentScorer.size()]; + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer); + DocIdsWriter docIdsWriter = new DocIdsWriter(); + for (int i = 0; i < centroidAssignmentScorer.size(); i++) { + float[] centroid = centroidAssignmentScorer.centroid(i); + binarizedByteVectorValues.centroid = centroid; + // TODO: sort by distance to the centroid + IntArrayList cluster = clusters[i]; + // TODO align??? + offsets[i] = postingsOutput.getFilePointer(); + int size = cluster.size(); + postingsOutput.writeVInt(size); + postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); + // TODO we might want to consider putting the docIds in a separate file + // to aid with only having to fetch vectors from slower storage when they are required + // keeping them in the same file indicates we pull the entire file into cache + docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput); + writePostingList(cluster, postingsOutput, binarizedByteVectorValues); + } + return offsets; + } + + private static void printClusterQualityStatistics(IntArrayList[] clusters, InfoStream infoStream) { + float min = Float.MAX_VALUE; + float max = Float.MIN_VALUE; + float mean = 0; + float m2 = 0; + // iteratively compute the variance & mean + int count = 0; + for (IntArrayList cluster : clusters) { + count += 1; + if (cluster == null) { + continue; + } + float delta = cluster.size() - mean; + mean += delta / count; + m2 += delta * (cluster.size() - mean); + min = Math.min(min, cluster.size()); + max = Math.max(max, cluster.size()); + } + float variance = m2 / (clusters.length - 1); + infoStream.message( + IVF_VECTOR_COMPONENT, + "Centroid count: " + + clusters.length + + " min: " + + min + + " max: " + + max + + " mean: " + + mean + + " stdDev: " + + Math.sqrt(variance) + + " variance: " + + variance + ); + } + + static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues vectors, IntArrayList[] clusters) throws IOException { + int numCentroids = scorer.size(); + // we at most will look at the EXT_SOAR_LIMIT_CHECK_RATIO nearest centroids if possible + int soarToCheck = (int) (numCentroids * EXT_SOAR_LIMIT_CHECK_RATIO); + int soarClusterCheckCount = Math.min(numCentroids - 1, soarToCheck); + NeighborQueue neighborsToCheck = new NeighborQueue(soarClusterCheckCount + 1, true); + OrdScoreIterator ordScoreIterator = new OrdScoreIterator(soarClusterCheckCount + 1); + float[] scratch = new float[vectors.dimension()]; + for (int docID = 0; docID < vectors.size(); docID++) { + float[] vector = vectors.vectorValue(docID); + scorer.setScoringVector(vector); + int bestCentroid = 0; + float bestScore = Float.MAX_VALUE; + if (numCentroids > 1) { + for (short c = 0; c < numCentroids; c++) { + float squareDist = scorer.score(c); + 
neighborsToCheck.insertWithOverflow(c, squareDist); + } + // pop the best + int sz = neighborsToCheck.size(); + int best = neighborsToCheck.consumeNodesAndScoresMin(ordScoreIterator.ords, ordScoreIterator.scores); + // Set the size to the number of neighbors we actually found + ordScoreIterator.setSize(sz); + bestScore = ordScoreIterator.getScore(best); + bestCentroid = ordScoreIterator.getOrd(best); + } + clusters[bestCentroid].add(docID); + if (soarClusterCheckCount > 0) { + assignCentroidSOAR( + ordScoreIterator, + docID, + bestCentroid, + scorer.centroid(bestCentroid), + bestScore, + scratch, + scorer, + vector, + clusters + ); + } + neighborsToCheck.clear(); + } + } + + static void assignCentroidSOAR( + OrdScoreIterator centroidsToCheck, + int vecOrd, + int bestCentroidId, + float[] bestCentroid, + float bestScore, + float[] scratch, + CentroidAssignmentScorer scorer, + float[] vector, + IntArrayList[] clusters + ) throws IOException { + ESVectorUtil.subtract(vector, bestCentroid, scratch); + int bestSecondaryCentroid = -1; + float minDist = Float.MAX_VALUE; + for (int i = 0; i < centroidsToCheck.size(); i++) { + float score = centroidsToCheck.getScore(i); + int centroidOrdinal = centroidsToCheck.getOrd(i); + if (centroidOrdinal == bestCentroidId) { + continue; + } + float proj = ESVectorUtil.soarResidual(vector, scorer.centroid(centroidOrdinal), scratch); + score += SOAR_LAMBDA * proj * proj / bestScore; + if (score < minDist) { + bestSecondaryCentroid = centroidOrdinal; + minDist = score; + } + } + if (bestSecondaryCentroid != -1) { + clusters[bestSecondaryCentroid].add(vecOrd); + } + } + + static class OrdScoreIterator { + private final int[] ords; + private final float[] scores; + private int idx = 0; + + OrdScoreIterator(int size) { + this.ords = new int[size]; + this.scores = new float[size]; + } + + int setSize(int size) { + if (size > ords.length) { + throw new IllegalArgumentException("size must be <= " + ords.length); + } + this.idx = size; + return size; + } + + int getOrd(int idx) { + return ords[idx]; + } + + float getScore(int idx) { + return scores[idx]; + } + + int size() { + return idx; + } + } + + // TODO unify with OSQ format + static class BinarizedFloatVectorValues { + private OptimizedScalarQuantizer.QuantizationResult corrections; + private final byte[] binarized; + private final byte[] initQuantized; + private float[] centroid; + private final FloatVectorValues values; + private final OptimizedScalarQuantizer quantizer; + + private int lastOrd = -1; + + BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer) { + this.values = delegate; + this.quantizer = quantizer; + this.binarized = new byte[discretize(delegate.dimension(), 64) / 8]; + this.initQuantized = new byte[delegate.dimension()]; + } + + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) { + if (ord != lastOrd) { + throw new IllegalStateException( + "attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd + ); + } + return corrections; + } + + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + binarize(ord); + lastOrd = ord; + } + return binarized; + } + + private void binarize(int ord) throws IOException { + corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid); + packAsBinary(initQuantized, binarized); + } + } + + static class OffHeapCentroidAssignmentScorer implements CentroidAssignmentScorer { + 
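+        // Layout assumption matching writeCentroids: the quantized centroids are stored
+        // first, so rawCentroidOffset skips (dimension + 3 floats + 1 short) bytes per
+        // centroid to reach the raw float32 centroids scored here.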
private final IndexInput centroidsInput; + private final int numCentroids; + private final int dimension; + private final float[] scratch; + private float[] q; + private final long rawCentroidOffset; + private int currOrd = -1; + + OffHeapCentroidAssignmentScorer(IndexInput centroidsInput, int numCentroids, FieldInfo info) { + this.centroidsInput = centroidsInput; + this.numCentroids = numCentroids; + this.dimension = info.getVectorDimension(); + this.scratch = new float[dimension]; + this.rawCentroidOffset = (dimension + 3 * Float.BYTES + Short.BYTES) * numCentroids; + } + + @Override + public int size() { + return numCentroids; + } + + @Override + public float[] centroid(int centroidOrdinal) throws IOException { + if (centroidOrdinal == currOrd) { + return scratch; + } + centroidsInput.seek(rawCentroidOffset + (long) centroidOrdinal * dimension * Float.BYTES); + centroidsInput.readFloats(scratch, 0, dimension); + this.currOrd = centroidOrdinal; + return scratch; + } + + @Override + public void setScoringVector(float[] vector) { + q = vector; + } + + @Override + public float score(int centroidOrdinal) throws IOException { + return VectorUtil.squareDistance(centroid(centroidOrdinal), q); + } + } + + // TODO throw away rawCentroids + static class OnHeapCentroidAssignmentScorer implements CentroidAssignmentScorer { + private final float[][] centroids; + private float[] q; + + OnHeapCentroidAssignmentScorer(float[][] centroids) { + this.centroids = centroids; + } + + @Override + public int size() { + return centroids.length; + } + + @Override + public void setScoringVector(float[] vector) { + q = vector; + } + + @Override + public float[] centroid(int centroidOrdinal) throws IOException { + return centroids[centroidOrdinal]; + } + + @Override + public float score(int centroidOrdinal) throws IOException { + return VectorUtil.squareDistance(centroid(centroidOrdinal), q); + } + } + + static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections) + throws IOException { + indexOutput.writeBytes(binaryValue, binaryValue.length); + indexOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + indexOutput.writeInt(Float.floatToIntBits(corrections.upperInterval())); + indexOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + indexOutput.writeShort((short) corrections.quantizedComponentSum()); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java new file mode 100644 index 000000000000..f124e978116e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +import java.io.IOException; + +/** + * Codec format for Inverted File Vector indexes. This index expects to break the dimensional space + * into clusters and assign each vector to a cluster generating a posting list of vectors. Clusters + * are represented by centroids. + * The vector quantization format used here is a per-vector optimized scalar quantization. Also see {@link + * OptimizedScalarQuantizer}. Some of key features are: + * + * The format is stored in three files: + * + *
<h2>.cenivf (centroid data) file</h2>
+ * <p> Which stores the raw and quantized centroid vectors.
+ *
+ * <h2>.clivf (cluster data) file</h2>
+ *
+ * <p> Stores the quantized vectors for each cluster, inline and stored in blocks. Additionally, the docIds of
+ * each vector are stored.
+ *
+ * <h2>.mivf (centroid metadata) file</h2>
+ *
+ * <p> Stores metadata including the number of centroids and their offsets in the clivf file
+ * + */ +public class IVFVectorsFormat extends KnnVectorsFormat { + + public static final String IVF_VECTOR_COMPONENT = "IVF"; + public static final String NAME = "IVFVectorsFormat"; + // centroid ordinals -> centroid values, offsets + public static final String CENTROID_EXTENSION = "cenivf"; + // offsets contained in cen_ivf, [vector ordinals, actually just docIds](long varint), quantized + // vectors (OSQ bit) + public static final String CLUSTER_EXTENSION = "clivf"; + static final String IVF_META_EXTENSION = "mivf"; + + public static final int VERSION_START = 0; + public static final int VERSION_CURRENT = VERSION_START; + + private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + private static final int DEFAULT_VECTORS_PER_CLUSTER = 1000; + + private final int vectorPerCluster; + + public IVFVectorsFormat(int vectorPerCluster) { + super(NAME); + if (vectorPerCluster <= 0) { + throw new IllegalArgumentException("vectorPerCluster must be > 0"); + } + this.vectorPerCluster = vectorPerCluster; + } + + /** Constructs a format using the given graph construction parameters and scalar quantization. */ + public IVFVectorsFormat() { + this(DEFAULT_VECTORS_PER_CLUSTER); + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new DefaultIVFVectorsWriter(state, rawVectorFormat.fieldsWriter(state), vectorPerCluster); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new DefaultIVFVectorsReader(state, rawVectorFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 1024; + } + + @Override + public String toString() { + return "IVFVectorFormat"; + } + + static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + if (vectorsReader instanceof IVFVectorsReader reader) { + return reader; + } + return null; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java new file mode 100644 index 000000000000..4319c7c47c82 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -0,0 +1,354 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataInput; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.hnsw.NeighborQueue; +import org.elasticsearch.core.IOUtils; + +import java.io.IOException; +import java.util.function.IntPredicate; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; + +/** + * Reader for IVF vectors. This reader is used to read the IVF vectors from the index. + */ +public abstract class IVFVectorsReader extends KnnVectorsReader { + + private final IndexInput ivfCentroids, ivfClusters; + private final SegmentReadState state; + private final FieldInfos fieldInfos; + protected final IntObjectHashMap fields; + private final FlatVectorsReader rawVectorsReader; + + protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException { + this.state = state; + this.fieldInfos = state.fieldInfos; + this.rawVectorsReader = rawVectorsReader; + this.fields = new IntObjectHashMap<>(); + String meta = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, IVFVectorsFormat.IVF_META_EXTENSION); + + int versionMeta = -1; + boolean success = false; + try (ChecksumIndexInput ivfMeta = state.directory.openChecksumInput(meta)) { + Throwable priorE = null; + try { + versionMeta = CodecUtil.checkIndexHeader( + ivfMeta, + IVFVectorsFormat.NAME, + IVFVectorsFormat.VERSION_START, + IVFVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + readFields(ivfMeta); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(ivfMeta, priorE); + } + ivfCentroids = openDataInput(state, versionMeta, IVFVectorsFormat.CENTROID_EXTENSION, IVFVectorsFormat.NAME, state.context); + ivfClusters = openDataInput(state, versionMeta, IVFVectorsFormat.CLUSTER_EXTENSION, IVFVectorsFormat.NAME, state.context); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + abstract CentroidQueryScorer getCentroidScorer( + FieldInfo fieldInfo, + int numCentroids, + IndexInput centroids, + float[] target, + IndexInput clusters + ) throws IOException; + + protected abstract FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) throws IOException; + + public FloatVectorValues getCentroids(FieldInfo fieldInfo) throws IOException { + FieldEntry entry = fields.get(fieldInfo.number); + if (entry == null) { + return null; + } + return getCentroids(entry.centroidSlice(ivfCentroids), 
entry.postingListOffsets.length, fieldInfo); + } + + int centroidSize(String fieldName, int centroidOrdinal) throws IOException { + FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldName); + FieldEntry entry = fields.get(fieldInfo.number); + ivfClusters.seek(entry.postingListOffsets[centroidOrdinal]); + return ivfClusters.readVInt(); + } + + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String codecName, + IOContext context + ) throws IOException { + final String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + final IndexInput in = state.directory.openInput(fileName, context); + boolean success = false; + try { + final int versionVectorData = CodecUtil.checkIndexHeader( + in, + codecName, + IVFVectorsFormat.VERSION_START, + IVFVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, + in + ); + } + CodecUtil.retrieveChecksum(in); + success = true; + return in; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(in); + } + } + } + + private void readFields(ChecksumIndexInput meta) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + final FieldInfo info = fieldInfos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + fields.put(info.number, readField(meta, info)); + } + } + + private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { + final VectorEncoding vectorEncoding = readVectorEncoding(input); + final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + final long centroidOffset = input.readLong(); + final long centroidLength = input.readLong(); + final int numPostingLists = input.readVInt(); + final long[] postingListOffsets = new long[numPostingLists]; + for (int i = 0; i < numPostingLists; i++) { + postingListOffsets[i] = input.readLong(); + } + final float[] globalCentroid = new float[info.getVectorDimension()]; + float globalCentroidDp = 0; + if (numPostingLists > 0) { + input.readFloats(globalCentroid, 0, globalCentroid.length); + globalCentroidDp = Float.intBitsToFloat(input.readInt()); + } + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction() + ); + } + return new FieldEntry( + similarityFunction, + vectorEncoding, + centroidOffset, + centroidLength, + postingListOffsets, + globalCentroid, + globalCentroidDp + ); + } + + private static VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException { + final int i = input.readInt(); + if (i < 0 || i >= SIMILARITY_FUNCTIONS.size()) { + throw new IllegalArgumentException("invalid distance function: " + i); + } + return SIMILARITY_FUNCTIONS.get(i); + } + + private static VectorEncoding readVectorEncoding(DataInput input) throws IOException { + final int encodingId = input.readInt(); + if (encodingId < 0 || encodingId >= VectorEncoding.values().length) { + throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input); + } + return 
VectorEncoding.values()[encodingId]; + } + + @Override + public final void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + CodecUtil.checksumEntireFile(ivfCentroids); + CodecUtil.checksumEntireFile(ivfClusters); + } + + @Override + public final FloatVectorValues getFloatVectorValues(String field) throws IOException { + return rawVectorsReader.getFloatVectorValues(field); + } + + @Override + public final ByteVectorValues getByteVectorValues(String field) throws IOException { + return rawVectorsReader.getByteVectorValues(field); + } + + protected float[] getGlobalCentroid(FieldInfo info) { + if (info == null || info.getVectorEncoding().equals(VectorEncoding.BYTE)) { + return null; + } + FieldEntry entry = fields.get(info.number); + if (entry == null) { + return null; + } + return entry.globalCentroid(); + } + + @Override + public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) { + rawVectorsReader.search(field, target, knnCollector, acceptDocs); + return; + } + // TODO add new ivf search strategy + int nProbe = 10; + float percentFiltered = 1f; + if (acceptDocs instanceof BitSet bitSet) { + percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length())); + } + int numVectors = rawVectorsReader.getFloatVectorValues(field).size(); + BitSet visitedDocs = new FixedBitSet(state.segmentInfo.maxDoc() + 1); + IntPredicate needsScoring = docId -> { + if (acceptDocs != null && acceptDocs.get(docId) == false) { + return false; + } + return visitedDocs.getAndSet(docId) == false; + }; + + FieldEntry entry = fields.get(fieldInfo.number); + CentroidQueryScorer centroidQueryScorer = getCentroidScorer( + fieldInfo, + entry.postingListOffsets.length, + entry.centroidSlice(ivfCentroids), + target, + ivfClusters + ); + final NeighborQueue centroidQueue = scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe); + PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring); + int centroidsVisited = 0; + long expectedDocs = 0; + long actualDocs = 0; + // initially we visit only the "centroids to search" + while (centroidQueue.size() > 0 && centroidsVisited < nProbe) { + ++centroidsVisited; + // todo do we actually need to know the score??? + int centroidOrdinal = centroidQueue.pop(); + // todo do we need direct access to the raw centroid??? 
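+ // resetPostingsScorer returns the posting list size (the docs we expect to visit); visit returns how many actually survived the needsScoring filter and were scored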
+ expectedDocs += scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal)); + actualDocs += scorer.visit(knnCollector); + } + if (acceptDocs != null) { + float unfilteredRatioVisited = (float) expectedDocs / numVectors; + int filteredVectors = (int) Math.ceil(numVectors * percentFiltered); + float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f); + while (centroidQueue.size() > 0 && (actualDocs < expectedScored || actualDocs < knnCollector.k())) { + int centroidOrdinal = centroidQueue.pop(); + scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal)); + actualDocs += scorer.visit(knnCollector); + } + } + } + + @Override + public final void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field); + final ByteVectorValues values = rawVectorsReader.getByteVectorValues(field); + for (int i = 0; i < values.size(); i++) { + final float score = fieldInfo.getVectorSimilarityFunction().compare(target, values.vectorValue(i)); + knnCollector.collect(values.ordToDoc(i), score); + if (knnCollector.earlyTerminated()) { + return; + } + } + } + + abstract NeighborQueue scorePostingLists( + FieldInfo fieldInfo, + KnnCollector knnCollector, + CentroidQueryScorer centroidQueryScorer, + int nProbe + ) throws IOException; + + @Override + public void close() throws IOException { + IOUtils.close(rawVectorsReader, ivfCentroids, ivfClusters); + } + + protected record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + long centroidOffset, + long centroidLength, + long[] postingListOffsets, + float[] globalCentroid, + float globalCentroidDp + ) { + IndexInput centroidSlice(IndexInput centroidFile) throws IOException { + return centroidFile.slice("centroids", centroidOffset, centroidLength); + } + } + + abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, IntPredicate needsScoring) + throws IOException; + + interface CentroidQueryScorer { + int size(); + + float[] centroid(int centroidOrdinal) throws IOException; + + float score(int centroidOrdinal) throws IOException; + } + + interface PostingVisitor { + // TODO maybe we can not specifically pass the centroid... + + /** returns the number of documents in the posting list */ + int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException; + + /** returns the number of scored documents */ + int visit(KnnCollector collector) throws IOException; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java new file mode 100644 index 000000000000..4e9c4ee47e3f --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java @@ -0,0 +1,486 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.core.SuppressForbidden; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * Base class for IVF vectors writer. + */ +public abstract class IVFVectorsWriter extends KnnVectorsWriter { + + private final List fieldWriters = new ArrayList<>(); + private final IndexOutput ivfCentroids, ivfClusters; + private final IndexOutput ivfMeta; + private final FlatVectorsWriter rawVectorDelegate; + private final SegmentWriteState segmentWriteState; + + protected IVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) throws IOException { + this.segmentWriteState = state; + this.rawVectorDelegate = rawVectorDelegate; + final String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + IVFVectorsFormat.IVF_META_EXTENSION + ); + + final String ivfCentroidsFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + IVFVectorsFormat.CENTROID_EXTENSION + ); + final String ivfClustersFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + IVFVectorsFormat.CLUSTER_EXTENSION + ); + boolean success = false; + try { + ivfMeta = state.directory.createOutput(metaFileName, state.context); + CodecUtil.writeIndexHeader( + ivfMeta, + IVFVectorsFormat.NAME, + IVFVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + ivfCentroids = state.directory.createOutput(ivfCentroidsFileName, state.context); + CodecUtil.writeIndexHeader( + ivfCentroids, + IVFVectorsFormat.NAME, + IVFVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + ivfClusters = state.directory.createOutput(ivfClustersFileName, state.context); + CodecUtil.writeIndexHeader( + ivfClusters, + IVFVectorsFormat.NAME, + IVFVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public final KnnFieldVectorsWriter 
<?>
addField(FieldInfo fieldInfo) throws IOException { + if (fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE) { + throw new IllegalArgumentException("IVF does not support cosine similarity"); + } + final FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + @SuppressWarnings("unchecked") + final FlatFieldVectorsWriter floatWriter = (FlatFieldVectorsWriter) rawVectorDelegate; + fieldWriters.add(new FieldWriter(fieldInfo, floatWriter)); + } + return rawVectorDelegate; + } + + protected abstract int calculateAndWriteCentroids( + FieldInfo fieldInfo, + FloatVectorValues floatVectorValues, + IndexOutput temporaryCentroidOutput, + MergeState mergeState, + float[] globalCentroid + ) throws IOException; + + abstract long[] buildAndWritePostingsLists( + FieldInfo fieldInfo, + CentroidAssignmentScorer scorer, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput, + MergeState mergeState + ) throws IOException; + + abstract CentroidAssignmentScorer calculateAndWriteCentroids( + FieldInfo fieldInfo, + FloatVectorValues floatVectorValues, + IndexOutput centroidOutput, + float[] globalCentroid + ) throws IOException; + + abstract long[] buildAndWritePostingsLists( + FieldInfo fieldInfo, + InfoStream infoStream, + CentroidAssignmentScorer scorer, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput + ) throws IOException; + + abstract CentroidAssignmentScorer createCentroidScorer( + IndexInput centroidsInput, + int numCentroids, + FieldInfo fieldInfo, + float[] globalCentroid + ) throws IOException; + + @Override + public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorDelegate.flush(maxDoc, sortMap); + for (FieldWriter fieldWriter : fieldWriters) { + float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()]; + // calculate global centroid + for (var vector : fieldWriter.delegate.getVectors()) { + for (int i = 0; i < globalCentroid.length; i++) { + globalCentroid[i] += vector[i]; + } + } + for (int i = 0; i < globalCentroid.length; i++) { + globalCentroid[i] /= fieldWriter.delegate.getVectors().size(); + } + // build a float vector values with random access + final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc); + // build centroids + long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); + final CentroidAssignmentScorer centroidAssignmentScorer = calculateAndWriteCentroids( + fieldWriter.fieldInfo, + floatVectorValues, + ivfCentroids, + globalCentroid + ); + long centroidLength = ivfCentroids.getFilePointer() - centroidOffset; + final long[] offsets = buildAndWritePostingsLists( + fieldWriter.fieldInfo, + segmentWriteState.infoStream, + centroidAssignmentScorer, + floatVectorValues, + ivfClusters + ); + writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid); + } + } + + private static FloatVectorValues getFloatVectorValues( + FieldInfo fieldInfo, + FlatFieldVectorsWriter fieldVectorsWriter, + int maxDoc + ) throws IOException { + List vectors = fieldVectorsWriter.getVectors(); + if (vectors.size() == maxDoc) { + return FloatVectorValues.fromFloats(vectors, fieldInfo.getVectorDimension()); + } + final DocIdSetIterator iterator = fieldVectorsWriter.getDocsWithFieldSet().iterator(); + final int[] docIds = new int[vectors.size()]; + for (int i = 0; i < docIds.length; i++) { + docIds[i] = 
iterator.nextDoc(); + } + assert iterator.nextDoc() == NO_MORE_DOCS; + return new FloatVectorValues() { + @Override + public float[] vectorValue(int ord) { + return vectors.get(ord); + } + + @Override + public FloatVectorValues copy() { + return this; + } + + @Override + public int dimension() { + return fieldInfo.getVectorDimension(); + } + + @Override + public int size() { + return vectors.size(); + } + + @Override + public int ordToDoc(int ord) { + return docIds[ord]; + } + }; + } + + static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + if (vectorsReader instanceof IVFVectorsReader reader) { + return reader; + } + return null; + } + + @Override + @SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)") + public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + final int numVectors; + String tempRawVectorsFileName = null; + boolean success = false; + // build a float vector values with random access. In order to do that we dump the vectors to + // a temporary file + // and write the docID follow by the vector + try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivf_", IOContext.DEFAULT)) { + tempRawVectorsFileName = out.getName(); + // TODO do this better, we shouldn't have to write to a temp file, we should be able to + // to just from the merged vector values, the tricky part is the random access. + numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + CodecUtil.writeFooter(out); + success = true; + } finally { + if (success == false && tempRawVectorsFileName != null) { + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName); + } + } + try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) { + float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()]; + final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors); + success = false; + CentroidAssignmentScorer centroidAssignmentScorer; + long centroidOffset; + long centroidLength; + String centroidTempName = null; + int numCentroids; + IndexOutput centroidTemp = null; + try { + centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT); + centroidTempName = centroidTemp.getName(); + numCentroids = calculateAndWriteCentroids( + fieldInfo, + floatVectorValues, + centroidTemp, + mergeState, + calculatedGlobalCentroid + ); + success = true; + } finally { + if (success == false && centroidTempName != null) { + IOUtils.closeWhileHandlingException(centroidTemp); + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName); + } + } + try { + if (numCentroids == 0) { + centroidOffset = ivfCentroids.getFilePointer(); + writeMeta(fieldInfo, centroidOffset, 0, new long[0], null); + CodecUtil.writeFooter(centroidTemp); + IOUtils.close(centroidTemp); + return; + } + CodecUtil.writeFooter(centroidTemp); + IOUtils.close(centroidTemp); + centroidOffset = 
ivfCentroids.alignFilePointer(Float.BYTES); + try (IndexInput centroidInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) { + ivfCentroids.copyBytes(centroidInput, centroidInput.length() - CodecUtil.footerLength()); + centroidLength = ivfCentroids.getFilePointer() - centroidOffset; + centroidAssignmentScorer = createCentroidScorer(centroidInput, numCentroids, fieldInfo, calculatedGlobalCentroid); + assert centroidAssignmentScorer.size() == numCentroids; + // build a float vector values with random access + // build centroids + final long[] offsets = buildAndWritePostingsLists( + fieldInfo, + centroidAssignmentScorer, + floatVectorValues, + ivfClusters, + mergeState + ); + assert offsets.length == centroidAssignmentScorer.size(); + writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid); + } + } finally { + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions( + mergeState.segmentInfo.dir, + tempRawVectorsFileName, + centroidTempName + ); + } + } finally { + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName); + } + } + } + + private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput randomAccessInput, int numVectors) { + if (numVectors == 0) { + return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension()); + } + final long length = (long) Float.BYTES * fieldInfo.getVectorDimension() + Integer.BYTES; + final float[] vector = new float[fieldInfo.getVectorDimension()]; + return new FloatVectorValues() { + @Override + public float[] vectorValue(int ord) throws IOException { + randomAccessInput.seek(ord * length + Integer.BYTES); + randomAccessInput.readFloats(vector, 0, vector.length); + return vector; + } + + @Override + public FloatVectorValues copy() { + return this; + } + + @Override + public int dimension() { + return fieldInfo.getVectorDimension(); + } + + @Override + public int size() { + return numVectors; + } + + @Override + public int ordToDoc(int ord) { + try { + randomAccessInput.seek(ord * length); + return randomAccessInput.readInt(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + }; + } + + private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out, FloatVectorValues floatVectorValues) + throws IOException { + int numVectors = 0; + final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + final KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + numVectors++; + float[] vector = floatVectorValues.vectorValue(iterator.index()); + out.writeInt(iterator.docID()); + buffer.asFloatBuffer().put(vector); + out.writeBytes(buffer.array(), buffer.array().length); + } + return numVectors; + } + + private void writeMeta(FieldInfo field, long centroidOffset, long centroidLength, long[] offsets, float[] globalCentroid) + throws IOException { + ivfMeta.writeInt(field.number); + ivfMeta.writeInt(field.getVectorEncoding().ordinal()); + ivfMeta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction())); + ivfMeta.writeLong(centroidOffset); + ivfMeta.writeLong(centroidLength); + ivfMeta.writeVInt(offsets.length); + for (long offset : offsets) { + ivfMeta.writeLong(offset); + } + if (offsets.length > 0) { + final ByteBuffer buffer = ByteBuffer.allocate(globalCentroid.length * 
Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + buffer.asFloatBuffer().put(globalCentroid); + ivfMeta.writeBytes(buffer.array(), buffer.array().length); + ivfMeta.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(globalCentroid, globalCentroid))); + } + } + + private static int distFuncToOrd(VectorSimilarityFunction func) { + for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) { + if (SIMILARITY_FUNCTIONS.get(i).equals(func)) { + return (byte) i; + } + } + throw new IllegalArgumentException("invalid distance function: " + func); + } + + @Override + public final void finish() throws IOException { + rawVectorDelegate.finish(); + if (ivfMeta != null) { + // write end of fields marker + ivfMeta.writeInt(-1); + CodecUtil.writeFooter(ivfMeta); + } + if (ivfCentroids != null) { + CodecUtil.writeFooter(ivfCentroids); + } + if (ivfClusters != null) { + CodecUtil.writeFooter(ivfClusters); + } + } + + @Override + public final void close() throws IOException { + IOUtils.close(rawVectorDelegate, ivfMeta, ivfCentroids, ivfClusters); + } + + @Override + public final long ramBytesUsed() { + return rawVectorDelegate.ramBytesUsed(); + } + + private record FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter delegate) {} + + interface CentroidAssignmentScorer { + CentroidAssignmentScorer EMPTY = new CentroidAssignmentScorer() { + @Override + public int size() { + return 0; + } + + @Override + public float[] centroid(int centroidOrdinal) { + throw new IllegalStateException("No centroids"); + } + + @Override + public float score(int centroidOrdinal) { + throw new IllegalStateException("No centroids"); + } + + @Override + public void setScoringVector(float[] vector) { + throw new IllegalStateException("No centroids"); + } + }; + + int size(); + + float[] centroid(int centroidOrdinal) throws IOException; + + void setScoringVector(float[] vector); + + float score(int centroidOrdinal) throws IOException; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java new file mode 100644 index 000000000000..f27e85d46cdd --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java @@ -0,0 +1,159 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.util.LongHeap; +import org.apache.lucene.util.NumericUtils; + +/** + * Copied from and modified from Apache Lucene. 
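+ * Each (node, score) pair is packed into one sortable long, with the score (as a sortable int) in the high 32 bits and the complemented node id in the low 32 bits, so a single LongHeap comparison orders by score and breaks ties toward the smaller node id.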
+ */ +class NeighborQueue { + + private enum Order { + MIN_HEAP { + @Override + long apply(long v) { + return v; + } + }, + MAX_HEAP { + @Override + long apply(long v) { + // This cannot be just `-v` since Long.MIN_VALUE doesn't have a positive counterpart. It + // needs a function that returns MAX_VALUE for MIN_VALUE and vice-versa. + return -1 - v; + } + }; + + abstract long apply(long v); + } + + private final LongHeap heap; + private final Order order; + + NeighborQueue(int initialSize, boolean maxHeap) { + this.heap = new LongHeap(initialSize); + this.order = maxHeap ? Order.MAX_HEAP : Order.MIN_HEAP; + } + + /** + * @return the number of elements in the heap + */ + public int size() { + return heap.size(); + } + + /** + * Adds a new graph arc, extending the storage as needed. + * + * @param newNode the neighbor node id + * @param newScore the score of the neighbor, relative to some other node + */ + public void add(int newNode, float newScore) { + heap.push(encode(newNode, newScore)); + } + + /** + * If the heap is not full (size is less than the initialSize provided to the constructor), adds a + * new node-and-score element. If the heap is full, compares the score against the current top + * score, and replaces the top element if newScore is better than (greater than unless the heap is + * reversed), the current top score. + * + * @param newNode the neighbor node id + * @param newScore the score of the neighbor, relative to some other node + */ + public boolean insertWithOverflow(int newNode, float newScore) { + return heap.insertWithOverflow(encode(newNode, newScore)); + } + + /** + * Encodes the node ID and its similarity score as long, preserving the Lucene tie-breaking rule + * that when two scores are equal, the smaller node ID must win. + * @param node the node ID + * @param score the node score + * @return the encoded score, node ID + */ + private long encode(int node, float score) { + return order.apply((((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node)); + } + + /** Returns the top element's node id. */ + int topNode() { + return decodeNodeId(heap.top()); + } + + /** + * Returns the top element's node score. For the min heap this is the minimum score. For the max + * heap this is the maximum score. + */ + float topScore() { + return decodeScore(heap.top()); + } + + private float decodeScore(long heapValue) { + return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32)); + } + + private int decodeNodeId(long heapValue) { + return (int) ~(order.apply(heapValue)); + } + + /** Removes the top element and returns its node id. */ + public int pop() { + return decodeNodeId(heap.pop()); + } + + public void consumeNodes(int[] dest) { + if (dest.length < size()) { + throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements."); + } + for (int i = 0; i < size(); i++) { + dest[i] = decodeNodeId(heap.get(i + 1)); + } + } + + public int consumeNodesAndScoresMin(int[] dest, float[] scores) { + if (dest.length < size() || scores.length < size()) { + throw new IllegalArgumentException("Destination array is too small. 
Expected at least " + size() + " elements."); + } + float bestScore = Float.POSITIVE_INFINITY; + int bestIdx = 0; + for (int i = 0; i < size(); i++) { + long heapValue = heap.get(i + 1); + scores[i] = decodeScore(heapValue); + dest[i] = decodeNodeId(heapValue); + if (scores[i] < bestScore) { + bestScore = scores[i]; + bestIdx = i; + } + } + return bestIdx; + } + + public void clear() { + heap.clear(); + } + + @Override + public String toString() { + return "Neighbors[" + heap.size() + "]"; + } +} diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index cef8d0998081..14e68029abc3 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -7,3 +7,4 @@ org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.IVFVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java new file mode 100644 index 000000000000..c822d71a358f --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ +package org.elasticsearch.index.codec.vectors; + +import com.carrotsearch.randomizedtesting.generators.RandomPicks; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.elasticsearch.common.logging.LogConfigurator; +import org.junit.Before; + +import java.util.List; + +public class IVFVectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + KnnVectorsFormat format; + + @Before + @Override + public void setUp() throws Exception { + format = new IVFVectorsFormat(random().nextInt(10, 1000)); + super.setUp(); + } + + @Override + protected VectorSimilarityFunction randomSimilarity() { + return RandomPicks.randomFrom( + random(), + List.of( + VectorSimilarityFunction.DOT_PRODUCT, + VectorSimilarityFunction.EUCLIDEAN, + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT + ) + ); + } + + @Override + protected VectorEncoding randomVectorEncoding() { + return VectorEncoding.FLOAT32; + } + + @Override + public void testSearchWithVisitedLimit() { + // ivf doesn't enforce visitation limit + } + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(format); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/NeighborQueueTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/NeighborQueueTests.java new file mode 100644 index 000000000000..7238f58d746d --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/NeighborQueueTests.java @@ -0,0 +1,119 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * Modifications copyright (C) 2025 Elasticsearch B.V. 
+ */ + +package org.elasticsearch.index.codec.vectors; + +import org.elasticsearch.test.ESTestCase; + +/** + * copied and modified from Lucene + */ +public class NeighborQueueTests extends ESTestCase { + public void testNeighborsProduct() { + // make sure we have the sign correct + NeighborQueue nn = new NeighborQueue(2, false); + assertTrue(nn.insertWithOverflow(2, 0.5f)); + assertTrue(nn.insertWithOverflow(1, 0.2f)); + assertTrue(nn.insertWithOverflow(3, 1f)); + assertEquals(0.5f, nn.topScore(), 0); + nn.pop(); + assertEquals(1f, nn.topScore(), 0); + nn.pop(); + } + + public void testNeighborsMaxHeap() { + NeighborQueue nn = new NeighborQueue(2, true); + assertTrue(nn.insertWithOverflow(2, 2)); + assertTrue(nn.insertWithOverflow(1, 1)); + assertFalse(nn.insertWithOverflow(3, 3)); + assertEquals(2f, nn.topScore(), 0); + nn.pop(); + assertEquals(1f, nn.topScore(), 0); + } + + public void testTopMaxHeap() { + NeighborQueue nn = new NeighborQueue(2, true); + nn.add(1, 2); + nn.add(2, 1); + // lower scores are better; highest score on top + assertEquals(2, nn.topScore(), 0); + assertEquals(1, nn.topNode()); + } + + public void testTopMinHeap() { + NeighborQueue nn = new NeighborQueue(2, false); + nn.add(1, 0.5f); + nn.add(2, -0.5f); + // higher scores are better; lowest score on top + assertEquals(-0.5f, nn.topScore(), 0); + assertEquals(2, nn.topNode()); + } + + public void testClear() { + NeighborQueue nn = new NeighborQueue(2, false); + nn.add(1, 1.1f); + nn.add(2, -2.2f); + nn.clear(); + + assertEquals(0, nn.size()); + } + + public void testMaxSizeQueue() { + NeighborQueue nn = new NeighborQueue(2, false); + nn.add(1, 1); + nn.add(2, 2); + assertEquals(2, nn.size()); + assertEquals(1, nn.topNode()); + + // insertWithOverflow does not extend the queue + nn.insertWithOverflow(3, 3); + assertEquals(2, nn.size()); + assertEquals(2, nn.topNode()); + + // add does extend the queue beyond maxSize + nn.add(4, 1); + assertEquals(3, nn.size()); + } + + public void testUnboundedQueue() { + NeighborQueue nn = new NeighborQueue(1, true); + float maxScore = -2; + int maxNode = -1; + for (int i = 0; i < 256; i++) { + // initial size is 32 + float score = random().nextFloat(); + if (score > maxScore) { + maxScore = score; + maxNode = i; + } + nn.add(i, score); + } + assertEquals(maxScore, nn.topScore(), 0); + assertEquals(maxNode, nn.topNode()); + } + + public void testInvalidArguments() { + expectThrows(IllegalArgumentException.class, () -> new NeighborQueue(0, false)); + } + + public void testToString() { + assertEquals("Neighbors[0]", new NeighborQueue(2, false).toString()); + } + +}
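Usage sketch (illustrative, not part of the patch): the format is experimental and unexposed, but it can be wired into a Lucene index through a per-field codec. The codec name, the 128 vectors-per-cluster value, and the field name below are assumptions for the example; reopening the resulting index would additionally require registering the codec via SPI, as the patch does for the format itself in META-INF/services.

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;
import org.elasticsearch.index.codec.vectors.IVFVectorsFormat;

public class IVFUsageSketch {
    public static void main(String[] args) throws Exception {
        // route every vector field to the IVF format; 128 vectors per cluster is arbitrary here
        Codec codec = new FilterCodec("IVFSketchCodec", Codec.getDefault()) {
            @Override
            public KnnVectorsFormat knnVectorsFormat() {
                return new PerFieldKnnVectorsFormat() {
                    @Override
                    public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
                        return new IVFVectorsFormat(128);
                    }
                };
            }
        };
        try (
            Directory dir = new ByteBuffersDirectory();
            IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig().setCodec(codec))
        ) {
            Document doc = new Document();
            // addField rejects COSINE, so pick EUCLIDEAN (DOT_PRODUCT and MAXIMUM_INNER_PRODUCT also work)
            doc.add(new KnnFloatVectorField("vec", new float[] { 0.1f, 0.2f, 0.3f }, VectorSimilarityFunction.EUCLIDEAN));
            writer.addDocument(doc);
            writer.commit();
        }
    }
}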