Refactor libvec to replace custom scorer types with Lucene types (#108917)

This commit refactors libvec to replace custom scorer types with Lucene types.

The initial implementation created separate types to model the vector scorer, with an adapter between them and the Lucene types. This was done to avoid a dependency on Lucene from the native module. That is no longer an issue: the code is now separate from the native module and already depends on Lucene. This PR drops the custom types in favour of the Lucene ones, which will help future refactoring and avoid bugs by reusing the existing, well-known model in this area.

I also took the liberty of reflowing the code to match the recent change in Lucene that supports off-heap scoring - this code is now very similar to that, and will become even cleaner and more streamlined in the lucene_snapshot branch. This refactoring is not directly dependent on the next version of Lucene, so it is done in main.
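As a rough illustration of the new surface, here is a minimal sketch of the call pattern after this change, reconstructed from the diff below. It is not code from this change: the class name, method name and ordinals are invented, and the factory only yields a scorer when the IndexInput is MemorySegment backed (e.g. opened via MMapDirectory).

import org.apache.lucene.codecs.lucene99.OffHeapQuantizedByteVectorValues;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.elasticsearch.vec.VectorScorerFactory;
import org.elasticsearch.vec.VectorSimilarityType;

import java.io.IOException;

class Int7ScorerSketch {
    // Scores the vector at ordinal 1 against the vector at ordinal 0.
    static float scoreOneAgainstZero(IndexInput in, int dims, int size, float scoreCorrectionConstant) throws IOException {
        // Lucene's off-heap view over the quantized data: dims bytes plus a 4-byte correction per vector.
        RandomAccessQuantizedByteVectorValues values = new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(
            dims,
            size,
            in.slice("values", 0, in.length())
        );
        VectorScorerFactory factory = VectorScorerFactory.instance().orElseThrow();
        // Empty optional if the input is not MemorySegment backed, or native support is unavailable.
        RandomVectorScorerSupplier supplier = factory.getInt7ScalarQuantizedVectorScorer(
            VectorSimilarityType.DOT_PRODUCT,
            in,
            values,
            scoreCorrectionConstant
        ).orElseThrow();
        RandomVectorScorer scorer = supplier.scorer(0);
        return scorer.score(1);
    }
}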
Chris Hegarty 2024-05-23 15:28:32 +01:00 committed by GitHub
parent 3578dca668
commit ff37f1f767
GPG key ID: B5690EEEBB952194
15 changed files with 443 additions and 606 deletions

@@ -8,16 +8,19 @@
package org.elasticsearch.benchmark.vector; package org.elasticsearch.benchmark.vector;
import org.apache.lucene.codecs.lucene99.OffHeapQuantizedByteVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorer;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity; import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.IOUtils;
import org.elasticsearch.vec.VectorScorer;
import org.elasticsearch.vec.VectorScorerFactory; import org.elasticsearch.vec.VectorScorerFactory;
import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.BenchmarkMode;
@@ -71,10 +74,10 @@ public class VectorScorerBenchmark {
float vec2Offset; float vec2Offset;
float scoreCorrectionConstant; float scoreCorrectionConstant;
ScalarQuantizedVectorSimilarity luceneDotScorer; RandomVectorScorer luceneDotScorer;
ScalarQuantizedVectorSimilarity luceneSqrScorer; RandomVectorScorer luceneSqrScorer;
VectorScorer nativeDotScorer; RandomVectorScorer nativeDotScorer;
VectorScorer nativeSqrScorer; RandomVectorScorer nativeSqrScorer;
@Setup @Setup
public void setup() throws IOException { public void setup() throws IOException {
@@ -107,14 +110,22 @@ public class VectorScorerBenchmark {
out.writeInt(Float.floatToIntBits(vec2Offset)); out.writeInt(Float.floatToIntBits(vec2Offset));
} }
in = dir.openInput("vector.data", IOContext.DEFAULT); in = dir.openInput("vector.data", IOContext.DEFAULT);
var values = vectorValues(dims, 2, in);
luceneDotScorer = ScalarQuantizedVectorSimilarity.fromVectorSimilarity( luceneDotScorer = new ScalarQuantizedRandomVectorScorer(
VectorSimilarityFunction.DOT_PRODUCT, ScalarQuantizedVectorSimilarity.fromVectorSimilarity(VectorSimilarityFunction.DOT_PRODUCT, scoreCorrectionConstant),
scoreCorrectionConstant values.copy(),
vec1,
vec1Offset
); );
luceneSqrScorer = ScalarQuantizedVectorSimilarity.fromVectorSimilarity(VectorSimilarityFunction.EUCLIDEAN, scoreCorrectionConstant); luceneSqrScorer = new ScalarQuantizedRandomVectorScorer(
nativeDotScorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, scoreCorrectionConstant, DOT_PRODUCT, in).get(); ScalarQuantizedVectorSimilarity.fromVectorSimilarity(VectorSimilarityFunction.EUCLIDEAN, scoreCorrectionConstant),
nativeSqrScorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, scoreCorrectionConstant, EUCLIDEAN, in).get(); values.copy(),
vec1,
vec1Offset
);
nativeDotScorer = factory.getInt7ScalarQuantizedVectorScorer(DOT_PRODUCT, in, values, scoreCorrectionConstant).get().scorer(0);
nativeSqrScorer = factory.getInt7ScalarQuantizedVectorScorer(EUCLIDEAN, in, values, scoreCorrectionConstant).get().scorer(0);
// sanity // sanity
var f1 = dotProductLucene(); var f1 = dotProductLucene();
@@ -144,13 +155,13 @@ public class VectorScorerBenchmark {
} }
@Benchmark @Benchmark
public float dotProductLucene() { public float dotProductLucene() throws IOException {
return luceneDotScorer.score(vec1, vec1Offset, vec2, vec2Offset); return luceneDotScorer.score(1);
} }
@Benchmark @Benchmark
public float dotProductNative() throws IOException { public float dotProductNative() throws IOException {
return nativeDotScorer.score(0, 1); return nativeDotScorer.score(1);
} }
@Benchmark @Benchmark
@@ -166,13 +177,13 @@ public class VectorScorerBenchmark {
// -- square distance // -- square distance
@Benchmark @Benchmark
public float squareDistanceLucene() { public float squareDistanceLucene() throws IOException {
return luceneSqrScorer.score(vec1, vec1Offset, vec2, vec2Offset); return luceneSqrScorer.score(1);
} }
@Benchmark @Benchmark
public float squareDistanceNative() throws IOException { public float squareDistanceNative() throws IOException {
return nativeSqrScorer.score(0, 1); return nativeSqrScorer.score(1);
} }
@Benchmark @Benchmark
@@ -186,6 +197,10 @@ public class VectorScorerBenchmark {
return 1 / (1f + adjustedDistance); return 1 / (1f + adjustedDistance);
} }
RandomAccessQuantizedByteVectorValues vectorValues(int dims, int size, IndexInput in) throws IOException {
return new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(dims, size, in.slice("values", 0, in.length()));
}
// Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive). // Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
static final byte MIN_INT7_VALUE = 0; static final byte MIN_INT7_VALUE = 0;
static final byte MAX_INT7_VALUE = 127; static final byte MAX_INT7_VALUE = 127;
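The upshot for the benchmark is that the Lucene baseline and the native implementation are now both RandomVectorScorer instances, so every benchmark method reduces to the same call shape. Illustrative only; the scorer is assumed to have been created for ordinal 0, as in the setup above.

import java.io.IOException;

import org.apache.lucene.util.hnsw.RandomVectorScorer;

class ScorerCallShape {
    // The same call works for luceneDotScorer, luceneSqrScorer, nativeDotScorer and nativeSqrScorer.
    static float scoreOrdinalOne(RandomVectorScorer scorer) throws IOException {
        return scorer.score(1);
    }
}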

@@ -1,27 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.vec;
import java.io.IOException;
/** A scorer of vectors. */
public interface VectorScorer {
/** Computes the score of the vectors at the given ordinals. */
float score(int firstOrd, int secondOrd) throws IOException;
/** The per-vector dimension size. */
int dims();
/** The maximum ordinal of vector this scorer can score. */
int maxOrd();
VectorScorer copy();
}

@@ -9,6 +9,8 @@
package org.elasticsearch.vec; package org.elasticsearch.vec;
import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import java.util.Optional; import java.util.Optional;
@@ -23,20 +25,18 @@ public interface VectorScorerFactory {
* Returns an optional containing an int7 scalar quantized vector scorer for * Returns an optional containing an int7 scalar quantized vector scorer for
* the given parameters, or an empty optional if a scorer is not supported. * the given parameters, or an empty optional if a scorer is not supported.
* *
* @param dims the vector dimensions
* @param maxOrd the ordinal of the largest vector accessible
* @param scoreCorrectionConstant the score correction constant
* @param similarityType the similarity type * @param similarityType the similarity type
* @param indexInput the index input containing the vector data; * @param input the index input containing the vector data;
* offset of the first vector is 0, * offset of the first vector is 0,
* the length must be (maxOrd + Float#BYTES) * dims * the length must be (maxOrd + Float#BYTES) * dims
* @return an optional containing the vector scorer, or empty * @param values the random access vector values
* @param scoreCorrectionConstant the score correction constant
* @return an optional containing the vector scorer supplier, or empty
*/ */
Optional<VectorScorer> getInt7ScalarQuantizedVectorScorer( Optional<RandomVectorScorerSupplier> getInt7ScalarQuantizedVectorScorer(
int dims,
int maxOrd,
float scoreCorrectionConstant,
VectorSimilarityType similarityType, VectorSimilarityType similarityType,
IndexInput indexInput IndexInput input,
RandomAccessQuantizedByteVectorValues values,
float scoreCorrectionConstant
); );
} }
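Since the side-by-side diff above is hard to read, here is how the new factory method reads after the change, reconstructed from the diff (package declaration, the static instance() accessor, and the full javadoc are omitted):

import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.elasticsearch.vec.VectorSimilarityType;

import java.util.Optional;

public interface VectorScorerFactory {
    /**
     * Returns an optional containing an int7 scalar quantized vector scorer
     * supplier for the given parameters, or an empty optional if a scorer
     * is not supported.
     */
    Optional<RandomVectorScorerSupplier> getInt7ScalarQuantizedVectorScorer(
        VectorSimilarityType similarityType,
        IndexInput input,
        RandomAccessQuantizedByteVectorValues values,
        float scoreCorrectionConstant
    );
}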

@@ -9,6 +9,8 @@
package org.elasticsearch.vec; package org.elasticsearch.vec;
import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import java.util.Optional; import java.util.Optional;
@@ -17,12 +19,11 @@ class VectorScorerFactoryImpl implements VectorScorerFactory {
static final VectorScorerFactoryImpl INSTANCE = null; static final VectorScorerFactoryImpl INSTANCE = null;
@Override @Override
public Optional<VectorScorer> getInt7ScalarQuantizedVectorScorer( public Optional<RandomVectorScorerSupplier> getInt7ScalarQuantizedVectorScorer(
int dims,
int maxOrd,
float scoreCorrectionConstant,
VectorSimilarityType similarityType, VectorSimilarityType similarityType,
IndexInput input IndexInput input,
RandomAccessQuantizedByteVectorValues values,
float scoreCorrectionConstant
) { ) {
throw new UnsupportedOperationException("should not reach here"); throw new UnsupportedOperationException("should not reach here");
} }

@@ -1,46 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.vec;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import java.io.IOException;
/** An adapter between VectorScorer and RandomVectorScorerSupplier. */
public final class VectorScorerSupplierAdapter implements RandomVectorScorerSupplier {
private final VectorScorer scorer;
public VectorScorerSupplierAdapter(VectorScorer scorer) {
this.scorer = scorer;
}
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
return new RandomVectorScorer() {
final int firstOrd = ord;
@Override
public float score(int otherOrd) throws IOException {
return scorer.score(firstOrd, otherOrd);
}
@Override
public int maxOrd() {
return scorer.maxOrd();
}
};
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new VectorScorerSupplierAdapter(scorer.copy());
}
}

@@ -9,11 +9,13 @@
package org.elasticsearch.vec; package org.elasticsearch.vec;
import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.elasticsearch.nativeaccess.NativeAccess; import org.elasticsearch.nativeaccess.NativeAccess;
import org.elasticsearch.vec.internal.IndexInputUtils; import org.elasticsearch.vec.internal.IndexInputUtils;
import org.elasticsearch.vec.internal.Int7DotProduct; import org.elasticsearch.vec.internal.Int7SQVectorScorerSupplier.DotProductSupplier;
import org.elasticsearch.vec.internal.Int7Euclidean; import org.elasticsearch.vec.internal.Int7SQVectorScorerSupplier.EuclideanSupplier;
import org.elasticsearch.vec.internal.Int7MaximumInnerProduct; import org.elasticsearch.vec.internal.Int7SQVectorScorerSupplier.MaxInnerProductSupplier;
import java.util.Optional; import java.util.Optional;
@@ -28,21 +30,27 @@ class VectorScorerFactoryImpl implements VectorScorerFactory {
} }
@Override @Override
public Optional<VectorScorer> getInt7ScalarQuantizedVectorScorer( public Optional<RandomVectorScorerSupplier> getInt7ScalarQuantizedVectorScorer(
int dims,
int maxOrd,
float scoreCorrectionConstant,
VectorSimilarityType similarityType, VectorSimilarityType similarityType,
IndexInput input IndexInput input,
RandomAccessQuantizedByteVectorValues values,
float scoreCorrectionConstant
) { ) {
input = IndexInputUtils.unwrapAndCheckInputOrNull(input); input = IndexInputUtils.unwrapAndCheckInputOrNull(input);
if (input == null) { if (input == null) {
return Optional.empty(); // the input type is not MemorySegment based return Optional.empty(); // the input type is not MemorySegment based
} }
return Optional.of(switch (similarityType) { checkInvariants(values.size(), values.dimension(), input);
case COSINE, DOT_PRODUCT -> new Int7DotProduct(dims, maxOrd, scoreCorrectionConstant, input); return switch (similarityType) {
case EUCLIDEAN -> new Int7Euclidean(dims, maxOrd, scoreCorrectionConstant, input); case COSINE, DOT_PRODUCT -> Optional.of(new DotProductSupplier(input, values, scoreCorrectionConstant));
case MAXIMUM_INNER_PRODUCT -> new Int7MaximumInnerProduct(dims, maxOrd, scoreCorrectionConstant, input); case EUCLIDEAN -> Optional.of(new EuclideanSupplier(input, values, scoreCorrectionConstant));
}); case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(input, values, scoreCorrectionConstant));
};
}
static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
if (input.length() < (long) vectorByteLength * maxOrd) {
throw new IllegalArgumentException("input length is less than expected vector data");
}
} }
} }

@@ -1,154 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.vec.internal;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import org.elasticsearch.nativeaccess.NativeAccess;
import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
import org.elasticsearch.vec.VectorScorer;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.lang.invoke.MethodHandle;
abstract sealed class AbstractInt7ScalarQuantizedVectorScorer implements VectorScorer permits Int7DotProduct, Int7Euclidean,
Int7MaximumInnerProduct {
static final VectorSimilarityFunctions DISTANCE_FUNCS = NativeAccess.instance()
.getVectorSimilarityFunctions()
.orElseThrow(AssertionError::new);
protected final int dims;
protected final int maxOrd;
protected final float scoreCorrectionConstant;
protected final IndexInput input;
protected final MemorySegment segment;
protected final MemorySegment[] segments;
protected final long offset;
protected final int chunkSizePower;
protected final long chunkSizeMask;
private final ScalarQuantizedVectorSimilarity fallbackScorer;
protected AbstractInt7ScalarQuantizedVectorScorer(
int dims,
int maxOrd,
float scoreCorrectionConstant,
IndexInput input,
ScalarQuantizedVectorSimilarity fallbackScorer
) {
this.dims = dims;
this.maxOrd = maxOrd;
this.scoreCorrectionConstant = scoreCorrectionConstant;
this.input = input;
this.fallbackScorer = fallbackScorer;
this.segments = IndexInputUtils.segmentArray(input);
if (segments.length == 1) {
segment = segments[0];
offset = 0L;
} else {
segment = null;
offset = IndexInputUtils.offset(input);
}
this.chunkSizePower = IndexInputUtils.chunkSizePower(input);
this.chunkSizeMask = IndexInputUtils.chunkSizeMask(input);
}
@Override
public final int dims() {
return dims;
}
@Override
public final int maxOrd() {
return maxOrd;
}
protected final void checkOrdinal(int ord) {
if (ord < 0 || ord > maxOrd) {
throw new IllegalArgumentException("illegal ordinal: " + ord);
}
}
protected final float fallbackScore(long firstByteOffset, long secondByteOffset) throws IOException {
input.seek(firstByteOffset);
byte[] a = new byte[dims];
input.readBytes(a, 0, a.length);
float aOffsetValue = Float.intBitsToFloat(input.readInt());
input.seek(secondByteOffset);
byte[] b = new byte[dims];
input.readBytes(b, 0, a.length);
float bOffsetValue = Float.intBitsToFloat(input.readInt());
return fallbackScorer.score(a, aOffsetValue, b, bOffsetValue);
}
protected final MemorySegment segmentSlice(long pos, int length) {
if (segment != null) {
// single
if (checkIndex(pos, segment.byteSize() + 1)) {
return segment.asSlice(pos, length);
}
} else {
// multi
pos = pos + this.offset;
final int si = (int) (pos >> chunkSizePower);
final MemorySegment seg = segments[si];
long offset = pos & chunkSizeMask;
if (checkIndex(offset + length, seg.byteSize() + 1)) {
return seg.asSlice(offset, length);
}
}
return null;
}
static boolean checkIndex(long index, long length) {
return index >= 0 && index < length;
}
static final MethodHandle DOT_PRODUCT_7U = DISTANCE_FUNCS.dotProductHandle7u();
static final MethodHandle SQUARE_DISTANCE_7U = DISTANCE_FUNCS.squareDistanceHandle7u();
static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
// assert assertSegments(a, b, length);
try {
return (int) DOT_PRODUCT_7U.invokeExact(a, b, length);
} catch (Throwable e) {
if (e instanceof Error err) {
throw err;
} else if (e instanceof RuntimeException re) {
throw re;
} else {
throw new RuntimeException(e);
}
}
}
static int squareDistance7u(MemorySegment a, MemorySegment b, int length) {
// assert assertSegments(a, b, length);
try {
return (int) SQUARE_DISTANCE_7U.invokeExact(a, b, length);
} catch (Throwable e) {
if (e instanceof Error err) {
throw err;
} else if (e instanceof RuntimeException re) {
throw re;
} else {
throw new RuntimeException(e);
}
}
}
static boolean assertSegments(MemorySegment a, MemorySegment b, int length) {
return a.isNative() && a.byteSize() >= length && b.isNative() && b.byteSize() >= length;
}
}

@@ -1,62 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.vec.internal;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
// Scalar Quantized vectors are inherently byte sized, so dims is equal to the length in bytes.
public final class Int7DotProduct extends AbstractInt7ScalarQuantizedVectorScorer {
public Int7DotProduct(int dims, int maxOrd, float scoreCorrectionConstant, IndexInput input) {
super(
dims,
maxOrd,
scoreCorrectionConstant,
input,
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(VectorSimilarityFunction.DOT_PRODUCT, scoreCorrectionConstant)
);
}
@Override
public float score(int firstOrd, int secondOrd) throws IOException {
checkOrdinal(firstOrd);
checkOrdinal(secondOrd);
final int length = dims;
long firstByteOffset = (long) firstOrd * (length + Float.BYTES);
long secondByteOffset = (long) secondOrd * (length + Float.BYTES);
MemorySegment firstSeg = segmentSlice(firstByteOffset, length);
input.seek(firstByteOffset + length);
float firstOffset = Float.intBitsToFloat(input.readInt());
MemorySegment secondSeg = segmentSlice(secondByteOffset, length);
input.seek(secondByteOffset + length);
float secondOffset = Float.intBitsToFloat(input.readInt());
if (firstSeg != null && secondSeg != null) {
int dotProduct = dotProduct7u(firstSeg, secondSeg, length);
assert dotProduct >= 0;
float adjustedDistance = dotProduct * scoreCorrectionConstant + firstOffset + secondOffset;
return Math.max((1 + adjustedDistance) / 2, 0f);
} else {
return Math.max(fallbackScore(firstByteOffset, secondByteOffset), 0f);
}
}
@Override
public Int7DotProduct copy() {
return new Int7DotProduct(dims, maxOrd, scoreCorrectionConstant, input.clone());
}
}

@@ -1,56 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.vec.internal;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
// Scalar Quantized vectors are inherently bytes.
public final class Int7Euclidean extends AbstractInt7ScalarQuantizedVectorScorer {
public Int7Euclidean(int dims, int maxOrd, float scoreCorrectionConstant, IndexInput input) {
super(
dims,
maxOrd,
scoreCorrectionConstant,
input,
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(VectorSimilarityFunction.EUCLIDEAN, scoreCorrectionConstant)
);
}
@Override
public float score(int firstOrd, int secondOrd) throws IOException {
checkOrdinal(firstOrd);
checkOrdinal(secondOrd);
final int length = dims;
long firstByteOffset = (long) firstOrd * (length + Float.BYTES);
long secondByteOffset = (long) secondOrd * (length + Float.BYTES);
MemorySegment firstSeg = segmentSlice(firstByteOffset, length);
MemorySegment secondSeg = segmentSlice(secondByteOffset, length);
if (firstSeg != null && secondSeg != null) {
int squareDistance = squareDistance7u(firstSeg, secondSeg, length);
float adjustedDistance = squareDistance * scoreCorrectionConstant;
return 1 / (1f + adjustedDistance);
} else {
return fallbackScore(firstByteOffset, secondByteOffset);
}
}
@Override
public Int7Euclidean copy() {
return new Int7Euclidean(dims, maxOrd, scoreCorrectionConstant, input.clone());
}
}

@@ -1,72 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.vec.internal;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
// Scalar Quantized vectors are inherently bytes.
public final class Int7MaximumInnerProduct extends AbstractInt7ScalarQuantizedVectorScorer {
public Int7MaximumInnerProduct(int dims, int maxOrd, float scoreCorrectionConstant, IndexInput input) {
super(
dims,
maxOrd,
scoreCorrectionConstant,
input,
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, scoreCorrectionConstant)
);
}
@Override
public float score(int firstOrd, int secondOrd) throws IOException {
checkOrdinal(firstOrd);
checkOrdinal(secondOrd);
final int length = dims;
long firstByteOffset = (long) firstOrd * (length + Float.BYTES);
long secondByteOffset = (long) secondOrd * (length + Float.BYTES);
MemorySegment firstSeg = segmentSlice(firstByteOffset, length);
input.seek(firstByteOffset + length);
float firstOffset = Float.intBitsToFloat(input.readInt());
MemorySegment secondSeg = segmentSlice(secondByteOffset, length);
input.seek(secondByteOffset + length);
float secondOffset = Float.intBitsToFloat(input.readInt());
if (firstSeg != null && secondSeg != null) {
int dotProduct = dotProduct7u(firstSeg, secondSeg, length);
float adjustedDistance = dotProduct * scoreCorrectionConstant + firstOffset + secondOffset;
return scaleMaxInnerProductScore(adjustedDistance);
} else {
return fallbackScore(firstByteOffset, secondByteOffset);
}
}
/**
* Returns a scaled score preventing negative scores for maximum-inner-product
* @param rawSimilarity the raw similarity between two vectors
*/
static float scaleMaxInnerProductScore(float rawSimilarity) {
if (rawSimilarity < 0) {
return 1 / (1 + -1 * rawSimilarity);
}
return rawSimilarity + 1;
}
@Override
public Int7MaximumInnerProduct copy() {
return new Int7MaximumInnerProduct(dims, maxOrd, scoreCorrectionConstant, input.clone());
}
}

@@ -0,0 +1,221 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.vec.internal;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
import static org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity.fromVectorSimilarity;
public abstract sealed class Int7SQVectorScorerSupplier implements RandomVectorScorerSupplier {
final int dims;
final int maxOrd;
final float scoreCorrectionConstant;
final IndexInput input;
final RandomAccessQuantizedByteVectorValues values; // to support ordToDoc/getAcceptOrds
final ScalarQuantizedVectorSimilarity fallbackScorer;
final MemorySegment segment;
final MemorySegment[] segments;
final long offset;
final int chunkSizePower;
final long chunkSizeMask;
protected Int7SQVectorScorerSupplier(
IndexInput input,
RandomAccessQuantizedByteVectorValues values,
float scoreCorrectionConstant,
ScalarQuantizedVectorSimilarity fallbackScorer
) {
this.input = input;
this.values = values;
this.dims = values.dimension();
this.maxOrd = values.size();
this.scoreCorrectionConstant = scoreCorrectionConstant;
this.fallbackScorer = fallbackScorer;
this.segments = IndexInputUtils.segmentArray(input);
if (segments.length == 1) {
segment = segments[0];
offset = 0L;
} else {
segment = null;
offset = IndexInputUtils.offset(input);
}
this.chunkSizePower = IndexInputUtils.chunkSizePower(input);
this.chunkSizeMask = IndexInputUtils.chunkSizeMask(input);
}
protected final void checkOrdinal(int ord) {
if (ord < 0 || ord > maxOrd) {
throw new IllegalArgumentException("illegal ordinal: " + ord);
}
}
final float scoreFromOrds(int firstOrd, int secondOrd) throws IOException {
checkOrdinal(firstOrd);
checkOrdinal(secondOrd);
final int length = dims;
long firstByteOffset = (long) firstOrd * (length + Float.BYTES);
long secondByteOffset = (long) secondOrd * (length + Float.BYTES);
MemorySegment firstSeg = segmentSlice(firstByteOffset, length);
if (firstSeg == null) {
return fallbackScore(firstByteOffset, secondByteOffset);
}
input.seek(firstByteOffset + length);
float firstOffset = Float.intBitsToFloat(input.readInt());
MemorySegment secondSeg = segmentSlice(secondByteOffset, length);
if (secondSeg == null) {
return fallbackScore(firstByteOffset, secondByteOffset);
}
input.seek(secondByteOffset + length);
float secondOffset = Float.intBitsToFloat(input.readInt());
return scoreFromSegments(firstSeg, firstOffset, secondSeg, secondOffset);
}
abstract float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float bOffset);
protected final float fallbackScore(long firstByteOffset, long secondByteOffset) throws IOException {
input.seek(firstByteOffset);
byte[] a = new byte[dims];
input.readBytes(a, 0, a.length);
float aOffsetValue = Float.intBitsToFloat(input.readInt());
input.seek(secondByteOffset);
byte[] b = new byte[dims];
input.readBytes(b, 0, a.length);
float bOffsetValue = Float.intBitsToFloat(input.readInt());
return fallbackScorer.score(a, aOffsetValue, b, bOffsetValue);
}
@Override
public RandomVectorScorer scorer(int ord) {
checkOrdinal(ord);
return new RandomVectorScorer.AbstractRandomVectorScorer<>(values) {
@Override
public float score(int node) throws IOException {
return scoreFromOrds(ord, node);
}
};
}
protected final MemorySegment segmentSlice(long pos, int length) {
if (segment != null) {
// single
if (checkIndex(pos, segment.byteSize() + 1)) {
return segment.asSlice(pos, length);
}
} else {
// multi
pos = pos + this.offset;
final int si = (int) (pos >> chunkSizePower);
final MemorySegment seg = segments[si];
long offset = pos & chunkSizeMask;
if (checkIndex(offset + length, seg.byteSize() + 1)) {
return seg.asSlice(offset, length);
}
}
return null;
}
public static final class EuclideanSupplier extends Int7SQVectorScorerSupplier {
public EuclideanSupplier(IndexInput input, RandomAccessQuantizedByteVectorValues values, float scoreCorrectionConstant) {
super(input, values, scoreCorrectionConstant, fromVectorSimilarity(EUCLIDEAN, scoreCorrectionConstant));
}
@Override
float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float bOffset) {
int squareDistance = Similarities.squareDistance7u(a, b, dims);
float adjustedDistance = squareDistance * scoreCorrectionConstant;
return 1 / (1f + adjustedDistance);
}
@Override
public EuclideanSupplier copy() {
return new EuclideanSupplier(input.clone(), values, scoreCorrectionConstant);
}
}
// This will be removed when we upgrade to 9.11, see https://github.com/apache/lucene/pull/13356
static final class DelegateDotScorer implements ScalarQuantizedVectorSimilarity {
final ScalarQuantizedVectorSimilarity delegate;
DelegateDotScorer(float scoreCorrectionConstant) {
delegate = fromVectorSimilarity(DOT_PRODUCT, scoreCorrectionConstant);
}
@Override
public float score(byte[] queryVector, float queryVectorOffset, byte[] storedVector, float vectorOffset) {
return Math.max(delegate.score(queryVector, queryVectorOffset, storedVector, vectorOffset), 0f);
}
}
public static final class DotProductSupplier extends Int7SQVectorScorerSupplier {
public DotProductSupplier(IndexInput input, RandomAccessQuantizedByteVectorValues values, float scoreCorrectionConstant) {
super(input, values, scoreCorrectionConstant, new DelegateDotScorer(scoreCorrectionConstant));
}
@Override
float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float bOffset) {
int dotProduct = Similarities.dotProduct7u(a, b, dims);
assert dotProduct >= 0;
float adjustedDistance = dotProduct * scoreCorrectionConstant + aOffset + bOffset;
return Math.max((1 + adjustedDistance) / 2, 0f);
}
@Override
public DotProductSupplier copy() {
return new DotProductSupplier(input.clone(), values, scoreCorrectionConstant);
}
}
public static final class MaxInnerProductSupplier extends Int7SQVectorScorerSupplier {
public MaxInnerProductSupplier(IndexInput input, RandomAccessQuantizedByteVectorValues values, float scoreCorrectionConstant) {
super(input, values, scoreCorrectionConstant, fromVectorSimilarity(MAXIMUM_INNER_PRODUCT, scoreCorrectionConstant));
}
@Override
float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float bOffset) {
int dotProduct = Similarities.dotProduct7u(a, b, dims);
assert dotProduct >= 0;
float adjustedDistance = dotProduct * scoreCorrectionConstant + aOffset + bOffset;
if (adjustedDistance < 0) {
return 1 / (1 + -1 * adjustedDistance);
}
return adjustedDistance + 1;
}
@Override
public MaxInnerProductSupplier copy() {
return new MaxInnerProductSupplier(input.clone(), values, scoreCorrectionConstant);
}
}
static boolean checkIndex(long index, long length) {
return index >= 0 && index < length;
}
}
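The ordinal-to-offset arithmetic in scoreFromOrds and fallbackScore assumes the int7 scalar quantized layout used throughout this module: dims quantized bytes per vector followed by a 4-byte float score correction. A tiny illustration of that arithmetic (hypothetical helper, not part of the change):

final class Int7Layout {
    // Start offset of the vector with the given ordinal: each entry occupies
    // dims bytes of quantized data plus Float.BYTES for its correction.
    static long vectorByteOffset(int ord, int dims) {
        return (long) ord * (dims + Float.BYTES);
    }

    // Offset of the stored float correction for that ordinal.
    static long correctionByteOffset(int ord, int dims) {
        return vectorByteOffset(ord, dims) + dims;
    }
}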

@@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.vec.internal;
import org.elasticsearch.nativeaccess.NativeAccess;
import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
import java.lang.foreign.MemorySegment;
import java.lang.invoke.MethodHandle;
public class Similarities {
static final VectorSimilarityFunctions DISTANCE_FUNCS = NativeAccess.instance()
.getVectorSimilarityFunctions()
.orElseThrow(AssertionError::new);
static final MethodHandle DOT_PRODUCT_7U = DISTANCE_FUNCS.dotProductHandle7u();
static final MethodHandle SQUARE_DISTANCE_7U = DISTANCE_FUNCS.squareDistanceHandle7u();
static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
assert assertSegments(a, b, length);
try {
return (int) DOT_PRODUCT_7U.invokeExact(a, b, length);
} catch (Throwable e) {
if (e instanceof Error err) {
throw err;
} else if (e instanceof RuntimeException re) {
throw re;
} else {
throw new RuntimeException(e);
}
}
}
static int squareDistance7u(MemorySegment a, MemorySegment b, int length) {
assert assertSegments(a, b, length);
try {
return (int) SQUARE_DISTANCE_7U.invokeExact(a, b, length);
} catch (Throwable e) {
if (e instanceof Error err) {
throw err;
} else if (e instanceof RuntimeException re) {
throw re;
} else {
throw new RuntimeException(e);
}
}
}
static boolean assertSegments(MemorySegment a, MemorySegment b, int length) {
return a.isNative() && a.byteSize() >= length && b.isNative() && b.byteSize() >= length;
}
}

@@ -10,12 +10,15 @@ package org.elasticsearch.vec;
import com.carrotsearch.randomizedtesting.generators.RandomNumbers; import com.carrotsearch.randomizedtesting.generators.RandomNumbers;
import org.apache.lucene.codecs.lucene99.OffHeapQuantizedByteVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
@@ -65,12 +68,12 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
assumeTrue(notSupportedMsg(), supported()); assumeTrue(notSupportedMsg(), supported());
var factory = AbstractVectorTestCase.factory.get(); var factory = AbstractVectorTestCase.factory.get();
try (Directory dir = new MMapDirectory(createTempDir(getTestName()), maxChunkSize)) { try (Directory dir = new MMapDirectory(createTempDir("testSimpleImpl"), maxChunkSize)) {
for (int dims : List.of(31, 32, 33)) { for (int dims : List.of(31, 32, 33)) {
// dimensions that cross the scalar / native boundary (stride) // dimensions that cross the scalar / native boundary (stride)
byte[] vec1 = new byte[dims]; byte[] vec1 = new byte[dims];
byte[] vec2 = new byte[dims]; byte[] vec2 = new byte[dims];
String fileName = getTestName() + "-" + dims; String fileName = "testSimpleImpl" + "-" + dims;
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
for (int i = 0; i < dims; i++) { for (int i = 0; i < dims; i++) {
vec1[i] = (byte) i; vec1[i] = (byte) i;
@@ -81,26 +84,12 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
out.writeBytes(bytes, 0, bytes.length); out.writeBytes(bytes, 0, bytes.length);
} }
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
// dot product for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
float expected = luceneScore(DOT_PRODUCT, vec1, vec2, 1, 1, 1); var values = vectorValues(dims, 2, in, VectorSimilarityType.of(sim));
var scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, 2, 1, DOT_PRODUCT, in).get(); float expected = luceneScore(sim, vec1, vec2, 1, 1, 1);
assertThat(scorer.score(0, 1), equalTo(expected)); var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, 1).get();
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected)); assertThat(supplier.scorer(0).score(1), equalTo(expected));
// max inner product }
expected = luceneScore(MAXIMUM_INNER_PRODUCT, vec1, vec2, 1, 1, 1);
scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, 2, 1, MAXIMUM_INNER_PRODUCT, in).get();
assertThat(scorer.score(0, 1), equalTo(expected));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected));
// cosine
expected = luceneScore(COSINE, vec1, vec2, 1, 1, 1);
scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, 2, 1, COSINE, in).get();
assertThat(scorer.score(0, 1), equalTo(expected));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected));
// euclidean
expected = luceneScore(EUCLIDEAN, vec1, vec2, 1, 1, 1);
scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, 2, 1, EUCLIDEAN, in).get();
assertThat(scorer.score(0, 1), equalTo(expected));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected));
} }
} }
} }
@@ -110,43 +99,40 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
assumeTrue(notSupportedMsg(), supported()); assumeTrue(notSupportedMsg(), supported());
var factory = AbstractVectorTestCase.factory.get(); var factory = AbstractVectorTestCase.factory.get();
try (Directory dir = new MMapDirectory(createTempDir(getTestName()), MMapDirectory.DEFAULT_MAX_CHUNK_SIZE)) { try (Directory dir = new MMapDirectory(createTempDir("testNonNegativeDotProduct"), MMapDirectory.DEFAULT_MAX_CHUNK_SIZE)) {
// keep vecs `0` so dot product is `0` // keep vecs `0` so dot product is `0`
byte[] vec1 = new byte[32]; byte[] vec1 = new byte[32];
byte[] vec2 = new byte[32]; byte[] vec2 = new byte[32];
String fileName = getTestName() + "-32"; String fileName = "testNonNegativeDotProduct-32";
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
var negativeOffset = floatToByteArray(-5f); var negativeOffset = floatToByteArray(-5f);
byte[] bytes = concat(vec1, negativeOffset, vec2, negativeOffset); byte[] bytes = concat(vec1, negativeOffset, vec2, negativeOffset);
out.writeBytes(bytes, 0, bytes.length); out.writeBytes(bytes, 0, bytes.length);
} }
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
var values = vectorValues(32, 2, in, VectorSimilarityType.of(DOT_PRODUCT));
// dot product // dot product
float expected = 0f; // TODO fix in Lucene: https://github.com/apache/lucene/pull/13356 luceneScore(DOT_PRODUCT, vec1, vec2, float expected = 0f; // TODO fix in Lucene: https://github.com/apache/lucene/pull/13356 luceneScore(DOT_PRODUCT, vec1, vec2,
// 1, -5, -5); // 1, -5, -5);
var scorer = factory.getInt7ScalarQuantizedVectorScorer(32, 2, 1, DOT_PRODUCT, in).get(); var supplier = factory.getInt7ScalarQuantizedVectorScorer(DOT_PRODUCT, in, values, 1).get();
assertThat(scorer.score(0, 1), equalTo(expected)); assertThat(supplier.scorer(0).score(1), equalTo(expected));
assertThat(scorer.score(0, 1), greaterThanOrEqualTo(0f)); assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected));
// max inner product // max inner product
expected = luceneScore(MAXIMUM_INNER_PRODUCT, vec1, vec2, 1, -5, -5); expected = luceneScore(MAXIMUM_INNER_PRODUCT, vec1, vec2, 1, -5, -5);
scorer = factory.getInt7ScalarQuantizedVectorScorer(32, 2, 1, MAXIMUM_INNER_PRODUCT, in).get(); supplier = factory.getInt7ScalarQuantizedVectorScorer(MAXIMUM_INNER_PRODUCT, in, values, 1).get();
assertThat(scorer.score(0, 1), greaterThanOrEqualTo(0f)); assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
assertThat(scorer.score(0, 1), equalTo(expected)); assertThat(supplier.scorer(0).score(1), equalTo(expected));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected));
// cosine // cosine
expected = 0f; // TODO fix in Lucene: https://github.com/apache/lucene/pull/13356 luceneScore(COSINE, vec1, vec2, 1, -5, expected = 0f; // TODO fix in Lucene: https://github.com/apache/lucene/pull/13356 luceneScore(COSINE, vec1, vec2, 1, -5,
// -5); // -5);
scorer = factory.getInt7ScalarQuantizedVectorScorer(32, 2, 1, COSINE, in).get(); supplier = factory.getInt7ScalarQuantizedVectorScorer(COSINE, in, values, 1).get();
assertThat(scorer.score(0, 1), equalTo(expected)); assertThat(supplier.scorer(0).score(1), equalTo(expected));
assertThat(scorer.score(0, 1), greaterThanOrEqualTo(0f)); assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected));
// euclidean // euclidean
expected = luceneScore(EUCLIDEAN, vec1, vec2, 1, -5, -5); expected = luceneScore(EUCLIDEAN, vec1, vec2, 1, -5, -5);
scorer = factory.getInt7ScalarQuantizedVectorScorer(32, 2, 1, EUCLIDEAN, in).get(); supplier = factory.getInt7ScalarQuantizedVectorScorer(EUCLIDEAN, in, values, 1).get();
assertThat(scorer.score(0, 1), equalTo(expected)); assertThat(supplier.scorer(0).score(1), equalTo(expected));
assertThat(scorer.score(0, 1), greaterThanOrEqualTo(0f)); assertThat(supplier.scorer(0).score(1), greaterThanOrEqualTo(0f));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(0).score(1), equalTo(expected));
} }
} }
} }
@@ -176,49 +162,35 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
void testRandom(long maxChunkSize, Function<Integer, byte[]> byteArraySupplier) throws IOException { void testRandom(long maxChunkSize, Function<Integer, byte[]> byteArraySupplier) throws IOException {
var factory = AbstractVectorTestCase.factory.get(); var factory = AbstractVectorTestCase.factory.get();
try (Directory dir = new MMapDirectory(createTempDir(getTestName()), maxChunkSize)) { try (Directory dir = new MMapDirectory(createTempDir("testRandom"), maxChunkSize)) {
for (int times = 0; times < TIMES; times++) { final int dims = randomIntBetween(1, 4096);
final int dims = randomIntBetween(1, 4096); final int size = randomIntBetween(2, 100);
final int size = randomIntBetween(2, 100); final float correction = randomFloat();
final float correction = randomFloat(); final byte[][] vectors = new byte[size][];
final byte[][] vectors = new byte[size][]; final float[] offsets = new float[size];
final float[] offsets = new float[size];
String fileName = getTestName() + "-" + times + "-" + dims; String fileName = "testRandom-" + dims;
logger.info("Testing " + fileName); logger.info("Testing " + fileName);
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
var vec = byteArraySupplier.apply(dims); var vec = byteArraySupplier.apply(dims);
var off = randomFloat(); var off = randomFloat();
out.writeBytes(vec, 0, vec.length); out.writeBytes(vec, 0, vec.length);
out.writeInt(Float.floatToIntBits(off)); out.writeInt(Float.floatToIntBits(off));
vectors[i] = vec; vectors[i] = vec;
offsets[i] = off; offsets[i] = off;
}
} }
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { }
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
for (int times = 0; times < TIMES; times++) {
int idx0 = randomIntBetween(0, size - 1); int idx0 = randomIntBetween(0, size - 1);
int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok. int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok.
// dot product for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
float expected = luceneScore(DOT_PRODUCT, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]); var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim));
var scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, DOT_PRODUCT, in).get(); float expected = luceneScore(sim, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]);
assertThat(scorer.score(idx0, idx1), equalTo(expected)); var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, correction).get();
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected)); assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected));
// max inner product }
expected = luceneScore(MAXIMUM_INNER_PRODUCT, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]);
scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, MAXIMUM_INNER_PRODUCT, in).get();
assertThat(scorer.score(idx0, idx1), equalTo(expected));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected));
// cosine
expected = luceneScore(COSINE, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]);
scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, COSINE, in).get();
assertThat(scorer.score(idx0, idx1), equalTo(expected));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected));
// euclidean
expected = luceneScore(EUCLIDEAN, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]);
scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, EUCLIDEAN, in).get();
assertThat(scorer.score(idx0, idx1), equalTo(expected));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected));
} }
} }
} }
@@ -233,14 +205,14 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
throws IOException { throws IOException {
var factory = AbstractVectorTestCase.factory.get(); var factory = AbstractVectorTestCase.factory.get();
try (Directory dir = new MMapDirectory(createTempDir(getTestName()), maxChunkSize)) { try (Directory dir = new MMapDirectory(createTempDir("testRandomSliceImpl"), maxChunkSize)) {
for (int times = 0; times < TIMES; times++) { for (int times = 0; times < TIMES; times++) {
final int size = randomIntBetween(2, 100); final int size = randomIntBetween(2, 100);
final float correction = randomFloat(); final float correction = randomFloat();
final byte[][] vectors = new byte[size][]; final byte[][] vectors = new byte[size][];
final float[] offsets = new float[size]; final float[] offsets = new float[size];
String fileName = getTestName() + "-" + times + "-" + dims; String fileName = "testRandomSliceImpl-" + times + "-" + dims;
logger.info("Testing " + fileName); logger.info("Testing " + fileName);
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
byte[] ba = new byte[initialPadding]; byte[] ba = new byte[initialPadding];
@@ -258,28 +230,16 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
var outter = dir.openInput(fileName, IOContext.DEFAULT); var outter = dir.openInput(fileName, IOContext.DEFAULT);
var in = outter.slice("slice", initialPadding, outter.length() - initialPadding) var in = outter.slice("slice", initialPadding, outter.length() - initialPadding)
) { ) {
int idx0 = randomIntBetween(0, size - 1); for (int itrs = 0; itrs < TIMES / 10; itrs++) {
int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok. int idx0 = randomIntBetween(0, size - 1);
// dot product int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok.
float expected = luceneScore(DOT_PRODUCT, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]); for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
var scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, DOT_PRODUCT, in).get(); var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim));
assertThat(scorer.score(idx0, idx1), equalTo(expected)); float expected = luceneScore(sim, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]);
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected)); var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, correction).get();
// max inner product assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected));
expected = luceneScore(MAXIMUM_INNER_PRODUCT, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]); }
scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, MAXIMUM_INNER_PRODUCT, in).get(); }
assertThat(scorer.score(idx0, idx1), equalTo(expected));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected));
// cosine
expected = luceneScore(COSINE, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]);
scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, COSINE, in).get();
assertThat(scorer.score(idx0, idx1), equalTo(expected));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected));
// euclidean
expected = luceneScore(EUCLIDEAN, vectors[idx0], vectors[idx1], correction, offsets[idx0], offsets[idx1]);
scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, EUCLIDEAN, in).get();
assertThat(scorer.score(idx0, idx1), equalTo(expected));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected));
} }
} }
} }
@@ -290,12 +250,12 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
public void testLarge() throws IOException { public void testLarge() throws IOException {
var factory = AbstractVectorTestCase.factory.get(); var factory = AbstractVectorTestCase.factory.get();
try (Directory dir = new MMapDirectory(createTempDir(getTestName()))) { try (Directory dir = new MMapDirectory(createTempDir("testLarge"))) {
final int dims = 8192; final int dims = 8192;
final int size = 262144; final int size = 262144;
final float correction = randomFloat(); final float correction = randomFloat();
String fileName = getTestName() + "-" + dims; String fileName = "testLarge-" + dims;
logger.info("Testing " + fileName); logger.info("Testing " + fileName);
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
@@ -311,26 +271,12 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
int idx1 = size - 1; int idx1 = size - 1;
float off0 = (float) idx0; float off0 = (float) idx0;
float off1 = (float) idx1; float off1 = (float) idx1;
// dot product for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) {
float expected = luceneScore(DOT_PRODUCT, vector(idx0, dims), vector(idx1, dims), correction, off0, off1); var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim));
var scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, DOT_PRODUCT, in).get(); float expected = luceneScore(sim, vector(idx0, dims), vector(idx1, dims), correction, off0, off1);
assertThat(scorer.score(idx0, idx1), equalTo(expected)); var supplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, correction).get();
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected)); assertThat(supplier.scorer(idx0).score(idx1), equalTo(expected));
// max inner product }
expected = luceneScore(MAXIMUM_INNER_PRODUCT, vector(idx0, dims), vector(idx1, dims), correction, off0, off1);
scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, MAXIMUM_INNER_PRODUCT, in).get();
assertThat(scorer.score(idx0, idx1), equalTo(expected));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected));
// cosine
expected = luceneScore(COSINE, vector(idx0, dims), vector(idx1, dims), correction, off0, off1);
scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, COSINE, in).get();
assertThat(scorer.score(idx0, idx1), equalTo(expected));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected));
// euclidean
expected = luceneScore(EUCLIDEAN, vector(idx0, dims), vector(idx1, dims), correction, off0, off1);
scorer = factory.getInt7ScalarQuantizedVectorScorer(dims, size, correction, EUCLIDEAN, in).get();
assertThat(scorer.score(idx0, idx1), equalTo(expected));
assertThat((new VectorScorerSupplierAdapter(scorer)).scorer(idx0).score(idx1), equalTo(expected));
} }
} }
} }
@@ -355,7 +301,7 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
IntStream.range(0, dims).forEach(i -> vec1[i] = 1); IntStream.range(0, dims).forEach(i -> vec1[i] = 1);
IntStream.range(0, dims).forEach(i -> vec2[i] = 2); IntStream.range(0, dims).forEach(i -> vec2[i] = 2);
try (Directory dir = new MMapDirectory(createTempDir("testRace"), maxChunkSize)) { try (Directory dir = new MMapDirectory(createTempDir("testRace"), maxChunkSize)) {
String fileName = getTestName() + "-" + dims; String fileName = "testRace-" + dims;
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
var one = floatToByteArray(1f); var one = floatToByteArray(1f);
byte[] bytes = concat(vec1, one, vec1, one, vec2, one, vec2, one); byte[] bytes = concat(vec1, one, vec1, one, vec2, one, vec2, one);
@@ -365,11 +311,11 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
var expectedScore2 = luceneScore(sim, vec2, vec2, 1, 1, 1); var expectedScore2 = luceneScore(sim, vec2, vec2, 1, 1, 1);
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
var scoreSupplier = factory.getInt7ScalarQuantizedVectorScorer(dims, 4, 1, sim, in).get(); var values = vectorValues(dims, 4, in, VectorSimilarityType.of(sim));
var scorer = new VectorScorerSupplierAdapter(scoreSupplier); var scoreSupplier = factory.getInt7ScalarQuantizedVectorScorer(sim, in, values, 1f).get();
var tasks = List.<Callable<Optional<Throwable>>>of( var tasks = List.<Callable<Optional<Throwable>>>of(
new ScoreCallable(scorer.copy().scorer(0), 1, expectedScore1), new ScoreCallable(scoreSupplier.copy().scorer(0), 1, expectedScore1),
new ScoreCallable(scorer.copy().scorer(2), 3, expectedScore2) new ScoreCallable(scoreSupplier.copy().scorer(2), 3, expectedScore2)
); );
var executor = Executors.newFixedThreadPool(2); var executor = Executors.newFixedThreadPool(2);
var results = executor.invokeAll(tasks); var results = executor.invokeAll(tasks);
@@ -408,6 +354,10 @@ public class VectorScorerFactoryTests extends AbstractVectorTestCase {
} }
} }
RandomAccessQuantizedByteVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
return new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(dims, size, in.slice("values", 0, in.length()));
}
// creates the vector based on the given ordinal, which is reproducible given the ord and dims // creates the vector based on the given ordinal, which is reproducible given the ord and dims
static byte[] vector(int ord, int dims) { static byte[] vector(int ord, int dims) {
var random = new Random(Objects.hash(ord, dims)); var random = new Random(Objects.hash(ord, dims));

@@ -28,7 +28,7 @@ public class IndexInputUtilsTests extends ESTestCase {
public void testSingleSegment() throws IOException { public void testSingleSegment() throws IOException {
try (Directory dir = new MMapDirectory(createTempDir(getTestName()))) { try (Directory dir = new MMapDirectory(createTempDir(getTestName()))) {
for (int times = 0; times < TIMES; times++) { for (int times = 0; times < TIMES; times++) {
String fileName = getTestName() + times; String fileName = "testSingleSegment" + times;
int size = randomIntBetween(10, 127); int size = randomIntBetween(10, 127);
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
byte[] ba = new byte[size]; byte[] ba = new byte[size];
@@ -80,7 +80,7 @@ public class IndexInputUtilsTests extends ESTestCase {
public void testMultiSegment() throws IOException { public void testMultiSegment() throws IOException {
try (Directory dir = new MMapDirectory(createTempDir(getTestName()), 32L)) { try (Directory dir = new MMapDirectory(createTempDir(getTestName()), 32L)) {
for (int times = 0; times < TIMES; times++) { for (int times = 0; times < TIMES; times++) {
String fileName = getTestName() + times; String fileName = "testMultiSegment" + times;
int size = randomIntBetween(65, 1511); int size = randomIntBetween(65, 1511);
int expectedNumSegs = size / 32 + 1; int expectedNumSegs = size / 32 + 1;
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {

@@ -48,12 +48,12 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedVectorsReader; import org.apache.lucene.util.quantization.QuantizedVectorsReader;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorerSupplier; import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.ScalarQuantizer; import org.apache.lucene.util.quantization.ScalarQuantizer;
import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.vec.VectorScorerFactory; import org.elasticsearch.vec.VectorScorerFactory;
import org.elasticsearch.vec.VectorScorerSupplierAdapter;
import org.elasticsearch.vec.VectorSimilarityType; import org.elasticsearch.vec.VectorSimilarityType;
import java.io.Closeable; import java.io.Closeable;
@@ -425,19 +425,23 @@ public final class ES814ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
success = true; success = true;
final IndexInput finalQuantizationDataInput = quantizationDataInput; final IndexInput finalQuantizationDataInput = quantizationDataInput;
final RandomAccessQuantizedByteVectorValues vectorValues = new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
quantizationDataInput
);
// retrieve a scorer // retrieve a scorer
RandomVectorScorerSupplier scorerSupplier = null; RandomVectorScorerSupplier scorerSupplier = null;
Optional<VectorScorerFactory> factory = VectorScorerFactory.instance(); Optional<VectorScorerFactory> factory = VectorScorerFactory.instance();
if (factory.isPresent()) { if (factory.isPresent()) {
var scorer = factory.get() var scorer = factory.get()
.getInt7ScalarQuantizedVectorScorer( .getInt7ScalarQuantizedVectorScorer(
byteVectorValues.dimension(),
docsWithField.cardinality(),
mergedQuantizationState.getConstantMultiplier(),
VectorSimilarityType.of(fieldInfo.getVectorSimilarityFunction()), VectorSimilarityType.of(fieldInfo.getVectorSimilarityFunction()),
quantizationDataInput quantizationDataInput,
) vectorValues,
.map(VectorScorerSupplierAdapter::new); mergedQuantizationState.getConstantMultiplier()
);
if (scorer.isPresent()) { if (scorer.isPresent()) {
scorerSupplier = scorer.get(); scorerSupplier = scorer.get();
} }
@@ -446,11 +450,7 @@ public final class ES814ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
scorerSupplier = new ScalarQuantizedRandomVectorScorerSupplier( scorerSupplier = new ScalarQuantizedRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(), fieldInfo.getVectorSimilarityFunction(),
mergedQuantizationState, mergedQuantizationState,
new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues( vectorValues
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
quantizationDataInput
)
); );
} }
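The merge path in the writer now reduces to a simple selection: prefer the native, MemorySegment-based supplier when the factory can provide one, otherwise fall back to Lucene's scalar quantized supplier over the same vector values. A paraphrased sketch of that selection (the class, method name and parameter list are invented for illustration; error handling and the surrounding writer state are omitted):

import java.io.IOException;
import java.util.Optional;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.ScalarQuantizer;
import org.elasticsearch.vec.VectorScorerFactory;
import org.elasticsearch.vec.VectorSimilarityType;

class ScorerSupplierSelection {
    static RandomVectorScorerSupplier select(
        VectorSimilarityFunction sim,
        ScalarQuantizer quantizationState,
        RandomAccessQuantizedByteVectorValues vectorValues,
        IndexInput quantizationDataInput
    ) throws IOException {
        Optional<VectorScorerFactory> factory = VectorScorerFactory.instance();
        if (factory.isPresent()) {
            var nativeSupplier = factory.get()
                .getInt7ScalarQuantizedVectorScorer(
                    VectorSimilarityType.of(sim),
                    quantizationDataInput,
                    vectorValues,
                    quantizationState.getConstantMultiplier()
                );
            if (nativeSupplier.isPresent()) {
                return nativeSupplier.get(); // native scorer over MemorySegments
            }
        }
        // Fallback: Lucene's own scalar quantized scorer supplier over the same values.
        return new ScalarQuantizedRandomVectorScorerSupplier(sim, quantizationState, vectorValues);
    }
}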