mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
Use the SIMD optimized SQ vector scorer at search time (#109109)
This commit extends the custom SIMD optimized SQ vector scorer to include search time scoring. When run on JDK 22+, vector scoring will be done with the custom scorer. The implementation uses the JDK 22+ on-heap ALLOW_HEAP_ACCESS Linker.Option so that the native code can access the query vector directly.
This commit is contained in:
parent
b35239c952
commit
f71aba1fdd
14 changed files with 538 additions and 64 deletions
|
@ -80,6 +80,9 @@ public class VectorScorerBenchmark {
|
|||
RandomVectorScorer nativeDotScorer;
|
||||
RandomVectorScorer nativeSqrScorer;
|
||||
|
||||
RandomVectorScorer luceneDotScorerQuery;
|
||||
RandomVectorScorer nativeDotScorerQuery;
|
||||
|
||||
@Setup
|
||||
public void setup() throws IOException {
|
||||
var optionalVectorScorerFactory = VectorScorerFactory.instance();
|
||||
|
@ -116,8 +119,16 @@ public class VectorScorerBenchmark {
|
|||
values = vectorValues(dims, 2, in, VectorSimilarityFunction.EUCLIDEAN);
|
||||
luceneSqrScorer = luceneScoreSupplier(values, VectorSimilarityFunction.EUCLIDEAN).scorer(0);
|
||||
|
||||
nativeDotScorer = factory.getInt7ScalarQuantizedVectorScorer(DOT_PRODUCT, in, values, scoreCorrectionConstant).get().scorer(0);
|
||||
nativeSqrScorer = factory.getInt7ScalarQuantizedVectorScorer(EUCLIDEAN, in, values, scoreCorrectionConstant).get().scorer(0);
|
||||
nativeDotScorer = factory.getInt7SQVectorScorerSupplier(DOT_PRODUCT, in, values, scoreCorrectionConstant).get().scorer(0);
|
||||
nativeSqrScorer = factory.getInt7SQVectorScorerSupplier(EUCLIDEAN, in, values, scoreCorrectionConstant).get().scorer(0);
|
||||
|
||||
// setup for getInt7SQVectorScorer / query vector scoring
|
||||
float[] queryVec = new float[dims];
|
||||
for (int i = 0; i < dims; i++) {
|
||||
queryVec[i] = ThreadLocalRandom.current().nextFloat();
|
||||
}
|
||||
luceneDotScorerQuery = luceneScorer(values, VectorSimilarityFunction.DOT_PRODUCT, queryVec);
|
||||
nativeDotScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.DOT_PRODUCT, values, queryVec).get();
|
||||
|
||||
// sanity
|
||||
var f1 = dotProductLucene();
|
||||
|
@ -139,6 +150,12 @@ public class VectorScorerBenchmark {
|
|||
if (f1 != f3) {
|
||||
throw new AssertionError("lucene[" + f1 + "] != " + "scalar[" + f3 + "]");
|
||||
}
|
||||
|
||||
var q1 = dotProductLuceneQuery();
|
||||
var q2 = dotProductNativeQuery();
|
||||
if (q1 != q2) {
|
||||
throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]");
|
||||
}
|
||||
}
|
||||
|
||||
@TearDown
|
||||
|
@ -166,6 +183,16 @@ public class VectorScorerBenchmark {
|
|||
return (1 + adjustedDistance) / 2;
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public float dotProductLuceneQuery() throws IOException {
|
||||
return luceneDotScorerQuery.score(1);
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public float dotProductNativeQuery() throws IOException {
|
||||
return nativeDotScorerQuery.score(1);
|
||||
}
|
||||
|
||||
// -- square distance
|
||||
|
||||
@Benchmark
|
||||
|
@ -200,6 +227,11 @@ public class VectorScorerBenchmark {
|
|||
return new Lucene99ScalarQuantizedVectorScorer(null).getRandomVectorScorerSupplier(sim, values);
|
||||
}
|
||||
|
||||
RandomVectorScorer luceneScorer(RandomAccessQuantizedByteVectorValues values, VectorSimilarityFunction sim, float[] queryVec)
|
||||
throws IOException {
|
||||
return new Lucene99ScalarQuantizedVectorScorer(null).getRandomVectorScorer(sim, values, queryVec);
|
||||
}
|
||||
|
||||
// Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
|
||||
static final byte MIN_INT7_VALUE = 0;
|
||||
static final byte MAX_INT7_VALUE = 127;
|
||||
|
|
|
@ -50,8 +50,16 @@ public final class JdkVectorLibrary implements VectorLibrary {
|
|||
|
||||
private static final class JdkVectorSimilarityFunctions implements VectorSimilarityFunctions {
|
||||
|
||||
static final MethodHandle dot7u$mh = downcallHandle("dot7u", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
|
||||
static final MethodHandle sqr7u$mh = downcallHandle("sqr7u", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
|
||||
static final MethodHandle dot7u$mh = downcallHandle(
|
||||
"dot7u",
|
||||
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
|
||||
LinkerHelperUtil.critical()
|
||||
);
|
||||
static final MethodHandle sqr7u$mh = downcallHandle(
|
||||
"sqr7u",
|
||||
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
|
||||
LinkerHelperUtil.critical()
|
||||
);
|
||||
|
||||
/**
|
||||
* Computes the dot product of given unsigned int7 byte vectors.
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0 and the Server Side Public License, v 1; you may not use this file except
|
||||
* in compliance with, at your election, the Elastic License 2.0 or the Server
|
||||
* Side Public License, v 1.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.nativeaccess.jdk;
|
||||
|
||||
import java.lang.foreign.Linker;
|
||||
|
||||
public class LinkerHelperUtil {
|
||||
|
||||
static final Linker.Option[] NONE = new Linker.Option[0];
|
||||
|
||||
/** Returns an empty linker option array, since critical is only available since Java 22. */
|
||||
static Linker.Option[] critical() {
|
||||
return NONE;
|
||||
}
|
||||
|
||||
private LinkerHelperUtil() {}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0 and the Server Side Public License, v 1; you may not use this file except
|
||||
* in compliance with, at your election, the Elastic License 2.0 or the Server
|
||||
* Side Public License, v 1.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.nativeaccess.jdk;
|
||||
|
||||
import java.lang.foreign.Linker;
|
||||
|
||||
public class LinkerHelperUtil {
|
||||
|
||||
static final Linker.Option[] ALLOW_HEAP_ACCESS = new Linker.Option[] { Linker.Option.critical(true) };
|
||||
|
||||
/** Returns a linker option used to mark a foreign function as critical. */
|
||||
static Linker.Option[] critical() {
|
||||
return ALLOW_HEAP_ACCESS;
|
||||
}
|
||||
|
||||
private LinkerHelperUtil() {}
|
||||
}
|
|
@ -68,15 +68,35 @@ public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {
|
|||
for (int i = 0; i < loopTimes; i++) {
|
||||
int first = randomInt(numVecs - 1);
|
||||
int second = randomInt(numVecs - 1);
|
||||
// dot product
|
||||
int implDot = dotProduct7u(segment.asSlice((long) first * dims, dims), segment.asSlice((long) second * dims, dims), dims);
|
||||
int otherDot = dotProductScalar(values[first], values[second]);
|
||||
assertEquals(otherDot, implDot);
|
||||
var nativeSeg1 = segment.asSlice((long) first * dims, dims);
|
||||
var nativeSeg2 = segment.asSlice((long) second * dims, dims);
|
||||
|
||||
int implSqr = squareDistance7u(segment.asSlice((long) first * dims, dims), segment.asSlice((long) second * dims, dims), dims);
|
||||
int otherSqr = squareDistanceScalar(values[first], values[second]);
|
||||
assertEquals(otherSqr, implSqr);
|
||||
// dot product
|
||||
int expected = dotProductScalar(values[first], values[second]);
|
||||
assertEquals(expected, dotProduct7u(nativeSeg1, nativeSeg2, dims));
|
||||
if (testWithHeapSegments()) {
|
||||
var heapSeg1 = MemorySegment.ofArray(values[first]);
|
||||
var heapSeg2 = MemorySegment.ofArray(values[second]);
|
||||
assertEquals(expected, dotProduct7u(heapSeg1, heapSeg2, dims));
|
||||
assertEquals(expected, dotProduct7u(nativeSeg1, heapSeg2, dims));
|
||||
assertEquals(expected, dotProduct7u(heapSeg1, nativeSeg2, dims));
|
||||
}
|
||||
|
||||
// square distance
|
||||
expected = squareDistanceScalar(values[first], values[second]);
|
||||
assertEquals(expected, squareDistance7u(nativeSeg1, nativeSeg2, dims));
|
||||
if (testWithHeapSegments()) {
|
||||
var heapSeg1 = MemorySegment.ofArray(values[first]);
|
||||
var heapSeg2 = MemorySegment.ofArray(values[second]);
|
||||
assertEquals(expected, squareDistance7u(heapSeg1, heapSeg2, dims));
|
||||
assertEquals(expected, squareDistance7u(nativeSeg1, heapSeg2, dims));
|
||||
assertEquals(expected, squareDistance7u(heapSeg1, nativeSeg2, dims));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static boolean testWithHeapSegments() {
|
||||
return Runtime.version().feature() >= 22;
|
||||
}
|
||||
|
||||
public void testIllegalDims() {
|
||||
|
|
|
@ -8,7 +8,9 @@
|
|||
|
||||
package org.elasticsearch.vec;
|
||||
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
|
||||
|
||||
|
@ -22,8 +24,8 @@ public interface VectorScorerFactory {
|
|||
}
|
||||
|
||||
/**
|
||||
* Returns an optional containing an int7 scalar quantized vector scorer for
|
||||
* the given parameters, or an empty optional if a scorer is not supported.
|
||||
* Returns an optional containing an int7 scalar quantized vector scorer supplier
|
||||
* for the given parameters, or an empty optional if a scorer is not supported.
|
||||
*
|
||||
* @param similarityType the similarity type
|
||||
* @param input the index input containing the vector data;
|
||||
|
@ -33,10 +35,25 @@ public interface VectorScorerFactory {
|
|||
* @param scoreCorrectionConstant the score correction constant
|
||||
* @return an optional containing the vector scorer supplier, or empty
|
||||
*/
|
||||
Optional<RandomVectorScorerSupplier> getInt7ScalarQuantizedVectorScorer(
|
||||
Optional<RandomVectorScorerSupplier> getInt7SQVectorScorerSupplier(
|
||||
VectorSimilarityType similarityType,
|
||||
IndexInput input,
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
float scoreCorrectionConstant
|
||||
);
|
||||
|
||||
/**
|
||||
* Returns an optional containing an int7 scalar quantized vector scorer for
|
||||
* the given parameters, or an empty optional if a scorer is not supported.
|
||||
*
|
||||
* @param sim the similarity type
|
||||
* @param values the random access vector values
|
||||
* @param queryVector the query vector
|
||||
* @return an optional containing the vector scorer, or empty
|
||||
*/
|
||||
Optional<RandomVectorScorer> getInt7SQVectorScorer(
|
||||
VectorSimilarityFunction sim,
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
float[] queryVector
|
||||
);
|
||||
}
|
||||
|
|
|
@ -8,18 +8,20 @@
|
|||
|
||||
package org.elasticsearch.vec;
|
||||
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
|
||||
|
||||
import java.util.Optional;
|
||||
|
||||
class VectorScorerFactoryImpl implements VectorScorerFactory {
|
||||
final class VectorScorerFactoryImpl implements VectorScorerFactory {
|
||||
|
||||
static final VectorScorerFactoryImpl INSTANCE = null;
|
||||
|
||||
@Override
|
||||
public Optional<RandomVectorScorerSupplier> getInt7ScalarQuantizedVectorScorer(
|
||||
public Optional<RandomVectorScorerSupplier> getInt7SQVectorScorerSupplier(
|
||||
VectorSimilarityType similarityType,
|
||||
IndexInput input,
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
|
@ -27,4 +29,13 @@ class VectorScorerFactoryImpl implements VectorScorerFactory {
|
|||
) {
|
||||
throw new UnsupportedOperationException("should not reach here");
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<RandomVectorScorer> getInt7SQVectorScorer(
|
||||
VectorSimilarityFunction sim,
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
float[] queryVector
|
||||
) {
|
||||
throw new UnsupportedOperationException("should not reach here");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,19 +8,22 @@
|
|||
|
||||
package org.elasticsearch.vec;
|
||||
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.FilterIndexInput;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.MemorySegmentAccessInput;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
|
||||
import org.elasticsearch.nativeaccess.NativeAccess;
|
||||
import org.elasticsearch.vec.internal.Int7SQVectorScorer;
|
||||
import org.elasticsearch.vec.internal.Int7SQVectorScorerSupplier.DotProductSupplier;
|
||||
import org.elasticsearch.vec.internal.Int7SQVectorScorerSupplier.EuclideanSupplier;
|
||||
import org.elasticsearch.vec.internal.Int7SQVectorScorerSupplier.MaxInnerProductSupplier;
|
||||
|
||||
import java.util.Optional;
|
||||
|
||||
class VectorScorerFactoryImpl implements VectorScorerFactory {
|
||||
final class VectorScorerFactoryImpl implements VectorScorerFactory {
|
||||
|
||||
static final VectorScorerFactoryImpl INSTANCE;
|
||||
|
||||
|
@ -31,7 +34,7 @@ class VectorScorerFactoryImpl implements VectorScorerFactory {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Optional<RandomVectorScorerSupplier> getInt7ScalarQuantizedVectorScorer(
|
||||
public Optional<RandomVectorScorerSupplier> getInt7SQVectorScorerSupplier(
|
||||
VectorSimilarityType similarityType,
|
||||
IndexInput input,
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
|
@ -50,6 +53,15 @@ class VectorScorerFactoryImpl implements VectorScorerFactory {
|
|||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<RandomVectorScorer> getInt7SQVectorScorer(
|
||||
VectorSimilarityFunction sim,
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
float[] queryVector
|
||||
) {
|
||||
return Int7SQVectorScorer.create(sim, values, queryVector);
|
||||
}
|
||||
|
||||
static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
|
||||
if (input.length() < (long) vectorByteLength * maxOrd) {
|
||||
throw new IllegalArgumentException("input length is less than expected vector data");
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0 and the Server Side Public License, v 1; you may not use this file except
|
||||
* in compliance with, at your election, the Elastic License 2.0 or the Server
|
||||
* Side Public License, v 1.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.vec.internal;
|
||||
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
|
||||
|
||||
import java.util.Optional;
|
||||
|
||||
public final class Int7SQVectorScorer {
|
||||
|
||||
// Unconditionally returns an empty optional on <= JDK 21, since the scorer is only supported on JDK 22+
|
||||
public static Optional<RandomVectorScorer> create(
|
||||
VectorSimilarityFunction sim,
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
float[] queryVector
|
||||
) {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
private Int7SQVectorScorer() {}
|
||||
}
|
|
@ -24,7 +24,6 @@ public class Similarities {
|
|||
static final MethodHandle SQUARE_DISTANCE_7U = DISTANCE_FUNCS.squareDistanceHandle7u();
|
||||
|
||||
static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
|
||||
assert assertSegments(a, b, length);
|
||||
try {
|
||||
return (int) DOT_PRODUCT_7U.invokeExact(a, b, length);
|
||||
} catch (Throwable e) {
|
||||
|
@ -39,7 +38,6 @@ public class Similarities {
|
|||
}
|
||||
|
||||
static int squareDistance7u(MemorySegment a, MemorySegment b, int length) {
|
||||
assert assertSegments(a, b, length);
|
||||
try {
|
||||
return (int) SQUARE_DISTANCE_7U.invokeExact(a, b, length);
|
||||
} catch (Throwable e) {
|
||||
|
@ -52,8 +50,4 @@ public class Similarities {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
static boolean assertSegments(MemorySegment a, MemorySegment b, int length) {
|
||||
return a.isNative() && a.byteSize() >= length && b.isNative() && b.byteSize() >= length;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0 and the Server Side Public License, v 1; you may not use this file except
|
||||
* in compliance with, at your election, the Elastic License 2.0 or the Server
|
||||
* Side Public License, v 1.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.vec.internal;
|
||||
|
||||
import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.FilterIndexInput;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.MemorySegmentAccessInput;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
|
||||
import org.apache.lucene.util.quantization.ScalarQuantizer;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.foreign.MemorySegment;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.vec.internal.Similarities.dotProduct7u;
|
||||
import static org.elasticsearch.vec.internal.Similarities.squareDistance7u;
|
||||
|
||||
public abstract sealed class Int7SQVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
|
||||
|
||||
final int vectorByteSize;
|
||||
final MemorySegmentAccessInput input;
|
||||
final MemorySegment query;
|
||||
final float scoreCorrectionConstant;
|
||||
final float queryCorrection;
|
||||
byte[] scratch;
|
||||
|
||||
/** Return an optional whose value, if present, is the scorer. Otherwise, an empty optional is returned. */
|
||||
public static Optional<RandomVectorScorer> create(
|
||||
VectorSimilarityFunction sim,
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
float[] queryVector
|
||||
) {
|
||||
checkDimensions(queryVector.length, values.dimension());
|
||||
var input = values.getSlice();
|
||||
if (input == null) {
|
||||
return Optional.empty();
|
||||
}
|
||||
input = FilterIndexInput.unwrapOnlyTest(input);
|
||||
if (input instanceof MemorySegmentAccessInput == false) {
|
||||
return Optional.empty();
|
||||
}
|
||||
MemorySegmentAccessInput msInput = (MemorySegmentAccessInput) input;
|
||||
checkInvariants(values.size(), values.dimension(), input);
|
||||
|
||||
ScalarQuantizer scalarQuantizer = values.getScalarQuantizer();
|
||||
// TODO assert scalarQuantizer.getBits() == 7 or 8 ?
|
||||
byte[] quantizedQuery = new byte[queryVector.length];
|
||||
float queryCorrection = ScalarQuantizedVectorScorer.quantizeQuery(queryVector, quantizedQuery, sim, scalarQuantizer);
|
||||
return switch (sim) {
|
||||
case COSINE, DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, quantizedQuery, queryCorrection));
|
||||
case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, quantizedQuery, queryCorrection));
|
||||
case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductScorer(msInput, values, quantizedQuery, queryCorrection));
|
||||
};
|
||||
}
|
||||
|
||||
Int7SQVectorScorer(
|
||||
MemorySegmentAccessInput input,
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
byte[] queryVector,
|
||||
float queryCorrection
|
||||
) {
|
||||
super(values);
|
||||
this.input = input;
|
||||
assert queryVector.length == values.getVectorByteLength();
|
||||
this.vectorByteSize = values.getVectorByteLength();
|
||||
this.query = MemorySegment.ofArray(queryVector);
|
||||
this.queryCorrection = queryCorrection;
|
||||
this.scoreCorrectionConstant = values.getScalarQuantizer().getConstantMultiplier();
|
||||
}
|
||||
|
||||
final MemorySegment getSegment(int ord) throws IOException {
|
||||
checkOrdinal(ord);
|
||||
long byteOffset = (long) ord * (vectorByteSize + Float.BYTES);
|
||||
MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize);
|
||||
if (seg == null) {
|
||||
if (scratch == null) {
|
||||
scratch = new byte[vectorByteSize];
|
||||
}
|
||||
input.readBytes(byteOffset, scratch, 0, vectorByteSize);
|
||||
seg = MemorySegment.ofArray(scratch);
|
||||
}
|
||||
return seg;
|
||||
}
|
||||
|
||||
static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
|
||||
if (input.length() < (long) vectorByteLength * maxOrd) {
|
||||
throw new IllegalArgumentException("input length is less than expected vector data");
|
||||
}
|
||||
}
|
||||
|
||||
final void checkOrdinal(int ord) {
|
||||
if (ord < 0 || ord >= maxOrd()) {
|
||||
throw new IllegalArgumentException("illegal ordinal: " + ord);
|
||||
}
|
||||
}
|
||||
|
||||
public static final class DotProductScorer extends Int7SQVectorScorer {
|
||||
public DotProductScorer(MemorySegmentAccessInput in, RandomAccessQuantizedByteVectorValues values, byte[] query, float correction) {
|
||||
super(in, values, query, correction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int node) throws IOException {
|
||||
checkOrdinal(node);
|
||||
int dotProduct = dotProduct7u(query, getSegment(node), vectorByteSize);
|
||||
assert dotProduct >= 0;
|
||||
long byteOffset = (long) node * (vectorByteSize + Float.BYTES);
|
||||
float nodeCorrection = Float.intBitsToFloat(input.readInt(byteOffset + vectorByteSize));
|
||||
float adjustedDistance = dotProduct * scoreCorrectionConstant + queryCorrection + nodeCorrection;
|
||||
return Math.max((1 + adjustedDistance) / 2, 0f);
|
||||
}
|
||||
}
|
||||
|
||||
public static final class EuclideanScorer extends Int7SQVectorScorer {
|
||||
public EuclideanScorer(MemorySegmentAccessInput in, RandomAccessQuantizedByteVectorValues values, byte[] query, float correction) {
|
||||
super(in, values, query, correction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int node) throws IOException {
|
||||
checkOrdinal(node);
|
||||
int sqDist = squareDistance7u(query, getSegment(node), vectorByteSize);
|
||||
float adjustedDistance = sqDist * scoreCorrectionConstant;
|
||||
return 1 / (1f + adjustedDistance);
|
||||
}
|
||||
}
|
||||
|
||||
public static final class MaxInnerProductScorer extends Int7SQVectorScorer {
|
||||
public MaxInnerProductScorer(MemorySegmentAccessInput in, RandomAccessQuantizedByteVectorValues values, byte[] query, float corr) {
|
||||
super(in, values, query, corr);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int node) throws IOException {
|
||||
checkOrdinal(node);
|
||||
int dotProduct = dotProduct7u(query, getSegment(node), vectorByteSize);
|
||||
assert dotProduct >= 0;
|
||||
long byteOffset = (long) node * (vectorByteSize + Float.BYTES);
|
||||
float nodeCorrection = Float.intBitsToFloat(input.readInt(byteOffset + vectorByteSize));
|
||||
float adjustedDistance = dotProduct * scoreCorrectionConstant + queryCorrection + nodeCorrection;
|
||||
if (adjustedDistance < 0) {
|
||||
return 1 / (1 + -1 * adjustedDistance);
|
||||
}
|
||||
return adjustedDistance + 1;
|
||||
}
|
||||
}
|
||||
|
||||
static void checkDimensions(int queryLen, int fieldLen) {
|
||||
if (queryLen != fieldLen) {
|
||||
throw new IllegalArgumentException("vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -37,6 +37,7 @@ import java.util.function.Function;
|
|||
import java.util.function.Predicate;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantizeQuery;
|
||||
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
|
||||
import static org.elasticsearch.vec.VectorSimilarityType.COSINE;
|
||||
import static org.elasticsearch.vec.VectorSimilarityType.DOT_PRODUCT;
|
||||
|
@ -70,32 +71,42 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
|
|||
void testSimpleImpl(long maxChunkSize) throws IOException {
|
||||
assumeTrue(notSupportedMsg(), supported());
|
||||
var factory = AbstractVectorTestCase.factory.get();
|
||||
var scalarQuantizer = new ScalarQuantizer(0.1f, 0.9f, (byte) 7);
|
||||
|
||||
try (Directory dir = new MMapDirectory(createTempDir("testSimpleImpl"), maxChunkSize)) {
|
||||
for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
|
||||
for (int dims : List.of(31, 32, 33)) {
|
||||
// dimensions that cross the scalar / native boundary (stride)
|
||||
byte[] vec1 = new byte[dims];
|
||||
byte[] vec2 = new byte[dims];
|
||||
String fileName = "testSimpleImpl" + "-" + dims;
|
||||
float[] query1 = new float[dims];
|
||||
float[] query2 = new float[dims];
|
||||
float vec1Correction, vec2Correction;
|
||||
String fileName = "testSimpleImpl-" + sim + "-" + dims + ".vex";
|
||||
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
|
||||
for (int i = 0; i < dims; i++) {
|
||||
vec1[i] = (byte) i;
|
||||
vec2[i] = (byte) (dims - i);
|
||||
query1[i] = (float) i;
|
||||
query2[i] = (float) (dims - i);
|
||||
}
|
||||
var oneFactor = floatToByteArray(1f);
|
||||
byte[] bytes = concat(vec1, oneFactor, vec2, oneFactor);
|
||||
vec1Correction = quantizeQuery(query1, vec1, VectorSimilarityType.of(sim), scalarQuantizer);
|
||||
vec2Correction = quantizeQuery(query2, vec2, VectorSimilarityType.of(sim), scalarQuantizer);
|
||||
byte[] bytes = concat(vec1, floatToByteArray(vec1Correction), vec2, floatToByteArray(vec2Correction));
|
||||
out.writeBytes(bytes, 0, bytes.length);
|
||||
}
|
||||
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
|
||||
for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
|
||||
var values = vectorValues(dims, 2, in, VectorSimilarityType.of(sim));
|
||||
float scc = values.getScalarQuantizer().getConstantMultiplier();
|
||||
float expected = luceneScore(sim, vec1, vec2, scc, 1, 1);
|
||||
float expected = luceneScore(sim, vec1, vec2, scc, vec1Correction, vec2Correction);
|
||||
|
||||
var luceneSupplier = luceneScoreSupplier(values, VectorSimilarityType.of(sim)).scorer(0);
|
||||
assertThat(luceneSupplier.score(1), equalTo(expected));
|
||||
var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, scc).get();
|
||||
var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, scc).get();
|
||||
assertThat(supplier.scorer(0).score(1), equalTo(expected));
|
||||
|
||||
if (Runtime.version().feature() >= 22) {
|
||||
var qScorer = factory.getInt7SQVectorScorer(VectorSimilarityType.of(sim), values, query1).get();
|
||||
assertThat(qScorer.score(1), equalTo(expected));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -121,23 +132,23 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
|
|||
// dot product
|
||||
float expected = 0f;
|
||||
assertThat(luceneScore(DOT_PRODUCT, vec1, vec2, 1, -5, -5), equalTo(expected));
|
||||
var supplier = factory.getInt7ScalarQuantizedVectorScorer(DOT_PRODUCT, in, values, 1).get();
|
||||
var supplier = factory.getInt7SQVectorScorerSupplier(DOT_PRODUCT, in, values, 1).get();
|
||||
assertThat(supplier.scorer(0).score(1), equalTo(expected));
|
||||
assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
|
||||
// max inner product
|
||||
expected = luceneScore(MAXIMUM_INNER_PRODUCT, vec1, vec2, 1, -5, -5);
|
||||
supplier = factory.getInt7ScalarQuantizedVectorScorer(MAXIMUM_INNER_PRODUCT, in, values, 1).get();
|
||||
supplier = factory.getInt7SQVectorScorerSupplier(MAXIMUM_INNER_PRODUCT, in, values, 1).get();
|
||||
assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
|
||||
assertThat(supplier.scorer(0).score(1), equalTo(expected));
|
||||
// cosine
|
||||
expected = 0f;
|
||||
assertThat(luceneScore(COSINE, vec1, vec2, 1, -5, -5), equalTo(expected));
|
||||
supplier = factory.getInt7ScalarQuantizedVectorScorer(COSINE, in, values, 1).get();
|
||||
supplier = factory.getInt7SQVectorScorerSupplier(COSINE, in, values, 1).get();
|
||||
assertThat(supplier.scorer(0).score(1), equalTo(expected));
|
||||
assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
|
||||
// euclidean
|
||||
expected = luceneScore(EUCLIDEAN, vec1, vec2, 1, -5, -5);
|
||||
supplier = factory.getInt7ScalarQuantizedVectorScorer(EUCLIDEAN, in, values, 1).get();
|
||||
supplier = factory.getInt7SQVectorScorerSupplier(EUCLIDEAN, in, values, 1).get();
|
||||
assertThat(supplier.scorer(0).score(1), equalTo(expected));
|
||||
assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
|
||||
}
|
||||
|
@ -146,27 +157,27 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
|
|||
|
||||
public void testRandom() throws IOException {
|
||||
assumeTrue(notSupportedMsg(), supported());
|
||||
testRandom(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_RANDOM_INT7_FUNC);
|
||||
testRandomSupplier(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_RANDOM_INT7_FUNC);
|
||||
}
|
||||
|
||||
public void testRandomMaxChunkSizeSmall() throws IOException {
|
||||
assumeTrue(notSupportedMsg(), supported());
|
||||
long maxChunkSize = randomLongBetween(32, 128);
|
||||
logger.info("maxChunkSize=" + maxChunkSize);
|
||||
testRandom(maxChunkSize, BYTE_ARRAY_RANDOM_INT7_FUNC);
|
||||
testRandomSupplier(maxChunkSize, BYTE_ARRAY_RANDOM_INT7_FUNC);
|
||||
}
|
||||
|
||||
public void testRandomMax() throws IOException {
|
||||
assumeTrue(notSupportedMsg(), supported());
|
||||
testRandom(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MAX_INT7_FUNC);
|
||||
testRandomSupplier(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MAX_INT7_FUNC);
|
||||
}
|
||||
|
||||
public void testRandomMin() throws IOException {
|
||||
assumeTrue(notSupportedMsg(), supported());
|
||||
testRandom(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MIN_INT7_FUNC);
|
||||
testRandomSupplier(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MIN_INT7_FUNC);
|
||||
}
|
||||
|
||||
void testRandom(long maxChunkSize, Function<Integer, byte[]> byteArraySupplier) throws IOException {
|
||||
void testRandomSupplier(long maxChunkSize, Function<Integer, byte[]> byteArraySupplier) throws IOException {
|
||||
var factory = AbstractVectorTestCase.factory.get();
|
||||
|
||||
try (Directory dir = new MMapDirectory(createTempDir("testRandom"), maxChunkSize)) {
|
||||
|
@ -195,7 +206,7 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
|
|||
for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
|
||||
var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim));
|
||||
float expected = luceneScore(sim, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]);
|
||||
var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, correction).get();
|
||||
var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get();
|
||||
assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected));
|
||||
}
|
||||
}
|
||||
|
@ -203,6 +214,61 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testRandomScorer() throws IOException {
|
||||
testRandomScorerImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, VectorScorerFactoryTests.FLOAT_ARRAY_RANDOM_FUNC);
|
||||
}
|
||||
|
||||
public void testRandomScorerMax() throws IOException {
|
||||
testRandomScorerImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, VectorScorerFactoryTests.FLOAT_ARRAY_MAX_FUNC);
|
||||
}
|
||||
|
||||
public void testRandomScorerChunkSizeSmall() throws IOException {
|
||||
assumeTrue(notSupportedMsg(), supported());
|
||||
long maxChunkSize = randomLongBetween(32, 128);
|
||||
logger.info("maxChunkSize=" + maxChunkSize);
|
||||
testRandomScorerImpl(maxChunkSize, FLOAT_ARRAY_RANDOM_FUNC);
|
||||
}
|
||||
|
||||
void testRandomScorerImpl(long maxChunkSize, Function<Integer, float[]> floatArraySupplier) throws IOException {
|
||||
assumeTrue("scorer only supported on JDK 22+", Runtime.version().feature() >= 22);
|
||||
var factory = AbstractVectorTestCase.factory.get();
|
||||
var scalarQuantizer = new ScalarQuantizer(0.1f, 0.9f, (byte) 7);
|
||||
|
||||
try (Directory dir = new MMapDirectory(createTempDir("testRandom"), maxChunkSize)) {
|
||||
for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
|
||||
final int dims = randomIntBetween(1, 4096);
|
||||
final int size = randomIntBetween(2, 100);
|
||||
final float[][] vectors = new float[size][];
|
||||
final byte[][] qVectors = new byte[size][];
|
||||
final float[] corrections = new float[size];
|
||||
|
||||
String fileName = "testRandom-" + sim + "-" + dims + ".vex";
|
||||
logger.info("Testing " + fileName);
|
||||
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
vectors[i] = floatArraySupplier.apply(dims);
|
||||
qVectors[i] = new byte[dims];
|
||||
corrections[i] = quantizeQuery(vectors[i], qVectors[i], VectorSimilarityType.of(sim), scalarQuantizer);
|
||||
out.writeBytes(qVectors[i], 0, dims);
|
||||
out.writeBytes(floatToByteArray(corrections[i]), 0, 4);
|
||||
}
|
||||
}
|
||||
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
|
||||
for (int times = 0; times < TIMES; times++) {
|
||||
int idx0 = randomIntBetween(0, size - 1);
|
||||
int idx1 = randomIntBetween(0, size - 1);
|
||||
var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim));
|
||||
var correction = scalarQuantizer.getConstantMultiplier();
|
||||
|
||||
var expected = luceneScore(sim, qVectors[idx0], qVectors[idx1], correction, corrections[idx0], corrections[idx1]);
|
||||
var scorer = factory.getInt7SQVectorScorer(VectorSimilarityType.of(sim), values, vectors[idx0]).get();
|
||||
assertThat(scorer.score(idx1), equalTo(expected));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testRandomSlice() throws IOException {
|
||||
assumeTrue(notSupportedMsg(), supported());
|
||||
testRandomSliceImpl(30, 64, 1, BYTE_ARRAY_RANDOM_INT7_FUNC);
|
||||
|
@ -243,7 +309,7 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
|
|||
for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
|
||||
var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim));
|
||||
float expected = luceneScore(sim, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]);
|
||||
var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, correction).get();
|
||||
var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get();
|
||||
assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected));
|
||||
}
|
||||
}
|
||||
|
@ -281,7 +347,7 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
|
|||
for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
|
||||
var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim));
|
||||
float expected = luceneScore(sim, vector(idx0, dims), vector(idx1, dims), correction, off0, off1);
|
||||
var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, correction).get();
|
||||
var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get();
|
||||
assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected));
|
||||
}
|
||||
}
|
||||
|
@ -319,7 +385,7 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
|
|||
|
||||
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
|
||||
var values = vectorValues(dims, 4, in, VectorSimilarityType.of(sim));
|
||||
var scoreSupplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, 1f).get();
|
||||
var scoreSupplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, 1f).get();
|
||||
var tasks = List.<Callable<Optional<Throwable>>>of(
|
||||
new ScoreCallable(scoreSupplier.copy().scorer(0), 1, expectedScore1),
|
||||
new ScoreCallable(scoreSupplier.copy().scorer(2), 3, expectedScore2)
|
||||
|
@ -382,6 +448,20 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
|
|||
return ba;
|
||||
}
|
||||
|
||||
static Function<Integer, float[]> FLOAT_ARRAY_RANDOM_FUNC = size -> {
|
||||
float[] fa = new float[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
fa[i] = randomFloat();
|
||||
}
|
||||
return fa;
|
||||
};
|
||||
|
||||
static Function<Integer, float[]> FLOAT_ARRAY_MAX_FUNC = size -> {
|
||||
float[] fa = new float[size];
|
||||
Arrays.fill(fa, Float.MAX_VALUE);
|
||||
return fa;
|
||||
};
|
||||
|
||||
static Function<Integer, byte[]> BYTE_ARRAY_RANDOM_INT7_FUNC = size -> {
|
||||
byte[] ba = new byte[size];
|
||||
randomBytesBetween(ba, MIN_INT7_VALUE, MAX_INT7_VALUE);
|
||||
|
|
|
@ -39,7 +39,6 @@ import org.elasticsearch.vec.VectorScorerFactory;
|
|||
import org.elasticsearch.vec.VectorSimilarityType;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Optional;
|
||||
|
||||
public class ES814ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
|
||||
|
||||
|
@ -198,19 +197,19 @@ public class ES814ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
|
|||
static final class ESFlatVectorsScorer implements FlatVectorsScorer {
|
||||
|
||||
final FlatVectorsScorer delegate;
|
||||
final VectorScorerFactory factory;
|
||||
|
||||
/**
 * Creates a scorer that falls back to {@code delegate}.
 *
 * <p>The {@link VectorScorerFactory} is looked up once here; {@code factory} is left
 * {@code null} when {@code VectorScorerFactory.instance()} is empty.
 *
 * @param delegate the scorer to fall back to
 */
ESFlatVectorsScorer(FlatVectorsScorer delegate) {
    this.delegate = delegate;
    this.factory = VectorScorerFactory.instance().orElse(null);
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction sim, RandomAccessVectorValues values)
|
||||
throws IOException {
|
||||
if (values instanceof RandomAccessQuantizedByteVectorValues qValues && values.getSlice() != null) {
|
||||
Optional<VectorScorerFactory> factory = VectorScorerFactory.instance();
|
||||
if (factory.isPresent()) {
|
||||
var scorer = factory.get()
|
||||
.getInt7ScalarQuantizedVectorScorer(
|
||||
if (factory != null) {
|
||||
var scorer = factory.getInt7SQVectorScorerSupplier(
|
||||
VectorSimilarityType.of(sim),
|
||||
values.getSlice(),
|
||||
qValues,
|
||||
|
@ -227,6 +226,14 @@ public class ES814ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
|
|||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, RandomAccessVectorValues values, float[] query)
|
||||
throws IOException {
|
||||
if (values instanceof RandomAccessQuantizedByteVectorValues qValues && values.getSlice() != null) {
|
||||
if (factory != null) {
|
||||
var scorer = factory.getInt7SQVectorScorer(sim, qValues, query);
|
||||
if (scorer.isPresent()) {
|
||||
return scorer.get();
|
||||
}
|
||||
}
|
||||
}
|
||||
return delegate.getRandomVectorScorer(sim, values, query);
|
||||
}
|
||||
|
||||
|
|
|
@ -12,12 +12,14 @@ import org.apache.lucene.codecs.Codec;
|
|||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.StoredFields;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.store.MMapDirectory;
|
||||
|
@ -32,6 +34,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
|||
public class ES814HnswScalarQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
|
||||
|
||||
// Runs when this test class is initialized: loads the Log4j plugins and configures ES logging.
static {
|
||||
LogConfigurator.loadLog4jPlugins();
|
||||
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
|
||||
}
|
||||
|
||||
|
@ -117,4 +120,57 @@ public class ES814HnswScalarQuantizedVectorsFormatTests extends BaseKnnVectorsFo
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testSingleVectorPerSegmentCosine() throws Exception {
|
||||
testSingleVectorPerSegment(VectorSimilarityFunction.COSINE);
|
||||
}
|
||||
|
||||
public void testSingleVectorPerSegmentDot() throws Exception {
|
||||
testSingleVectorPerSegment(VectorSimilarityFunction.DOT_PRODUCT);
|
||||
}
|
||||
|
||||
public void testSingleVectorPerSegmentEuclidean() throws Exception {
|
||||
testSingleVectorPerSegment(VectorSimilarityFunction.EUCLIDEAN);
|
||||
}
|
||||
|
||||
public void testSingleVectorPerSegmentMIP() throws Exception {
|
||||
testSingleVectorPerSegment(VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);
|
||||
}
|
||||
|
||||
private void testSingleVectorPerSegment(VectorSimilarityFunction sim) throws Exception {
|
||||
var codec = getCodec();
|
||||
try (Directory dir = new MMapDirectory(createTempDir().resolve("dir1"))) {
|
||||
try (IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig().setCodec(codec))) {
|
||||
Document doc2 = new Document();
|
||||
doc2.add(new KnnFloatVectorField("field", new float[] { 0.8f, 0.6f }, sim));
|
||||
doc2.add(newTextField("id", "A", Field.Store.YES));
|
||||
writer.addDocument(doc2);
|
||||
writer.commit();
|
||||
|
||||
Document doc1 = new Document();
|
||||
doc1.add(new KnnFloatVectorField("field", new float[] { 0.6f, 0.8f }, sim));
|
||||
doc1.add(newTextField("id", "B", Field.Store.YES));
|
||||
writer.addDocument(doc1);
|
||||
writer.commit();
|
||||
|
||||
Document doc3 = new Document();
|
||||
doc3.add(new KnnFloatVectorField("field", new float[] { -0.6f, -0.8f }, sim));
|
||||
doc3.add(newTextField("id", "C", Field.Store.YES));
|
||||
writer.addDocument(doc3);
|
||||
writer.commit();
|
||||
|
||||
writer.forceMerge(1);
|
||||
}
|
||||
try (DirectoryReader reader = DirectoryReader.open(dir)) {
|
||||
LeafReader leafReader = getOnlyLeafReader(reader);
|
||||
StoredFields storedFields = reader.storedFields();
|
||||
float[] queryVector = new float[] { 0.6f, 0.8f };
|
||||
var hits = leafReader.searchNearestVectors("field", queryVector, 3, null, 100);
|
||||
assertEquals(hits.scoreDocs.length, 3);
|
||||
assertEquals("B", storedFields.document(hits.scoreDocs[0].doc).get("id"));
|
||||
assertEquals("A", storedFields.document(hits.scoreDocs[1].doc).get("id"));
|
||||
assertEquals("C", storedFields.document(hits.scoreDocs[2].doc).get("id"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue