Use the SIMD optimized SQ vector scorer at search time (#109109)

This commit extends the custom SIMD optimized SQ vector scorer to include search time scoring.

When run on JDK 22+, vector scoring will be done with the custom scorer. The implementation uses the JDK 22+ on-heap ALLOW_HEAP_ACCESS Linker.Option so that the native code can access the query vector directly.
This commit is contained in:
Chris Hegarty 2024-05-29 16:32:06 +01:00 committed by GitHub
parent b35239c952
commit f71aba1fdd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 538 additions and 64 deletions

View file

@ -80,6 +80,9 @@ public class VectorScorerBenchmark {
RandomVectorScorer nativeDotScorer;
RandomVectorScorer nativeSqrScorer;
RandomVectorScorer luceneDotScorerQuery;
RandomVectorScorer nativeDotScorerQuery;
@Setup
public void setup() throws IOException {
var optionalVectorScorerFactory = VectorScorerFactory.instance();
@ -116,8 +119,16 @@ public class VectorScorerBenchmark {
values = vectorValues(dims, 2, in, VectorSimilarityFunction.EUCLIDEAN);
luceneSqrScorer = luceneScoreSupplier(values, VectorSimilarityFunction.EUCLIDEAN).scorer(0);
nativeDotScorer = factory.getInt7ScalarQuantizedVectorScorer(DOT_PRODUCT, in, values, scoreCorrectionConstant).get().scorer(0);
nativeSqrScorer = factory.getInt7ScalarQuantizedVectorScorer(EUCLIDEAN, in, values, scoreCorrectionConstant).get().scorer(0);
nativeDotScorer = factory.getInt7SQVectorScorerSupplier(DOT_PRODUCT, in, values, scoreCorrectionConstant).get().scorer(0);
nativeSqrScorer = factory.getInt7SQVectorScorerSupplier(EUCLIDEAN, in, values, scoreCorrectionConstant).get().scorer(0);
// setup for getInt7SQVectorScorer / query vector scoring
float[] queryVec = new float[dims];
for (int i = 0; i < dims; i++) {
queryVec[i] = ThreadLocalRandom.current().nextFloat();
}
luceneDotScorerQuery = luceneScorer(values, VectorSimilarityFunction.DOT_PRODUCT, queryVec);
nativeDotScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.DOT_PRODUCT, values, queryVec).get();
// sanity
var f1 = dotProductLucene();
@ -139,6 +150,12 @@ public class VectorScorerBenchmark {
if (f1 != f3) {
throw new AssertionError("lucene[" + f1 + "] != " + "scalar[" + f3 + "]");
}
var q1 = dotProductLuceneQuery();
var q2 = dotProductNativeQuery();
if (q1 != q2) {
throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]");
}
}
@TearDown
@ -166,6 +183,16 @@ public class VectorScorerBenchmark {
return (1 + adjustedDistance) / 2;
}
/** Scores ordinal 1 against the query vector with the Lucene query scorer prepared in setup. */
@Benchmark
public float dotProductLuceneQuery() throws IOException {
    // The score is returned so the benchmark harness consumes the result.
    final float luceneScore = luceneDotScorerQuery.score(1);
    return luceneScore;
}
/** Scores ordinal 1 against the query vector with the native query scorer prepared in setup. */
@Benchmark
public float dotProductNativeQuery() throws IOException {
    // The score is returned so the benchmark harness consumes the result.
    final float nativeScore = nativeDotScorerQuery.score(1);
    return nativeScore;
}
// -- square distance
@Benchmark
@ -200,6 +227,11 @@ public class VectorScorerBenchmark {
return new Lucene99ScalarQuantizedVectorScorer(null).getRandomVectorScorerSupplier(sim, values);
}
/**
 * Builds Lucene's own scalar quantized scorer for a float query vector; used as the
 * baseline that the native query scorer is compared against.
 *
 * @param values   the quantized vector values to score against
 * @param sim      the similarity function
 * @param queryVec the (unquantized) query vector
 * @return the Lucene scorer for the given query
 */
RandomVectorScorer luceneScorer(RandomAccessQuantizedByteVectorValues values, VectorSimilarityFunction sim, float[] queryVec)
    throws IOException {
    var luceneFlatScorer = new Lucene99ScalarQuantizedVectorScorer(null);
    return luceneFlatScorer.getRandomVectorScorer(sim, values, queryVec);
}
// Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
static final byte MIN_INT7_VALUE = 0;
static final byte MAX_INT7_VALUE = 127;

View file

@ -50,8 +50,16 @@ public final class JdkVectorLibrary implements VectorLibrary {
private static final class JdkVectorSimilarityFunctions implements VectorSimilarityFunctions {
static final MethodHandle dot7u$mh = downcallHandle("dot7u", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
static final MethodHandle sqr7u$mh = downcallHandle("sqr7u", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
static final MethodHandle dot7u$mh = downcallHandle(
"dot7u",
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
LinkerHelperUtil.critical()
);
static final MethodHandle sqr7u$mh = downcallHandle(
"sqr7u",
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
LinkerHelperUtil.critical()
);
/**
* Computes the dot product of given unsigned int7 byte vectors.

View file

@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.nativeaccess.jdk;
import java.lang.foreign.Linker;
/**
 * Linker option helper for runtimes where the critical linker option does not exist.
 * This variant always hands out an empty option array.
 */
public class LinkerHelperUtil {

    private LinkerHelperUtil() {}

    /** Shared, zero-length option array returned by {@link #critical()}. */
    static final Linker.Option[] NONE = {};

    /**
     * Returns an empty linker option array, since critical is only available since Java 22.
     *
     * @return the shared empty option array
     */
    static Linker.Option[] critical() {
        return NONE;
    }
}

View file

@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.nativeaccess.jdk;
import java.lang.foreign.Linker;
/**
 * Linker option helper for Java 22+, where {@code Linker.Option.critical} is available.
 */
public class LinkerHelperUtil {

    private LinkerHelperUtil() {}

    /** Single-element option array: a critical option created with heap access allowed. */
    static final Linker.Option[] ALLOW_HEAP_ACCESS = { Linker.Option.critical(true) };

    /**
     * Returns a linker option used to mark a foreign function as critical.
     *
     * @return the shared option array containing the critical (heap access allowed) option
     */
    static Linker.Option[] critical() {
        return ALLOW_HEAP_ACCESS;
    }
}

View file

@ -68,17 +68,37 @@ public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {
for (int i = 0; i < loopTimes; i++) {
int first = randomInt(numVecs - 1);
int second = randomInt(numVecs - 1);
// dot product
int implDot = dotProduct7u(segment.asSlice((long) first * dims, dims), segment.asSlice((long) second * dims, dims), dims);
int otherDot = dotProductScalar(values[first], values[second]);
assertEquals(otherDot, implDot);
var nativeSeg1 = segment.asSlice((long) first * dims, dims);
var nativeSeg2 = segment.asSlice((long) second * dims, dims);
int implSqr = squareDistance7u(segment.asSlice((long) first * dims, dims), segment.asSlice((long) second * dims, dims), dims);
int otherSqr = squareDistanceScalar(values[first], values[second]);
assertEquals(otherSqr, implSqr);
// dot product
int expected = dotProductScalar(values[first], values[second]);
assertEquals(expected, dotProduct7u(nativeSeg1, nativeSeg2, dims));
if (testWithHeapSegments()) {
var heapSeg1 = MemorySegment.ofArray(values[first]);
var heapSeg2 = MemorySegment.ofArray(values[second]);
assertEquals(expected, dotProduct7u(heapSeg1, heapSeg2, dims));
assertEquals(expected, dotProduct7u(nativeSeg1, heapSeg2, dims));
assertEquals(expected, dotProduct7u(heapSeg1, nativeSeg2, dims));
}
// square distance
expected = squareDistanceScalar(values[first], values[second]);
assertEquals(expected, squareDistance7u(nativeSeg1, nativeSeg2, dims));
if (testWithHeapSegments()) {
var heapSeg1 = MemorySegment.ofArray(values[first]);
var heapSeg2 = MemorySegment.ofArray(values[second]);
assertEquals(expected, squareDistance7u(heapSeg1, heapSeg2, dims));
assertEquals(expected, squareDistance7u(nativeSeg1, heapSeg2, dims));
assertEquals(expected, squareDistance7u(heapSeg1, nativeSeg2, dims));
}
}
}
/**
 * Tells whether the heap-segment variants of the similarity checks should run.
 *
 * @return {@code true} when the running JDK's feature version is 22 or later
 */
static boolean testWithHeapSegments() {
    final int featureRelease = Runtime.version().feature();
    return featureRelease >= 22;
}
public void testIllegalDims() {
assumeTrue(notSupportedMsg(), supported());
var segment = arena.allocate((long) size * 3);

View file

@ -8,7 +8,9 @@
package org.elasticsearch.vec;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
@ -22,8 +24,8 @@ public interface VectorScorerFactory {
}
/**
* Returns an optional containing an int7 scalar quantized vector scorer for
* the given parameters, or an empty optional if a scorer is not supported.
* Returns an optional containing an int7 scalar quantized vector scorer supplier
* for the given parameters, or an empty optional if a scorer is not supported.
*
* @param similarityType the similarity type
* @param input the index input containing the vector data;
@ -33,10 +35,25 @@ public interface VectorScorerFactory {
* @param scoreCorrectionConstant the score correction constant
* @return an optional containing the vector scorer supplier, or empty
*/
Optional<RandomVectorScorerSupplier> getInt7ScalarQuantizedVectorScorer(
Optional<RandomVectorScorerSupplier> getInt7SQVectorScorerSupplier(
VectorSimilarityType similarityType,
IndexInput input,
RandomAccessQuantizedByteVectorValues values,
float scoreCorrectionConstant
);
/**
* Returns an optional containing an int7 scalar quantized vector scorer for
* the given parameters, or an empty optional if a scorer is not supported.
*
* <p>The scorer is created for the supplied float query vector, in contrast to
* {@code getInt7SQVectorScorerSupplier}, which returns a scorer supplier.
* Callers should be prepared for an empty result and fall back to another scorer.
*
* @param sim the similarity type
* @param values the random access vector values
* @param queryVector the query vector
* @return an optional containing the vector scorer, or empty
*/
Optional<RandomVectorScorer> getInt7SQVectorScorer(
VectorSimilarityFunction sim,
RandomAccessQuantizedByteVectorValues values,
float[] queryVector
);
}

View file

@ -8,18 +8,20 @@
package org.elasticsearch.vec;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import java.util.Optional;
class VectorScorerFactoryImpl implements VectorScorerFactory {
final class VectorScorerFactoryImpl implements VectorScorerFactory {
static final VectorScorerFactoryImpl INSTANCE = null;
@Override
public Optional<RandomVectorScorerSupplier> getInt7ScalarQuantizedVectorScorer(
public Optional<RandomVectorScorerSupplier> getInt7SQVectorScorerSupplier(
VectorSimilarityType similarityType,
IndexInput input,
RandomAccessQuantizedByteVectorValues values,
@ -27,4 +29,13 @@ class VectorScorerFactoryImpl implements VectorScorerFactory {
) {
throw new UnsupportedOperationException("should not reach here");
}
/**
* Placeholder implementation for the variant of this class whose {@code INSTANCE} is
* {@code null}; no instance is expected to exist, so this always throws.
*
* @throws UnsupportedOperationException always
*/
@Override
public Optional<RandomVectorScorer> getInt7SQVectorScorer(
VectorSimilarityFunction sim,
RandomAccessQuantizedByteVectorValues values,
float[] queryVector
) {
throw new UnsupportedOperationException("should not reach here");
}
}

View file

@ -8,19 +8,22 @@
package org.elasticsearch.vec;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.FilterIndexInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.elasticsearch.nativeaccess.NativeAccess;
import org.elasticsearch.vec.internal.Int7SQVectorScorer;
import org.elasticsearch.vec.internal.Int7SQVectorScorerSupplier.DotProductSupplier;
import org.elasticsearch.vec.internal.Int7SQVectorScorerSupplier.EuclideanSupplier;
import org.elasticsearch.vec.internal.Int7SQVectorScorerSupplier.MaxInnerProductSupplier;
import java.util.Optional;
class VectorScorerFactoryImpl implements VectorScorerFactory {
final class VectorScorerFactoryImpl implements VectorScorerFactory {
static final VectorScorerFactoryImpl INSTANCE;
@ -31,7 +34,7 @@ class VectorScorerFactoryImpl implements VectorScorerFactory {
}
@Override
public Optional<RandomVectorScorerSupplier> getInt7ScalarQuantizedVectorScorer(
public Optional<RandomVectorScorerSupplier> getInt7SQVectorScorerSupplier(
VectorSimilarityType similarityType,
IndexInput input,
RandomAccessQuantizedByteVectorValues values,
@ -50,6 +53,15 @@ class VectorScorerFactoryImpl implements VectorScorerFactory {
};
}
/**
 * Delegates to {@code Int7SQVectorScorer.create}, which decides whether a scorer
 * can be provided for the given values and query vector.
 */
@Override
public Optional<RandomVectorScorer> getInt7SQVectorScorer(
    VectorSimilarityFunction sim,
    RandomAccessQuantizedByteVectorValues values,
    float[] queryVector
) {
    final Optional<RandomVectorScorer> maybeScorer = Int7SQVectorScorer.create(sim, values, queryVector);
    return maybeScorer;
}
static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
if (input.length() < (long) vectorByteLength * maxOrd) {
throw new IllegalArgumentException("input length is less than expected vector data");

View file

@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.vec.internal;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import java.util.Optional;
/**
 * Variant of the int7 scalar quantized query scorer factory that never provides a scorer.
 */
public final class Int7SQVectorScorer {

    private Int7SQVectorScorer() {}

    /**
     * Unconditionally returns an empty optional on <= JDK 21, since the scorer is only supported on JDK 22+.
     *
     * @param sim         the similarity function (unused)
     * @param values      the quantized vector values (unused)
     * @param queryVector the query vector (unused)
     * @return always an empty optional
     */
    public static Optional<RandomVectorScorer> create(
        VectorSimilarityFunction sim,
        RandomAccessQuantizedByteVectorValues values,
        float[] queryVector
    ) {
        final Optional<RandomVectorScorer> unsupported = Optional.empty();
        return unsupported;
    }
}

View file

@ -24,7 +24,6 @@ public class Similarities {
static final MethodHandle SQUARE_DISTANCE_7U = DISTANCE_FUNCS.squareDistanceHandle7u();
static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
assert assertSegments(a, b, length);
try {
return (int) DOT_PRODUCT_7U.invokeExact(a, b, length);
} catch (Throwable e) {
@ -39,7 +38,6 @@ public class Similarities {
}
static int squareDistance7u(MemorySegment a, MemorySegment b, int length) {
assert assertSegments(a, b, length);
try {
return (int) SQUARE_DISTANCE_7U.invokeExact(a, b, length);
} catch (Throwable e) {
@ -52,8 +50,4 @@ public class Similarities {
}
}
}
static boolean assertSegments(MemorySegment a, MemorySegment b, int length) {
return a.isNative() && a.byteSize() >= length && b.isNative() && b.byteSize() >= length;
}
}

View file

@ -0,0 +1,162 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.vec.internal;
import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.FilterIndexInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.util.Optional;
import static org.elasticsearch.vec.internal.Similarities.dotProduct7u;
import static org.elasticsearch.vec.internal.Similarities.squareDistance7u;
/**
 * Scores stored int7 scalar quantized vectors against a quantized query vector using the
 * native {@code dotProduct7u} / {@code squareDistance7u} functions.
 *
 * <p>Stored entries are addressed as {@code vectorByteSize} vector bytes immediately followed
 * by a 4-byte float correction, i.e. entry {@code ord} starts at
 * {@code ord * (vectorByteSize + Float.BYTES)}. The query vector is kept in a heap-backed
 * memory segment.
 */
public abstract sealed class Int7SQVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {

    final int vectorByteSize;
    final MemorySegmentAccessInput input;
    final MemorySegment query;
    final float scoreCorrectionConstant;
    final float queryCorrection;
    // Lazily allocated buffer used when a stored vector cannot be exposed as a segment slice.
    byte[] scratch;

    /**
     * Return an optional whose value, if present, is the scorer. Otherwise, an empty optional is returned.
     *
     * <p>An empty optional is returned when the values expose no slice, or when the (unwrapped)
     * slice is not a {@code MemorySegmentAccessInput}.
     *
     * @param sim         the similarity function, which selects the concrete scorer
     * @param values      the quantized vector values to score against
     * @param queryVector the float query vector; quantized here before scoring
     * @throws IllegalArgumentException if the query and field dimensions differ, or the input
     *                                  is shorter than the expected vector data
     */
    public static Optional<RandomVectorScorer> create(
        VectorSimilarityFunction sim,
        RandomAccessQuantizedByteVectorValues values,
        float[] queryVector
    ) {
        checkDimensions(queryVector.length, values.dimension());
        var slice = values.getSlice();
        if (slice == null) {
            return Optional.empty();
        }
        slice = FilterIndexInput.unwrapOnlyTest(slice);
        if (slice instanceof MemorySegmentAccessInput msInput) {
            checkInvariants(values.size(), values.dimension(), slice);

            ScalarQuantizer scalarQuantizer = values.getScalarQuantizer();
            // TODO assert scalarQuantizer.getBits() == 7 or 8 ?
            byte[] quantizedQuery = new byte[queryVector.length];
            float queryCorrection = ScalarQuantizedVectorScorer.quantizeQuery(queryVector, quantizedQuery, sim, scalarQuantizer);
            RandomVectorScorer scorer = switch (sim) {
                case COSINE, DOT_PRODUCT -> new DotProductScorer(msInput, values, quantizedQuery, queryCorrection);
                case EUCLIDEAN -> new EuclideanScorer(msInput, values, quantizedQuery, queryCorrection);
                case MAXIMUM_INNER_PRODUCT -> new MaxInnerProductScorer(msInput, values, quantizedQuery, queryCorrection);
            };
            return Optional.of(scorer);
        }
        return Optional.empty();
    }

    /**
     * @param input           the input holding the stored vectors and their corrections
     * @param values          the quantized vector values; supplies the byte length and quantizer
     * @param queryVector     the already-quantized query vector
     * @param queryCorrection the correction computed while quantizing the query
     */
    Int7SQVectorScorer(
        MemorySegmentAccessInput input,
        RandomAccessQuantizedByteVectorValues values,
        byte[] queryVector,
        float queryCorrection
    ) {
        super(values);
        final int byteLength = values.getVectorByteLength();
        assert queryVector.length == byteLength;
        this.input = input;
        this.vectorByteSize = byteLength;
        this.query = MemorySegment.ofArray(queryVector);
        this.queryCorrection = queryCorrection;
        this.scoreCorrectionConstant = values.getScalarQuantizer().getConstantMultiplier();
    }

    /**
     * Returns a segment holding the vector bytes of the given ordinal. When the input cannot
     * provide a direct slice, the bytes are copied into the reusable {@link #scratch} array
     * and a heap segment over that array is returned.
     */
    final MemorySegment getSegment(int ord) throws IOException {
        checkOrdinal(ord);
        final long byteOffset = (long) ord * (vectorByteSize + Float.BYTES);
        final MemorySegment direct = input.segmentSliceOrNull(byteOffset, vectorByteSize);
        if (direct != null) {
            return direct;
        }
        // No direct slice available: read the bytes into the scratch buffer instead.
        if (scratch == null) {
            scratch = new byte[vectorByteSize];
        }
        input.readBytes(byteOffset, scratch, 0, vectorByteSize);
        return MemorySegment.ofArray(scratch);
    }

    /** Rejects an input that is shorter than {@code vectorByteLength * maxOrd} bytes. */
    static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
        final long inputLength = input.length();
        final long expectedVectorBytes = (long) vectorByteLength * maxOrd;
        if (inputLength < expectedVectorBytes) {
            throw new IllegalArgumentException("input length is less than expected vector data");
        }
    }

    /** Rejects ordinals outside {@code [0, maxOrd())}. */
    final void checkOrdinal(int ord) {
        if (ord >= 0 && ord < maxOrd()) {
            return;
        }
        throw new IllegalArgumentException("illegal ordinal: " + ord);
    }

    /**
     * Computes the native dot product between the query and the stored vector of {@code node},
     * scales it by the score correction constant and adds the query correction and the
     * correction stored right after the node's vector bytes.
     */
    final float correctedDotProduct(int node) throws IOException {
        checkOrdinal(node);
        int dotProduct = dotProduct7u(query, getSegment(node), vectorByteSize);
        assert dotProduct >= 0;
        long byteOffset = (long) node * (vectorByteSize + Float.BYTES);
        float nodeCorrection = Float.intBitsToFloat(input.readInt(byteOffset + vectorByteSize));
        return dotProduct * scoreCorrectionConstant + queryCorrection + nodeCorrection;
    }

    /** Scorer used for the COSINE and DOT_PRODUCT similarities. */
    public static final class DotProductScorer extends Int7SQVectorScorer {
        public DotProductScorer(MemorySegmentAccessInput in, RandomAccessQuantizedByteVectorValues values, byte[] query, float correction) {
            super(in, values, query, correction);
        }

        @Override
        public float score(int node) throws IOException {
            float adjustedDistance = correctedDotProduct(node);
            // Clamp at zero so the returned score is never negative.
            return Math.max((1 + adjustedDistance) / 2, 0f);
        }
    }

    /** Scorer used for the EUCLIDEAN similarity. */
    public static final class EuclideanScorer extends Int7SQVectorScorer {
        public EuclideanScorer(MemorySegmentAccessInput in, RandomAccessQuantizedByteVectorValues values, byte[] query, float correction) {
            super(in, values, query, correction);
        }

        @Override
        public float score(int node) throws IOException {
            checkOrdinal(node);
            int sqDist = squareDistance7u(query, getSegment(node), vectorByteSize);
            return 1 / (1f + sqDist * scoreCorrectionConstant);
        }
    }

    /** Scorer used for the MAXIMUM_INNER_PRODUCT similarity. */
    public static final class MaxInnerProductScorer extends Int7SQVectorScorer {
        public MaxInnerProductScorer(MemorySegmentAccessInput in, RandomAccessQuantizedByteVectorValues values, byte[] query, float corr) {
            super(in, values, query, corr);
        }

        @Override
        public float score(int node) throws IOException {
            float adjustedDistance = correctedDotProduct(node);
            // Negative values are mapped into (0, 1); non-negative values are shifted up by one.
            return adjustedDistance < 0 ? 1 / (1 + -1 * adjustedDistance) : adjustedDistance + 1;
        }
    }

    /** Rejects a query whose length differs from the field's dimension. */
    static void checkDimensions(int queryLen, int fieldLen) {
        if (queryLen == fieldLen) {
            return;
        }
        throw new IllegalArgumentException("vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen);
    }
}

View file

@ -37,6 +37,7 @@ import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.IntStream;
import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantizeQuery;
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty;
import static org.elasticsearch.vec.VectorSimilarityType.COSINE;
import static org.elasticsearch.vec.VectorSimilarityType.DOT_PRODUCT;
@ -70,32 +71,42 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
void testSimpleImpl(long maxChunkSize) throws IOException {
assumeTrue(notSupportedMsg(), supported());
var factory = AbstractVectorTestCase.factory.get();
var scalarQuantizer = new ScalarQuantizer(0.1f, 0.9f, (byte) 7);
try (Directory dir = new MMapDirectory(createTempDir("testSimpleImpl"), maxChunkSize)) {
for (int dims : List.of(31, 32, 33)) {
// dimensions that cross the scalar / native boundary (stride)
byte[] vec1 = new byte[dims];
byte[] vec2 = new byte[dims];
String fileName = "testSimpleImpl" + "-" + dims;
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
for (int i = 0; i < dims; i++) {
vec1[i] = (byte) i;
vec2[i] = (byte) (dims - i);
for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
for (int dims : List.of(31, 32, 33)) {
// dimensions that cross the scalar / native boundary (stride)
byte[] vec1 = new byte[dims];
byte[] vec2 = new byte[dims];
float[] query1 = new float[dims];
float[] query2 = new float[dims];
float vec1Correction, vec2Correction;
String fileName = "testSimpleImpl-" + sim + "-" + dims + ".vex";
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
for (int i = 0; i < dims; i++) {
query1[i] = (float) i;
query2[i] = (float) (dims - i);
}
vec1Correction = quantizeQuery(query1, vec1, VectorSimilarityType.of(sim), scalarQuantizer);
vec2Correction = quantizeQuery(query2, vec2, VectorSimilarityType.of(sim), scalarQuantizer);
byte[] bytes = concat(vec1, floatToByteArray(vec1Correction), vec2, floatToByteArray(vec2Correction));
out.writeBytes(bytes, 0, bytes.length);
}
var oneFactor = floatToByteArray(1f);
byte[] bytes = concat(vec1, oneFactor, vec2, oneFactor);
out.writeBytes(bytes, 0, bytes.length);
}
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
var values = vectorValues(dims, 2, in, VectorSimilarityType.of(sim));
float scc = values.getScalarQuantizer().getConstantMultiplier();
float expected = luceneScore(sim, vec1, vec2, scc, 1, 1);
float expected = luceneScore(sim, vec1, vec2, scc, vec1Correction, vec2Correction);
var luceneSupplier = luceneScoreSupplier(values, VectorSimilarityType.of(sim)).scorer(0);
assertThat(luceneSupplier.score(1), equalTo(expected));
var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, scc).get();
var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, scc).get();
assertThat(supplier.scorer(0).score(1), equalTo(expected));
if (Runtime.version().feature() >= 22) {
var qScorer = factory.getInt7SQVectorScorer(VectorSimilarityType.of(sim), values, query1).get();
assertThat(qScorer.score(1), equalTo(expected));
}
}
}
}
@ -121,23 +132,23 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
// dot product
float expected = 0f;
assertThat(luceneScore(DOT_PRODUCT, vec1, vec2, 1, -5, -5), equalTo(expected));
var supplier = factory.getInt7ScalarQuantizedVectorScorer(DOT_PRODUCT, in, values, 1).get();
var supplier = factory.getInt7SQVectorScorerSupplier(DOT_PRODUCT, in, values, 1).get();
assertThat(supplier.scorer(0).score(1), equalTo(expected));
assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
// max inner product
expected = luceneScore(MAXIMUM_INNER_PRODUCT, vec1, vec2, 1, -5, -5);
supplier = factory.getInt7ScalarQuantizedVectorScorer(MAXIMUM_INNER_PRODUCT, in, values, 1).get();
supplier = factory.getInt7SQVectorScorerSupplier(MAXIMUM_INNER_PRODUCT, in, values, 1).get();
assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
assertThat(supplier.scorer(0).score(1), equalTo(expected));
// cosine
expected = 0f;
assertThat(luceneScore(COSINE, vec1, vec2, 1, -5, -5), equalTo(expected));
supplier = factory.getInt7ScalarQuantizedVectorScorer(COSINE, in, values, 1).get();
supplier = factory.getInt7SQVectorScorerSupplier(COSINE, in, values, 1).get();
assertThat(supplier.scorer(0).score(1), equalTo(expected));
assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
// euclidean
expected = luceneScore(EUCLIDEAN, vec1, vec2, 1, -5, -5);
supplier = factory.getInt7ScalarQuantizedVectorScorer(EUCLIDEAN, in, values, 1).get();
supplier = factory.getInt7SQVectorScorerSupplier(EUCLIDEAN, in, values, 1).get();
assertThat(supplier.scorer(0).score(1), equalTo(expected));
assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
}
@ -146,27 +157,27 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
public void testRandom() throws IOException {
assumeTrue(notSupportedMsg(), supported());
testRandom(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_RANDOM_INT7_FUNC);
testRandomSupplier(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_RANDOM_INT7_FUNC);
}
public void testRandomMaxChunkSizeSmall() throws IOException {
assumeTrue(notSupportedMsg(), supported());
long maxChunkSize = randomLongBetween(32, 128);
logger.info("maxChunkSize=" + maxChunkSize);
testRandom(maxChunkSize, BYTE_ARRAY_RANDOM_INT7_FUNC);
testRandomSupplier(maxChunkSize, BYTE_ARRAY_RANDOM_INT7_FUNC);
}
public void testRandomMax() throws IOException {
assumeTrue(notSupportedMsg(), supported());
testRandom(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MAX_INT7_FUNC);
testRandomSupplier(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MAX_INT7_FUNC);
}
public void testRandomMin() throws IOException {
assumeTrue(notSupportedMsg(), supported());
testRandom(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MIN_INT7_FUNC);
testRandomSupplier(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, BYTE_ARRAY_MIN_INT7_FUNC);
}
void testRandom(long maxChunkSize, Function<Integer, byte[]> byteArraySupplier) throws IOException {
void testRandomSupplier(long maxChunkSize, Function<Integer, byte[]> byteArraySupplier) throws IOException {
var factory = AbstractVectorTestCase.factory.get();
try (Directory dir = new MMapDirectory(createTempDir("testRandom"), maxChunkSize)) {
@ -195,7 +206,7 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim));
float expected = luceneScore(sim, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]);
var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, correction).get();
var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get();
assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected));
}
}
@ -203,6 +214,61 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
}
}
/** Runs the random query-scorer check with the default chunk size and random float vectors. */
public void testRandomScorer() throws IOException {
    final long defaultChunkSize = MMapDirectory.DEFAULT_MAX_CHUNK_SIZE;
    testRandomScorerImpl(defaultChunkSize, FLOAT_ARRAY_RANDOM_FUNC);
}
/** Runs the random query-scorer check with the default chunk size and vectors filled with the max float value. */
public void testRandomScorerMax() throws IOException {
    final long defaultChunkSize = MMapDirectory.DEFAULT_MAX_CHUNK_SIZE;
    testRandomScorerImpl(defaultChunkSize, FLOAT_ARRAY_MAX_FUNC);
}
/** Runs the random query-scorer check with a small, randomly chosen mmap chunk size (32 to 128). */
public void testRandomScorerChunkSizeSmall() throws IOException {
    assumeTrue(notSupportedMsg(), supported());
    final long smallChunkSize = randomLongBetween(32, 128);
    logger.info("maxChunkSize=" + smallChunkSize);
    testRandomScorerImpl(smallChunkSize, FLOAT_ARRAY_RANDOM_FUNC);
}
/**
 * Checks that the query-vector scorer returned by the factory produces the same score as
 * the Lucene reference computation, for every similarity and for randomly sized data.
 *
 * <p>For each similarity, {@code size} float vectors are generated, quantized, and written
 * to a file as the quantized bytes followed by the 4-byte correction. Random pairs are then
 * scored: the float vector at {@code idx0} is used as the query and compared against the
 * stored entry {@code idx1}.
 *
 * @param maxChunkSize       the max chunk size passed to the {@code MMapDirectory}
 * @param floatArraySupplier produces a float vector of the requested dimension
 */
void testRandomScorerImpl(long maxChunkSize, Function<Integer, float[]> floatArraySupplier) throws IOException {
    assumeTrue("scorer only supported on JDK 22+", Runtime.version().feature() >= 22);
    // Same guard the sibling tests use before calling factory.get(); without it, callers that
    // do not check support themselves would fail rather than be skipped where unsupported.
    assumeTrue(notSupportedMsg(), supported());
    var factory = AbstractVectorTestCase.factory.get();
    var scalarQuantizer = new ScalarQuantizer(0.1f, 0.9f, (byte) 7);

    try (Directory dir = new MMapDirectory(createTempDir("testRandom"), maxChunkSize)) {
        for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
            final int dims = randomIntBetween(1, 4096);
            final int size = randomIntBetween(2, 100);
            final float[][] vectors = new float[size][];
            final byte[][] qVectors = new byte[size][];
            final float[] corrections = new float[size];

            String fileName = "testRandom-" + sim + "-" + dims + ".vex";
            logger.info("Testing " + fileName);
            try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
                for (int i = 0; i < size; i++) {
                    vectors[i] = floatArraySupplier.apply(dims);
                    qVectors[i] = new byte[dims];
                    corrections[i] = quantizeQuery(vectors[i], qVectors[i], VectorSimilarityType.of(sim), scalarQuantizer);
                    // On-disk entry: quantized vector bytes followed by the float correction.
                    out.writeBytes(qVectors[i], 0, dims);
                    out.writeBytes(floatToByteArray(corrections[i]), 0, 4);
                }
            }
            try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
                for (int times = 0; times < TIMES; times++) {
                    int idx0 = randomIntBetween(0, size - 1);
                    int idx1 = randomIntBetween(0, size - 1);
                    var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim));
                    var correction = scalarQuantizer.getConstantMultiplier();

                    var expected = luceneScore(sim, qVectors[idx0], qVectors[idx1], correction, corrections[idx0], corrections[idx1]);
                    var scorer = factory.getInt7SQVectorScorer(VectorSimilarityType.of(sim), values, vectors[idx0]).get();
                    assertThat(scorer.score(idx1), equalTo(expected));
                }
            }
        }
    }
}
public void testRandomSlice() throws IOException {
assumeTrue(notSupportedMsg(), supported());
testRandomSliceImpl(30, 64, 1, BYTE_ARRAY_RANDOM_INT7_FUNC);
@ -243,7 +309,7 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim));
float expected = luceneScore(sim, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]);
var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, correction).get();
var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get();
assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected));
}
}
@ -281,7 +347,7 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim));
float expected = luceneScore(sim, vector(idx0, dims), vector(idx1, dims), correction, off0, off1);
var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, correction).get();
var supplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, correction).get();
assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected));
}
}
@ -319,7 +385,7 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
var values = vectorValues(dims, 4, in, VectorSimilarityType.of(sim));
var scoreSupplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, 1f).get();
var scoreSupplier = factory.getInt7SQVectorScorerSupplier(sim, in, values, 1f).get();
var tasks = List.<Callable<Optional<Throwable>>>of(
new ScoreCallable(scoreSupplier.copy().scorer(0), 1, expectedScore1),
new ScoreCallable(scoreSupplier.copy().scorer(2), 3, expectedScore2)
@ -382,6 +448,20 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
return ba;
}
/** Builds a float array of the requested length, filling each slot in ascending order with {@code randomFloat()}. */
static Function<Integer, float[]> FLOAT_ARRAY_RANDOM_FUNC = length -> {
    final float[] randoms = new float[length];
    int pos = 0;
    while (pos < randoms.length) {
        randoms[pos] = randomFloat();
        pos++;
    }
    return randoms;
};
/** Builds a float array of the requested length with every element set to {@link Float#MAX_VALUE}. */
static Function<Integer, float[]> FLOAT_ARRAY_MAX_FUNC = length -> {
    final float[] maxed = new float[length];
    for (int pos = 0; pos < maxed.length; pos++) {
        maxed[pos] = Float.MAX_VALUE;
    }
    return maxed;
};
static Function<Integer, byte[]> BYTE_ARRAY_RANDOM_INT7_FUNC = size -> {
byte[] ba = new byte[size];
randomBytesBetween(ba, MIN_INT7_VALUE, MAX_INT7_VALUE);

View file

@ -39,7 +39,6 @@ import org.elasticsearch.vec.VectorScorerFactory;
import org.elasticsearch.vec.VectorSimilarityType;
import java.io.IOException;
import java.util.Optional;
public class ES814ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
@ -198,24 +197,24 @@ public class ES814ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
static final class ESFlatVectorsScorer implements FlatVectorsScorer {
final FlatVectorsScorer delegate;
final VectorScorerFactory factory;
/**
 * Wraps the given scorer and looks up the optional {@link VectorScorerFactory} once, at construction time.
 *
 * @param delegate the scorer stored in {@code this.delegate}
 */
ESFlatVectorsScorer(FlatVectorsScorer delegate) {
    this.delegate = delegate;
    // factory is null when VectorScorerFactory.instance() is empty, so users of the field must null-check it
    this.factory = VectorScorerFactory.instance().orElse(null);
}
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction sim, RandomAccessVectorValues values)
throws IOException {
if (values instanceof RandomAccessQuantizedByteVectorValues qValues && values.getSlice() != null) {
Optional<VectorScorerFactory> factory = VectorScorerFactory.instance();
if (factory.isPresent()) {
var scorer = factory.get()
.getInt7ScalarQuantizedVectorScorer(
VectorSimilarityType.of(sim),
values.getSlice(),
qValues,
qValues.getScalarQuantizer().getConstantMultiplier()
);
if (factory != null) {
var scorer = factory.getInt7SQVectorScorerSupplier(
VectorSimilarityType.of(sim),
values.getSlice(),
qValues,
qValues.getScalarQuantizer().getConstantMultiplier()
);
if (scorer.isPresent()) {
return scorer.get();
}
@ -227,6 +226,14 @@ public class ES814ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
/**
 * Returns a scorer for {@code query} over {@code values}.
 *
 * <p>When the values are quantized byte vectors with a non-null slice and a factory is available,
 * the factory is asked for an int7 scalar-quantized scorer; if it supplies one, that scorer is returned.
 * In every other case the call is forwarded to the delegate.
 */
@Override
public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, RandomAccessVectorValues values, float[] query)
    throws IOException {
    if (values instanceof RandomAccessQuantizedByteVectorValues quantized && values.getSlice() != null && factory != null) {
        var nativeScorer = factory.getInt7SQVectorScorer(sim, quantized, query);
        if (nativeScorer.isPresent()) {
            return nativeScorer.get();
        }
    }
    // no factory-provided scorer: fall back to the wrapped implementation
    return delegate.getRandomVectorScorer(sim, values, query);
}

View file

@ -12,12 +12,14 @@ import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.MMapDirectory;
@ -32,6 +34,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
public class ES814HnswScalarQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
// Static initializer: loads the Log4j plugins and configures ES logging when this class is initialized.
static {
LogConfigurator.loadLog4jPlugins();
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
}
@ -117,4 +120,57 @@ public class ES814HnswScalarQuantizedVectorsFormatTests extends BaseKnnVectorsFo
}
}
}
/** Runs the single-vector-per-segment scenario with {@code COSINE} similarity. */
public void testSingleVectorPerSegmentCosine() throws Exception {
testSingleVectorPerSegment(VectorSimilarityFunction.COSINE);
}
/** Runs the single-vector-per-segment scenario with {@code DOT_PRODUCT} similarity. */
public void testSingleVectorPerSegmentDot() throws Exception {
testSingleVectorPerSegment(VectorSimilarityFunction.DOT_PRODUCT);
}
/** Runs the single-vector-per-segment scenario with {@code EUCLIDEAN} similarity. */
public void testSingleVectorPerSegmentEuclidean() throws Exception {
testSingleVectorPerSegment(VectorSimilarityFunction.EUCLIDEAN);
}
/** Runs the single-vector-per-segment scenario with {@code MAXIMUM_INNER_PRODUCT} similarity. */
public void testSingleVectorPerSegmentMIP() throws Exception {
testSingleVectorPerSegment(VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);
}
/**
 * Indexes three two-dimensional vectors, committing after each document, force-merges to a single
 * segment, and then checks that a nearest-vector search for {@code {0.6, 0.8}} returns the documents
 * in the order B (identical vector), A, C (negated vector).
 *
 * @param sim the similarity function the vector field is indexed with
 * @throws Exception if indexing or searching fails
 */
private void testSingleVectorPerSegment(VectorSimilarityFunction sim) throws Exception {
    var codec = getCodec();
    try (Directory dir = new MMapDirectory(createTempDir().resolve("dir1"))) {
        try (IndexWriter writer = new IndexWriter(dir, newIndexWriterConfig().setCodec(codec))) {
            // one commit per document, so each vector is flushed on its own before the merge
            Document docA = new Document();
            docA.add(new KnnFloatVectorField("field", new float[] { 0.8f, 0.6f }, sim));
            docA.add(newTextField("id", "A", Field.Store.YES));
            writer.addDocument(docA);
            writer.commit();

            Document docB = new Document();
            docB.add(new KnnFloatVectorField("field", new float[] { 0.6f, 0.8f }, sim));
            docB.add(newTextField("id", "B", Field.Store.YES));
            writer.addDocument(docB);
            writer.commit();

            Document docC = new Document();
            docC.add(new KnnFloatVectorField("field", new float[] { -0.6f, -0.8f }, sim));
            docC.add(newTextField("id", "C", Field.Store.YES));
            writer.addDocument(docC);
            writer.commit();

            writer.forceMerge(1);
        }
        try (DirectoryReader reader = DirectoryReader.open(dir)) {
            LeafReader leafReader = getOnlyLeafReader(reader);
            // hits carry leaf-relative doc IDs, so resolve stored fields against the same leaf reader
            StoredFields storedFields = leafReader.storedFields();
            float[] queryVector = new float[] { 0.6f, 0.8f };
            var hits = leafReader.searchNearestVectors("field", queryVector, 3, null, 100);
            // expected value first, so a failure message reports expected/actual the right way round
            assertEquals(3, hits.scoreDocs.length);
            assertEquals("B", storedFields.document(hits.scoreDocs[0].doc).get("id"));
            assertEquals("A", storedFields.document(hits.scoreDocs[1].doc).get("id"));
            assertEquals("C", storedFields.document(hits.scoreDocs[2].doc).get("id"));
        }
    }
}
}