Reapply "Adds new unexposed and experimental IVF format (#127528)" (#128005) (#128051)

This reverts commit 8a17a5ed5f.

Reapplying the IVF format, but with a fix.
Benjamin Trent 2025-05-13 18:47:59 -04:00 committed by GitHub
parent d07ec0cc44
commit 1324ee0115
23 changed files with 2576 additions and 3 deletions


@@ -17,7 +17,7 @@ import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.common.logging.LogConfigurator;
-import org.elasticsearch.simdvec.internal.vectorization.ES91OSQVectorsScorer;
+import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;


@@ -6,7 +6,7 @@
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
-package org.elasticsearch.simdvec.internal.vectorization;
+package org.elasticsearch.simdvec;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;


@@ -9,11 +9,13 @@
package org.elasticsearch.simdvec;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
import java.io.IOException;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
@@ -41,6 +43,10 @@ public class ESVectorUtil {
private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();
public static ES91OSQVectorsScorer getES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
return ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(input, dimension);
}
public static long ipByteBinByte(byte[] q, byte[] d) {
if (q.length != d.length * B_QUERY) {
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);
@@ -211,4 +217,40 @@ public class ESVectorUtil {
assert stats.length == 6;
IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats);
}
/**
* Calculates the difference between two vectors and stores the result in a third vector.
* @param v1 the first vector
* @param v2 the second vector
* @param result the result vector, must be the same length as the input vectors
*/
public static void subtract(float[] v1, float[] v2, float[] result) {
if (v1.length != v2.length) {
throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + v2.length);
}
if (result.length != v1.length) {
throw new IllegalArgumentException("vector dimensions differ: " + result.length + "!=" + v1.length);
}
for (int i = 0; i < v1.length; i++) {
result[i] = v1[i] - v2[i];
}
}
/**
* Calculates the spill-over score for a vector and a centroid, given the vector's residual with
* its nearest centroid.
* @param v1 the vector
* @param centroid the centroid
* @param originalResidual the residual with the nearest centroid
* @return the spill-over score (SOAR)
*/
public static float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
if (v1.length != centroid.length) {
throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + centroid.length);
}
if (originalResidual.length != v1.length) {
throw new IllegalArgumentException("vector dimensions differ: " + originalResidual.length + "!=" + v1.length);
}
return IMPL.soarResidual(v1, centroid, originalResidual);
}
}
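Taken together, `subtract` and `soarResidual` are the building blocks for the SOAR spill-over assignment used by `DefaultIVFVectorsWriter.assignCentroidSOAR` further down. A minimal sketch, where `vector`, `nearest`, `candidate`, `SOAR_LAMBDA` and the two squared distances are hypothetical values supplied by a surrounding clustering loop:

    float[] residual = new float[vector.length];
    ESVectorUtil.subtract(vector, nearest, residual); // r = x - c1
    float proj = ESVectorUtil.soarResidual(vector, candidate, residual); // (x - c2) . r
    float soarScore = candidateSquareDist + SOAR_LAMBDA * proj * proj / nearestSquareDist;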


@@ -138,6 +138,18 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
stats[5] = centroidDot;
}
@Override
public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
assert v1.length == centroid.length;
assert v1.length == originalResidual.length;
float proj = 0;
for (int i = 0; i < v1.length; i++) {
float djk = v1[i] - centroid[i];
proj = fma(djk, originalResidual[i], proj);
}
return proj;
}
public static int ipByteBitImpl(byte[] q, byte[] d) {
return ipByteBitImpl(q, d, 0);
}


@@ -10,6 +10,7 @@
package org.elasticsearch.simdvec.internal.vectorization;
import org.apache.lucene.store.IndexInput;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import java.io.IOException;


@@ -28,4 +28,7 @@ public interface ESVectorUtilSupport {
void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats);
void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);
float soarResidual(float[] v1, float[] centroid, float[] originalResidual);
}


@@ -10,6 +10,7 @@
package org.elasticsearch.simdvec.internal.vectorization;
import org.apache.lucene.store.IndexInput;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import java.io.IOException;
import java.util.Objects;


@@ -13,6 +13,7 @@ import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Constants;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import java.io.IOException;
import java.util.Locale;


@@ -20,6 +20,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import java.io.IOException;
import java.lang.foreign.MemorySegment;


@@ -367,6 +367,49 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
return (1f - lambda) * xe * xe / norm2 + lambda * e;
}
@Override
public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
assert v1.length == centroid.length;
assert v1.length == originalResidual.length;
float proj = 0;
int i = 0;
if (v1.length > 2 * FLOAT_SPECIES.length()) {
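// two independent accumulators let consecutive fused multiply-adds overlap
// instead of serializing on a single register's latency; tails follow below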
FloatVector projVec1 = FloatVector.zero(FLOAT_SPECIES);
FloatVector projVec2 = FloatVector.zero(FLOAT_SPECIES);
int unrolledLimit = FLOAT_SPECIES.loopBound(v1.length) - FLOAT_SPECIES.length();
for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) {
// one
FloatVector v1Vec0 = FloatVector.fromArray(FLOAT_SPECIES, v1, i);
FloatVector centroidVec0 = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
FloatVector originalResidualVec0 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
FloatVector djkVec0 = v1Vec0.sub(centroidVec0);
projVec1 = fma(djkVec0, originalResidualVec0, projVec1);
// two
FloatVector v1Vec1 = FloatVector.fromArray(FLOAT_SPECIES, v1, i + FLOAT_SPECIES.length());
FloatVector centroidVec1 = FloatVector.fromArray(FLOAT_SPECIES, centroid, i + FLOAT_SPECIES.length());
FloatVector originalResidualVec1 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i + FLOAT_SPECIES.length());
FloatVector djkVec1 = v1Vec1.sub(centroidVec1);
projVec2 = fma(djkVec1, originalResidualVec1, projVec2);
}
// vector tail
for (; i < FLOAT_SPECIES.loopBound(v1.length); i += FLOAT_SPECIES.length()) {
FloatVector v1Vec = FloatVector.fromArray(FLOAT_SPECIES, v1, i);
FloatVector centroidVec = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
FloatVector originalResidualVec = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
FloatVector djkVec = v1Vec.sub(centroidVec);
projVec1 = fma(djkVec, originalResidualVec, projVec1);
}
proj += projVec1.add(projVec2).reduceLanes(ADD);
}
// tail
for (; i < v1.length; i++) {
float djk = v1[i] - centroid[i];
proj = fma(djk, originalResidual[i], proj);
}
return proj;
}
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;


@@ -11,6 +11,7 @@ package org.elasticsearch.simdvec.internal.vectorization;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import java.io.IOException;
import java.lang.foreign.MemorySegment;


@@ -268,6 +268,22 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
}
}
public void testSoarOverspillScore() {
int size = random().nextInt(128, 512);
float deltaEps = 1e-5f * size;
var vector = new float[size];
var centroid = new float[size];
var preResidual = new float[size];
for (int i = 0; i < size; ++i) {
vector[i] = random().nextFloat();
centroid[i] = random().nextFloat();
preResidual[i] = random().nextFloat();
}
var expected = defaultedProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual);
var result = defOrPanamaProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual);
assertEquals(expected, result, deltaEps);
}
void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
int iterations = atLeast(50);
for (int i = 0; i < iterations; i++) {


@@ -16,6 +16,7 @@ import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import static org.hamcrest.Matchers.lessThan;


@@ -454,7 +454,8 @@ module org.elasticsearch.server {
org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat,
org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat,
org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat,
-org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat;
+org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat,
+org.elasticsearch.index.codec.vectors.IVFVectorsFormat;
provides org.apache.lucene.codecs.Codec
with


@@ -0,0 +1,420 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import org.elasticsearch.simdvec.ESVectorUtil;
import java.io.IOException;
import java.util.function.IntPredicate;
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte;
import static org.elasticsearch.simdvec.ES91OSQVectorsScorer.BULK_SIZE;
/**
* Default implementation of {@link IVFVectorsReader}. It scores the posting list centroids using
* brute force and then scores the top ones using the posting lists.
*/
public class DefaultIVFVectorsReader extends IVFVectorsReader {
private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);
public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
super(state, rawVectorsReader);
}
@Override
CentroidQueryScorer getCentroidScorer(
FieldInfo fieldInfo,
int numCentroids,
IndexInput centroids,
float[] targetQuery,
IndexInput clusters
) throws IOException {
FieldEntry fieldEntry = fields.get(fieldInfo.number);
float[] globalCentroid = fieldEntry.globalCentroid();
float globalCentroidDp = fieldEntry.globalCentroidDp();
OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
byte[] quantized = new byte[targetQuery.length];
float[] targetScratch = ArrayUtil.copyArray(targetQuery);
OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
targetScratch,
quantized,
(byte) 4,
globalCentroid
);
return new CentroidQueryScorer() {
int currentCentroid = -1;
private final byte[] quantizedCentroid = new byte[fieldInfo.getVectorDimension()];
private final float[] centroid = new float[fieldInfo.getVectorDimension()];
private final float[] centroidCorrectiveValues = new float[3];
private int quantizedCentroidComponentSum;
private final long centroidByteSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;
@Override
public int size() {
return numCentroids;
}
@Override
public float[] centroid(int centroidOrdinal) throws IOException {
readQuantizedAndRawCentroid(centroidOrdinal);
return centroid;
}
private void readQuantizedAndRawCentroid(int centroidOrdinal) throws IOException {
if (centroidOrdinal == currentCentroid) {
return;
}
centroids.seek(centroidOrdinal * centroidByteSize);
quantizedCentroidComponentSum = readQuantizedValue(centroids, quantizedCentroid, centroidCorrectiveValues);
centroids.seek(numCentroids * centroidByteSize + (long) Float.BYTES * quantizedCentroid.length * centroidOrdinal);
centroids.readFloats(centroid, 0, centroid.length);
currentCentroid = centroidOrdinal;
}
@Override
public float score(int centroidOrdinal) throws IOException {
readQuantizedAndRawCentroid(centroidOrdinal);
return int4QuantizedScore(
quantized,
queryParams,
fieldInfo.getVectorDimension(),
quantizedCentroid,
centroidCorrectiveValues,
quantizedCentroidComponentSum,
globalCentroidDp,
fieldInfo.getVectorSimilarityFunction()
);
}
};
}
@Override
protected FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) {
FieldEntry entry = fields.get(info.number);
if (entry == null) {
return null;
}
return new OffHeapCentroidFloatVectorValues(numCentroids, indexInput, info.getVectorDimension());
}
@Override
NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe)
throws IOException {
NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true);
// TODO Off heap scoring for quantized centroids?
for (int centroid = 0; centroid < centroidQueryScorer.size(); centroid++) {
neighborQueue.add(centroid, centroidQueryScorer.score(centroid));
}
return neighborQueue;
}
@Override
PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring)
throws IOException {
FieldEntry entry = fields.get(fieldInfo.number);
return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, needsScoring);
}
// TODO can we do this in off-heap blocks?
static float int4QuantizedScore(
byte[] quantizedQuery,
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
int dims,
byte[] binaryCode,
float[] targetCorrections,
int targetComponentSum,
float centroidDp,
VectorSimilarityFunction similarityFunction
) {
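// a sketch of the algebra below: with each dequantized component approximately
// a + l * q (lower interval plus scaled 4-bit code), the dot product expands to
//   <x, y> ~= ax*ay*dims + ay*lx*sum(xq) + ax*ly*sum(yq) + lx*ly*<xq, yq>
// where sum(xq) = targetComponentSum, sum(yq) is the query's
// quantizedComponentSum and <xq, yq> = qcDist; `score` is exactly this expansion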
float qcDist = VectorUtil.int4DotProduct(quantizedQuery, binaryCode);
float ax = targetCorrections[0];
// dequantize the intervals back into the 4-bit range; both sides are 4-bit codes, hence FOUR_BIT_SCALE (1/15)
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
float ay = queryCorrections.lowerInterval();
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
float y1 = queryCorrections.quantizedComponentSum();
float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
if (similarityFunction == EUCLIDEAN) {
score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
return Math.max(1 / (1f + score), 0);
} else {
// For cosine and max inner product, we need to apply the additional correction, which is
// assumed to be the non-centered dot-product between the vector and the centroid
score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
return VectorUtil.scaleMaxInnerProductScore(score);
}
return Math.max((1f + score) / 2f, 0);
}
}
static class OffHeapCentroidFloatVectorValues extends FloatVectorValues {
private final int numCentroids;
private final IndexInput input;
private final int dimension;
private final float[] centroid;
private final long centroidByteSize;
private int ord = -1;
OffHeapCentroidFloatVectorValues(int numCentroids, IndexInput input, int dimension) {
this.numCentroids = numCentroids;
this.input = input;
this.dimension = dimension;
this.centroid = new float[dimension];
this.centroidByteSize = dimension + 3 * Float.BYTES + Short.BYTES;
}
@Override
public float[] vectorValue(int ord) throws IOException {
if (ord < 0 || ord >= numCentroids) {
throw new IllegalArgumentException("ord must be in [0, " + numCentroids + "]");
}
if (ord == this.ord) {
return centroid;
}
readQuantizedCentroid(ord);
return centroid;
}
private void readQuantizedCentroid(int centroidOrdinal) throws IOException {
if (centroidOrdinal == ord) {
return;
}
input.seek(numCentroids * centroidByteSize + (long) Float.BYTES * dimension * centroidOrdinal);
input.readFloats(centroid, 0, centroid.length);
ord = centroidOrdinal;
}
@Override
public int dimension() {
return dimension;
}
@Override
public int size() {
return numCentroids;
}
@Override
public FloatVectorValues copy() throws IOException {
return new OffHeapCentroidFloatVectorValues(numCentroids, input.clone(), dimension);
}
}
private static class MemorySegmentPostingsVisitor implements PostingVisitor {
final long quantizedByteLength;
final IndexInput indexInput;
final float[] target;
final FieldEntry entry;
final FieldInfo fieldInfo;
final IntPredicate needsScoring;
private final ES91OSQVectorsScorer osqVectorsScorer;
final float[] scores = new float[BULK_SIZE];
final float[] correctionsLower = new float[BULK_SIZE];
final float[] correctionsUpper = new float[BULK_SIZE];
final int[] correctionsSum = new int[BULK_SIZE];
final float[] correctionsAdd = new float[BULK_SIZE];
int[] docIdsScratch = new int[0];
int vectors;
boolean quantized = false;
float centroidDp;
float[] centroid;
long slicePos;
OptimizedScalarQuantizer.QuantizationResult queryCorrections;
DocIdsWriter docIdsWriter = new DocIdsWriter();
final float[] scratch;
final byte[] quantizationScratch;
final byte[] quantizedQueryScratch;
final OptimizedScalarQuantizer quantizer;
final float[] correctiveValues = new float[3];
final long quantizedVectorByteSize;
MemorySegmentPostingsVisitor(
float[] target,
IndexInput indexInput,
FieldEntry entry,
FieldInfo fieldInfo,
IntPredicate needsScoring
) throws IOException {
this.target = target;
this.indexInput = indexInput;
this.entry = entry;
this.fieldInfo = fieldInfo;
this.needsScoring = needsScoring;
scratch = new float[target.length];
quantizationScratch = new byte[target.length];
final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64);
quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8];
quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES;
quantizedVectorByteSize = (discretizedDimensions / 8);
quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
osqVectorsScorer = ESVectorUtil.getES91OSQVectorsScorer(indexInput, fieldInfo.getVectorDimension());
}
@Override
public int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException {
quantized = false;
indexInput.seek(entry.postingListOffsets()[centroidOrdinal]);
vectors = indexInput.readVInt();
centroidDp = Float.intBitsToFloat(indexInput.readInt());
this.centroid = centroid;
// read the doc ids
docIdsScratch = vectors > docIdsScratch.length ? new int[vectors] : docIdsScratch;
docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
slicePos = indexInput.getFilePointer();
return vectors;
}
void scoreIndividually(int offset) throws IOException {
// score individually, first the quantized byte chunk
for (int j = 0; j < BULK_SIZE; j++) {
int doc = docIdsScratch[j + offset];
if (doc != -1) {
indexInput.seek(slicePos + (offset * quantizedByteLength) + (j * quantizedVectorByteSize));
float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
scores[j] = qcDist;
}
}
// read in all corrections
indexInput.seek(slicePos + (offset * quantizedByteLength) + (BULK_SIZE * quantizedVectorByteSize));
indexInput.readFloats(correctionsLower, 0, BULK_SIZE);
indexInput.readFloats(correctionsUpper, 0, BULK_SIZE);
for (int j = 0; j < BULK_SIZE; j++) {
correctionsSum[j] = Short.toUnsignedInt(indexInput.readShort());
}
indexInput.readFloats(correctionsAdd, 0, BULK_SIZE);
// Now apply corrections
for (int j = 0; j < BULK_SIZE; j++) {
int doc = docIdsScratch[offset + j];
if (doc != -1) {
scores[j] = osqVectorsScorer.score(
queryCorrections,
fieldInfo.getVectorSimilarityFunction(),
centroidDp,
correctionsLower[j],
correctionsUpper[j],
correctionsSum[j],
correctionsAdd[j],
scores[j]
);
}
}
}
@Override
public int visit(KnnCollector knnCollector) throws IOException {
// block processing
int scoredDocs = 0;
int limit = vectors - BULK_SIZE + 1;
int i = 0;
for (; i < limit; i += BULK_SIZE) {
int docsToScore = BULK_SIZE;
for (int j = 0; j < BULK_SIZE; j++) {
int doc = docIdsScratch[i + j];
if (needsScoring.test(doc) == false) {
docIdsScratch[i + j] = -1;
docsToScore--;
}
}
if (docsToScore == 0) {
continue;
}
quantizeQueryIfNecessary();
indexInput.seek(slicePos + i * quantizedByteLength);
if (docsToScore < BULK_SIZE / 2) {
scoreIndividually(i);
} else {
osqVectorsScorer.scoreBulk(
quantizedQueryScratch,
queryCorrections,
fieldInfo.getVectorSimilarityFunction(),
centroidDp,
scores
);
}
for (int j = 0; j < BULK_SIZE; j++) {
int doc = docIdsScratch[i + j];
if (doc != -1) {
scoredDocs++;
knnCollector.collect(doc, scores[j]);
}
}
}
// process tail
for (; i < vectors; i++) {
int doc = docIdsScratch[i];
if (needsScoring.test(doc)) {
quantizeQueryIfNecessary();
indexInput.seek(slicePos + i * quantizedByteLength);
float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
indexInput.readFloats(correctiveValues, 0, 3);
final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort());
float score = osqVectorsScorer.score(
queryCorrections,
fieldInfo.getVectorSimilarityFunction(),
centroidDp,
correctiveValues[0],
correctiveValues[1],
quantizedComponentSum,
correctiveValues[2],
qcDist
);
scoredDocs++;
knnCollector.collect(doc, score);
}
}
if (scoredDocs > 0) {
knnCollector.incVisitedCount(scoredDocs);
}
return scoredDocs;
}
private void quantizeQueryIfNecessary() {
if (quantized == false) {
System.arraycopy(target, 0, scratch, 0, target.length);
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
VectorUtil.l2normalize(scratch);
}
queryCorrections = quantizer.scalarQuantize(scratch, quantizationScratch, (byte) 4, centroid);
transposeHalfByte(quantizationScratch, quantizedQueryScratch);
quantized = true;
}
}
}
static int readQuantizedValue(IndexInput indexInput, byte[] binaryValue, float[] corrections) throws IOException {
assert corrections.length == 3;
indexInput.readBytes(binaryValue, 0, binaryValue.length);
corrections[0] = Float.intBitsToFloat(indexInput.readInt());
corrections[1] = Float.intBitsToFloat(indexInput.readInt());
corrections[2] = Float.intBitsToFloat(indexInput.readInt());
return Short.toUnsignedInt(indexInput.readShort());
}
}


@@ -0,0 +1,736 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntArrayList;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import org.elasticsearch.simdvec.ESVectorUtil;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary;
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.IVF_VECTOR_COMPONENT;
/**
* Default implementation of {@link IVFVectorsWriter}. It uses the {@link KMeans} algorithm to
* partition the vector space, and then stores the centroids and posting lists in a sequential
* fashion.
*/
public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
static final float SOAR_LAMBDA = 1.0f;
// What percentage of the centroids do we do a second check on for SOAR assignment
static final float EXT_SOAR_LIMIT_CHECK_RATIO = 0.10f;
private final int vectorPerCluster;
public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate, int vectorPerCluster) throws IOException {
super(state, rawVectorDelegate);
this.vectorPerCluster = vectorPerCluster;
}
@Override
CentroidAssignmentScorer calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput centroidOutput,
float[] globalCentroid
) throws IOException {
if (floatVectorValues.size() == 0) {
return CentroidAssignmentScorer.EMPTY;
}
// calculate the centroids
int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1;
int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters);
final KMeans.Results kMeans = KMeans.cluster(
floatVectorValues,
desiredClusters,
false,
42L,
KMeans.KmeansInitializationMethod.PLUS_PLUS,
null,
fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE,
1,
15,
desiredClusters * 256
);
float[][] centroids = kMeans.centroids();
// write them
writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput);
return new OnHeapCentroidAssignmentScorer(centroids);
}
@Override
long[] buildAndWritePostingsLists(
FieldInfo fieldInfo,
InfoStream infoStream,
CentroidAssignmentScorer randomCentroidScorer,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput
) throws IOException {
IntArrayList[] clusters = new IntArrayList[randomCentroidScorer.size()];
for (int i = 0; i < randomCentroidScorer.size(); i++) {
clusters[i] = new IntArrayList(floatVectorValues.size() / randomCentroidScorer.size() / 4);
}
assignCentroids(randomCentroidScorer, floatVectorValues, clusters);
if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
printClusterQualityStatistics(clusters, infoStream);
}
// write the posting lists
final long[] offsets = new long[randomCentroidScorer.size()];
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
DocIdsWriter docIdsWriter = new DocIdsWriter();
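// per-cluster layout written below (a sketch of this writer's output, not a spec):
// vint count | float bits of dot(centroid, centroid) | doc ids (DocIdsWriter),
// then BULK_SIZE-sized blocks of quantized vectors and their corrections
// (lower intervals, upper intervals, component-sum shorts, additional corrections),
// and finally a per-vector tail written via writeQuantizedValue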
for (int i = 0; i < randomCentroidScorer.size(); i++) {
float[] centroid = randomCentroidScorer.centroid(i);
binarizedByteVectorValues.centroid = centroid;
// TODO sort by distance to the centroid
IntArrayList cluster = clusters[i];
// TODO align???
offsets[i] = postingsOutput.getFilePointer();
int size = cluster.size();
postingsOutput.writeVInt(size);
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
// TODO we might want to consider putting the docIds in a separate file
// to aid with only having to fetch vectors from slower storage when they are required
// keeping them in the same file indicates we pull the entire file into cache
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), cluster.size(), postingsOutput);
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
}
return offsets;
}
private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
throws IOException {
int limit = cluster.size() - ES91OSQVectorsScorer.BULK_SIZE + 1;
int cidx = 0;
OptimizedScalarQuantizer.QuantizationResult[] corrections =
new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE];
// Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE.
for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) {
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
int ord = cluster.get(cidx + j);
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
// write vector
postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
corrections[j] = binarizedByteVectorValues.getCorrectiveTerms(ord);
}
// write corrections
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].lowerInterval()));
}
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].upperInterval()));
}
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
int targetComponentSum = corrections[j].quantizedComponentSum();
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
postingsOutput.writeShort((short) targetComponentSum);
}
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].additionalCorrection()));
}
}
// write tail
for (; cidx < cluster.size(); cidx++) {
int ord = cluster.get(cidx);
// write vector
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord);
writeQuantizedValue(postingsOutput, binaryValue, correction);
}
}
@Override
CentroidAssignmentScorer createCentroidScorer(
IndexInput centroidsInput,
int numCentroids,
FieldInfo fieldInfo,
float[] globalCentroid
) {
return new OffHeapCentroidAssignmentScorer(centroidsInput, numCentroids, fieldInfo);
}
static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput)
throws IOException {
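// layout sketch for the centroid file: every centroid first as a quantized record
// (dimension bytes + three correction floats + a component-sum short), then every
// centroid again as raw little-endian floats; DefaultIVFVectorsReader's
// centroidByteSize mirrors the quantized record size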
final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()];
float[] centroidScratch = new float[fieldInfo.getVectorDimension()];
// TODO do we want to store these distances as well for future use?
float[] distances = new float[centroids.length];
for (int i = 0; i < centroids.length; i++) {
distances[i] = VectorUtil.squareDistance(centroids[i], globalCentroid);
}
// sort the centroids by distance to globalCentroid, nearest (smallest distance), to furthest
// (largest)
for (int i = 0; i < centroids.length; i++) {
for (int j = i + 1; j < centroids.length; j++) {
if (distances[i] > distances[j]) {
float[] tmp = centroids[i];
centroids[i] = centroids[j];
centroids[j] = tmp;
float tmpDistance = distances[i];
distances[i] = distances[j];
distances[j] = tmpDistance;
}
}
}
for (float[] centroid : centroids) {
System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length);
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(
centroidScratch,
quantizedScratch,
(byte) 4,
globalCentroid
);
writeQuantizedValue(centroidOutput, quantizedScratch, result);
}
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (float[] centroid : centroids) {
buffer.asFloatBuffer().put(centroid);
centroidOutput.writeBytes(buffer.array(), buffer.array().length);
}
}
static float[][] gatherInitCentroids(
List<FloatVectorValues> centroidList,
List<SegmentCentroid> segmentCentroids,
int desiredClusters,
FieldInfo fieldInfo,
MergeState mergeState
) throws IOException {
if (centroidList.size() == 0) {
return null;
}
long startTime = System.nanoTime();
// use the segment with the most centroids as the base segment
FloatVectorValues baseSegment = centroidList.get(0);
for (var l : centroidList) {
if (l.size() > baseSegment.size()) {
baseSegment = l;
}
}
float[] scratch = new float[fieldInfo.getVectorDimension()];
float minimumDistance = Float.MAX_VALUE;
for (int j = 0; j < baseSegment.size(); j++) {
System.arraycopy(baseSegment.vectorValue(j), 0, scratch, 0, baseSegment.dimension());
for (int k = j + 1; k < baseSegment.size(); k++) {
float d = VectorUtil.squareDistance(scratch, baseSegment.vectorValue(k));
if (d < minimumDistance) {
minimumDistance = d;
}
}
}
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
mergeState.infoStream.message(
IVF_VECTOR_COMPONENT,
"Agglomerative cluster min distance: " + minimumDistance + " From biggest segment: " + baseSegment.size()
);
}
int[] labels = new int[segmentCentroids.size()];
// loop over segments
int clusterIdx = 0;
// keep track of all inter-centroid distances,
// using less than centroid * centroid space (e.g. not keeping track of duplicates)
for (int i = 0; i < segmentCentroids.size(); i++) {
if (labels[i] == 0) {
clusterIdx += 1;
labels[i] = clusterIdx;
}
SegmentCentroid segmentCentroid = segmentCentroids.get(i);
System.arraycopy(
centroidList.get(segmentCentroid.segment()).vectorValue(segmentCentroid.centroid()),
0,
scratch,
0,
baseSegment.dimension()
);
for (int j = i + 1; j < segmentCentroids.size(); j++) {
float d = VectorUtil.squareDistance(
scratch,
centroidList.get(segmentCentroids.get(j).segment()).vectorValue(segmentCentroids.get(j).centroid())
);
if (d < minimumDistance / 2) {
if (labels[j] == 0) {
labels[j] = labels[i];
} else {
for (int k = 0; k < labels.length; k++) {
if (labels[k] == labels[j]) {
labels[k] = labels[i];
}
}
}
}
}
}
float[][] initCentroids = new float[clusterIdx][fieldInfo.getVectorDimension()];
int[] sum = new int[clusterIdx];
for (int i = 0; i < segmentCentroids.size(); i++) {
SegmentCentroid segmentCentroid = segmentCentroids.get(i);
int label = labels[i];
FloatVectorValues segment = centroidList.get(segmentCentroid.segment());
float[] vector = segment.vectorValue(segmentCentroid.centroid());
for (int j = 0; j < vector.length; j++) {
initCentroids[label - 1][j] += (vector[j] * segmentCentroid.centroidSize());
}
sum[label - 1] += segmentCentroid.centroidSize();
}
for (int i = 0; i < initCentroids.length; i++) {
if (sum[i] == 0 || sum[i] == 1) {
continue;
}
for (int j = 0; j < initCentroids[i].length; j++) {
initCentroids[i][j] /= sum[i];
}
}
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
mergeState.infoStream.message(
IVF_VECTOR_COMPONENT,
"Agglomerative cluster time ms: " + ((System.nanoTime() - startTime) / 1000000.0)
);
mergeState.infoStream.message(
IVF_VECTOR_COMPONENT,
"Gathered initCentroids:" + initCentroids.length + " for desired: " + desiredClusters
);
}
return initCentroids;
}
record SegmentCentroid(int segment, int centroid, int centroidSize) {}
/**
* Calculate the centroids for the given field and write them to the given
* temporary centroid output.
* When merging, we first bootstrap the KMeans algorithm with the centroids contained in the merging segments.
* To prevent centroids that are too similar from having an outsized impact, all centroids that are closer than
* the largest segment's intra-cluster distance are merged into a single centroid.
* The resulting centroids are then used to initialize the KMeans algorithm.
*
* @param fieldInfo merging field info
* @param floatVectorValues the float vector values to merge
* @param temporaryCentroidOutput the temporary centroid output
* @param mergeState the merge state
* @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
* @return the number of centroids written
* @throws IOException if an I/O error occurs
*/
@Override
protected int calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput temporaryCentroidOutput,
MergeState mergeState,
float[] globalCentroid
) throws IOException {
if (floatVectorValues.size() == 0) {
return 0;
}
int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1;
int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters);
// init centroids from merge state
List<FloatVectorValues> centroidList = new ArrayList<>();
List<SegmentCentroid> segmentCentroids = new ArrayList<>(desiredClusters);
int segmentIdx = 0;
for (var reader : mergeState.knnVectorsReaders) {
IVFVectorsReader ivfVectorsReader = IVFVectorsFormat.getIVFReader(reader, fieldInfo.name);
if (ivfVectorsReader == null) {
continue;
}
FloatVectorValues centroid = ivfVectorsReader.getCentroids(fieldInfo);
if (centroid == null) {
continue;
}
centroidList.add(centroid);
for (int i = 0; i < centroid.size(); i++) {
int size = ivfVectorsReader.centroidSize(fieldInfo.name, i);
if (size == 0) {
continue;
}
segmentCentroids.add(new SegmentCentroid(segmentIdx, i, size));
}
segmentIdx++;
}
float[][] initCentroids = gatherInitCentroids(centroidList, segmentCentroids, desiredClusters, fieldInfo, mergeState);
// FIXME: run a custom version of KMeans that is just better...
long nanoTime = System.nanoTime();
final KMeans.Results kMeans = KMeans.cluster(
floatVectorValues,
desiredClusters,
false,
42L,
KMeans.KmeansInitializationMethod.PLUS_PLUS,
initCentroids,
fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE,
1,
5,
desiredClusters * 64
);
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "KMeans time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0));
}
float[][] centroids = kMeans.centroids();
// write them
// calculate the global centroid from all the centroids:
for (float[] centroid : centroids) {
for (int j = 0; j < centroid.length; j++) {
globalCentroid[j] += centroid[j];
}
}
for (int j = 0; j < globalCentroid.length; j++) {
globalCentroid[j] /= centroids.length;
}
writeCentroids(centroids, fieldInfo, globalCentroid, temporaryCentroidOutput);
return centroids.length;
}
@Override
long[] buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidAssignmentScorer centroidAssignmentScorer,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput,
MergeState mergeState
) throws IOException {
IntArrayList[] clusters = new IntArrayList[centroidAssignmentScorer.size()];
for (int i = 0; i < centroidAssignmentScorer.size(); i++) {
clusters[i] = new IntArrayList(floatVectorValues.size() / centroidAssignmentScorer.size() / 4);
}
long nanoTime = System.nanoTime();
// Can we do a pre-filter by finding the nearest centroids to the original vector centroids?
// We need to be careful on vecOrd vs. doc as we need random access to the raw vector for posting list writing
assignCentroids(centroidAssignmentScorer, floatVectorValues, clusters);
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "assignCentroids time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0));
}
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
printClusterQualityStatistics(clusters, mergeState.infoStream);
}
// write the posting lists
final long[] offsets = new long[centroidAssignmentScorer.size()];
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
DocIdsWriter docIdsWriter = new DocIdsWriter();
for (int i = 0; i < centroidAssignmentScorer.size(); i++) {
float[] centroid = centroidAssignmentScorer.centroid(i);
binarizedByteVectorValues.centroid = centroid;
// TODO: sort by distance to the centroid
IntArrayList cluster = clusters[i];
// TODO align???
offsets[i] = postingsOutput.getFilePointer();
int size = cluster.size();
postingsOutput.writeVInt(size);
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
// TODO we might want to consider putting the docIds in a separate file
// to aid with only having to fetch vectors from slower storage when they are required
// keeping them in the same file indicates we pull the entire file into cache
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput);
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
}
return offsets;
}
private static void printClusterQualityStatistics(IntArrayList[] clusters, InfoStream infoStream) {
float min = Float.MAX_VALUE;
float max = Float.MIN_VALUE;
float mean = 0;
float m2 = 0;
// iteratively compute the variance & mean
int count = 0;
for (IntArrayList cluster : clusters) {
count += 1;
if (cluster == null) {
continue;
}
float delta = cluster.size() - mean;
mean += delta / count;
m2 += delta * (cluster.size() - mean);
min = Math.min(min, cluster.size());
max = Math.max(max, cluster.size());
}
float variance = m2 / (clusters.length - 1);
infoStream.message(
IVF_VECTOR_COMPONENT,
"Centroid count: "
+ clusters.length
+ " min: "
+ min
+ " max: "
+ max
+ " mean: "
+ mean
+ " stdDev: "
+ Math.sqrt(variance)
+ " variance: "
+ variance
);
}
static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues vectors, IntArrayList[] clusters) throws IOException {
int numCentroids = scorer.size();
// check at most EXT_SOAR_LIMIT_CHECK_RATIO of the nearest centroids for SOAR assignment
int soarToCheck = (int) (numCentroids * EXT_SOAR_LIMIT_CHECK_RATIO);
int soarClusterCheckCount = Math.min(numCentroids - 1, soarToCheck);
NeighborQueue neighborsToCheck = new NeighborQueue(soarClusterCheckCount + 1, true);
OrdScoreIterator ordScoreIterator = new OrdScoreIterator(soarClusterCheckCount + 1);
float[] scratch = new float[vectors.dimension()];
for (int docID = 0; docID < vectors.size(); docID++) {
float[] vector = vectors.vectorValue(docID);
scorer.setScoringVector(vector);
int bestCentroid = 0;
float bestScore = Float.MAX_VALUE;
if (numCentroids > 1) {
for (short c = 0; c < numCentroids; c++) {
float squareDist = scorer.score(c);
neighborsToCheck.insertWithOverflow(c, squareDist);
}
// pop the best
int sz = neighborsToCheck.size();
int best = neighborsToCheck.consumeNodesAndScoresMin(ordScoreIterator.ords, ordScoreIterator.scores);
// Set the size to the number of neighbors we actually found
ordScoreIterator.setSize(sz);
bestScore = ordScoreIterator.getScore(best);
bestCentroid = ordScoreIterator.getOrd(best);
}
clusters[bestCentroid].add(docID);
if (soarClusterCheckCount > 0) {
assignCentroidSOAR(
ordScoreIterator,
docID,
bestCentroid,
scorer.centroid(bestCentroid),
bestScore,
scratch,
scorer,
vector,
clusters
);
}
neighborsToCheck.clear();
}
}
static void assignCentroidSOAR(
OrdScoreIterator centroidsToCheck,
int vecOrd,
int bestCentroidId,
float[] bestCentroid,
float bestScore,
float[] scratch,
CentroidAssignmentScorer scorer,
float[] vector,
IntArrayList[] clusters
) throws IOException {
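// SOAR sketch: with r = vector - bestCentroid (stored in `scratch` below), each
// candidate centroid c is ranked by
//   d(vector, c)^2 + SOAR_LAMBDA * ((vector - c) . r)^2 / d(vector, bestCentroid)^2
// and the vector is additionally assigned to the best-ranked secondary centroid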
ESVectorUtil.subtract(vector, bestCentroid, scratch);
int bestSecondaryCentroid = -1;
float minDist = Float.MAX_VALUE;
for (int i = 0; i < centroidsToCheck.size(); i++) {
float score = centroidsToCheck.getScore(i);
int centroidOrdinal = centroidsToCheck.getOrd(i);
if (centroidOrdinal == bestCentroidId) {
continue;
}
float proj = ESVectorUtil.soarResidual(vector, scorer.centroid(centroidOrdinal), scratch);
score += SOAR_LAMBDA * proj * proj / bestScore;
if (score < minDist) {
bestSecondaryCentroid = centroidOrdinal;
minDist = score;
}
}
if (bestSecondaryCentroid != -1) {
clusters[bestSecondaryCentroid].add(vecOrd);
}
}
static class OrdScoreIterator {
private final int[] ords;
private final float[] scores;
private int idx = 0;
OrdScoreIterator(int size) {
this.ords = new int[size];
this.scores = new float[size];
}
int setSize(int size) {
if (size > ords.length) {
throw new IllegalArgumentException("size must be <= " + ords.length);
}
this.idx = size;
return size;
}
int getOrd(int idx) {
return ords[idx];
}
float getScore(int idx) {
return scores[idx];
}
int size() {
return idx;
}
}
// TODO unify with OSQ format
static class BinarizedFloatVectorValues {
private OptimizedScalarQuantizer.QuantizationResult corrections;
private final byte[] binarized;
private final byte[] initQuantized;
private float[] centroid;
private final FloatVectorValues values;
private final OptimizedScalarQuantizer quantizer;
private int lastOrd = -1;
BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer) {
this.values = delegate;
this.quantizer = quantizer;
this.binarized = new byte[discretize(delegate.dimension(), 64) / 8];
this.initQuantized = new byte[delegate.dimension()];
}
public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) {
if (ord != lastOrd) {
throw new IllegalStateException(
"attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd
);
}
return corrections;
}
public byte[] vectorValue(int ord) throws IOException {
if (ord != lastOrd) {
binarize(ord);
lastOrd = ord;
}
return binarized;
}
private void binarize(int ord) throws IOException {
corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid);
packAsBinary(initQuantized, binarized);
}
}
static class OffHeapCentroidAssignmentScorer implements CentroidAssignmentScorer {
private final IndexInput centroidsInput;
private final int numCentroids;
private final int dimension;
private final float[] scratch;
private float[] q;
private final long rawCentroidOffset;
private int currOrd = -1;
OffHeapCentroidAssignmentScorer(IndexInput centroidsInput, int numCentroids, FieldInfo info) {
this.centroidsInput = centroidsInput;
this.numCentroids = numCentroids;
this.dimension = info.getVectorDimension();
this.scratch = new float[dimension];
this.rawCentroidOffset = (dimension + 3 * Float.BYTES + Short.BYTES) * numCentroids;
}
@Override
public int size() {
return numCentroids;
}
@Override
public float[] centroid(int centroidOrdinal) throws IOException {
if (centroidOrdinal == currOrd) {
return scratch;
}
centroidsInput.seek(rawCentroidOffset + (long) centroidOrdinal * dimension * Float.BYTES);
centroidsInput.readFloats(scratch, 0, dimension);
this.currOrd = centroidOrdinal;
return scratch;
}
@Override
public void setScoringVector(float[] vector) {
q = vector;
}
@Override
public float score(int centroidOrdinal) throws IOException {
return VectorUtil.squareDistance(centroid(centroidOrdinal), q);
}
}
// TODO throw away rawCentroids
static class OnHeapCentroidAssignmentScorer implements CentroidAssignmentScorer {
private final float[][] centroids;
private float[] q;
OnHeapCentroidAssignmentScorer(float[][] centroids) {
this.centroids = centroids;
}
@Override
public int size() {
return centroids.length;
}
@Override
public void setScoringVector(float[] vector) {
q = vector;
}
@Override
public float[] centroid(int centroidOrdinal) throws IOException {
return centroids[centroidOrdinal];
}
@Override
public float score(int centroidOrdinal) throws IOException {
return VectorUtil.squareDistance(centroid(centroidOrdinal), q);
}
}
static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
throws IOException {
indexOutput.writeBytes(binaryValue, binaryValue.length);
indexOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
indexOutput.writeInt(Float.floatToIntBits(corrections.upperInterval()));
indexOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff;
indexOutput.writeShort((short) corrections.quantizedComponentSum());
}
}


@@ -0,0 +1,110 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import java.io.IOException;
/**
* Codec format for Inverted File Vector indexes. This index expects to break the dimensional space
* into clusters and assign each vector to a cluster generating a posting list of vectors. Clusters
* are represented by centroids.
* The vector quantization format used here is a per-vector optimized scalar quantization. Also see {@link
* OptimizedScalarQuantizer}.
*
* The format is stored in three files:
*
* <h2>.cenivf (centroid data) file</h2>
* <p> Which stores the raw and quantized centroid vectors.
*
* <h2>.clivf (cluster data) file</h2>
*
* <p> Stores the quantized vectors for each cluster, inline and stored in blocks. Additionally, the docId of
* each vector is stored.
*
* <h2>.mivf (centroid metadata) file</h2>
*
* <p> Stores metadata including the number of centroids and their offsets in the clivf file</p>
*
*/
public class IVFVectorsFormat extends KnnVectorsFormat {
public static final String IVF_VECTOR_COMPONENT = "IVF";
public static final String NAME = "IVFVectorsFormat";
// centroid ordinals -> centroid values, offsets
public static final String CENTROID_EXTENSION = "cenivf";
// offsets contained in cen_ivf, [vector ordinals, actually just docIds](long varint), quantized
// vectors (OSQ bit)
public static final String CLUSTER_EXTENSION = "clivf";
static final String IVF_META_EXTENSION = "mivf";
public static final int VERSION_START = 0;
public static final int VERSION_CURRENT = VERSION_START;
private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
);
private static final int DEFAULT_VECTORS_PER_CLUSTER = 1000;
private final int vectorPerCluster;
public IVFVectorsFormat(int vectorPerCluster) {
super(NAME);
if (vectorPerCluster <= 0) {
throw new IllegalArgumentException("vectorPerCluster must be > 0");
}
this.vectorPerCluster = vectorPerCluster;
}
/** Constructs a format using the given graph construction parameters and scalar quantization. */
public IVFVectorsFormat() {
this(DEFAULT_VECTORS_PER_CLUSTER);
}
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new DefaultIVFVectorsWriter(state, rawVectorFormat.fieldsWriter(state), vectorPerCluster);
}
@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new DefaultIVFVectorsReader(state, rawVectorFormat.fieldsReader(state));
}
@Override
public int getMaxDimensions(String fieldName) {
return 1024;
}
@Override
public String toString() {
return "IVFVectorFormat";
}
static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) {
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
vectorsReader = candidateReader.getFieldReader(fieldName);
}
if (vectorsReader instanceof IVFVectorsReader reader) {
return reader;
}
return null;
}
}
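The format is unexposed, so nothing registers it against a mapping yet; a sketch of how it could be exercised directly, assuming a Lucene default codec that supports per-field KNN format overrides (`Lucene101Codec` and the cluster size of 1000 here are illustrative assumptions):

    IndexWriterConfig config = new IndexWriterConfig().setCodec(new Lucene101Codec() {
        @Override
        public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
            return new IVFVectorsFormat(1000);
        }
    });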


@@ -0,0 +1,354 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.elasticsearch.core.IOUtils;
import java.io.IOException;
import java.util.function.IntPredicate;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
/**
* Reader for IVF vectors. This reader is used to read the IVF vectors from the index.
*/
public abstract class IVFVectorsReader extends KnnVectorsReader {
private final IndexInput ivfCentroids, ivfClusters;
private final SegmentReadState state;
private final FieldInfos fieldInfos;
protected final IntObjectHashMap<FieldEntry> fields;
private final FlatVectorsReader rawVectorsReader;
protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
this.state = state;
this.fieldInfos = state.fieldInfos;
this.rawVectorsReader = rawVectorsReader;
this.fields = new IntObjectHashMap<>();
String meta = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, IVFVectorsFormat.IVF_META_EXTENSION);
int versionMeta = -1;
boolean success = false;
try (ChecksumIndexInput ivfMeta = state.directory.openChecksumInput(meta)) {
Throwable priorE = null;
try {
versionMeta = CodecUtil.checkIndexHeader(
ivfMeta,
IVFVectorsFormat.NAME,
IVFVectorsFormat.VERSION_START,
IVFVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix
);
readFields(ivfMeta);
} catch (Throwable exception) {
priorE = exception;
} finally {
CodecUtil.checkFooter(ivfMeta, priorE);
}
ivfCentroids = openDataInput(state, versionMeta, IVFVectorsFormat.CENTROID_EXTENSION, IVFVectorsFormat.NAME, state.context);
ivfClusters = openDataInput(state, versionMeta, IVFVectorsFormat.CLUSTER_EXTENSION, IVFVectorsFormat.NAME, state.context);
success = true;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(this);
}
}
}
abstract CentroidQueryScorer getCentroidScorer(
FieldInfo fieldInfo,
int numCentroids,
IndexInput centroids,
float[] target,
IndexInput clusters
) throws IOException;
protected abstract FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) throws IOException;
public FloatVectorValues getCentroids(FieldInfo fieldInfo) throws IOException {
FieldEntry entry = fields.get(fieldInfo.number);
if (entry == null) {
return null;
}
return getCentroids(entry.centroidSlice(ivfCentroids), entry.postingListOffsets.length, fieldInfo);
}
int centroidSize(String fieldName, int centroidOrdinal) throws IOException {
FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldName);
FieldEntry entry = fields.get(fieldInfo.number);
ivfClusters.seek(entry.postingListOffsets[centroidOrdinal]);
return ivfClusters.readVInt();
}
private static IndexInput openDataInput(
SegmentReadState state,
int versionMeta,
String fileExtension,
String codecName,
IOContext context
) throws IOException {
final String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
final IndexInput in = state.directory.openInput(fileName, context);
boolean success = false;
try {
final int versionVectorData = CodecUtil.checkIndexHeader(
in,
codecName,
IVFVectorsFormat.VERSION_START,
IVFVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix
);
if (versionMeta != versionVectorData) {
throw new CorruptIndexException(
"Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData,
in
);
}
CodecUtil.retrieveChecksum(in);
success = true;
return in;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(in);
}
}
}
private void readFields(ChecksumIndexInput meta) throws IOException {
for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
final FieldInfo info = fieldInfos.fieldInfo(fieldNumber);
if (info == null) {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
}
fields.put(info.number, readField(meta, info));
}
}
private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
final VectorEncoding vectorEncoding = readVectorEncoding(input);
final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
final long centroidOffset = input.readLong();
final long centroidLength = input.readLong();
final int numPostingLists = input.readVInt();
final long[] postingListOffsets = new long[numPostingLists];
for (int i = 0; i < numPostingLists; i++) {
postingListOffsets[i] = input.readLong();
}
final float[] globalCentroid = new float[info.getVectorDimension()];
float globalCentroidDp = 0;
if (numPostingLists > 0) {
input.readFloats(globalCentroid, 0, globalCentroid.length);
globalCentroidDp = Float.intBitsToFloat(input.readInt());
}
if (similarityFunction != info.getVectorSimilarityFunction()) {
throw new IllegalStateException(
"Inconsistent vector similarity function for field=\""
+ info.name
+ "\"; "
+ similarityFunction
+ " != "
+ info.getVectorSimilarityFunction()
);
}
return new FieldEntry(
similarityFunction,
vectorEncoding,
centroidOffset,
centroidLength,
postingListOffsets,
globalCentroid,
globalCentroidDp
);
}
private static VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
final int i = input.readInt();
if (i < 0 || i >= SIMILARITY_FUNCTIONS.size()) {
throw new IllegalArgumentException("invalid distance function: " + i);
}
return SIMILARITY_FUNCTIONS.get(i);
}
private static VectorEncoding readVectorEncoding(DataInput input) throws IOException {
final int encodingId = input.readInt();
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
}
return VectorEncoding.values()[encodingId];
}
@Override
public final void checkIntegrity() throws IOException {
rawVectorsReader.checkIntegrity();
CodecUtil.checksumEntireFile(ivfCentroids);
CodecUtil.checksumEntireFile(ivfClusters);
}
@Override
public final FloatVectorValues getFloatVectorValues(String field) throws IOException {
return rawVectorsReader.getFloatVectorValues(field);
}
@Override
public final ByteVectorValues getByteVectorValues(String field) throws IOException {
return rawVectorsReader.getByteVectorValues(field);
}
protected float[] getGlobalCentroid(FieldInfo info) {
if (info == null || info.getVectorEncoding().equals(VectorEncoding.BYTE)) {
return null;
}
FieldEntry entry = fields.get(info.number);
if (entry == null) {
return null;
}
return entry.globalCentroid();
}
@Override
public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) {
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
return;
}
// TODO: add a new IVF search strategy
int nProbe = 10;
float percentFiltered = 1f;
if (acceptDocs instanceof BitSet bitSet) {
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
}
int numVectors = rawVectorsReader.getFloatVectorValues(field).size();
BitSet visitedDocs = new FixedBitSet(state.segmentInfo.maxDoc() + 1);
IntPredicate needsScoring = docId -> {
if (acceptDocs != null && acceptDocs.get(docId) == false) {
return false;
}
return visitedDocs.getAndSet(docId) == false;
};
FieldEntry entry = fields.get(fieldInfo.number);
CentroidQueryScorer centroidQueryScorer = getCentroidScorer(
fieldInfo,
entry.postingListOffsets.length,
entry.centroidSlice(ivfCentroids),
target,
ivfClusters
);
final NeighborQueue centroidQueue = scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe);
PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring);
int centroidsVisited = 0;
long expectedDocs = 0;
long actualDocs = 0;
// initially we visit only the nProbe closest centroids
while (centroidQueue.size() > 0 && centroidsVisited < nProbe) {
++centroidsVisited;
// TODO: do we actually need to know the score here?
int centroidOrdinal = centroidQueue.pop();
// TODO: do we need direct access to the raw centroid?
expectedDocs += scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal));
actualDocs += scorer.visit(knnCollector);
}
if (acceptDocs != null) {
float unfilteredRatioVisited = (float) expectedDocs / numVectors;
int filteredVectors = (int) Math.ceil(numVectors * percentFiltered);
float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
while (centroidQueue.size() > 0 && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
int centroidOrdinal = centroidQueue.pop();
scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal));
actualDocs += scorer.visit(knnCollector);
}
}
}
@Override
public final void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
final ByteVectorValues values = rawVectorsReader.getByteVectorValues(field);
for (int i = 0; i < values.size(); i++) {
final float score = fieldInfo.getVectorSimilarityFunction().compare(target, values.vectorValue(i));
knnCollector.collect(values.ordToDoc(i), score);
if (knnCollector.earlyTerminated()) {
return;
}
}
}
abstract NeighborQueue scorePostingLists(
FieldInfo fieldInfo,
KnnCollector knnCollector,
CentroidQueryScorer centroidQueryScorer,
int nProbe
) throws IOException;
@Override
public void close() throws IOException {
IOUtils.close(rawVectorsReader, ivfCentroids, ivfClusters);
}
protected record FieldEntry(
VectorSimilarityFunction similarityFunction,
VectorEncoding vectorEncoding,
long centroidOffset,
long centroidLength,
long[] postingListOffsets,
float[] globalCentroid,
float globalCentroidDp
) {
IndexInput centroidSlice(IndexInput centroidFile) throws IOException {
return centroidFile.slice("centroids", centroidOffset, centroidLength);
}
}
abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, IntPredicate needsScoring)
throws IOException;
interface CentroidQueryScorer {
int size();
float[] centroid(int centroidOrdinal) throws IOException;
float score(int centroidOrdinal) throws IOException;
}
interface PostingVisitor {
// TODO: maybe we can avoid passing the centroid explicitly
/** returns the number of documents in the posting list */
int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException;
/** returns the number of scored documents */
int visit(KnnCollector collector) throws IOException;
}
}
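The filtered-search budget in the float-vector search(...) method above is easiest to follow with concrete numbers. Below is a minimal, self-contained sketch of that heuristic; all values are hypothetical and the class name is illustrative, not part of the format:

// Standalone sketch of the budget heuristic from IVFVectorsReader#search (illustrative values).
public class SearchBudgetSketch {
    public static void main(String[] args) {
        int numVectors = 100_000;      // vectors in the segment (hypothetical)
        long expectedDocs = 5_000;     // docs in the nProbe posting lists already visited
        float percentFiltered = 0.1f;  // 10% of docs pass the filter (hypothetical)

        // fraction of the segment covered by the initial nProbe pass
        float unfilteredRatioVisited = (float) expectedDocs / numVectors;    // 0.05
        int filteredVectors = (int) Math.ceil(numVectors * percentFiltered); // 10_000
        // extra centroids are visited until roughly this many docs have been scored
        float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
        System.out.println(expectedScored); // min(1000.0, 2500.0) = 1000.0
    }
}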

View file

@ -0,0 +1,486 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.SuppressForbidden;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
/**
* Base class for IVF vector writers. Writes the centroids, posting lists, and per-field metadata.
*/
public abstract class IVFVectorsWriter extends KnnVectorsWriter {
private final List<FieldWriter> fieldWriters = new ArrayList<>();
private final IndexOutput ivfCentroids, ivfClusters;
private final IndexOutput ivfMeta;
private final FlatVectorsWriter rawVectorDelegate;
private final SegmentWriteState segmentWriteState;
protected IVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) throws IOException {
this.segmentWriteState = state;
this.rawVectorDelegate = rawVectorDelegate;
final String metaFileName = IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
IVFVectorsFormat.IVF_META_EXTENSION
);
final String ivfCentroidsFileName = IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
IVFVectorsFormat.CENTROID_EXTENSION
);
final String ivfClustersFileName = IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
IVFVectorsFormat.CLUSTER_EXTENSION
);
boolean success = false;
try {
ivfMeta = state.directory.createOutput(metaFileName, state.context);
CodecUtil.writeIndexHeader(
ivfMeta,
IVFVectorsFormat.NAME,
IVFVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix
);
ivfCentroids = state.directory.createOutput(ivfCentroidsFileName, state.context);
CodecUtil.writeIndexHeader(
ivfCentroids,
IVFVectorsFormat.NAME,
IVFVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix
);
ivfClusters = state.directory.createOutput(ivfClustersFileName, state.context);
CodecUtil.writeIndexHeader(
ivfClusters,
IVFVectorsFormat.NAME,
IVFVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix
);
success = true;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(this);
}
}
}
@Override
public final KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
if (fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE) {
throw new IllegalArgumentException("IVF does not support cosine similarity");
}
final FlatFieldVectorsWriter<?> rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo);
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
@SuppressWarnings("unchecked")
final FlatFieldVectorsWriter<float[]> floatWriter = (FlatFieldVectorsWriter<float[]>) rawVectorDelegate;
fieldWriters.add(new FieldWriter(fieldInfo, floatWriter));
}
return rawVectorDelegate;
}
protected abstract int calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput temporaryCentroidOutput,
MergeState mergeState,
float[] globalCentroid
) throws IOException;
abstract long[] buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidAssignmentScorer scorer,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput,
MergeState mergeState
) throws IOException;
abstract CentroidAssignmentScorer calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput centroidOutput,
float[] globalCentroid
) throws IOException;
abstract long[] buildAndWritePostingsLists(
FieldInfo fieldInfo,
InfoStream infoStream,
CentroidAssignmentScorer scorer,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput
) throws IOException;
abstract CentroidAssignmentScorer createCentroidScorer(
IndexInput centroidsInput,
int numCentroids,
FieldInfo fieldInfo,
float[] globalCentroid
) throws IOException;
@Override
public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
rawVectorDelegate.flush(maxDoc, sortMap);
for (FieldWriter fieldWriter : fieldWriters) {
float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()];
// calculate global centroid
for (var vector : fieldWriter.delegate.getVectors()) {
for (int i = 0; i < globalCentroid.length; i++) {
globalCentroid[i] += vector[i];
}
}
for (int i = 0; i < globalCentroid.length; i++) {
globalCentroid[i] /= fieldWriter.delegate.getVectors().size();
}
// build a FloatVectorValues view with random access
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc);
// build centroids
long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
final CentroidAssignmentScorer centroidAssignmentScorer = calculateAndWriteCentroids(
fieldWriter.fieldInfo,
floatVectorValues,
ivfCentroids,
globalCentroid
);
long centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
final long[] offsets = buildAndWritePostingsLists(
fieldWriter.fieldInfo,
segmentWriteState.infoStream,
centroidAssignmentScorer,
floatVectorValues,
ivfClusters
);
writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid);
}
}
private static FloatVectorValues getFloatVectorValues(
FieldInfo fieldInfo,
FlatFieldVectorsWriter<float[]> fieldVectorsWriter,
int maxDoc
) throws IOException {
List<float[]> vectors = fieldVectorsWriter.getVectors();
if (vectors.size() == maxDoc) {
return FloatVectorValues.fromFloats(vectors, fieldInfo.getVectorDimension());
}
final DocIdSetIterator iterator = fieldVectorsWriter.getDocsWithFieldSet().iterator();
final int[] docIds = new int[vectors.size()];
for (int i = 0; i < docIds.length; i++) {
docIds[i] = iterator.nextDoc();
}
assert iterator.nextDoc() == NO_MORE_DOCS;
return new FloatVectorValues() {
@Override
public float[] vectorValue(int ord) {
return vectors.get(ord);
}
@Override
public FloatVectorValues copy() {
return this;
}
@Override
public int dimension() {
return fieldInfo.getVectorDimension();
}
@Override
public int size() {
return vectors.size();
}
@Override
public int ordToDoc(int ord) {
return docIds[ord];
}
};
}
static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) {
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
vectorsReader = candidateReader.getFieldReader(fieldName);
}
if (vectorsReader instanceof IVFVectorsReader reader) {
return reader;
}
return null;
}
@Override
@SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)")
public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
final int numVectors;
String tempRawVectorsFileName = null;
boolean success = false;
// build a FloatVectorValues view with random access. To do that we dump the vectors to
// a temporary file, writing the docID followed by the vector
try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivf_", IOContext.DEFAULT)) {
tempRawVectorsFileName = out.getName();
// TODO: do this better. We shouldn't have to write to a temp file; we should be able to
// read directly from the merged vector values. The tricky part is the random access.
numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
CodecUtil.writeFooter(out);
success = true;
} finally {
if (success == false && tempRawVectorsFileName != null) {
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
}
}
try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) {
float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors);
success = false;
CentroidAssignmentScorer centroidAssignmentScorer;
long centroidOffset;
long centroidLength;
String centroidTempName = null;
int numCentroids;
IndexOutput centroidTemp = null;
try {
centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT);
centroidTempName = centroidTemp.getName();
numCentroids = calculateAndWriteCentroids(
fieldInfo,
floatVectorValues,
centroidTemp,
mergeState,
calculatedGlobalCentroid
);
success = true;
} finally {
if (success == false && centroidTempName != null) {
IOUtils.closeWhileHandlingException(centroidTemp);
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName);
}
}
try {
if (numCentroids == 0) {
centroidOffset = ivfCentroids.getFilePointer();
writeMeta(fieldInfo, centroidOffset, 0, new long[0], null);
CodecUtil.writeFooter(centroidTemp);
IOUtils.close(centroidTemp);
return;
}
CodecUtil.writeFooter(centroidTemp);
IOUtils.close(centroidTemp);
centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
try (IndexInput centroidInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) {
ivfCentroids.copyBytes(centroidInput, centroidInput.length() - CodecUtil.footerLength());
centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
centroidAssignmentScorer = createCentroidScorer(centroidInput, numCentroids, fieldInfo, calculatedGlobalCentroid);
assert centroidAssignmentScorer.size() == numCentroids;
// assign each vector to a centroid and write the posting lists
final long[] offsets = buildAndWritePostingsLists(
fieldInfo,
centroidAssignmentScorer,
floatVectorValues,
ivfClusters,
mergeState
);
assert offsets.length == centroidAssignmentScorer.size();
writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid);
}
} finally {
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(
mergeState.segmentInfo.dir,
tempRawVectorsFileName,
centroidTempName
);
}
} finally {
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
}
}
}
private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput randomAccessInput, int numVectors) {
if (numVectors == 0) {
return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension());
}
final long length = (long) Float.BYTES * fieldInfo.getVectorDimension() + Integer.BYTES;
final float[] vector = new float[fieldInfo.getVectorDimension()];
return new FloatVectorValues() {
@Override
public float[] vectorValue(int ord) throws IOException {
randomAccessInput.seek(ord * length + Integer.BYTES);
randomAccessInput.readFloats(vector, 0, vector.length);
return vector;
}
@Override
public FloatVectorValues copy() {
return this;
}
@Override
public int dimension() {
return fieldInfo.getVectorDimension();
}
@Override
public int size() {
return numVectors;
}
@Override
public int ordToDoc(int ord) {
try {
randomAccessInput.seek(ord * length);
return randomAccessInput.readInt();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
};
}
private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out, FloatVectorValues floatVectorValues)
throws IOException {
int numVectors = 0;
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
final KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
numVectors++;
float[] vector = floatVectorValues.vectorValue(iterator.index());
out.writeInt(iterator.docID());
buffer.asFloatBuffer().put(vector);
out.writeBytes(buffer.array(), buffer.array().length);
}
return numVectors;
}
private void writeMeta(FieldInfo field, long centroidOffset, long centroidLength, long[] offsets, float[] globalCentroid)
throws IOException {
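// Per-field meta layout: field number (int), vector encoding ordinal (int), similarity
// function ordinal (int), centroid offset (long), centroid length (long), number of
// posting lists (vint), one offset (long) per posting list, and, when at least one
// posting list exists, the global centroid floats followed by their dot product (int bits).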
ivfMeta.writeInt(field.number);
ivfMeta.writeInt(field.getVectorEncoding().ordinal());
ivfMeta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction()));
ivfMeta.writeLong(centroidOffset);
ivfMeta.writeLong(centroidLength);
ivfMeta.writeVInt(offsets.length);
for (long offset : offsets) {
ivfMeta.writeLong(offset);
}
if (offsets.length > 0) {
final ByteBuffer buffer = ByteBuffer.allocate(globalCentroid.length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
buffer.asFloatBuffer().put(globalCentroid);
ivfMeta.writeBytes(buffer.array(), buffer.array().length);
ivfMeta.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(globalCentroid, globalCentroid)));
}
}
private static int distFuncToOrd(VectorSimilarityFunction func) {
for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) {
if (SIMILARITY_FUNCTIONS.get(i).equals(func)) {
return i;
}
}
throw new IllegalArgumentException("invalid distance function: " + func);
}
@Override
public final void finish() throws IOException {
rawVectorDelegate.finish();
if (ivfMeta != null) {
// write end of fields marker
ivfMeta.writeInt(-1);
CodecUtil.writeFooter(ivfMeta);
}
if (ivfCentroids != null) {
CodecUtil.writeFooter(ivfCentroids);
}
if (ivfClusters != null) {
CodecUtil.writeFooter(ivfClusters);
}
}
@Override
public final void close() throws IOException {
IOUtils.close(rawVectorDelegate, ivfMeta, ivfCentroids, ivfClusters);
}
@Override
public final long ramBytesUsed() {
return rawVectorDelegate.ramBytesUsed();
}
private record FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter<float[]> delegate) {}
interface CentroidAssignmentScorer {
CentroidAssignmentScorer EMPTY = new CentroidAssignmentScorer() {
@Override
public int size() {
return 0;
}
@Override
public float[] centroid(int centroidOrdinal) {
throw new IllegalStateException("No centroids");
}
@Override
public float score(int centroidOrdinal) {
throw new IllegalStateException("No centroids");
}
@Override
public void setScoringVector(float[] vector) {
throw new IllegalStateException("No centroids");
}
};
int size();
float[] centroid(int centroidOrdinal) throws IOException;
void setScoringVector(float[] vector);
float score(int centroidOrdinal) throws IOException;
}
}
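A note on the merge path above: the temporary raw-vector file is a flat array of fixed-size records, [int docID][dimension x float], written little-endian, which is what makes the random-access FloatVectorValues possible. A minimal sketch of the offset arithmetic (class and method names are illustrative, not from the original source):

// Sketch of the record layout used by the merge-time temp file (names are illustrative).
class TempVectorFileLayout {
    final int dimension;
    final long recordLength; // bytes per record: docID + vector

    TempVectorFileLayout(int dimension) {
        this.dimension = dimension;
        this.recordLength = (long) Float.BYTES * dimension + Integer.BYTES;
    }

    long docIdOffset(int ord) {
        return ord * recordLength; // seek here, then readInt() -> ordToDoc(ord)
    }

    long vectorOffset(int ord) {
        return ord * recordLength + Integer.BYTES; // seek here, then readFloats(dimension) -> vectorValue(ord)
    }
}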

View file

@ -0,0 +1,159 @@
/*
* @notice
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* Modifications copyright (C) 2025 Elasticsearch B.V.
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.util.LongHeap;
import org.apache.lucene.util.NumericUtils;
/**
* Copied from Apache Lucene and modified.
*/
class NeighborQueue {
private enum Order {
MIN_HEAP {
@Override
long apply(long v) {
return v;
}
},
MAX_HEAP {
@Override
long apply(long v) {
// This cannot be just `-v` since Long.MIN_VALUE doesn't have a positive counterpart. It
// needs a function that returns MAX_VALUE for MIN_VALUE and vice-versa.
return -1 - v;
}
};
abstract long apply(long v);
}
private final LongHeap heap;
private final Order order;
NeighborQueue(int initialSize, boolean maxHeap) {
this.heap = new LongHeap(initialSize);
this.order = maxHeap ? Order.MAX_HEAP : Order.MIN_HEAP;
}
/**
* @return the number of elements in the heap
*/
public int size() {
return heap.size();
}
/**
* Adds a new graph arc, extending the storage as needed.
*
* @param newNode the neighbor node id
* @param newScore the score of the neighbor, relative to some other node
*/
public void add(int newNode, float newScore) {
heap.push(encode(newNode, newScore));
}
/**
* If the heap is not full (size is less than the initialSize provided to the constructor), adds a
* new node-and-score element. If the heap is full, compares the score against the current top
* score, and replaces the top element if newScore is better (greater, unless the heap is
* reversed) than the current top score.
*
* @param newNode the neighbor node id
* @param newScore the score of the neighbor, relative to some other node
*/
public boolean insertWithOverflow(int newNode, float newScore) {
return heap.insertWithOverflow(encode(newNode, newScore));
}
/**
* Encodes the node ID and its similarity score as a long, preserving the Lucene tie-breaking rule
* that when two scores are equal, the smaller node ID must win.
* @param node the node ID
* @param score the node score
* @return the encoded score and node ID
*/
private long encode(int node, float score) {
return order.apply((((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node));
}
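// Illustrative example (not from the original source): for node = 5 and score = 1.0f,
// floatToSortableInt(1.0f) == 0x3F800000 and (0xFFFFFFFFL & ~5) == 0xFFFFFFFAL, so the
// MIN_HEAP encoding is 0x3F800000FFFFFFFAL. Entries with equal scores share the high 32
// bits, and a smaller node id yields a larger low word, so comparing encoded longs
// breaks score ties by node id, as the javadoc above requires.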
/** Returns the top element's node id. */
int topNode() {
return decodeNodeId(heap.top());
}
/**
* Returns the top element's node score. For the min heap this is the minimum score. For the max
* heap this is the maximum score.
*/
float topScore() {
return decodeScore(heap.top());
}
private float decodeScore(long heapValue) {
return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32));
}
private int decodeNodeId(long heapValue) {
return (int) ~(order.apply(heapValue));
}
/** Removes the top element and returns its node id. */
public int pop() {
return decodeNodeId(heap.pop());
}
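/** Copies the node ids of all elements on the heap into dest, in internal heap order (not sorted by score). */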
public void consumeNodes(int[] dest) {
if (dest.length < size()) {
throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements.");
}
for (int i = 0; i < size(); i++) {
dest[i] = decodeNodeId(heap.get(i + 1));
}
}
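/** Copies all node ids and scores into the given arrays and returns the index of the entry with the minimum score. */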
public int consumeNodesAndScoresMin(int[] dest, float[] scores) {
if (dest.length < size() || scores.length < size()) {
throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements.");
}
float bestScore = Float.POSITIVE_INFINITY;
int bestIdx = 0;
for (int i = 0; i < size(); i++) {
long heapValue = heap.get(i + 1);
scores[i] = decodeScore(heapValue);
dest[i] = decodeNodeId(heapValue);
if (scores[i] < bestScore) {
bestScore = scores[i];
bestIdx = i;
}
}
return bestIdx;
}
public void clear() {
heap.clear();
}
@Override
public String toString() {
return "Neighbors[" + heap.size() + "]";
}
}
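One more note on the Order.MAX_HEAP transform: -1 - v is just bitwise ~v, is its own inverse, and reverses the signed order of longs, which is why pushing transformed values into a min-heap behaves like a max-heap. A tiny standalone sketch of those properties (illustrative values, not from the original source):

// Sketch of why Order.MAX_HEAP uses -1 - v instead of -v.
public class MaxHeapOrderSketch {
    public static void main(String[] args) {
        long a = 3L, b = 7L;
        System.out.println((-1 - a) == ~a);                   // true: same transform, two spellings
        System.out.println(-1 - (-1 - a) == a);               // true: self-inverse
        System.out.println(Long.compare(a, b) == -Long.compare(-1 - a, -1 - b)); // true: order reversed
        System.out.println((-1 - Long.MIN_VALUE) == Long.MAX_VALUE); // true: no overflow corner case
    }
}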

View file

@ -7,3 +7,4 @@ org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat
org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat
org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat
org.elasticsearch.index.codec.vectors.IVFVectorsFormat
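Registering the class here makes it discoverable through Lucene's SPI machinery. A hedged sketch of a lookup, assuming the registered name matches the class name "IVFVectorsFormat" (the usual Lucene convention; this diff does not show the constructor that fixes the name):

import org.apache.lucene.codecs.KnnVectorsFormat;

public class ResolveIvfFormat {
    public static void main(String[] args) {
        // Hypothetical lookup; "IVFVectorsFormat" is assumed to be the registered SPI name.
        KnnVectorsFormat ivf = KnnVectorsFormat.forName("IVFVectorsFormat");
        System.out.println(ivf.getName());
    }
}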

View file

@ -0,0 +1,65 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.codec.vectors;
import com.carrotsearch.randomizedtesting.generators.RandomPicks;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.elasticsearch.common.logging.LogConfigurator;
import org.junit.Before;
import java.util.List;
public class IVFVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
static {
LogConfigurator.loadLog4jPlugins();
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
}
KnnVectorsFormat format;
@Before
@Override
public void setUp() throws Exception {
format = new IVFVectorsFormat(random().nextInt(10, 1000));
super.setUp();
}
@Override
protected VectorSimilarityFunction randomSimilarity() {
return RandomPicks.randomFrom(
random(),
List.of(
VectorSimilarityFunction.DOT_PRODUCT,
VectorSimilarityFunction.EUCLIDEAN,
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT
)
);
}
@Override
protected VectorEncoding randomVectorEncoding() {
return VectorEncoding.FLOAT32;
}
@Override
public void testSearchWithVisitedLimit() {
// IVF doesn't enforce a visitation limit
}
@Override
protected Codec getCodec() {
return TestUtil.alwaysKnnVectorsFormat(format);
}
}

View file

@ -0,0 +1,119 @@
/*
* @notice
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* Modifications copyright (C) 2025 Elasticsearch B.V.
*/
package org.elasticsearch.index.codec.vectors;
import org.elasticsearch.test.ESTestCase;
/**
* Copied from Apache Lucene and modified.
*/
public class NeighborQueueTests extends ESTestCase {
public void testNeighborsProduct() {
// make sure we have the sign correct
NeighborQueue nn = new NeighborQueue(2, false);
assertTrue(nn.insertWithOverflow(2, 0.5f));
assertTrue(nn.insertWithOverflow(1, 0.2f));
assertTrue(nn.insertWithOverflow(3, 1f));
assertEquals(0.5f, nn.topScore(), 0);
nn.pop();
assertEquals(1f, nn.topScore(), 0);
nn.pop();
}
public void testNeighborsMaxHeap() {
NeighborQueue nn = new NeighborQueue(2, true);
assertTrue(nn.insertWithOverflow(2, 2));
assertTrue(nn.insertWithOverflow(1, 1));
assertFalse(nn.insertWithOverflow(3, 3));
assertEquals(2f, nn.topScore(), 0);
nn.pop();
assertEquals(1f, nn.topScore(), 0);
}
public void testTopMaxHeap() {
NeighborQueue nn = new NeighborQueue(2, true);
nn.add(1, 2);
nn.add(2, 1);
// lower scores are better; highest score on top
assertEquals(2, nn.topScore(), 0);
assertEquals(1, nn.topNode());
}
public void testTopMinHeap() {
NeighborQueue nn = new NeighborQueue(2, false);
nn.add(1, 0.5f);
nn.add(2, -0.5f);
// higher scores are better; lowest score on top
assertEquals(-0.5f, nn.topScore(), 0);
assertEquals(2, nn.topNode());
}
public void testClear() {
NeighborQueue nn = new NeighborQueue(2, false);
nn.add(1, 1.1f);
nn.add(2, -2.2f);
nn.clear();
assertEquals(0, nn.size());
}
public void testMaxSizeQueue() {
NeighborQueue nn = new NeighborQueue(2, false);
nn.add(1, 1);
nn.add(2, 2);
assertEquals(2, nn.size());
assertEquals(1, nn.topNode());
// insertWithOverflow does not extend the queue
nn.insertWithOverflow(3, 3);
assertEquals(2, nn.size());
assertEquals(2, nn.topNode());
// add does extend the queue beyond maxSize
nn.add(4, 1);
assertEquals(3, nn.size());
}
public void testUnboundedQueue() {
NeighborQueue nn = new NeighborQueue(1, true);
float maxScore = -2;
int maxNode = -1;
for (int i = 0; i < 256; i++) {
// initial size is 32
float score = random().nextFloat();
if (score > maxScore) {
maxScore = score;
maxNode = i;
}
nn.add(i, score);
}
assertEquals(maxScore, nn.topScore(), 0);
assertEquals(maxNode, nn.topNode());
}
public void testInvalidArguments() {
expectThrows(IllegalArgumentException.class, () -> new NeighborQueue(0, false));
}
public void testToString() {
assertEquals("Neighbors[0]", new NeighborQueue(2, false).toString());
}
}