From f71aba1fdde066dc5ae6a5d353b51b9f5049a9c0 Mon Sep 17 00:00:00 2001 From: Chris Hegarty <62058229+ChrisHegarty@users.noreply.github.com> Date: Wed, 29 May 2024 16:32:06 +0100 Subject: [PATCH] Use the SIMD optimized SQ vector scorer at search time (#109109) This commit extends the custom SIMD optimized SQ vector scorer to include search time scoring. When run on JDK22+ vector scoring with be done with the custom scorer. The implementation uses the JDK 22+ on-heap ALLOW_HEAP_ACCESS Linker.Option so that the native code can access the query vector directly. --- .../vector/VectorScorerBenchmark.java | 36 +++- .../nativeaccess/jdk/JdkVectorLibrary.java | 12 +- .../nativeaccess/jdk/LinkerHelperUtil.java | 23 +++ .../nativeaccess/jdk/LinkerHelperUtil.java | 23 +++ .../jdk/JDKVectorLibraryTests.java | 34 +++- .../vec/VectorScorerFactory.java | 23 ++- .../vec/VectorScorerFactoryImpl.java | 15 +- .../vec/VectorScorerFactoryImpl.java | 16 +- .../vec/internal/Int7SQVectorScorer.java | 29 ++++ .../vec/internal/Similarities.java | 6 - .../vec/internal/Int7SQVectorScorer.java | 162 ++++++++++++++++++ .../vec/VectorScorerFactoryTests.java | 140 +++++++++++---- .../ES814ScalarQuantizedVectorsFormat.java | 27 +-- ...HnswScalarQuantizedVectorsFormatTests.java | 56 ++++++ 14 files changed, 538 insertions(+), 64 deletions(-) create mode 100644 libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/LinkerHelperUtil.java create mode 100644 libs/native/src/main22/java/org/elasticsearch/nativeaccess/jdk/LinkerHelperUtil.java create mode 100644 libs/vec/src/main21/java/org/elasticsearch/vec/internal/Int7SQVectorScorer.java create mode 100644 libs/vec/src/main22/java/org/elasticsearch/vec/internal/Int7SQVectorScorer.java diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java index 8836e6449b20..68e0f2151d1d 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java @@ -80,6 +80,9 @@ public class VectorScorerBenchmark { RandomVectorScorer nativeDotScorer; RandomVectorScorer nativeSqrScorer; + RandomVectorScorer luceneDotScorerQuery; + RandomVectorScorer nativeDotScorerQuery; + @Setup public void setup() throws IOException { var optionalVectorScorerFactory = VectorScorerFactory.instance(); @@ -116,8 +119,16 @@ public class VectorScorerBenchmark { values = vectorValues(dims, 2, in, VectorSimilarityFunction.EUCLIDEAN); luceneSqrScorer = luceneScoreSupplier(values, VectorSimilarityFunction.EUCLIDEAN).scorer(0); - nativeDotScorer = factory.getInt7ScalarQuantizedVectorScorer(DOT_PRODUCT, in, values, scoreCorrectionConstant).get().scorer(0); - nativeSqrScorer = factory.getInt7ScalarQuantizedVectorScorer(EUCLIDEAN, in, values, scoreCorrectionConstant).get().scorer(0); + nativeDotScorer = factory.getInt7SQVectorScorerSupplier(DOT_PRODUCT, in, values, scoreCorrectionConstant).get().scorer(0); + nativeSqrScorer = factory.getInt7SQVectorScorerSupplier(EUCLIDEAN, in, values, scoreCorrectionConstant).get().scorer(0); + + // setup for getInt7SQVectorScorer / query vector scoring + float[] queryVec = new float[dims]; + for (int i = 0; i < dims; i++) { + queryVec[i] = ThreadLocalRandom.current().nextFloat(); + } + luceneDotScorerQuery = luceneScorer(values, VectorSimilarityFunction.DOT_PRODUCT, queryVec); + nativeDotScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.DOT_PRODUCT, values, queryVec).get(); // sanity var f1 = dotProductLucene(); @@ -139,6 +150,12 @@ public class VectorScorerBenchmark { if (f1 != f3) { throw new AssertionError("lucene[" + f1 + "] != " + "scalar[" + f3 + "]"); } + + var q1 = dotProductLuceneQuery(); + var q2 = dotProductNativeQuery(); + if (q1 != q2) { + throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]"); + } } @TearDown @@ -166,6 +183,16 @@ public class VectorScorerBenchmark { return (1 + adjustedDistance) / 2; } + @Benchmark + public float dotProductLuceneQuery() throws IOException { + return luceneDotScorerQuery.score(1); + } + + @Benchmark + public float dotProductNativeQuery() throws IOException { + return nativeDotScorerQuery.score(1); + } + // -- square distance @Benchmark @@ -200,6 +227,11 @@ public class VectorScorerBenchmark { return new Lucene99ScalarQuantizedVectorScorer(null).getRandomVectorScorerSupplier(sim, values); } + RandomVectorScorer luceneScorer(RandomAccessQuantizedByteVectorValues values, VectorSimilarityFunction sim, float[] queryVec) + throws IOException { + return new Lucene99ScalarQuantizedVectorScorer(null).getRandomVectorScorer(sim, values, queryVec); + } + // Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive). static final byte MIN_INT7_VALUE = 0; static final byte MAX_INT7_VALUE = 127; diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java index e49b1985d643..db2e7b85c30d 100644 --- a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java @@ -50,8 +50,16 @@ public final class JdkVectorLibrary implements VectorLibrary { private static final class JdkVectorSimilarityFunctions implements VectorSimilarityFunctions { - static final MethodHandle dot7u$mh = downcallHandle("dot7u", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); - static final MethodHandle sqr7u$mh = downcallHandle("sqr7u", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); + static final MethodHandle dot7u$mh = downcallHandle( + "dot7u", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); + static final MethodHandle sqr7u$mh = downcallHandle( + "sqr7u", + FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); /** * Computes the dot product of given unsigned int7 byte vectors. diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/LinkerHelperUtil.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/LinkerHelperUtil.java new file mode 100644 index 000000000000..8befc4bec127 --- /dev/null +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/LinkerHelperUtil.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.nativeaccess.jdk; + +import java.lang.foreign.Linker; + +public class LinkerHelperUtil { + + static final Linker.Option[] NONE = new Linker.Option[0]; + + /** Returns an empty linker option array, since critical is only available since Java 22. */ + static Linker.Option[] critical() { + return NONE; + } + + private LinkerHelperUtil() {} +} diff --git a/libs/native/src/main22/java/org/elasticsearch/nativeaccess/jdk/LinkerHelperUtil.java b/libs/native/src/main22/java/org/elasticsearch/nativeaccess/jdk/LinkerHelperUtil.java new file mode 100644 index 000000000000..6ca3aeaa301c --- /dev/null +++ b/libs/native/src/main22/java/org/elasticsearch/nativeaccess/jdk/LinkerHelperUtil.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.nativeaccess.jdk; + +import java.lang.foreign.Linker; + +public class LinkerHelperUtil { + + static final Linker.Option[] ALLOW_HEAP_ACCESS = new Linker.Option[] { Linker.Option.critical(true) }; + + /** Returns a linker option used to mark a foreign function as critical. */ + static Linker.Option[] critical() { + return ALLOW_HEAP_ACCESS; + } + + private LinkerHelperUtil() {} +} diff --git a/libs/native/src/test21/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryTests.java b/libs/native/src/test21/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryTests.java index b34e25868844..cb68dca14cc4 100644 --- a/libs/native/src/test21/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryTests.java +++ b/libs/native/src/test21/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryTests.java @@ -68,17 +68,37 @@ public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests { for (int i = 0; i < loopTimes; i++) { int first = randomInt(numVecs - 1); int second = randomInt(numVecs - 1); - // dot product - int implDot = dotProduct7u(segment.asSlice((long) first * dims, dims), segment.asSlice((long) second * dims, dims), dims); - int otherDot = dotProductScalar(values[first], values[second]); - assertEquals(otherDot, implDot); + var nativeSeg1 = segment.asSlice((long) first * dims, dims); + var nativeSeg2 = segment.asSlice((long) second * dims, dims); - int implSqr = squareDistance7u(segment.asSlice((long) first * dims, dims), segment.asSlice((long) second * dims, dims), dims); - int otherSqr = squareDistanceScalar(values[first], values[second]); - assertEquals(otherSqr, implSqr); + // dot product + int expected = dotProductScalar(values[first], values[second]); + assertEquals(expected, dotProduct7u(nativeSeg1, nativeSeg2, dims)); + if (testWithHeapSegments()) { + var heapSeg1 = MemorySegment.ofArray(values[first]); + var heapSeg2 = MemorySegment.ofArray(values[second]); + assertEquals(expected, dotProduct7u(heapSeg1, heapSeg2, dims)); + assertEquals(expected, dotProduct7u(nativeSeg1, heapSeg2, dims)); + assertEquals(expected, dotProduct7u(heapSeg1, nativeSeg2, dims)); + } + + // square distance + expected = squareDistanceScalar(values[first], values[second]); + assertEquals(expected, squareDistance7u(nativeSeg1, nativeSeg2, dims)); + if (testWithHeapSegments()) { + var heapSeg1 = MemorySegment.ofArray(values[first]); + var heapSeg2 = MemorySegment.ofArray(values[second]); + assertEquals(expected, squareDistance7u(heapSeg1, heapSeg2, dims)); + assertEquals(expected, squareDistance7u(nativeSeg1, heapSeg2, dims)); + assertEquals(expected, squareDistance7u(heapSeg1, nativeSeg2, dims)); + } } } + static boolean testWithHeapSegments() { + return Runtime.version().feature() >= 22; + } + public void testIllegalDims() { assumeTrue(notSupportedMsg(), supported()); var segment = arena.allocate((long) size * 3); diff --git a/libs/vec/src/main/java/org/elasticsearch/vec/VectorScorerFactory.java b/libs/vec/src/main/java/org/elasticsearch/vec/VectorScorerFactory.java index ad7f467da9d2..600557278681 100644 --- a/libs/vec/src/main/java/org/elasticsearch/vec/VectorScorerFactory.java +++ b/libs/vec/src/main/java/org/elasticsearch/vec/VectorScorerFactory.java @@ -8,7 +8,9 @@ package org.elasticsearch.vec; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; @@ -22,8 +24,8 @@ public interface VectorScorerFactory { } /** - * Returns an optional containing an int7 scalar quantized vector scorer for - * the given parameters, or an empty optional if a scorer is not supported. + * Returns an optional containing an int7 scalar quantized vector score supplier + * for the given parameters, or an empty optional if a scorer is not supported. * * @param similarityType the similarity type * @param input the index input containing the vector data; @@ -33,10 +35,25 @@ public interface VectorScorerFactory { * @param scoreCorrectionConstant the score correction constant * @return an optional containing the vector scorer supplier, or empty */ - Optional getInt7ScalarQuantizedVectorScorer( + Optional getInt7SQVectorScorerSupplier( VectorSimilarityType similarityType, IndexInput input, RandomAccessQuantizedByteVectorValues values, float scoreCorrectionConstant ); + + /** + * Returns an optional containing an int7 scalar quantized vector scorer for + * the given parameters, or an empty optional if a scorer is not supported. + * + * @param sim the similarity type + * @param values the random access vector values + * @param queryVector the query vector + * @return an optional containing the vector scorer, or empty + */ + Optional getInt7SQVectorScorer( + VectorSimilarityFunction sim, + RandomAccessQuantizedByteVectorValues values, + float[] queryVector + ); } diff --git a/libs/vec/src/main/java/org/elasticsearch/vec/VectorScorerFactoryImpl.java b/libs/vec/src/main/java/org/elasticsearch/vec/VectorScorerFactoryImpl.java index 0b8231770490..3d6d0db38718 100644 --- a/libs/vec/src/main/java/org/elasticsearch/vec/VectorScorerFactoryImpl.java +++ b/libs/vec/src/main/java/org/elasticsearch/vec/VectorScorerFactoryImpl.java @@ -8,18 +8,20 @@ package org.elasticsearch.vec; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; import java.util.Optional; -class VectorScorerFactoryImpl implements VectorScorerFactory { +final class VectorScorerFactoryImpl implements VectorScorerFactory { static final VectorScorerFactoryImpl INSTANCE = null; @Override - public Optional getInt7ScalarQuantizedVectorScorer( + public Optional getInt7SQVectorScorerSupplier( VectorSimilarityType similarityType, IndexInput input, RandomAccessQuantizedByteVectorValues values, @@ -27,4 +29,13 @@ class VectorScorerFactoryImpl implements VectorScorerFactory { ) { throw new UnsupportedOperationException("should not reach here"); } + + @Override + public Optional getInt7SQVectorScorer( + VectorSimilarityFunction sim, + RandomAccessQuantizedByteVectorValues values, + float[] queryVector + ) { + throw new UnsupportedOperationException("should not reach here"); + } } diff --git a/libs/vec/src/main21/java/org/elasticsearch/vec/VectorScorerFactoryImpl.java b/libs/vec/src/main21/java/org/elasticsearch/vec/VectorScorerFactoryImpl.java index b48bf84cb2e7..944e886041b8 100644 --- a/libs/vec/src/main21/java/org/elasticsearch/vec/VectorScorerFactoryImpl.java +++ b/libs/vec/src/main21/java/org/elasticsearch/vec/VectorScorerFactoryImpl.java @@ -8,19 +8,22 @@ package org.elasticsearch.vec; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; import org.elasticsearch.nativeaccess.NativeAccess; +import org.elasticsearch.vec.internal.Int7SQVectorScorer; import org.elasticsearch.vec.internal.Int7SQVectorScorerSupplier.DotProductSupplier; import org.elasticsearch.vec.internal.Int7SQVectorScorerSupplier.EuclideanSupplier; import org.elasticsearch.vec.internal.Int7SQVectorScorerSupplier.MaxInnerProductSupplier; import java.util.Optional; -class VectorScorerFactoryImpl implements VectorScorerFactory { +final class VectorScorerFactoryImpl implements VectorScorerFactory { static final VectorScorerFactoryImpl INSTANCE; @@ -31,7 +34,7 @@ class VectorScorerFactoryImpl implements VectorScorerFactory { } @Override - public Optional getInt7ScalarQuantizedVectorScorer( + public Optional getInt7SQVectorScorerSupplier( VectorSimilarityType similarityType, IndexInput input, RandomAccessQuantizedByteVectorValues values, @@ -50,6 +53,15 @@ class VectorScorerFactoryImpl implements VectorScorerFactory { }; } + @Override + public Optional getInt7SQVectorScorer( + VectorSimilarityFunction sim, + RandomAccessQuantizedByteVectorValues values, + float[] queryVector + ) { + return Int7SQVectorScorer.create(sim, values, queryVector); + } + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { if (input.length() < (long) vectorByteLength * maxOrd) { throw new IllegalArgumentException("input length is less than expected vector data"); diff --git a/libs/vec/src/main21/java/org/elasticsearch/vec/internal/Int7SQVectorScorer.java b/libs/vec/src/main21/java/org/elasticsearch/vec/internal/Int7SQVectorScorer.java new file mode 100644 index 000000000000..95bf79eb9660 --- /dev/null +++ b/libs/vec/src/main21/java/org/elasticsearch/vec/internal/Int7SQVectorScorer.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.vec.internal; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; + +import java.util.Optional; + +public final class Int7SQVectorScorer { + + // Unconditionally returns an empty optional on <= JDK 21, since the scorer is only supported on JDK 22+ + public static Optional create( + VectorSimilarityFunction sim, + RandomAccessQuantizedByteVectorValues values, + float[] queryVector + ) { + return Optional.empty(); + } + + private Int7SQVectorScorer() {} +} diff --git a/libs/vec/src/main21/java/org/elasticsearch/vec/internal/Similarities.java b/libs/vec/src/main21/java/org/elasticsearch/vec/internal/Similarities.java index d0333931ce22..8f78e5a385a1 100644 --- a/libs/vec/src/main21/java/org/elasticsearch/vec/internal/Similarities.java +++ b/libs/vec/src/main21/java/org/elasticsearch/vec/internal/Similarities.java @@ -24,7 +24,6 @@ public class Similarities { static final MethodHandle SQUARE_DISTANCE_7U = DISTANCE_FUNCS.squareDistanceHandle7u(); static int dotProduct7u(MemorySegment a, MemorySegment b, int length) { - assert assertSegments(a, b, length); try { return (int) DOT_PRODUCT_7U.invokeExact(a, b, length); } catch (Throwable e) { @@ -39,7 +38,6 @@ public class Similarities { } static int squareDistance7u(MemorySegment a, MemorySegment b, int length) { - assert assertSegments(a, b, length); try { return (int) SQUARE_DISTANCE_7U.invokeExact(a, b, length); } catch (Throwable e) { @@ -52,8 +50,4 @@ public class Similarities { } } } - - static boolean assertSegments(MemorySegment a, MemorySegment b, int length) { - return a.isNative() && a.byteSize() >= length && b.isNative() && b.byteSize() >= length; - } } diff --git a/libs/vec/src/main22/java/org/elasticsearch/vec/internal/Int7SQVectorScorer.java b/libs/vec/src/main22/java/org/elasticsearch/vec/internal/Int7SQVectorScorer.java new file mode 100644 index 000000000000..b835e7734a48 --- /dev/null +++ b/libs/vec/src/main22/java/org/elasticsearch/vec/internal/Int7SQVectorScorer.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.vec.internal; + +import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.ScalarQuantizer; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.util.Optional; + +import static org.elasticsearch.vec.internal.Similarities.dotProduct7u; +import static org.elasticsearch.vec.internal.Similarities.squareDistance7u; + +public abstract sealed class Int7SQVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { + + final int vectorByteSize; + final MemorySegmentAccessInput input; + final MemorySegment query; + final float scoreCorrectionConstant; + final float queryCorrection; + byte[] scratch; + + /** Return an optional whose value, if present, is the scorer. Otherwise, an empty optional is returned. */ + public static Optional create( + VectorSimilarityFunction sim, + RandomAccessQuantizedByteVectorValues values, + float[] queryVector + ) { + checkDimensions(queryVector.length, values.dimension()); + var input = values.getSlice(); + if (input == null) { + return Optional.empty(); + } + input = FilterIndexInput.unwrapOnlyTest(input); + if (input instanceof MemorySegmentAccessInput == false) { + return Optional.empty(); + } + MemorySegmentAccessInput msInput = (MemorySegmentAccessInput) input; + checkInvariants(values.size(), values.dimension(), input); + + ScalarQuantizer scalarQuantizer = values.getScalarQuantizer(); + // TODO assert scalarQuantizer.getBits() == 7 or 8 ? + byte[] quantizedQuery = new byte[queryVector.length]; + float queryCorrection = ScalarQuantizedVectorScorer.quantizeQuery(queryVector, quantizedQuery, sim, scalarQuantizer); + return switch (sim) { + case COSINE, DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, quantizedQuery, queryCorrection)); + case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, quantizedQuery, queryCorrection)); + case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductScorer(msInput, values, quantizedQuery, queryCorrection)); + }; + } + + Int7SQVectorScorer( + MemorySegmentAccessInput input, + RandomAccessQuantizedByteVectorValues values, + byte[] queryVector, + float queryCorrection + ) { + super(values); + this.input = input; + assert queryVector.length == values.getVectorByteLength(); + this.vectorByteSize = values.getVectorByteLength(); + this.query = MemorySegment.ofArray(queryVector); + this.queryCorrection = queryCorrection; + this.scoreCorrectionConstant = values.getScalarQuantizer().getConstantMultiplier(); + } + + final MemorySegment getSegment(int ord) throws IOException { + checkOrdinal(ord); + long byteOffset = (long) ord * (vectorByteSize + Float.BYTES); + MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); + if (seg == null) { + if (scratch == null) { + scratch = new byte[vectorByteSize]; + } + input.readBytes(byteOffset, scratch, 0, vectorByteSize); + seg = MemorySegment.ofArray(scratch); + } + return seg; + } + + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { + if (input.length() < (long) vectorByteLength * maxOrd) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } + + final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd()) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + public static final class DotProductScorer extends Int7SQVectorScorer { + public DotProductScorer(MemorySegmentAccessInput in, RandomAccessQuantizedByteVectorValues values, byte[] query, float correction) { + super(in, values, query, correction); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + int dotProduct = dotProduct7u(query, getSegment(node), vectorByteSize); + assert dotProduct >= 0; + long byteOffset = (long) node * (vectorByteSize + Float.BYTES); + float nodeCorrection = Float.intBitsToFloat(input.readInt(byteOffset + vectorByteSize)); + float adjustedDistance = dotProduct * scoreCorrectionConstant + queryCorrection + nodeCorrection; + return Math.max((1 + adjustedDistance) / 2, 0f); + } + } + + public static final class EuclideanScorer extends Int7SQVectorScorer { + public EuclideanScorer(MemorySegmentAccessInput in, RandomAccessQuantizedByteVectorValues values, byte[] query, float correction) { + super(in, values, query, correction); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + int sqDist = squareDistance7u(query, getSegment(node), vectorByteSize); + float adjustedDistance = sqDist * scoreCorrectionConstant; + return 1 / (1f + adjustedDistance); + } + } + + public static final class MaxInnerProductScorer extends Int7SQVectorScorer { + public MaxInnerProductScorer(MemorySegmentAccessInput in, RandomAccessQuantizedByteVectorValues values, byte[] query, float corr) { + super(in, values, query, corr); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + int dotProduct = dotProduct7u(query, getSegment(node), vectorByteSize); + assert dotProduct >= 0; + long byteOffset = (long) node * (vectorByteSize + Float.BYTES); + float nodeCorrection = Float.intBitsToFloat(input.readInt(byteOffset + vectorByteSize)); + float adjustedDistance = dotProduct * scoreCorrectionConstant + queryCorrection + nodeCorrection; + if (adjustedDistance < 0) { + return 1 / (1 + -1 * adjustedDistance); + } + return adjustedDistance + 1; + } + } + + static void checkDimensions(int queryLen, int fieldLen) { + if (queryLen != fieldLen) { + throw new IllegalArgumentException("vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen); + } + } +} diff --git a/libs/vec/src/test/java/org/elasticsearch/vec/VectorScorerFactoryTests.java b/libs/vec/src/test/java/org/elasticsearch/vec/VectorScorerFactoryTests.java index 1d429032b3cc..742585722c8a 100644 --- a/libs/vec/src/test/java/org/elasticsearch/vec/VectorScorerFactoryTests.java +++ b/libs/vec/src/test/java/org/elasticsearch/vec/VectorScorerFactoryTests.java @@ -37,6 +37,7 @@ import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.IntStream; +import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantizeQuery; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; import static org.elasticsearch.vec.VectorSimilarityType.COSINE; import static org.elasticsearch.vec.VectorSimilarityType.DOT_PRODUCT; @@ -70,32 +71,42 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase { void testSimpleImpl(long maxChunkSize) throws IOException { assumeTrue(notSupportedMsg(), supported()); var factory = AbstractVectorTestCase.factory.get(); + var scalarQuantizer = new ScalarQuantizer(0.1f, 0.9f, (byte) 7); try (Directory dir = new MMapDirectory(createTempDir("testSimpleImpl"), maxChunkSize)) { - for (int dims : List.of(31, 32, 33)) { - // dimensions that cross the scalar / native boundary (stride) - byte[] vec1 = new byte[dims]; - byte[] vec2 = new byte[dims]; - String fileName = "testSimpleImpl" + "-" + dims; - try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { - for (int i = 0; i < dims; i++) { - vec1[i] = (byte) i; - vec2[i] = (byte) (dims - i); + for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + for (int dims : List.of(31, 32, 33)) { + // dimensions that cross the scalar / native boundary (stride) + byte[] vec1 = new byte[dims]; + byte[] vec2 = new byte[dims]; + float[] query1 = new float[dims]; + float[] query2 = new float[dims]; + float vec1Correction, vec2Correction; + String fileName = "testSimpleImpl-" + sim + "-" + dims + ".vex"; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < dims; i++) { + query1[i] = (float) i; + query2[i] = (float) (dims - i); + } + vec1Correction = quantizeQuery(query1, vec1, VectorSimilarityType.of(sim), scalarQuantizer); + vec2Correction = quantizeQuery(query2, vec2, VectorSimilarityType.of(sim), scalarQuantizer); + byte[] bytes = concat(vec1, floatToByteArray(vec1Correction), vec2, floatToByteArray(vec2Correction)); + out.writeBytes(bytes, 0, bytes.length); } - var oneFactor = floatToByteArray(1f); - byte[] bytes = concat(vec1, oneFactor, vec2, oneFactor); - out.writeBytes(bytes, 0, bytes.length); - } - try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { - for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { var values = vectorValues(dims, 2, in, VectorSimilarityType.of(sim)); float scc = values.getScalarQuantizer().getConstantMultiplier(); - float expected = luceneScore(sim, vec1, vec2, scc, 1, 1); + float expected = luceneScore(sim, vec1, vec2, scc, vec1Correction, vec2Correction); var luceneSupplier = luceneScoreSupplier(values, VectorSimilarityType.of(sim)).scorer(0); assertThat(luceneSupplier.score(1), equalTo(expected)); - var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, scc).get(); + var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, scc).get(); assertThat(supplier.scorer(0).score(1), equalTo(expected)); + + if (Runtime.version().feature() >= 22) { + var qScorer = factory.getInt7SQVectorScorer(VectorSimilarityType.of(sim), values, query1).get(); + assertThat(qScorer.score(1), equalTo(expected)); + } } } } @@ -121,23 +132,23 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase { // dot product float expected = 0f; assertThat(luceneScore(DOT_PRODUCT, vec1, vec2, 1, -5, -5), equalTo(expected)); - var supplier = factory.getInt7ScalarQuantizedVectorScorer(DOT_PRODUCT, in, values, 1).get(); + var supplier = factory.getInt7SQVectorScorerSupplier(DOT_PRODUCT, in, values, 1).get(); assertThat(supplier.scorer(0).score(1), equalTo(expected)); assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f)); // max inner product expected = luceneScore(MAXIMUM_INNER_PRODUCT, vec1, vec2, 1, -5, -5); - supplier = factory.getInt7ScalarQuantizedVectorScorer(MAXIMUM_INNER_PRODUCT, in, values, 1).get(); + supplier = factory.getInt7SQVectorScorerSupplier(MAXIMUM_INNER_PRODUCT, in, values, 1).get(); assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f)); assertThat(supplier.scorer(0).score(1), equalTo(expected)); // cosine expected = 0f; assertThat(luceneScore(COSINE, vec1, vec2, 1, -5, -5), equalTo(expected)); - supplier = factory.getInt7ScalarQuantizedVectorScorer(COSINE, in, values, 1).get(); + supplier = factory.getInt7SQVectorScorerSupplier(COSINE, in, values, 1).get(); assertThat(supplier.scorer(0).score(1), equalTo(expected)); assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f)); // euclidean expected = luceneScore(EUCLIDEAN, vec1, vec2, 1, -5, -5); - supplier = factory.getInt7ScalarQuantizedVectorScorer(EUCLIDEAN, in, values, 1).get(); + supplier = factory.getInt7SQVectorScorerSupplier(EUCLIDEAN, in, values, 1).get(); assertThat(supplier.scorer(0).score(1), equalTo(expected)); assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f)); } @@ -146,27 +157,27 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase { public void testRandom() throws IOException { assumeTrue(notSupportedMsg(), supported()); - testRandom(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_RANDOM_INT7_FUNC); + testRandomSupplier(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_RANDOM_INT7_FUNC); } public void testRandomMaxChunkSizeSmall() throws IOException { assumeTrue(notSupportedMsg(), supported()); long maxChunkSize = randomLongBetween(32, 128); logger.info("maxChunkSize=" + maxChunkSize); - testRandom(maxChunkSize, BYTE_ARRAY_RANDOM_INT7_FUNC); + testRandomSupplier(maxChunkSize, BYTE_ARRAY_RANDOM_INT7_FUNC); } public void testRandomMax() throws IOException { assumeTrue(notSupportedMsg(), supported()); - testRandom(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MAX_INT7_FUNC); + testRandomSupplier(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MAX_INT7_FUNC); } public void testRandomMin() throws IOException { assumeTrue(notSupportedMsg(), supported()); - testRandom(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MIN_INT7_FUNC); + testRandomSupplier(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MIN_INT7_FUNC); } - void testRandom(long maxChunkSize, Function byteArraySupplier) throws IOException { + void testRandomSupplier(long maxChunkSize, Function byteArraySupplier) throws IOException { var factory = AbstractVectorTestCase.factory.get(); try (Directory dir = new MMapDirectory(createTempDir("testRandom"), maxChunkSize)) { @@ -195,7 +206,7 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase { for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); float expected = luceneScore(sim, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]); - var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, correction).get(); + var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get(); assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected)); } } @@ -203,6 +214,61 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase { } } + public void testRandomScorer() throws IOException { + testRandomScorerImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, VectorScorerFactoryTests.FLOAT_ARRAY_RANDOM_FUNC); + } + + public void testRandomScorerMax() throws IOException { + testRandomScorerImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, VectorScorerFactoryTests.FLOAT_ARRAY_MAX_FUNC); + } + + public void testRandomScorerChunkSizeSmall() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + long maxChunkSize = randomLongBetween(32, 128); + logger.info("maxChunkSize=" + maxChunkSize); + testRandomScorerImpl(maxChunkSize, FLOAT_ARRAY_RANDOM_FUNC); + } + + void testRandomScorerImpl(long maxChunkSize, Function floatArraySupplier) throws IOException { + assumeTrue("scorer only supported on JDK 22+", Runtime.version().feature() >= 22); + var factory = AbstractVectorTestCase.factory.get(); + var scalarQuantizer = new ScalarQuantizer(0.1f, 0.9f, (byte) 7); + + try (Directory dir = new MMapDirectory(createTempDir("testRandom"), maxChunkSize)) { + for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + final int dims = randomIntBetween(1, 4096); + final int size = randomIntBetween(2, 100); + final float[][] vectors = new float[size][]; + final byte[][] qVectors = new byte[size][]; + final float[] corrections = new float[size]; + + String fileName = "testRandom-" + sim + "-" + dims + ".vex"; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + vectors[i] = floatArraySupplier.apply(dims); + qVectors[i] = new byte[dims]; + corrections[i] = quantizeQuery(vectors[i], qVectors[i], VectorSimilarityType.of(sim), scalarQuantizer); + out.writeBytes(qVectors[i], 0, dims); + out.writeBytes(floatToByteArray(corrections[i]), 0, 4); + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); + var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); + var correction = scalarQuantizer.getConstantMultiplier(); + + var expected = luceneScore(sim, qVectors[idx0], qVectors[idx1], correction, corrections[idx0], corrections[idx1]); + var scorer = factory.getInt7SQVectorScorer(VectorSimilarityType.of(sim), values, vectors[idx0]).get(); + assertThat(scorer.score(idx1), equalTo(expected)); + } + } + } + } + } + public void testRandomSlice() throws IOException { assumeTrue(notSupportedMsg(), supported()); testRandomSliceImpl(30, 64, 1, BYTE_ARRAY_RANDOM_INT7_FUNC); @@ -243,7 +309,7 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase { for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); float expected = luceneScore(sim, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]); - var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, correction).get(); + var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get(); assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected)); } } @@ -281,7 +347,7 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase { for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); float expected = luceneScore(sim, vector(idx0, dims), vector(idx1, dims), correction, off0, off1); - var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, correction).get(); + var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get(); assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected)); } } @@ -319,7 +385,7 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase { try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { var values = vectorValues(dims, 4, in, VectorSimilarityType.of(sim)); - var scoreSupplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, 1f).get(); + var scoreSupplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, 1f).get(); var tasks = List.>>of( new ScoreCallable(scoreSupplier.copy().scorer(0), 1, expectedScore1), new ScoreCallable(scoreSupplier.copy().scorer(2), 3, expectedScore2) @@ -382,6 +448,20 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase { return ba; } + static Function FLOAT_ARRAY_RANDOM_FUNC = size -> { + float[] fa = new float[size]; + for (int i = 0; i < size; i++) { + fa[i] = randomFloat(); + } + return fa; + }; + + static Function FLOAT_ARRAY_MAX_FUNC = size -> { + float[] fa = new float[size]; + Arrays.fill(fa, Float.MAX_VALUE); + return fa; + }; + static Function BYTE_ARRAY_RANDOM_INT7_FUNC = size -> { byte[] ba = new byte[size]; randomBytesBetween(ba, MIN_INT7_VALUE, MAX_INT7_VALUE); diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java index 942d968b48c1..0d1c5efeb3e2 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsFormat.java @@ -39,7 +39,6 @@ import org.elasticsearch.vec.VectorScorerFactory; import org.elasticsearch.vec.VectorSimilarityType; import java.io.IOException; -import java.util.Optional; public class ES814ScalarQuantizedVectorsFormat extends FlatVectorsFormat { @@ -198,24 +197,24 @@ public class ES814ScalarQuantizedVectorsFormat extends FlatVectorsFormat { static final class ESFlatVectorsScorer implements FlatVectorsScorer { final FlatVectorsScorer delegate; + final VectorScorerFactory factory; ESFlatVectorsScorer(FlatVectorsScorer delegte) { this.delegate = delegte; + factory = VectorScorerFactory.instance().orElse(null); } @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction sim, RandomAccessVectorValues values) throws IOException { if (values instanceof RandomAccessQuantizedByteVectorValues qValues && values.getSlice() != null) { - Optional factory = VectorScorerFactory.instance(); - if (factory.isPresent()) { - var scorer = factory.get() - .getInt7ScalarQuantizedVectorScorer( - VectorSimilarityType.of(sim), - values.getSlice(), - qValues, - qValues.getScalarQuantizer().getConstantMultiplier() - ); + if (factory != null) { + var scorer = factory.getInt7SQVectorScorerSupplier( + VectorSimilarityType.of(sim), + values.getSlice(), + qValues, + qValues.getScalarQuantizer().getConstantMultiplier() + ); if (scorer.isPresent()) { return scorer.get(); } @@ -227,6 +226,14 @@ public class ES814ScalarQuantizedVectorsFormat extends FlatVectorsFormat { @Override public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, RandomAccessVectorValues values, float[] query) throws IOException { + if (values instanceof RandomAccessQuantizedByteVectorValues qValues && values.getSlice() != null) { + if (factory != null) { + var scorer = factory.getInt7SQVectorScorer(sim, qValues, query); + if (scorer.isPresent()) { + return scorer.get(); + } + } + } return delegate.getRandomVectorScorer(sim, values, query); } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java index 915c5f655f18..ca446a607f63 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java @@ -12,12 +12,14 @@ import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99Codec; import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.store.MMapDirectory; @@ -32,6 +34,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; public class ES814HnswScalarQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { static { + LogConfigurator.loadLog4jPlugins(); LogConfigurator.configureESLogging(); // native access requires logging to be initialized } @@ -117,4 +120,57 @@ public class ES814HnswScalarQuantizedVectorsFormatTests extends BaseKnnVectorsFo } } } + + public void testSingleVectorPerSegmentCosine() throws Exception { + testSingleVectorPerSegment(VectorSimilarityFunction.COSINE); + } + + public void testSingleVectorPerSegmentDot() throws Exception { + testSingleVectorPerSegment(VectorSimilarityFunction.DOT_PRODUCT); + } + + public void testSingleVectorPerSegmentEuclidean() throws Exception { + testSingleVectorPerSegment(VectorSimilarityFunction.EUCLIDEAN); + } + + public void testSingleVectorPerSegmentMIP() throws Exception { + testSingleVectorPerSegment(VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT); + } + + private void testSingleVectorPerSegment(VectorSimilarityFunction sim) throws Exception { + var codec = getCodec(); + try (Directory dir = new MMapDirectory(createTempDir().resolve("dir1"))) { + try (IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig().setCodec(codec))) { + Document doc2 = new Document(); + doc2.add(new KnnFloatVectorField("field", new float[] { 0.8f, 0.6f }, sim)); + doc2.add(newTextField("id", "A", Field.Store.YES)); + writer.addDocument(doc2); + writer.commit(); + + Document doc1 = new Document(); + doc1.add(new KnnFloatVectorField("field", new float[] { 0.6f, 0.8f }, sim)); + doc1.add(newTextField("id", "B", Field.Store.YES)); + writer.addDocument(doc1); + writer.commit(); + + Document doc3 = new Document(); + doc3.add(new KnnFloatVectorField("field", new float[] { -0.6f, -0.8f }, sim)); + doc3.add(newTextField("id", "C", Field.Store.YES)); + writer.addDocument(doc3); + writer.commit(); + + writer.forceMerge(1); + } + try (DirectoryReader reader = DirectoryReader.open(dir)) { + LeafReader leafReader = getOnlyLeafReader(reader); + StoredFields storedFields = reader.storedFields(); + float[] queryVector = new float[] { 0.6f, 0.8f }; + var hits = leafReader.searchNearestVectors("field", queryVector, 3, null, 100); + assertEquals(hits.scoreDocs.length, 3); + assertEquals("B", storedFields.document(hits.scoreDocs[0].doc).get("id")); + assertEquals("A", storedFields.document(hits.scoreDocs[1].doc).get("id")); + assertEquals("C", storedFields.document(hits.scoreDocs[2].doc).get("id")); + } + } + } }