mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
This reverts commit 8a17a5ed5f
.
reapplying ivf format, but with a fix.
This commit is contained in:
parent
d07ec0cc44
commit
1324ee0115
23 changed files with 2576 additions and 3 deletions
|
@ -17,7 +17,7 @@ import org.apache.lucene.store.MMapDirectory;
|
|||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
|
||||
import org.elasticsearch.common.logging.LogConfigurator;
|
||||
import org.elasticsearch.simdvec.internal.vectorization.ES91OSQVectorsScorer;
|
||||
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
|
||||
import org.openjdk.jmh.annotations.Benchmark;
|
||||
import org.openjdk.jmh.annotations.BenchmarkMode;
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
package org.elasticsearch.simdvec.internal.vectorization;
|
||||
package org.elasticsearch.simdvec;
|
||||
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.IndexInput;
|
|
@ -9,11 +9,13 @@
|
|||
|
||||
package org.elasticsearch.simdvec;
|
||||
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.util.BitUtil;
|
||||
import org.apache.lucene.util.Constants;
|
||||
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
|
||||
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.invoke.MethodHandle;
|
||||
import java.lang.invoke.MethodHandles;
|
||||
import java.lang.invoke.MethodType;
|
||||
|
@ -41,6 +43,10 @@ public class ESVectorUtil {
|
|||
|
||||
private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();
|
||||
|
||||
public static ES91OSQVectorsScorer getES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
|
||||
return ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(input, dimension);
|
||||
}
|
||||
|
||||
public static long ipByteBinByte(byte[] q, byte[] d) {
|
||||
if (q.length != d.length * B_QUERY) {
|
||||
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);
|
||||
|
@ -211,4 +217,40 @@ public class ESVectorUtil {
|
|||
assert stats.length == 6;
|
||||
IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates the difference between two vectors and stores the result in a third vector.
|
||||
* @param v1 the first vector
|
||||
* @param v2 the second vector
|
||||
* @param result the result vector, must be the same length as the input vectors
|
||||
*/
|
||||
public static void subtract(float[] v1, float[] v2, float[] result) {
|
||||
if (v1.length != v2.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + v2.length);
|
||||
}
|
||||
if (result.length != v1.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + result.length + "!=" + v1.length);
|
||||
}
|
||||
for (int i = 0; i < v1.length; i++) {
|
||||
result[i] = v1[i] - v2[i];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* calculates the spill-over score for a vector and a centroid, given its residual with
|
||||
* its actually nearest centroid
|
||||
* @param v1 the vector
|
||||
* @param centroid the centroid
|
||||
* @param originalResidual the residual with the actually nearest centroid
|
||||
* @return the spill-over score (soar)
|
||||
*/
|
||||
public static float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
|
||||
if (v1.length != centroid.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + centroid.length);
|
||||
}
|
||||
if (originalResidual.length != v1.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + originalResidual.length + "!=" + v1.length);
|
||||
}
|
||||
return IMPL.soarResidual(v1, centroid, originalResidual);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -138,6 +138,18 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
|
|||
stats[5] = centroidDot;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
|
||||
assert v1.length == centroid.length;
|
||||
assert v1.length == originalResidual.length;
|
||||
float proj = 0;
|
||||
for (int i = 0; i < v1.length; i++) {
|
||||
float djk = v1[i] - centroid[i];
|
||||
proj = fma(djk, originalResidual[i], proj);
|
||||
}
|
||||
return proj;
|
||||
}
|
||||
|
||||
public static int ipByteBitImpl(byte[] q, byte[] d) {
|
||||
return ipByteBitImpl(q, d, 0);
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
package org.elasticsearch.simdvec.internal.vectorization;
|
||||
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
|
|
|
@ -28,4 +28,7 @@ public interface ESVectorUtilSupport {
|
|||
void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats);
|
||||
|
||||
void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);
|
||||
|
||||
float soarResidual(float[] v1, float[] centroid, float[] originalResidual);
|
||||
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
package org.elasticsearch.simdvec.internal.vectorization;
|
||||
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
|
|
@ -13,6 +13,7 @@ import org.apache.lucene.store.IndexInput;
|
|||
import org.apache.lucene.util.Constants;
|
||||
import org.elasticsearch.logging.LogManager;
|
||||
import org.elasticsearch.logging.Logger;
|
||||
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Locale;
|
||||
|
|
|
@ -20,6 +20,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
|
|||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
|
||||
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.foreign.MemorySegment;
|
||||
|
|
|
@ -367,6 +367,49 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
|||
return (1f - lambda) * xe * xe / norm2 + lambda * e;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
|
||||
assert v1.length == centroid.length;
|
||||
assert v1.length == originalResidual.length;
|
||||
float proj = 0;
|
||||
int i = 0;
|
||||
if (v1.length > 2 * FLOAT_SPECIES.length()) {
|
||||
FloatVector projVec1 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector projVec2 = FloatVector.zero(FLOAT_SPECIES);
|
||||
int unrolledLimit = FLOAT_SPECIES.loopBound(v1.length) - FLOAT_SPECIES.length();
|
||||
for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) {
|
||||
// one
|
||||
FloatVector v1Vec0 = FloatVector.fromArray(FLOAT_SPECIES, v1, i);
|
||||
FloatVector centroidVec0 = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
|
||||
FloatVector originalResidualVec0 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
|
||||
FloatVector djkVec0 = v1Vec0.sub(centroidVec0);
|
||||
projVec1 = fma(djkVec0, originalResidualVec0, projVec1);
|
||||
|
||||
// two
|
||||
FloatVector v1Vec1 = FloatVector.fromArray(FLOAT_SPECIES, v1, i + FLOAT_SPECIES.length());
|
||||
FloatVector centroidVec1 = FloatVector.fromArray(FLOAT_SPECIES, centroid, i + FLOAT_SPECIES.length());
|
||||
FloatVector originalResidualVec1 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i + FLOAT_SPECIES.length());
|
||||
FloatVector djkVec1 = v1Vec1.sub(centroidVec1);
|
||||
projVec2 = fma(djkVec1, originalResidualVec1, projVec2);
|
||||
}
|
||||
// vector tail
|
||||
for (; i < FLOAT_SPECIES.loopBound(v1.length); i += FLOAT_SPECIES.length()) {
|
||||
FloatVector v1Vec = FloatVector.fromArray(FLOAT_SPECIES, v1, i);
|
||||
FloatVector centroidVec = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
|
||||
FloatVector originalResidualVec = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
|
||||
FloatVector djkVec = v1Vec.sub(centroidVec);
|
||||
projVec1 = fma(djkVec, originalResidualVec, projVec1);
|
||||
}
|
||||
proj += projVec1.add(projVec2).reduceLanes(ADD);
|
||||
}
|
||||
// tail
|
||||
for (; i < v1.length; i++) {
|
||||
float djk = v1[i] - centroid[i];
|
||||
proj = fma(djk, originalResidual[i], proj);
|
||||
}
|
||||
return proj;
|
||||
}
|
||||
|
||||
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
|
||||
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ package org.elasticsearch.simdvec.internal.vectorization;
|
|||
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.MemorySegmentAccessInput;
|
||||
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.lang.foreign.MemorySegment;
|
||||
|
|
|
@ -268,6 +268,22 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
|
|||
}
|
||||
}
|
||||
|
||||
public void testSoarOverspillScore() {
|
||||
int size = random().nextInt(128, 512);
|
||||
float deltaEps = 1e-5f * size;
|
||||
var vector = new float[size];
|
||||
var centroid = new float[size];
|
||||
var preResidual = new float[size];
|
||||
for (int i = 0; i < size; ++i) {
|
||||
vector[i] = random().nextFloat();
|
||||
centroid[i] = random().nextFloat();
|
||||
preResidual[i] = random().nextFloat();
|
||||
}
|
||||
var expected = defaultedProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual);
|
||||
var result = defOrPanamaProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual);
|
||||
assertEquals(expected, result, deltaEps);
|
||||
}
|
||||
|
||||
void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
|
||||
int iterations = atLeast(50);
|
||||
for (int i = 0; i < iterations; i++) {
|
||||
|
|
|
@ -16,6 +16,7 @@ import org.apache.lucene.store.IndexInput;
|
|||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.store.MMapDirectory;
|
||||
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
|
||||
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||
|
||||
import static org.hamcrest.Matchers.lessThan;
|
||||
|
||||
|
|
|
@ -454,7 +454,8 @@ module org.elasticsearch.server {
|
|||
org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat,
|
||||
org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat,
|
||||
org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat,
|
||||
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat;
|
||||
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat,
|
||||
org.elasticsearch.index.codec.vectors.IVFVectorsFormat;
|
||||
|
||||
provides org.apache.lucene.codecs.Codec
|
||||
with
|
||||
|
|
|
@ -0,0 +1,420 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors;
|
||||
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
|
||||
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||
import org.elasticsearch.simdvec.ESVectorUtil;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.IntPredicate;
|
||||
|
||||
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
|
||||
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
|
||||
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte;
|
||||
import static org.elasticsearch.simdvec.ES91OSQVectorsScorer.BULK_SIZE;
|
||||
|
||||
/**
|
||||
* Default implementation of {@link IVFVectorsReader}. It scores the posting lists centroids using
|
||||
* brute force and then scores the top ones using the posting list.
|
||||
*/
|
||||
public class DefaultIVFVectorsReader extends IVFVectorsReader {
|
||||
private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);
|
||||
|
||||
public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
|
||||
super(state, rawVectorsReader);
|
||||
}
|
||||
|
||||
@Override
|
||||
CentroidQueryScorer getCentroidScorer(
|
||||
FieldInfo fieldInfo,
|
||||
int numCentroids,
|
||||
IndexInput centroids,
|
||||
float[] targetQuery,
|
||||
IndexInput clusters
|
||||
) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(fieldInfo.number);
|
||||
float[] globalCentroid = fieldEntry.globalCentroid();
|
||||
float globalCentroidDp = fieldEntry.globalCentroidDp();
|
||||
OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
|
||||
byte[] quantized = new byte[targetQuery.length];
|
||||
float[] targetScratch = ArrayUtil.copyArray(targetQuery);
|
||||
OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
|
||||
targetScratch,
|
||||
quantized,
|
||||
(byte) 4,
|
||||
globalCentroid
|
||||
);
|
||||
return new CentroidQueryScorer() {
|
||||
int currentCentroid = -1;
|
||||
private final byte[] quantizedCentroid = new byte[fieldInfo.getVectorDimension()];
|
||||
private final float[] centroid = new float[fieldInfo.getVectorDimension()];
|
||||
private final float[] centroidCorrectiveValues = new float[3];
|
||||
private int quantizedCentroidComponentSum;
|
||||
private final long centroidByteSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return numCentroids;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] centroid(int centroidOrdinal) throws IOException {
|
||||
readQuantizedAndRawCentroid(centroidOrdinal);
|
||||
return centroid;
|
||||
}
|
||||
|
||||
private void readQuantizedAndRawCentroid(int centroidOrdinal) throws IOException {
|
||||
if (centroidOrdinal == currentCentroid) {
|
||||
return;
|
||||
}
|
||||
centroids.seek(centroidOrdinal * centroidByteSize);
|
||||
quantizedCentroidComponentSum = readQuantizedValue(centroids, quantizedCentroid, centroidCorrectiveValues);
|
||||
centroids.seek(numCentroids * centroidByteSize + (long) Float.BYTES * quantizedCentroid.length * centroidOrdinal);
|
||||
centroids.readFloats(centroid, 0, centroid.length);
|
||||
currentCentroid = centroidOrdinal;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int centroidOrdinal) throws IOException {
|
||||
readQuantizedAndRawCentroid(centroidOrdinal);
|
||||
return int4QuantizedScore(
|
||||
quantized,
|
||||
queryParams,
|
||||
fieldInfo.getVectorDimension(),
|
||||
quantizedCentroid,
|
||||
centroidCorrectiveValues,
|
||||
quantizedCentroidComponentSum,
|
||||
globalCentroidDp,
|
||||
fieldInfo.getVectorSimilarityFunction()
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
protected FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) {
|
||||
FieldEntry entry = fields.get(info.number);
|
||||
if (entry == null) {
|
||||
return null;
|
||||
}
|
||||
return new OffHeapCentroidFloatVectorValues(numCentroids, indexInput, info.getVectorDimension());
|
||||
}
|
||||
|
||||
@Override
|
||||
NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe)
|
||||
throws IOException {
|
||||
NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true);
|
||||
// TODO Off heap scoring for quantized centroids?
|
||||
for (int centroid = 0; centroid < centroidQueryScorer.size(); centroid++) {
|
||||
neighborQueue.add(centroid, centroidQueryScorer.score(centroid));
|
||||
}
|
||||
return neighborQueue;
|
||||
}
|
||||
|
||||
@Override
|
||||
PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring)
|
||||
throws IOException {
|
||||
FieldEntry entry = fields.get(fieldInfo.number);
|
||||
return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, needsScoring);
|
||||
}
|
||||
|
||||
// TODO can we do this in off-heap blocks?
|
||||
static float int4QuantizedScore(
|
||||
byte[] quantizedQuery,
|
||||
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
|
||||
int dims,
|
||||
byte[] binaryCode,
|
||||
float[] targetCorrections,
|
||||
int targetComponentSum,
|
||||
float centroidDp,
|
||||
VectorSimilarityFunction similarityFunction
|
||||
) {
|
||||
float qcDist = VectorUtil.int4DotProduct(quantizedQuery, binaryCode);
|
||||
float ax = targetCorrections[0];
|
||||
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
|
||||
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
|
||||
float ay = queryCorrections.lowerInterval();
|
||||
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
|
||||
float y1 = queryCorrections.quantizedComponentSum();
|
||||
float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
|
||||
if (similarityFunction == EUCLIDEAN) {
|
||||
score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
|
||||
return Math.max(1 / (1f + score), 0);
|
||||
} else {
|
||||
// For cosine and max inner product, we need to apply the additional correction, which is
|
||||
// assumed to be the non-centered dot-product between the vector and the centroid
|
||||
score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
|
||||
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
|
||||
return VectorUtil.scaleMaxInnerProductScore(score);
|
||||
}
|
||||
return Math.max((1f + score) / 2f, 0);
|
||||
}
|
||||
}
|
||||
|
||||
static class OffHeapCentroidFloatVectorValues extends FloatVectorValues {
|
||||
private final int numCentroids;
|
||||
private final IndexInput input;
|
||||
private final int dimension;
|
||||
private final float[] centroid;
|
||||
private final long centroidByteSize;
|
||||
private int ord = -1;
|
||||
|
||||
OffHeapCentroidFloatVectorValues(int numCentroids, IndexInput input, int dimension) {
|
||||
this.numCentroids = numCentroids;
|
||||
this.input = input;
|
||||
this.dimension = dimension;
|
||||
this.centroid = new float[dimension];
|
||||
this.centroidByteSize = dimension + 3 * Float.BYTES + Short.BYTES;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int ord) throws IOException {
|
||||
if (ord < 0 || ord >= numCentroids) {
|
||||
throw new IllegalArgumentException("ord must be in [0, " + numCentroids + "]");
|
||||
}
|
||||
if (ord == this.ord) {
|
||||
return centroid;
|
||||
}
|
||||
readQuantizedCentroid(ord);
|
||||
return centroid;
|
||||
}
|
||||
|
||||
private void readQuantizedCentroid(int centroidOrdinal) throws IOException {
|
||||
if (centroidOrdinal == ord) {
|
||||
return;
|
||||
}
|
||||
input.seek(numCentroids * centroidByteSize + (long) Float.BYTES * dimension * centroidOrdinal);
|
||||
input.readFloats(centroid, 0, centroid.length);
|
||||
ord = centroidOrdinal;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return numCentroids;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues copy() throws IOException {
|
||||
return new OffHeapCentroidFloatVectorValues(numCentroids, input.clone(), dimension);
|
||||
}
|
||||
}
|
||||
|
||||
private static class MemorySegmentPostingsVisitor implements PostingVisitor {
|
||||
final long quantizedByteLength;
|
||||
final IndexInput indexInput;
|
||||
final float[] target;
|
||||
final FieldEntry entry;
|
||||
final FieldInfo fieldInfo;
|
||||
final IntPredicate needsScoring;
|
||||
private final ES91OSQVectorsScorer osqVectorsScorer;
|
||||
final float[] scores = new float[BULK_SIZE];
|
||||
final float[] correctionsLower = new float[BULK_SIZE];
|
||||
final float[] correctionsUpper = new float[BULK_SIZE];
|
||||
final int[] correctionsSum = new int[BULK_SIZE];
|
||||
final float[] correctionsAdd = new float[BULK_SIZE];
|
||||
|
||||
int[] docIdsScratch = new int[0];
|
||||
int vectors;
|
||||
boolean quantized = false;
|
||||
float centroidDp;
|
||||
float[] centroid;
|
||||
long slicePos;
|
||||
OptimizedScalarQuantizer.QuantizationResult queryCorrections;
|
||||
DocIdsWriter docIdsWriter = new DocIdsWriter();
|
||||
|
||||
final float[] scratch;
|
||||
final byte[] quantizationScratch;
|
||||
final byte[] quantizedQueryScratch;
|
||||
final OptimizedScalarQuantizer quantizer;
|
||||
final float[] correctiveValues = new float[3];
|
||||
final long quantizedVectorByteSize;
|
||||
|
||||
MemorySegmentPostingsVisitor(
|
||||
float[] target,
|
||||
IndexInput indexInput,
|
||||
FieldEntry entry,
|
||||
FieldInfo fieldInfo,
|
||||
IntPredicate needsScoring
|
||||
) throws IOException {
|
||||
this.target = target;
|
||||
this.indexInput = indexInput;
|
||||
this.entry = entry;
|
||||
this.fieldInfo = fieldInfo;
|
||||
this.needsScoring = needsScoring;
|
||||
|
||||
scratch = new float[target.length];
|
||||
quantizationScratch = new byte[target.length];
|
||||
final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64);
|
||||
quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8];
|
||||
quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES;
|
||||
quantizedVectorByteSize = (discretizedDimensions / 8);
|
||||
quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
|
||||
osqVectorsScorer = ESVectorUtil.getES91OSQVectorsScorer(indexInput, fieldInfo.getVectorDimension());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException {
|
||||
quantized = false;
|
||||
indexInput.seek(entry.postingListOffsets()[centroidOrdinal]);
|
||||
vectors = indexInput.readVInt();
|
||||
centroidDp = Float.intBitsToFloat(indexInput.readInt());
|
||||
this.centroid = centroid;
|
||||
// read the doc ids
|
||||
docIdsScratch = vectors > docIdsScratch.length ? new int[vectors] : docIdsScratch;
|
||||
docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
|
||||
slicePos = indexInput.getFilePointer();
|
||||
return vectors;
|
||||
}
|
||||
|
||||
void scoreIndividually(int offset) throws IOException {
|
||||
// score individually, first the quantized byte chunk
|
||||
for (int j = 0; j < BULK_SIZE; j++) {
|
||||
int doc = docIdsScratch[j + offset];
|
||||
if (doc != -1) {
|
||||
indexInput.seek(slicePos + (offset * quantizedByteLength) + (j * quantizedVectorByteSize));
|
||||
float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
|
||||
scores[j] = qcDist;
|
||||
}
|
||||
}
|
||||
// read in all corrections
|
||||
indexInput.seek(slicePos + (offset * quantizedByteLength) + (BULK_SIZE * quantizedVectorByteSize));
|
||||
indexInput.readFloats(correctionsLower, 0, BULK_SIZE);
|
||||
indexInput.readFloats(correctionsUpper, 0, BULK_SIZE);
|
||||
for (int j = 0; j < BULK_SIZE; j++) {
|
||||
correctionsSum[j] = Short.toUnsignedInt(indexInput.readShort());
|
||||
}
|
||||
indexInput.readFloats(correctionsAdd, 0, BULK_SIZE);
|
||||
// Now apply corrections
|
||||
for (int j = 0; j < BULK_SIZE; j++) {
|
||||
int doc = docIdsScratch[offset + j];
|
||||
if (doc != -1) {
|
||||
scores[j] = osqVectorsScorer.score(
|
||||
queryCorrections,
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
centroidDp,
|
||||
correctionsLower[j],
|
||||
correctionsUpper[j],
|
||||
correctionsSum[j],
|
||||
correctionsAdd[j],
|
||||
scores[j]
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int visit(KnnCollector knnCollector) throws IOException {
|
||||
// block processing
|
||||
int scoredDocs = 0;
|
||||
int limit = vectors - BULK_SIZE + 1;
|
||||
int i = 0;
|
||||
for (; i < limit; i += BULK_SIZE) {
|
||||
int docsToScore = BULK_SIZE;
|
||||
for (int j = 0; j < BULK_SIZE; j++) {
|
||||
int doc = docIdsScratch[i + j];
|
||||
if (needsScoring.test(doc) == false) {
|
||||
docIdsScratch[i + j] = -1;
|
||||
docsToScore--;
|
||||
}
|
||||
}
|
||||
if (docsToScore == 0) {
|
||||
continue;
|
||||
}
|
||||
quantizeQueryIfNecessary();
|
||||
indexInput.seek(slicePos + i * quantizedByteLength);
|
||||
if (docsToScore < BULK_SIZE / 2) {
|
||||
scoreIndividually(i);
|
||||
} else {
|
||||
osqVectorsScorer.scoreBulk(
|
||||
quantizedQueryScratch,
|
||||
queryCorrections,
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
centroidDp,
|
||||
scores
|
||||
);
|
||||
}
|
||||
for (int j = 0; j < BULK_SIZE; j++) {
|
||||
int doc = docIdsScratch[i + j];
|
||||
if (doc != -1) {
|
||||
scoredDocs++;
|
||||
knnCollector.collect(doc, scores[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// process tail
|
||||
for (; i < vectors; i++) {
|
||||
int doc = docIdsScratch[i];
|
||||
if (needsScoring.test(doc)) {
|
||||
quantizeQueryIfNecessary();
|
||||
indexInput.seek(slicePos + i * quantizedByteLength);
|
||||
float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
|
||||
indexInput.readFloats(correctiveValues, 0, 3);
|
||||
final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort());
|
||||
float score = osqVectorsScorer.score(
|
||||
queryCorrections,
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
centroidDp,
|
||||
correctiveValues[0],
|
||||
correctiveValues[1],
|
||||
quantizedComponentSum,
|
||||
correctiveValues[2],
|
||||
qcDist
|
||||
);
|
||||
scoredDocs++;
|
||||
knnCollector.collect(doc, score);
|
||||
}
|
||||
}
|
||||
if (scoredDocs > 0) {
|
||||
knnCollector.incVisitedCount(scoredDocs);
|
||||
}
|
||||
return scoredDocs;
|
||||
}
|
||||
|
||||
private void quantizeQueryIfNecessary() {
|
||||
if (quantized == false) {
|
||||
System.arraycopy(target, 0, scratch, 0, target.length);
|
||||
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
|
||||
VectorUtil.l2normalize(scratch);
|
||||
}
|
||||
queryCorrections = quantizer.scalarQuantize(scratch, quantizationScratch, (byte) 4, centroid);
|
||||
transposeHalfByte(quantizationScratch, quantizedQueryScratch);
|
||||
quantized = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static int readQuantizedValue(IndexInput indexInput, byte[] binaryValue, float[] corrections) throws IOException {
|
||||
assert corrections.length == 3;
|
||||
indexInput.readBytes(binaryValue, 0, binaryValue.length);
|
||||
corrections[0] = Float.intBitsToFloat(indexInput.readInt());
|
||||
corrections[1] = Float.intBitsToFloat(indexInput.readInt());
|
||||
corrections[2] = Float.intBitsToFloat(indexInput.readInt());
|
||||
return Short.toUnsignedInt(indexInput.readShort());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,736 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors;
|
||||
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.hppc.IntArrayList;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.InfoStream;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
|
||||
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||
import org.elasticsearch.simdvec.ESVectorUtil;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
|
||||
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
|
||||
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary;
|
||||
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.IVF_VECTOR_COMPONENT;
|
||||
|
||||
/**
|
||||
* Default implementation of {@link IVFVectorsWriter}. It uses {@link KMeans} algorithm to
|
||||
* partition the vector space, and then stores the centroids an posting list in a sequential
|
||||
* fashion.
|
||||
*/
|
||||
public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
||||
|
||||
static final float SOAR_LAMBDA = 1.0f;
|
||||
// What percentage of the centroids do we do a second check on for SOAR assignment
|
||||
static final float EXT_SOAR_LIMIT_CHECK_RATIO = 0.10f;
|
||||
|
||||
private final int vectorPerCluster;
|
||||
|
||||
public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate, int vectorPerCluster) throws IOException {
|
||||
super(state, rawVectorDelegate);
|
||||
this.vectorPerCluster = vectorPerCluster;
|
||||
}
|
||||
|
||||
@Override
|
||||
CentroidAssignmentScorer calculateAndWriteCentroids(
|
||||
FieldInfo fieldInfo,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput centroidOutput,
|
||||
float[] globalCentroid
|
||||
) throws IOException {
|
||||
if (floatVectorValues.size() == 0) {
|
||||
return CentroidAssignmentScorer.EMPTY;
|
||||
}
|
||||
// calculate the centroids
|
||||
int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1;
|
||||
int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters);
|
||||
final KMeans.Results kMeans = KMeans.cluster(
|
||||
floatVectorValues,
|
||||
desiredClusters,
|
||||
false,
|
||||
42L,
|
||||
KMeans.KmeansInitializationMethod.PLUS_PLUS,
|
||||
null,
|
||||
fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE,
|
||||
1,
|
||||
15,
|
||||
desiredClusters * 256
|
||||
);
|
||||
float[][] centroids = kMeans.centroids();
|
||||
// write them
|
||||
writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput);
|
||||
return new OnHeapCentroidAssignmentScorer(centroids);
|
||||
}
|
||||
|
||||
@Override
|
||||
long[] buildAndWritePostingsLists(
|
||||
FieldInfo fieldInfo,
|
||||
InfoStream infoStream,
|
||||
CentroidAssignmentScorer randomCentroidScorer,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput postingsOutput
|
||||
) throws IOException {
|
||||
IntArrayList[] clusters = new IntArrayList[randomCentroidScorer.size()];
|
||||
for (int i = 0; i < randomCentroidScorer.size(); i++) {
|
||||
clusters[i] = new IntArrayList(floatVectorValues.size() / randomCentroidScorer.size() / 4);
|
||||
}
|
||||
assignCentroids(randomCentroidScorer, floatVectorValues, clusters);
|
||||
if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||
printClusterQualityStatistics(clusters, infoStream);
|
||||
}
|
||||
// write the posting lists
|
||||
final long[] offsets = new long[randomCentroidScorer.size()];
|
||||
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
|
||||
BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
|
||||
DocIdsWriter docIdsWriter = new DocIdsWriter();
|
||||
for (int i = 0; i < randomCentroidScorer.size(); i++) {
|
||||
float[] centroid = randomCentroidScorer.centroid(i);
|
||||
binarizedByteVectorValues.centroid = centroid;
|
||||
// TODO sort by distance to the centroid
|
||||
IntArrayList cluster = clusters[i];
|
||||
// TODO align???
|
||||
offsets[i] = postingsOutput.getFilePointer();
|
||||
int size = cluster.size();
|
||||
postingsOutput.writeVInt(size);
|
||||
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
|
||||
// TODO we might want to consider putting the docIds in a separate file
|
||||
// to aid with only having to fetch vectors from slower storage when they are required
|
||||
// keeping them in the same file indicates we pull the entire file into cache
|
||||
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), cluster.size(), postingsOutput);
|
||||
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
|
||||
}
|
||||
return offsets;
|
||||
}
|
||||
|
||||
private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
|
||||
throws IOException {
|
||||
int limit = cluster.size() - ES91OSQVectorsScorer.BULK_SIZE + 1;
|
||||
int cidx = 0;
|
||||
OptimizedScalarQuantizer.QuantizationResult[] corrections =
|
||||
new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE];
|
||||
// Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE.
|
||||
for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) {
|
||||
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
|
||||
int ord = cluster.get(cidx + j);
|
||||
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
|
||||
// write vector
|
||||
postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
|
||||
corrections[j] = binarizedByteVectorValues.getCorrectiveTerms(ord);
|
||||
}
|
||||
// write corrections
|
||||
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
|
||||
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].lowerInterval()));
|
||||
}
|
||||
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
|
||||
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].upperInterval()));
|
||||
}
|
||||
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
|
||||
int targetComponentSum = corrections[j].quantizedComponentSum();
|
||||
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
|
||||
postingsOutput.writeShort((short) targetComponentSum);
|
||||
}
|
||||
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
|
||||
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].additionalCorrection()));
|
||||
}
|
||||
}
|
||||
// write tail
|
||||
for (; cidx < cluster.size(); cidx++) {
|
||||
int ord = cluster.get(cidx);
|
||||
// write vector
|
||||
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
|
||||
OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord);
|
||||
writeQuantizedValue(postingsOutput, binaryValue, correction);
|
||||
binarizedByteVectorValues.getCorrectiveTerms(ord);
|
||||
postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
|
||||
postingsOutput.writeInt(Float.floatToIntBits(correction.lowerInterval()));
|
||||
postingsOutput.writeInt(Float.floatToIntBits(correction.upperInterval()));
|
||||
postingsOutput.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
|
||||
assert correction.quantizedComponentSum() >= 0 && correction.quantizedComponentSum() <= 0xffff;
|
||||
postingsOutput.writeShort((short) correction.quantizedComponentSum());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
CentroidAssignmentScorer createCentroidScorer(
|
||||
IndexInput centroidsInput,
|
||||
int numCentroids,
|
||||
FieldInfo fieldInfo,
|
||||
float[] globalCentroid
|
||||
) {
|
||||
return new OffHeapCentroidAssignmentScorer(centroidsInput, numCentroids, fieldInfo);
|
||||
}
|
||||
|
||||
static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput)
|
||||
throws IOException {
|
||||
final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
|
||||
byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()];
|
||||
float[] centroidScratch = new float[fieldInfo.getVectorDimension()];
|
||||
// TODO do we want to store these distances as well for future use?
|
||||
float[] distances = new float[centroids.length];
|
||||
for (int i = 0; i < centroids.length; i++) {
|
||||
distances[i] = VectorUtil.squareDistance(centroids[i], globalCentroid);
|
||||
}
|
||||
// sort the centroids by distance to globalCentroid, nearest (smallest distance), to furthest
|
||||
// (largest)
|
||||
for (int i = 0; i < centroids.length; i++) {
|
||||
for (int j = i + 1; j < centroids.length; j++) {
|
||||
if (distances[i] > distances[j]) {
|
||||
float[] tmp = centroids[i];
|
||||
centroids[i] = centroids[j];
|
||||
centroids[j] = tmp;
|
||||
float tmpDistance = distances[i];
|
||||
distances[i] = distances[j];
|
||||
distances[j] = tmpDistance;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (float[] centroid : centroids) {
|
||||
System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length);
|
||||
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(
|
||||
centroidScratch,
|
||||
quantizedScratch,
|
||||
(byte) 4,
|
||||
globalCentroid
|
||||
);
|
||||
writeQuantizedValue(centroidOutput, quantizedScratch, result);
|
||||
}
|
||||
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
for (float[] centroid : centroids) {
|
||||
buffer.asFloatBuffer().put(centroid);
|
||||
centroidOutput.writeBytes(buffer.array(), buffer.array().length);
|
||||
}
|
||||
}
|
||||
|
||||
static float[][] gatherInitCentroids(
|
||||
List<FloatVectorValues> centroidList,
|
||||
List<SegmentCentroid> segmentCentroids,
|
||||
int desiredClusters,
|
||||
FieldInfo fieldInfo,
|
||||
MergeState mergeState
|
||||
) throws IOException {
|
||||
if (centroidList.size() == 0) {
|
||||
return null;
|
||||
}
|
||||
long startTime = System.nanoTime();
|
||||
// sort centroid list by floatvector size
|
||||
FloatVectorValues baseSegment = centroidList.get(0);
|
||||
for (var l : centroidList) {
|
||||
if (l.size() > baseSegment.size()) {
|
||||
baseSegment = l;
|
||||
}
|
||||
}
|
||||
float[] scratch = new float[fieldInfo.getVectorDimension()];
|
||||
float minimumDistance = Float.MAX_VALUE;
|
||||
for (int j = 0; j < baseSegment.size(); j++) {
|
||||
System.arraycopy(baseSegment.vectorValue(j), 0, scratch, 0, baseSegment.dimension());
|
||||
for (int k = j + 1; k < baseSegment.size(); k++) {
|
||||
float d = VectorUtil.squareDistance(scratch, baseSegment.vectorValue(k));
|
||||
if (d < minimumDistance) {
|
||||
minimumDistance = d;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||
mergeState.infoStream.message(
|
||||
IVF_VECTOR_COMPONENT,
|
||||
"Agglomerative cluster min distance: " + minimumDistance + " From biggest segment: " + baseSegment.size()
|
||||
);
|
||||
}
|
||||
int[] labels = new int[segmentCentroids.size()];
|
||||
// loop over segments
|
||||
int clusterIdx = 0;
|
||||
// keep track of all inter-centroid distances,
|
||||
// using less than centroid * centroid space (e.g. not keeping track of duplicates)
|
||||
for (int i = 0; i < segmentCentroids.size(); i++) {
|
||||
if (labels[i] == 0) {
|
||||
clusterIdx += 1;
|
||||
labels[i] = clusterIdx;
|
||||
}
|
||||
SegmentCentroid segmentCentroid = segmentCentroids.get(i);
|
||||
System.arraycopy(
|
||||
centroidList.get(segmentCentroid.segment()).vectorValue(segmentCentroid.centroid),
|
||||
0,
|
||||
scratch,
|
||||
0,
|
||||
baseSegment.dimension()
|
||||
);
|
||||
for (int j = i + 1; j < segmentCentroids.size(); j++) {
|
||||
float d = VectorUtil.squareDistance(
|
||||
scratch,
|
||||
centroidList.get(segmentCentroids.get(j).segment()).vectorValue(segmentCentroids.get(j).centroid())
|
||||
);
|
||||
if (d < minimumDistance / 2) {
|
||||
if (labels[j] == 0) {
|
||||
labels[j] = labels[i];
|
||||
} else {
|
||||
for (int k = 0; k < labels.length; k++) {
|
||||
if (labels[k] == labels[j]) {
|
||||
labels[k] = labels[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
float[][] initCentroids = new float[clusterIdx][fieldInfo.getVectorDimension()];
|
||||
int[] sum = new int[clusterIdx];
|
||||
for (int i = 0; i < segmentCentroids.size(); i++) {
|
||||
SegmentCentroid segmentCentroid = segmentCentroids.get(i);
|
||||
int label = labels[i];
|
||||
FloatVectorValues segment = centroidList.get(segmentCentroid.segment());
|
||||
float[] vector = segment.vectorValue(segmentCentroid.centroid);
|
||||
for (int j = 0; j < vector.length; j++) {
|
||||
initCentroids[label - 1][j] += (vector[j] * segmentCentroid.centroidSize);
|
||||
}
|
||||
sum[label - 1] += segmentCentroid.centroidSize;
|
||||
}
|
||||
for (int i = 0; i < initCentroids.length; i++) {
|
||||
if (sum[i] == 0 || sum[i] == 1) {
|
||||
continue;
|
||||
}
|
||||
for (int j = 0; j < initCentroids[i].length; j++) {
|
||||
initCentroids[i][j] /= sum[i];
|
||||
}
|
||||
}
|
||||
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||
mergeState.infoStream.message(
|
||||
IVF_VECTOR_COMPONENT,
|
||||
"Agglomerative cluster time ms: " + ((System.nanoTime() - startTime) / 1000000.0)
|
||||
);
|
||||
mergeState.infoStream.message(
|
||||
IVF_VECTOR_COMPONENT,
|
||||
"Gathered initCentroids:" + initCentroids.length + " for desired: " + desiredClusters
|
||||
);
|
||||
}
|
||||
return initCentroids;
|
||||
}
|
||||
|
||||
record SegmentCentroid(int segment, int centroid, int centroidSize) {}
|
||||
|
||||
/**
|
||||
* Calculate the centroids for the given field and write them to the given
|
||||
* temporary centroid output.
|
||||
* When merging, we first bootstrap the KMeans algorithm with the centroids contained in the merging segments.
|
||||
* To prevent centroids that are too similar from having an outsized impact, all centroids that are closer than
|
||||
* the largest segments intra-cluster distance are merged into a single centroid.
|
||||
* The resulting centroids are then used to initialize the KMeans algorithm.
|
||||
*
|
||||
* @param fieldInfo merging field info
|
||||
* @param floatVectorValues the float vector values to merge
|
||||
* @param temporaryCentroidOutput the temporary centroid output
|
||||
* @param mergeState the merge state
|
||||
* @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
|
||||
* @return the number of centroids written
|
||||
* @throws IOException if an I/O error occurs
|
||||
*/
|
||||
@Override
|
||||
protected int calculateAndWriteCentroids(
|
||||
FieldInfo fieldInfo,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput temporaryCentroidOutput,
|
||||
MergeState mergeState,
|
||||
float[] globalCentroid
|
||||
) throws IOException {
|
||||
if (floatVectorValues.size() == 0) {
|
||||
return 0;
|
||||
}
|
||||
int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1;
|
||||
int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters);
|
||||
// init centroids from merge state
|
||||
List<FloatVectorValues> centroidList = new ArrayList<>();
|
||||
List<SegmentCentroid> segmentCentroids = new ArrayList<>(desiredClusters);
|
||||
|
||||
int segmentIdx = 0;
|
||||
for (var reader : mergeState.knnVectorsReaders) {
|
||||
IVFVectorsReader ivfVectorsReader = IVFVectorsFormat.getIVFReader(reader, fieldInfo.name);
|
||||
if (ivfVectorsReader == null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
FloatVectorValues centroid = ivfVectorsReader.getCentroids(fieldInfo);
|
||||
if (centroid == null) {
|
||||
continue;
|
||||
}
|
||||
centroidList.add(centroid);
|
||||
for (int i = 0; i < centroid.size(); i++) {
|
||||
int size = ivfVectorsReader.centroidSize(fieldInfo.name, i);
|
||||
if (size == 0) {
|
||||
continue;
|
||||
}
|
||||
segmentCentroids.add(new SegmentCentroid(segmentIdx, i, size));
|
||||
}
|
||||
segmentIdx++;
|
||||
}
|
||||
|
||||
float[][] initCentroids = gatherInitCentroids(centroidList, segmentCentroids, desiredClusters, fieldInfo, mergeState);
|
||||
|
||||
// FIXME: run a custom version of KMeans that is just better...
|
||||
long nanoTime = System.nanoTime();
|
||||
final KMeans.Results kMeans = KMeans.cluster(
|
||||
floatVectorValues,
|
||||
desiredClusters,
|
||||
false,
|
||||
42L,
|
||||
KMeans.KmeansInitializationMethod.PLUS_PLUS,
|
||||
initCentroids,
|
||||
fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE,
|
||||
1,
|
||||
5,
|
||||
desiredClusters * 64
|
||||
);
|
||||
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||
mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "KMeans time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0));
|
||||
}
|
||||
float[][] centroids = kMeans.centroids();
|
||||
|
||||
// write them
|
||||
// calculate the global centroid from all the centroids:
|
||||
for (float[] centroid : centroids) {
|
||||
for (int j = 0; j < centroid.length; j++) {
|
||||
globalCentroid[j] += centroid[j];
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < globalCentroid.length; j++) {
|
||||
globalCentroid[j] /= centroids.length;
|
||||
}
|
||||
writeCentroids(centroids, fieldInfo, globalCentroid, temporaryCentroidOutput);
|
||||
return centroids.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
long[] buildAndWritePostingsLists(
|
||||
FieldInfo fieldInfo,
|
||||
CentroidAssignmentScorer centroidAssignmentScorer,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput postingsOutput,
|
||||
MergeState mergeState
|
||||
) throws IOException {
|
||||
IntArrayList[] clusters = new IntArrayList[centroidAssignmentScorer.size()];
|
||||
for (int i = 0; i < centroidAssignmentScorer.size(); i++) {
|
||||
clusters[i] = new IntArrayList(floatVectorValues.size() / centroidAssignmentScorer.size() / 4);
|
||||
}
|
||||
long nanoTime = System.nanoTime();
|
||||
// Can we do a pre-filter by finding the nearest centroids to the original vector centroids?
|
||||
// We need to be careful on vecOrd vs. doc as we need random access to the raw vector for posting list writing
|
||||
assignCentroids(centroidAssignmentScorer, floatVectorValues, clusters);
|
||||
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||
mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "assignCentroids time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0));
|
||||
}
|
||||
|
||||
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||
printClusterQualityStatistics(clusters, mergeState.infoStream);
|
||||
}
|
||||
// write the posting lists
|
||||
final long[] offsets = new long[centroidAssignmentScorer.size()];
|
||||
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
|
||||
BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
|
||||
DocIdsWriter docIdsWriter = new DocIdsWriter();
|
||||
for (int i = 0; i < centroidAssignmentScorer.size(); i++) {
|
||||
float[] centroid = centroidAssignmentScorer.centroid(i);
|
||||
binarizedByteVectorValues.centroid = centroid;
|
||||
// TODO: sort by distance to the centroid
|
||||
IntArrayList cluster = clusters[i];
|
||||
// TODO align???
|
||||
offsets[i] = postingsOutput.getFilePointer();
|
||||
int size = cluster.size();
|
||||
postingsOutput.writeVInt(size);
|
||||
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
|
||||
// TODO we might want to consider putting the docIds in a separate file
|
||||
// to aid with only having to fetch vectors from slower storage when they are required
|
||||
// keeping them in the same file indicates we pull the entire file into cache
|
||||
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput);
|
||||
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
|
||||
}
|
||||
return offsets;
|
||||
}
|
||||
|
||||
private static void printClusterQualityStatistics(IntArrayList[] clusters, InfoStream infoStream) {
|
||||
float min = Float.MAX_VALUE;
|
||||
float max = Float.MIN_VALUE;
|
||||
float mean = 0;
|
||||
float m2 = 0;
|
||||
// iteratively compute the variance & mean
|
||||
int count = 0;
|
||||
for (IntArrayList cluster : clusters) {
|
||||
count += 1;
|
||||
if (cluster == null) {
|
||||
continue;
|
||||
}
|
||||
float delta = cluster.size() - mean;
|
||||
mean += delta / count;
|
||||
m2 += delta * (cluster.size() - mean);
|
||||
min = Math.min(min, cluster.size());
|
||||
max = Math.max(max, cluster.size());
|
||||
}
|
||||
float variance = m2 / (clusters.length - 1);
|
||||
infoStream.message(
|
||||
IVF_VECTOR_COMPONENT,
|
||||
"Centroid count: "
|
||||
+ clusters.length
|
||||
+ " min: "
|
||||
+ min
|
||||
+ " max: "
|
||||
+ max
|
||||
+ " mean: "
|
||||
+ mean
|
||||
+ " stdDev: "
|
||||
+ Math.sqrt(variance)
|
||||
+ " variance: "
|
||||
+ variance
|
||||
);
|
||||
}
|
||||
|
||||
static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues vectors, IntArrayList[] clusters) throws IOException {
|
||||
int numCentroids = scorer.size();
|
||||
// we at most will look at the EXT_SOAR_LIMIT_CHECK_RATIO nearest centroids if possible
|
||||
int soarToCheck = (int) (numCentroids * EXT_SOAR_LIMIT_CHECK_RATIO);
|
||||
int soarClusterCheckCount = Math.min(numCentroids - 1, soarToCheck);
|
||||
NeighborQueue neighborsToCheck = new NeighborQueue(soarClusterCheckCount + 1, true);
|
||||
OrdScoreIterator ordScoreIterator = new OrdScoreIterator(soarClusterCheckCount + 1);
|
||||
float[] scratch = new float[vectors.dimension()];
|
||||
for (int docID = 0; docID < vectors.size(); docID++) {
|
||||
float[] vector = vectors.vectorValue(docID);
|
||||
scorer.setScoringVector(vector);
|
||||
int bestCentroid = 0;
|
||||
float bestScore = Float.MAX_VALUE;
|
||||
if (numCentroids > 1) {
|
||||
for (short c = 0; c < numCentroids; c++) {
|
||||
float squareDist = scorer.score(c);
|
||||
neighborsToCheck.insertWithOverflow(c, squareDist);
|
||||
}
|
||||
// pop the best
|
||||
int sz = neighborsToCheck.size();
|
||||
int best = neighborsToCheck.consumeNodesAndScoresMin(ordScoreIterator.ords, ordScoreIterator.scores);
|
||||
// Set the size to the number of neighbors we actually found
|
||||
ordScoreIterator.setSize(sz);
|
||||
bestScore = ordScoreIterator.getScore(best);
|
||||
bestCentroid = ordScoreIterator.getOrd(best);
|
||||
}
|
||||
clusters[bestCentroid].add(docID);
|
||||
if (soarClusterCheckCount > 0) {
|
||||
assignCentroidSOAR(
|
||||
ordScoreIterator,
|
||||
docID,
|
||||
bestCentroid,
|
||||
scorer.centroid(bestCentroid),
|
||||
bestScore,
|
||||
scratch,
|
||||
scorer,
|
||||
vector,
|
||||
clusters
|
||||
);
|
||||
}
|
||||
neighborsToCheck.clear();
|
||||
}
|
||||
}
|
||||
|
||||
static void assignCentroidSOAR(
|
||||
OrdScoreIterator centroidsToCheck,
|
||||
int vecOrd,
|
||||
int bestCentroidId,
|
||||
float[] bestCentroid,
|
||||
float bestScore,
|
||||
float[] scratch,
|
||||
CentroidAssignmentScorer scorer,
|
||||
float[] vector,
|
||||
IntArrayList[] clusters
|
||||
) throws IOException {
|
||||
ESVectorUtil.subtract(vector, bestCentroid, scratch);
|
||||
int bestSecondaryCentroid = -1;
|
||||
float minDist = Float.MAX_VALUE;
|
||||
for (int i = 0; i < centroidsToCheck.size(); i++) {
|
||||
float score = centroidsToCheck.getScore(i);
|
||||
int centroidOrdinal = centroidsToCheck.getOrd(i);
|
||||
if (centroidOrdinal == bestCentroidId) {
|
||||
continue;
|
||||
}
|
||||
float proj = ESVectorUtil.soarResidual(vector, scorer.centroid(centroidOrdinal), scratch);
|
||||
score += SOAR_LAMBDA * proj * proj / bestScore;
|
||||
if (score < minDist) {
|
||||
bestSecondaryCentroid = centroidOrdinal;
|
||||
minDist = score;
|
||||
}
|
||||
}
|
||||
if (bestSecondaryCentroid != -1) {
|
||||
clusters[bestSecondaryCentroid].add(vecOrd);
|
||||
}
|
||||
}
|
||||
|
||||
static class OrdScoreIterator {
|
||||
private final int[] ords;
|
||||
private final float[] scores;
|
||||
private int idx = 0;
|
||||
|
||||
OrdScoreIterator(int size) {
|
||||
this.ords = new int[size];
|
||||
this.scores = new float[size];
|
||||
}
|
||||
|
||||
int setSize(int size) {
|
||||
if (size > ords.length) {
|
||||
throw new IllegalArgumentException("size must be <= " + ords.length);
|
||||
}
|
||||
this.idx = size;
|
||||
return size;
|
||||
}
|
||||
|
||||
int getOrd(int idx) {
|
||||
return ords[idx];
|
||||
}
|
||||
|
||||
float getScore(int idx) {
|
||||
return scores[idx];
|
||||
}
|
||||
|
||||
int size() {
|
||||
return idx;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO unify with OSQ format
|
||||
static class BinarizedFloatVectorValues {
|
||||
private OptimizedScalarQuantizer.QuantizationResult corrections;
|
||||
private final byte[] binarized;
|
||||
private final byte[] initQuantized;
|
||||
private float[] centroid;
|
||||
private final FloatVectorValues values;
|
||||
private final OptimizedScalarQuantizer quantizer;
|
||||
|
||||
private int lastOrd = -1;
|
||||
|
||||
BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer) {
|
||||
this.values = delegate;
|
||||
this.quantizer = quantizer;
|
||||
this.binarized = new byte[discretize(delegate.dimension(), 64) / 8];
|
||||
this.initQuantized = new byte[delegate.dimension()];
|
||||
}
|
||||
|
||||
public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) {
|
||||
if (ord != lastOrd) {
|
||||
throw new IllegalStateException(
|
||||
"attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd
|
||||
);
|
||||
}
|
||||
return corrections;
|
||||
}
|
||||
|
||||
public byte[] vectorValue(int ord) throws IOException {
|
||||
if (ord != lastOrd) {
|
||||
binarize(ord);
|
||||
lastOrd = ord;
|
||||
}
|
||||
return binarized;
|
||||
}
|
||||
|
||||
private void binarize(int ord) throws IOException {
|
||||
corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid);
|
||||
packAsBinary(initQuantized, binarized);
|
||||
}
|
||||
}
|
||||
|
||||
static class OffHeapCentroidAssignmentScorer implements CentroidAssignmentScorer {
|
||||
private final IndexInput centroidsInput;
|
||||
private final int numCentroids;
|
||||
private final int dimension;
|
||||
private final float[] scratch;
|
||||
private float[] q;
|
||||
private final long rawCentroidOffset;
|
||||
private int currOrd = -1;
|
||||
|
||||
OffHeapCentroidAssignmentScorer(IndexInput centroidsInput, int numCentroids, FieldInfo info) {
|
||||
this.centroidsInput = centroidsInput;
|
||||
this.numCentroids = numCentroids;
|
||||
this.dimension = info.getVectorDimension();
|
||||
this.scratch = new float[dimension];
|
||||
this.rawCentroidOffset = (dimension + 3 * Float.BYTES + Short.BYTES) * numCentroids;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return numCentroids;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] centroid(int centroidOrdinal) throws IOException {
|
||||
if (centroidOrdinal == currOrd) {
|
||||
return scratch;
|
||||
}
|
||||
centroidsInput.seek(rawCentroidOffset + (long) centroidOrdinal * dimension * Float.BYTES);
|
||||
centroidsInput.readFloats(scratch, 0, dimension);
|
||||
this.currOrd = centroidOrdinal;
|
||||
return scratch;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setScoringVector(float[] vector) {
|
||||
q = vector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int centroidOrdinal) throws IOException {
|
||||
return VectorUtil.squareDistance(centroid(centroidOrdinal), q);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO throw away rawCentroids
|
||||
static class OnHeapCentroidAssignmentScorer implements CentroidAssignmentScorer {
|
||||
private final float[][] centroids;
|
||||
private float[] q;
|
||||
|
||||
OnHeapCentroidAssignmentScorer(float[][] centroids) {
|
||||
this.centroids = centroids;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return centroids.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setScoringVector(float[] vector) {
|
||||
q = vector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] centroid(int centroidOrdinal) throws IOException {
|
||||
return centroids[centroidOrdinal];
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int centroidOrdinal) throws IOException {
|
||||
return VectorUtil.squareDistance(centroid(centroidOrdinal), q);
|
||||
}
|
||||
}
|
||||
|
||||
static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
|
||||
throws IOException {
|
||||
indexOutput.writeBytes(binaryValue, binaryValue.length);
|
||||
indexOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
|
||||
indexOutput.writeInt(Float.floatToIntBits(corrections.upperInterval()));
|
||||
indexOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
|
||||
assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff;
|
||||
indexOutput.writeShort((short) corrections.quantizedComponentSum());
|
||||
}
|
||||
}
|
|
@ -0,0 +1,110 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors;
|
||||
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Codec format for Inverted File Vector indexes. This index expects to break the dimensional space
|
||||
* into clusters and assign each vector to a cluster generating a posting list of vectors. Clusters
|
||||
* are represented by centroids.
|
||||
* The vector quantization format used here is a per-vector optimized scalar quantization. Also see {@link
|
||||
* OptimizedScalarQuantizer}. Some of key features are:
|
||||
*
|
||||
* The format is stored in three files:
|
||||
*
|
||||
* <h2>.cenivf (centroid data) file</h2>
|
||||
* <p> Which stores the raw and quantized centroid vectors.
|
||||
*
|
||||
* <h2>.clivf (cluster data) file</h2>
|
||||
*
|
||||
* <p> Stores the quantized vectors for each cluster, inline and stored in blocks. Additionally, the docIds of
|
||||
* each vector is stored.
|
||||
*
|
||||
* <h2>.mivf (centroid metadata) file</h2>
|
||||
*
|
||||
* <p> Stores metadata including the number of centroids and their offsets in the clivf file</p>
|
||||
*
|
||||
*/
|
||||
public class IVFVectorsFormat extends KnnVectorsFormat {
|
||||
|
||||
public static final String IVF_VECTOR_COMPONENT = "IVF";
|
||||
public static final String NAME = "IVFVectorsFormat";
|
||||
// centroid ordinals -> centroid values, offsets
|
||||
public static final String CENTROID_EXTENSION = "cenivf";
|
||||
// offsets contained in cen_ivf, [vector ordinals, actually just docIds](long varint), quantized
|
||||
// vectors (OSQ bit)
|
||||
public static final String CLUSTER_EXTENSION = "clivf";
|
||||
static final String IVF_META_EXTENSION = "mivf";
|
||||
|
||||
public static final int VERSION_START = 0;
|
||||
public static final int VERSION_CURRENT = VERSION_START;
|
||||
|
||||
private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(
|
||||
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
|
||||
);
|
||||
|
||||
private static final int DEFAULT_VECTORS_PER_CLUSTER = 1000;
|
||||
|
||||
private final int vectorPerCluster;
|
||||
|
||||
public IVFVectorsFormat(int vectorPerCluster) {
|
||||
super(NAME);
|
||||
if (vectorPerCluster <= 0) {
|
||||
throw new IllegalArgumentException("vectorPerCluster must be > 0");
|
||||
}
|
||||
this.vectorPerCluster = vectorPerCluster;
|
||||
}
|
||||
|
||||
/** Constructs a format using the given graph construction parameters and scalar quantization. */
|
||||
public IVFVectorsFormat() {
|
||||
this(DEFAULT_VECTORS_PER_CLUSTER);
|
||||
}
|
||||
|
||||
@Override
|
||||
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
|
||||
return new DefaultIVFVectorsWriter(state, rawVectorFormat.fieldsWriter(state), vectorPerCluster);
|
||||
}
|
||||
|
||||
@Override
|
||||
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
|
||||
return new DefaultIVFVectorsReader(state, rawVectorFormat.fieldsReader(state));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxDimensions(String fieldName) {
|
||||
return 1024;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "IVFVectorFormat";
|
||||
}
|
||||
|
||||
static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) {
|
||||
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
|
||||
vectorsReader = candidateReader.getFieldReader(fieldName);
|
||||
}
|
||||
if (vectorsReader instanceof IVFVectorsReader reader) {
|
||||
return reader;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,354 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors;
|
||||
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.CorruptIndexException;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
import org.apache.lucene.store.DataInput;
|
||||
import org.apache.lucene.store.IOContext;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
import org.elasticsearch.core.IOUtils;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.IntPredicate;
|
||||
|
||||
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
|
||||
|
||||
/**
|
||||
* Reader for IVF vectors. This reader is used to read the IVF vectors from the index.
|
||||
*/
|
||||
public abstract class IVFVectorsReader extends KnnVectorsReader {
|
||||
|
||||
private final IndexInput ivfCentroids, ivfClusters;
|
||||
private final SegmentReadState state;
|
||||
private final FieldInfos fieldInfos;
|
||||
protected final IntObjectHashMap<FieldEntry> fields;
|
||||
private final FlatVectorsReader rawVectorsReader;
|
||||
|
||||
protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
|
||||
this.state = state;
|
||||
this.fieldInfos = state.fieldInfos;
|
||||
this.rawVectorsReader = rawVectorsReader;
|
||||
this.fields = new IntObjectHashMap<>();
|
||||
String meta = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, IVFVectorsFormat.IVF_META_EXTENSION);
|
||||
|
||||
int versionMeta = -1;
|
||||
boolean success = false;
|
||||
try (ChecksumIndexInput ivfMeta = state.directory.openChecksumInput(meta)) {
|
||||
Throwable priorE = null;
|
||||
try {
|
||||
versionMeta = CodecUtil.checkIndexHeader(
|
||||
ivfMeta,
|
||||
IVFVectorsFormat.NAME,
|
||||
IVFVectorsFormat.VERSION_START,
|
||||
IVFVectorsFormat.VERSION_CURRENT,
|
||||
state.segmentInfo.getId(),
|
||||
state.segmentSuffix
|
||||
);
|
||||
readFields(ivfMeta);
|
||||
} catch (Throwable exception) {
|
||||
priorE = exception;
|
||||
} finally {
|
||||
CodecUtil.checkFooter(ivfMeta, priorE);
|
||||
}
|
||||
ivfCentroids = openDataInput(state, versionMeta, IVFVectorsFormat.CENTROID_EXTENSION, IVFVectorsFormat.NAME, state.context);
|
||||
ivfClusters = openDataInput(state, versionMeta, IVFVectorsFormat.CLUSTER_EXTENSION, IVFVectorsFormat.NAME, state.context);
|
||||
success = true;
|
||||
} finally {
|
||||
if (success == false) {
|
||||
IOUtils.closeWhileHandlingException(this);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
abstract CentroidQueryScorer getCentroidScorer(
|
||||
FieldInfo fieldInfo,
|
||||
int numCentroids,
|
||||
IndexInput centroids,
|
||||
float[] target,
|
||||
IndexInput clusters
|
||||
) throws IOException;
|
||||
|
||||
protected abstract FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) throws IOException;
|
||||
|
||||
public FloatVectorValues getCentroids(FieldInfo fieldInfo) throws IOException {
|
||||
FieldEntry entry = fields.get(fieldInfo.number);
|
||||
if (entry == null) {
|
||||
return null;
|
||||
}
|
||||
return getCentroids(entry.centroidSlice(ivfCentroids), entry.postingListOffsets.length, fieldInfo);
|
||||
}
|
||||
|
||||
int centroidSize(String fieldName, int centroidOrdinal) throws IOException {
|
||||
FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldName);
|
||||
FieldEntry entry = fields.get(fieldInfo.number);
|
||||
ivfClusters.seek(entry.postingListOffsets[centroidOrdinal]);
|
||||
return ivfClusters.readVInt();
|
||||
}
|
||||
|
||||
private static IndexInput openDataInput(
|
||||
SegmentReadState state,
|
||||
int versionMeta,
|
||||
String fileExtension,
|
||||
String codecName,
|
||||
IOContext context
|
||||
) throws IOException {
|
||||
final String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
|
||||
final IndexInput in = state.directory.openInput(fileName, context);
|
||||
boolean success = false;
|
||||
try {
|
||||
final int versionVectorData = CodecUtil.checkIndexHeader(
|
||||
in,
|
||||
codecName,
|
||||
IVFVectorsFormat.VERSION_START,
|
||||
IVFVectorsFormat.VERSION_CURRENT,
|
||||
state.segmentInfo.getId(),
|
||||
state.segmentSuffix
|
||||
);
|
||||
if (versionMeta != versionVectorData) {
|
||||
throw new CorruptIndexException(
|
||||
"Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData,
|
||||
in
|
||||
);
|
||||
}
|
||||
CodecUtil.retrieveChecksum(in);
|
||||
success = true;
|
||||
return in;
|
||||
} finally {
|
||||
if (success == false) {
|
||||
IOUtils.closeWhileHandlingException(in);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void readFields(ChecksumIndexInput meta) throws IOException {
|
||||
for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
|
||||
final FieldInfo info = fieldInfos.fieldInfo(fieldNumber);
|
||||
if (info == null) {
|
||||
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
|
||||
}
|
||||
fields.put(info.number, readField(meta, info));
|
||||
}
|
||||
}
|
||||
|
||||
private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
|
||||
final VectorEncoding vectorEncoding = readVectorEncoding(input);
|
||||
final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
|
||||
final long centroidOffset = input.readLong();
|
||||
final long centroidLength = input.readLong();
|
||||
final int numPostingLists = input.readVInt();
|
||||
final long[] postingListOffsets = new long[numPostingLists];
|
||||
for (int i = 0; i < numPostingLists; i++) {
|
||||
postingListOffsets[i] = input.readLong();
|
||||
}
|
||||
final float[] globalCentroid = new float[info.getVectorDimension()];
|
||||
float globalCentroidDp = 0;
|
||||
if (numPostingLists > 0) {
|
||||
input.readFloats(globalCentroid, 0, globalCentroid.length);
|
||||
globalCentroidDp = Float.intBitsToFloat(input.readInt());
|
||||
}
|
||||
if (similarityFunction != info.getVectorSimilarityFunction()) {
|
||||
throw new IllegalStateException(
|
||||
"Inconsistent vector similarity function for field=\""
|
||||
+ info.name
|
||||
+ "\"; "
|
||||
+ similarityFunction
|
||||
+ " != "
|
||||
+ info.getVectorSimilarityFunction()
|
||||
);
|
||||
}
|
||||
return new FieldEntry(
|
||||
similarityFunction,
|
||||
vectorEncoding,
|
||||
centroidOffset,
|
||||
centroidLength,
|
||||
postingListOffsets,
|
||||
globalCentroid,
|
||||
globalCentroidDp
|
||||
);
|
||||
}
|
||||
|
||||
private static VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
|
||||
final int i = input.readInt();
|
||||
if (i < 0 || i >= SIMILARITY_FUNCTIONS.size()) {
|
||||
throw new IllegalArgumentException("invalid distance function: " + i);
|
||||
}
|
||||
return SIMILARITY_FUNCTIONS.get(i);
|
||||
}
|
||||
|
||||
private static VectorEncoding readVectorEncoding(DataInput input) throws IOException {
|
||||
final int encodingId = input.readInt();
|
||||
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
|
||||
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
|
||||
}
|
||||
return VectorEncoding.values()[encodingId];
|
||||
}
|
||||
|
||||
@Override
|
||||
public final void checkIntegrity() throws IOException {
|
||||
rawVectorsReader.checkIntegrity();
|
||||
CodecUtil.checksumEntireFile(ivfCentroids);
|
||||
CodecUtil.checksumEntireFile(ivfClusters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
return rawVectorsReader.getFloatVectorValues(field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
return rawVectorsReader.getByteVectorValues(field);
|
||||
}
|
||||
|
||||
protected float[] getGlobalCentroid(FieldInfo info) {
|
||||
if (info == null || info.getVectorEncoding().equals(VectorEncoding.BYTE)) {
|
||||
return null;
|
||||
}
|
||||
FieldEntry entry = fields.get(info.number);
|
||||
if (entry == null) {
|
||||
return null;
|
||||
}
|
||||
return entry.globalCentroid();
|
||||
}
|
||||
|
||||
@Override
|
||||
public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
|
||||
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
|
||||
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) {
|
||||
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
|
||||
return;
|
||||
}
|
||||
// TODO add new ivf search strategy
|
||||
int nProbe = 10;
|
||||
float percentFiltered = 1f;
|
||||
if (acceptDocs instanceof BitSet bitSet) {
|
||||
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
|
||||
}
|
||||
int numVectors = rawVectorsReader.getFloatVectorValues(field).size();
|
||||
BitSet visitedDocs = new FixedBitSet(state.segmentInfo.maxDoc() + 1);
|
||||
IntPredicate needsScoring = docId -> {
|
||||
if (acceptDocs != null && acceptDocs.get(docId) == false) {
|
||||
return false;
|
||||
}
|
||||
return visitedDocs.getAndSet(docId) == false;
|
||||
};
|
||||
|
||||
FieldEntry entry = fields.get(fieldInfo.number);
|
||||
CentroidQueryScorer centroidQueryScorer = getCentroidScorer(
|
||||
fieldInfo,
|
||||
entry.postingListOffsets.length,
|
||||
entry.centroidSlice(ivfCentroids),
|
||||
target,
|
||||
ivfClusters
|
||||
);
|
||||
final NeighborQueue centroidQueue = scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe);
|
||||
PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring);
|
||||
int centroidsVisited = 0;
|
||||
long expectedDocs = 0;
|
||||
long actualDocs = 0;
|
||||
// initially we visit only the "centroids to search"
|
||||
while (centroidQueue.size() > 0 && centroidsVisited < nProbe) {
|
||||
++centroidsVisited;
|
||||
// todo do we actually need to know the score???
|
||||
int centroidOrdinal = centroidQueue.pop();
|
||||
// todo do we need direct access to the raw centroid???
|
||||
expectedDocs += scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal));
|
||||
actualDocs += scorer.visit(knnCollector);
|
||||
}
|
||||
if (acceptDocs != null) {
|
||||
float unfilteredRatioVisited = (float) expectedDocs / numVectors;
|
||||
int filteredVectors = (int) Math.ceil(numVectors * percentFiltered);
|
||||
float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
|
||||
while (centroidQueue.size() > 0 && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
|
||||
int centroidOrdinal = centroidQueue.pop();
|
||||
scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal));
|
||||
actualDocs += scorer.visit(knnCollector);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public final void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
|
||||
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
|
||||
final ByteVectorValues values = rawVectorsReader.getByteVectorValues(field);
|
||||
for (int i = 0; i < values.size(); i++) {
|
||||
final float score = fieldInfo.getVectorSimilarityFunction().compare(target, values.vectorValue(i));
|
||||
knnCollector.collect(values.ordToDoc(i), score);
|
||||
if (knnCollector.earlyTerminated()) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
abstract NeighborQueue scorePostingLists(
|
||||
FieldInfo fieldInfo,
|
||||
KnnCollector knnCollector,
|
||||
CentroidQueryScorer centroidQueryScorer,
|
||||
int nProbe
|
||||
) throws IOException;
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
IOUtils.close(rawVectorsReader, ivfCentroids, ivfClusters);
|
||||
}
|
||||
|
||||
protected record FieldEntry(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
VectorEncoding vectorEncoding,
|
||||
long centroidOffset,
|
||||
long centroidLength,
|
||||
long[] postingListOffsets,
|
||||
float[] globalCentroid,
|
||||
float globalCentroidDp
|
||||
) {
|
||||
IndexInput centroidSlice(IndexInput centroidFile) throws IOException {
|
||||
return centroidFile.slice("centroids", centroidOffset, centroidLength);
|
||||
}
|
||||
}
|
||||
|
||||
abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, IntPredicate needsScoring)
|
||||
throws IOException;
|
||||
|
||||
interface CentroidQueryScorer {
|
||||
int size();
|
||||
|
||||
float[] centroid(int centroidOrdinal) throws IOException;
|
||||
|
||||
float score(int centroidOrdinal) throws IOException;
|
||||
}
|
||||
|
||||
interface PostingVisitor {
|
||||
// TODO maybe we can not specifically pass the centroid...
|
||||
|
||||
/** returns the number of documents in the posting list */
|
||||
int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException;
|
||||
|
||||
/** returns the number of scored documents */
|
||||
int visit(KnnCollector collector) throws IOException;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,486 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors;
|
||||
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.store.IOContext;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.InfoStream;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.elasticsearch.core.IOUtils;
|
||||
import org.elasticsearch.core.SuppressForbidden;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
/**
|
||||
* Base class for IVF vectors writer.
|
||||
*/
|
||||
public abstract class IVFVectorsWriter extends KnnVectorsWriter {
|
||||
|
||||
private final List<FieldWriter> fieldWriters = new ArrayList<>();
|
||||
private final IndexOutput ivfCentroids, ivfClusters;
|
||||
private final IndexOutput ivfMeta;
|
||||
private final FlatVectorsWriter rawVectorDelegate;
|
||||
private final SegmentWriteState segmentWriteState;
|
||||
|
||||
protected IVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) throws IOException {
|
||||
this.segmentWriteState = state;
|
||||
this.rawVectorDelegate = rawVectorDelegate;
|
||||
final String metaFileName = IndexFileNames.segmentFileName(
|
||||
state.segmentInfo.name,
|
||||
state.segmentSuffix,
|
||||
IVFVectorsFormat.IVF_META_EXTENSION
|
||||
);
|
||||
|
||||
final String ivfCentroidsFileName = IndexFileNames.segmentFileName(
|
||||
state.segmentInfo.name,
|
||||
state.segmentSuffix,
|
||||
IVFVectorsFormat.CENTROID_EXTENSION
|
||||
);
|
||||
final String ivfClustersFileName = IndexFileNames.segmentFileName(
|
||||
state.segmentInfo.name,
|
||||
state.segmentSuffix,
|
||||
IVFVectorsFormat.CLUSTER_EXTENSION
|
||||
);
|
||||
boolean success = false;
|
||||
try {
|
||||
ivfMeta = state.directory.createOutput(metaFileName, state.context);
|
||||
CodecUtil.writeIndexHeader(
|
||||
ivfMeta,
|
||||
IVFVectorsFormat.NAME,
|
||||
IVFVectorsFormat.VERSION_CURRENT,
|
||||
state.segmentInfo.getId(),
|
||||
state.segmentSuffix
|
||||
);
|
||||
ivfCentroids = state.directory.createOutput(ivfCentroidsFileName, state.context);
|
||||
CodecUtil.writeIndexHeader(
|
||||
ivfCentroids,
|
||||
IVFVectorsFormat.NAME,
|
||||
IVFVectorsFormat.VERSION_CURRENT,
|
||||
state.segmentInfo.getId(),
|
||||
state.segmentSuffix
|
||||
);
|
||||
ivfClusters = state.directory.createOutput(ivfClustersFileName, state.context);
|
||||
CodecUtil.writeIndexHeader(
|
||||
ivfClusters,
|
||||
IVFVectorsFormat.NAME,
|
||||
IVFVectorsFormat.VERSION_CURRENT,
|
||||
state.segmentInfo.getId(),
|
||||
state.segmentSuffix
|
||||
);
|
||||
success = true;
|
||||
} finally {
|
||||
if (success == false) {
|
||||
IOUtils.closeWhileHandlingException(this);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public final KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
|
||||
if (fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE) {
|
||||
throw new IllegalArgumentException("IVF does not support cosine similarity");
|
||||
}
|
||||
final FlatFieldVectorsWriter<?> rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo);
|
||||
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
|
||||
@SuppressWarnings("unchecked")
|
||||
final FlatFieldVectorsWriter<float[]> floatWriter = (FlatFieldVectorsWriter<float[]>) rawVectorDelegate;
|
||||
fieldWriters.add(new FieldWriter(fieldInfo, floatWriter));
|
||||
}
|
||||
return rawVectorDelegate;
|
||||
}
|
||||
|
||||
protected abstract int calculateAndWriteCentroids(
|
||||
FieldInfo fieldInfo,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput temporaryCentroidOutput,
|
||||
MergeState mergeState,
|
||||
float[] globalCentroid
|
||||
) throws IOException;
|
||||
|
||||
abstract long[] buildAndWritePostingsLists(
|
||||
FieldInfo fieldInfo,
|
||||
CentroidAssignmentScorer scorer,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput postingsOutput,
|
||||
MergeState mergeState
|
||||
) throws IOException;
|
||||
|
||||
abstract CentroidAssignmentScorer calculateAndWriteCentroids(
|
||||
FieldInfo fieldInfo,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput centroidOutput,
|
||||
float[] globalCentroid
|
||||
) throws IOException;
|
||||
|
||||
abstract long[] buildAndWritePostingsLists(
|
||||
FieldInfo fieldInfo,
|
||||
InfoStream infoStream,
|
||||
CentroidAssignmentScorer scorer,
|
||||
FloatVectorValues floatVectorValues,
|
||||
IndexOutput postingsOutput
|
||||
) throws IOException;
|
||||
|
||||
abstract CentroidAssignmentScorer createCentroidScorer(
|
||||
IndexInput centroidsInput,
|
||||
int numCentroids,
|
||||
FieldInfo fieldInfo,
|
||||
float[] globalCentroid
|
||||
) throws IOException;
|
||||
|
||||
@Override
|
||||
public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
||||
rawVectorDelegate.flush(maxDoc, sortMap);
|
||||
for (FieldWriter fieldWriter : fieldWriters) {
|
||||
float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()];
|
||||
// calculate global centroid
|
||||
for (var vector : fieldWriter.delegate.getVectors()) {
|
||||
for (int i = 0; i < globalCentroid.length; i++) {
|
||||
globalCentroid[i] += vector[i];
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < globalCentroid.length; i++) {
|
||||
globalCentroid[i] /= fieldWriter.delegate.getVectors().size();
|
||||
}
|
||||
// build a float vector values with random access
|
||||
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc);
|
||||
// build centroids
|
||||
long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
|
||||
final CentroidAssignmentScorer centroidAssignmentScorer = calculateAndWriteCentroids(
|
||||
fieldWriter.fieldInfo,
|
||||
floatVectorValues,
|
||||
ivfCentroids,
|
||||
globalCentroid
|
||||
);
|
||||
long centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
|
||||
final long[] offsets = buildAndWritePostingsLists(
|
||||
fieldWriter.fieldInfo,
|
||||
segmentWriteState.infoStream,
|
||||
centroidAssignmentScorer,
|
||||
floatVectorValues,
|
||||
ivfClusters
|
||||
);
|
||||
writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid);
|
||||
}
|
||||
}
|
||||
|
||||
private static FloatVectorValues getFloatVectorValues(
|
||||
FieldInfo fieldInfo,
|
||||
FlatFieldVectorsWriter<float[]> fieldVectorsWriter,
|
||||
int maxDoc
|
||||
) throws IOException {
|
||||
List<float[]> vectors = fieldVectorsWriter.getVectors();
|
||||
if (vectors.size() == maxDoc) {
|
||||
return FloatVectorValues.fromFloats(vectors, fieldInfo.getVectorDimension());
|
||||
}
|
||||
final DocIdSetIterator iterator = fieldVectorsWriter.getDocsWithFieldSet().iterator();
|
||||
final int[] docIds = new int[vectors.size()];
|
||||
for (int i = 0; i < docIds.length; i++) {
|
||||
docIds[i] = iterator.nextDoc();
|
||||
}
|
||||
assert iterator.nextDoc() == NO_MORE_DOCS;
|
||||
return new FloatVectorValues() {
|
||||
@Override
|
||||
public float[] vectorValue(int ord) {
|
||||
return vectors.get(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues copy() {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return fieldInfo.getVectorDimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return vectors.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
return docIds[ord];
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) {
|
||||
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
|
||||
vectorsReader = candidateReader.getFieldReader(fieldName);
|
||||
}
|
||||
if (vectorsReader instanceof IVFVectorsReader reader) {
|
||||
return reader;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)")
|
||||
public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
|
||||
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
|
||||
final int numVectors;
|
||||
String tempRawVectorsFileName = null;
|
||||
boolean success = false;
|
||||
// build a float vector values with random access. In order to do that we dump the vectors to
|
||||
// a temporary file
|
||||
// and write the docID follow by the vector
|
||||
try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivf_", IOContext.DEFAULT)) {
|
||||
tempRawVectorsFileName = out.getName();
|
||||
// TODO do this better, we shouldn't have to write to a temp file, we should be able to
|
||||
// to just from the merged vector values, the tricky part is the random access.
|
||||
numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
|
||||
CodecUtil.writeFooter(out);
|
||||
success = true;
|
||||
} finally {
|
||||
if (success == false && tempRawVectorsFileName != null) {
|
||||
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
|
||||
}
|
||||
}
|
||||
try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) {
|
||||
float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
|
||||
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors);
|
||||
success = false;
|
||||
CentroidAssignmentScorer centroidAssignmentScorer;
|
||||
long centroidOffset;
|
||||
long centroidLength;
|
||||
String centroidTempName = null;
|
||||
int numCentroids;
|
||||
IndexOutput centroidTemp = null;
|
||||
try {
|
||||
centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT);
|
||||
centroidTempName = centroidTemp.getName();
|
||||
numCentroids = calculateAndWriteCentroids(
|
||||
fieldInfo,
|
||||
floatVectorValues,
|
||||
centroidTemp,
|
||||
mergeState,
|
||||
calculatedGlobalCentroid
|
||||
);
|
||||
success = true;
|
||||
} finally {
|
||||
if (success == false && centroidTempName != null) {
|
||||
IOUtils.closeWhileHandlingException(centroidTemp);
|
||||
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName);
|
||||
}
|
||||
}
|
||||
try {
|
||||
if (numCentroids == 0) {
|
||||
centroidOffset = ivfCentroids.getFilePointer();
|
||||
writeMeta(fieldInfo, centroidOffset, 0, new long[0], null);
|
||||
CodecUtil.writeFooter(centroidTemp);
|
||||
IOUtils.close(centroidTemp);
|
||||
return;
|
||||
}
|
||||
CodecUtil.writeFooter(centroidTemp);
|
||||
IOUtils.close(centroidTemp);
|
||||
centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
|
||||
try (IndexInput centroidInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) {
|
||||
ivfCentroids.copyBytes(centroidInput, centroidInput.length() - CodecUtil.footerLength());
|
||||
centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
|
||||
centroidAssignmentScorer = createCentroidScorer(centroidInput, numCentroids, fieldInfo, calculatedGlobalCentroid);
|
||||
assert centroidAssignmentScorer.size() == numCentroids;
|
||||
// build a float vector values with random access
|
||||
// build centroids
|
||||
final long[] offsets = buildAndWritePostingsLists(
|
||||
fieldInfo,
|
||||
centroidAssignmentScorer,
|
||||
floatVectorValues,
|
||||
ivfClusters,
|
||||
mergeState
|
||||
);
|
||||
assert offsets.length == centroidAssignmentScorer.size();
|
||||
writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid);
|
||||
}
|
||||
} finally {
|
||||
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(
|
||||
mergeState.segmentInfo.dir,
|
||||
tempRawVectorsFileName,
|
||||
centroidTempName
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput randomAccessInput, int numVectors) {
|
||||
if (numVectors == 0) {
|
||||
return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension());
|
||||
}
|
||||
final long length = (long) Float.BYTES * fieldInfo.getVectorDimension() + Integer.BYTES;
|
||||
final float[] vector = new float[fieldInfo.getVectorDimension()];
|
||||
return new FloatVectorValues() {
|
||||
@Override
|
||||
public float[] vectorValue(int ord) throws IOException {
|
||||
randomAccessInput.seek(ord * length + Integer.BYTES);
|
||||
randomAccessInput.readFloats(vector, 0, vector.length);
|
||||
return vector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues copy() {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return fieldInfo.getVectorDimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return numVectors;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
try {
|
||||
randomAccessInput.seek(ord * length);
|
||||
return randomAccessInput.readInt();
|
||||
} catch (IOException e) {
|
||||
throw new UncheckedIOException(e);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out, FloatVectorValues floatVectorValues)
|
||||
throws IOException {
|
||||
int numVectors = 0;
|
||||
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
final KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
|
||||
for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
|
||||
numVectors++;
|
||||
float[] vector = floatVectorValues.vectorValue(iterator.index());
|
||||
out.writeInt(iterator.docID());
|
||||
buffer.asFloatBuffer().put(vector);
|
||||
out.writeBytes(buffer.array(), buffer.array().length);
|
||||
}
|
||||
return numVectors;
|
||||
}
|
||||
|
||||
private void writeMeta(FieldInfo field, long centroidOffset, long centroidLength, long[] offsets, float[] globalCentroid)
|
||||
throws IOException {
|
||||
ivfMeta.writeInt(field.number);
|
||||
ivfMeta.writeInt(field.getVectorEncoding().ordinal());
|
||||
ivfMeta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction()));
|
||||
ivfMeta.writeLong(centroidOffset);
|
||||
ivfMeta.writeLong(centroidLength);
|
||||
ivfMeta.writeVInt(offsets.length);
|
||||
for (long offset : offsets) {
|
||||
ivfMeta.writeLong(offset);
|
||||
}
|
||||
if (offsets.length > 0) {
|
||||
final ByteBuffer buffer = ByteBuffer.allocate(globalCentroid.length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
buffer.asFloatBuffer().put(globalCentroid);
|
||||
ivfMeta.writeBytes(buffer.array(), buffer.array().length);
|
||||
ivfMeta.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(globalCentroid, globalCentroid)));
|
||||
}
|
||||
}
|
||||
|
||||
private static int distFuncToOrd(VectorSimilarityFunction func) {
|
||||
for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) {
|
||||
if (SIMILARITY_FUNCTIONS.get(i).equals(func)) {
|
||||
return (byte) i;
|
||||
}
|
||||
}
|
||||
throw new IllegalArgumentException("invalid distance function: " + func);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final void finish() throws IOException {
|
||||
rawVectorDelegate.finish();
|
||||
if (ivfMeta != null) {
|
||||
// write end of fields marker
|
||||
ivfMeta.writeInt(-1);
|
||||
CodecUtil.writeFooter(ivfMeta);
|
||||
}
|
||||
if (ivfCentroids != null) {
|
||||
CodecUtil.writeFooter(ivfCentroids);
|
||||
}
|
||||
if (ivfClusters != null) {
|
||||
CodecUtil.writeFooter(ivfClusters);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public final void close() throws IOException {
|
||||
IOUtils.close(rawVectorDelegate, ivfMeta, ivfCentroids, ivfClusters);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final long ramBytesUsed() {
|
||||
return rawVectorDelegate.ramBytesUsed();
|
||||
}
|
||||
|
||||
private record FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter<float[]> delegate) {}
|
||||
|
||||
interface CentroidAssignmentScorer {
|
||||
CentroidAssignmentScorer EMPTY = new CentroidAssignmentScorer() {
|
||||
@Override
|
||||
public int size() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] centroid(int centroidOrdinal) {
|
||||
throw new IllegalStateException("No centroids");
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int centroidOrdinal) {
|
||||
throw new IllegalStateException("No centroids");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setScoringVector(float[] vector) {
|
||||
throw new IllegalStateException("No centroids");
|
||||
}
|
||||
};
|
||||
|
||||
int size();
|
||||
|
||||
float[] centroid(int centroidOrdinal) throws IOException;
|
||||
|
||||
void setScoringVector(float[] vector);
|
||||
|
||||
float score(int centroidOrdinal) throws IOException;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,159 @@
|
|||
/*
|
||||
* @notice
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
* Modifications copyright (C) 2025 Elasticsearch B.V.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors;
|
||||
|
||||
import org.apache.lucene.util.LongHeap;
|
||||
import org.apache.lucene.util.NumericUtils;
|
||||
|
||||
/**
|
||||
* Copied from and modified from Apache Lucene.
|
||||
*/
|
||||
class NeighborQueue {
|
||||
|
||||
private enum Order {
|
||||
MIN_HEAP {
|
||||
@Override
|
||||
long apply(long v) {
|
||||
return v;
|
||||
}
|
||||
},
|
||||
MAX_HEAP {
|
||||
@Override
|
||||
long apply(long v) {
|
||||
// This cannot be just `-v` since Long.MIN_VALUE doesn't have a positive counterpart. It
|
||||
// needs a function that returns MAX_VALUE for MIN_VALUE and vice-versa.
|
||||
return -1 - v;
|
||||
}
|
||||
};
|
||||
|
||||
abstract long apply(long v);
|
||||
}
|
||||
|
||||
private final LongHeap heap;
|
||||
private final Order order;
|
||||
|
||||
NeighborQueue(int initialSize, boolean maxHeap) {
|
||||
this.heap = new LongHeap(initialSize);
|
||||
this.order = maxHeap ? Order.MAX_HEAP : Order.MIN_HEAP;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return the number of elements in the heap
|
||||
*/
|
||||
public int size() {
|
||||
return heap.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a new graph arc, extending the storage as needed.
|
||||
*
|
||||
* @param newNode the neighbor node id
|
||||
* @param newScore the score of the neighbor, relative to some other node
|
||||
*/
|
||||
public void add(int newNode, float newScore) {
|
||||
heap.push(encode(newNode, newScore));
|
||||
}
|
||||
|
||||
/**
|
||||
* If the heap is not full (size is less than the initialSize provided to the constructor), adds a
|
||||
* new node-and-score element. If the heap is full, compares the score against the current top
|
||||
* score, and replaces the top element if newScore is better than (greater than unless the heap is
|
||||
* reversed), the current top score.
|
||||
*
|
||||
* @param newNode the neighbor node id
|
||||
* @param newScore the score of the neighbor, relative to some other node
|
||||
*/
|
||||
public boolean insertWithOverflow(int newNode, float newScore) {
|
||||
return heap.insertWithOverflow(encode(newNode, newScore));
|
||||
}
|
||||
|
||||
/**
|
||||
* Encodes the node ID and its similarity score as long, preserving the Lucene tie-breaking rule
|
||||
* that when two scores are equal, the smaller node ID must win.
|
||||
* @param node the node ID
|
||||
* @param score the node score
|
||||
* @return the encoded score, node ID
|
||||
*/
|
||||
private long encode(int node, float score) {
|
||||
return order.apply((((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node));
|
||||
}
|
||||
|
||||
/** Returns the top element's node id. */
|
||||
int topNode() {
|
||||
return decodeNodeId(heap.top());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the top element's node score. For the min heap this is the minimum score. For the max
|
||||
* heap this is the maximum score.
|
||||
*/
|
||||
float topScore() {
|
||||
return decodeScore(heap.top());
|
||||
}
|
||||
|
||||
private float decodeScore(long heapValue) {
|
||||
return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32));
|
||||
}
|
||||
|
||||
private int decodeNodeId(long heapValue) {
|
||||
return (int) ~(order.apply(heapValue));
|
||||
}
|
||||
|
||||
/** Removes the top element and returns its node id. */
|
||||
public int pop() {
|
||||
return decodeNodeId(heap.pop());
|
||||
}
|
||||
|
||||
public void consumeNodes(int[] dest) {
|
||||
if (dest.length < size()) {
|
||||
throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements.");
|
||||
}
|
||||
for (int i = 0; i < size(); i++) {
|
||||
dest[i] = decodeNodeId(heap.get(i + 1));
|
||||
}
|
||||
}
|
||||
|
||||
public int consumeNodesAndScoresMin(int[] dest, float[] scores) {
|
||||
if (dest.length < size() || scores.length < size()) {
|
||||
throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements.");
|
||||
}
|
||||
float bestScore = Float.POSITIVE_INFINITY;
|
||||
int bestIdx = 0;
|
||||
for (int i = 0; i < size(); i++) {
|
||||
long heapValue = heap.get(i + 1);
|
||||
scores[i] = decodeScore(heapValue);
|
||||
dest[i] = decodeNodeId(heapValue);
|
||||
if (scores[i] < bestScore) {
|
||||
bestScore = scores[i];
|
||||
bestIdx = i;
|
||||
}
|
||||
}
|
||||
return bestIdx;
|
||||
}
|
||||
|
||||
public void clear() {
|
||||
heap.clear();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Neighbors[" + heap.size() + "]";
|
||||
}
|
||||
}
|
|
@ -7,3 +7,4 @@ org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat
|
|||
org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat
|
||||
org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat
|
||||
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat
|
||||
org.elasticsearch.index.codec.vectors.IVFVectorsFormat
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
package org.elasticsearch.index.codec.vectors;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.generators.RandomPicks;
|
||||
|
||||
import org.apache.lucene.codecs.Codec;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
|
||||
import org.apache.lucene.tests.util.TestUtil;
|
||||
import org.elasticsearch.common.logging.LogConfigurator;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class IVFVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
|
||||
|
||||
static {
|
||||
LogConfigurator.loadLog4jPlugins();
|
||||
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
|
||||
}
|
||||
KnnVectorsFormat format;
|
||||
|
||||
@Before
|
||||
@Override
|
||||
public void setUp() throws Exception {
|
||||
format = new IVFVectorsFormat(random().nextInt(10, 1000));
|
||||
super.setUp();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected VectorSimilarityFunction randomSimilarity() {
|
||||
return RandomPicks.randomFrom(
|
||||
random(),
|
||||
List.of(
|
||||
VectorSimilarityFunction.DOT_PRODUCT,
|
||||
VectorSimilarityFunction.EUCLIDEAN,
|
||||
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected VectorEncoding randomVectorEncoding() {
|
||||
return VectorEncoding.FLOAT32;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void testSearchWithVisitedLimit() {
|
||||
// ivf doesn't enforce visitation limit
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Codec getCodec() {
|
||||
return TestUtil.alwaysKnnVectorsFormat(format);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
/*
|
||||
* @notice
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
* Modifications copyright (C) 2025 Elasticsearch B.V.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors;
|
||||
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
/**
|
||||
* copied and modified from Lucene
|
||||
*/
|
||||
public class NeighborQueueTests extends ESTestCase {
|
||||
public void testNeighborsProduct() {
|
||||
// make sure we have the sign correct
|
||||
NeighborQueue nn = new NeighborQueue(2, false);
|
||||
assertTrue(nn.insertWithOverflow(2, 0.5f));
|
||||
assertTrue(nn.insertWithOverflow(1, 0.2f));
|
||||
assertTrue(nn.insertWithOverflow(3, 1f));
|
||||
assertEquals(0.5f, nn.topScore(), 0);
|
||||
nn.pop();
|
||||
assertEquals(1f, nn.topScore(), 0);
|
||||
nn.pop();
|
||||
}
|
||||
|
||||
public void testNeighborsMaxHeap() {
|
||||
NeighborQueue nn = new NeighborQueue(2, true);
|
||||
assertTrue(nn.insertWithOverflow(2, 2));
|
||||
assertTrue(nn.insertWithOverflow(1, 1));
|
||||
assertFalse(nn.insertWithOverflow(3, 3));
|
||||
assertEquals(2f, nn.topScore(), 0);
|
||||
nn.pop();
|
||||
assertEquals(1f, nn.topScore(), 0);
|
||||
}
|
||||
|
||||
public void testTopMaxHeap() {
|
||||
NeighborQueue nn = new NeighborQueue(2, true);
|
||||
nn.add(1, 2);
|
||||
nn.add(2, 1);
|
||||
// lower scores are better; highest score on top
|
||||
assertEquals(2, nn.topScore(), 0);
|
||||
assertEquals(1, nn.topNode());
|
||||
}
|
||||
|
||||
public void testTopMinHeap() {
|
||||
NeighborQueue nn = new NeighborQueue(2, false);
|
||||
nn.add(1, 0.5f);
|
||||
nn.add(2, -0.5f);
|
||||
// higher scores are better; lowest score on top
|
||||
assertEquals(-0.5f, nn.topScore(), 0);
|
||||
assertEquals(2, nn.topNode());
|
||||
}
|
||||
|
||||
public void testClear() {
|
||||
NeighborQueue nn = new NeighborQueue(2, false);
|
||||
nn.add(1, 1.1f);
|
||||
nn.add(2, -2.2f);
|
||||
nn.clear();
|
||||
|
||||
assertEquals(0, nn.size());
|
||||
}
|
||||
|
||||
public void testMaxSizeQueue() {
|
||||
NeighborQueue nn = new NeighborQueue(2, false);
|
||||
nn.add(1, 1);
|
||||
nn.add(2, 2);
|
||||
assertEquals(2, nn.size());
|
||||
assertEquals(1, nn.topNode());
|
||||
|
||||
// insertWithOverflow does not extend the queue
|
||||
nn.insertWithOverflow(3, 3);
|
||||
assertEquals(2, nn.size());
|
||||
assertEquals(2, nn.topNode());
|
||||
|
||||
// add does extend the queue beyond maxSize
|
||||
nn.add(4, 1);
|
||||
assertEquals(3, nn.size());
|
||||
}
|
||||
|
||||
public void testUnboundedQueue() {
|
||||
NeighborQueue nn = new NeighborQueue(1, true);
|
||||
float maxScore = -2;
|
||||
int maxNode = -1;
|
||||
for (int i = 0; i < 256; i++) {
|
||||
// initial size is 32
|
||||
float score = random().nextFloat();
|
||||
if (score > maxScore) {
|
||||
maxScore = score;
|
||||
maxNode = i;
|
||||
}
|
||||
nn.add(i, score);
|
||||
}
|
||||
assertEquals(maxScore, nn.topScore(), 0);
|
||||
assertEquals(maxNode, nn.topNode());
|
||||
}
|
||||
|
||||
public void testInvalidArguments() {
|
||||
expectThrows(IllegalArgumentException.class, () -> new NeighborQueue(0, false));
|
||||
}
|
||||
|
||||
public void testToString() {
|
||||
assertEquals("Neighbors[0]", new NeighborQueue(2, false).toString());
|
||||
}
|
||||
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue