mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
This reverts commit 8a17a5ed5f
.
reapplying ivf format, but with a fix.
This commit is contained in:
parent
d07ec0cc44
commit
1324ee0115
23 changed files with 2576 additions and 3 deletions
|
@ -17,7 +17,7 @@ import org.apache.lucene.store.MMapDirectory;
|
||||||
import org.apache.lucene.util.VectorUtil;
|
import org.apache.lucene.util.VectorUtil;
|
||||||
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
|
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
|
||||||
import org.elasticsearch.common.logging.LogConfigurator;
|
import org.elasticsearch.common.logging.LogConfigurator;
|
||||||
import org.elasticsearch.simdvec.internal.vectorization.ES91OSQVectorsScorer;
|
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||||
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
|
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
|
||||||
import org.openjdk.jmh.annotations.Benchmark;
|
import org.openjdk.jmh.annotations.Benchmark;
|
||||||
import org.openjdk.jmh.annotations.BenchmarkMode;
|
import org.openjdk.jmh.annotations.BenchmarkMode;
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.simdvec.internal.vectorization;
|
package org.elasticsearch.simdvec;
|
||||||
|
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.store.IndexInput;
|
import org.apache.lucene.store.IndexInput;
|
|
@ -9,11 +9,13 @@
|
||||||
|
|
||||||
package org.elasticsearch.simdvec;
|
package org.elasticsearch.simdvec;
|
||||||
|
|
||||||
|
import org.apache.lucene.store.IndexInput;
|
||||||
import org.apache.lucene.util.BitUtil;
|
import org.apache.lucene.util.BitUtil;
|
||||||
import org.apache.lucene.util.Constants;
|
import org.apache.lucene.util.Constants;
|
||||||
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
|
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
|
||||||
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
|
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
import java.lang.invoke.MethodHandle;
|
import java.lang.invoke.MethodHandle;
|
||||||
import java.lang.invoke.MethodHandles;
|
import java.lang.invoke.MethodHandles;
|
||||||
import java.lang.invoke.MethodType;
|
import java.lang.invoke.MethodType;
|
||||||
|
@ -41,6 +43,10 @@ public class ESVectorUtil {
|
||||||
|
|
||||||
private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();
|
private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();
|
||||||
|
|
||||||
|
public static ES91OSQVectorsScorer getES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
|
||||||
|
return ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(input, dimension);
|
||||||
|
}
|
||||||
|
|
||||||
public static long ipByteBinByte(byte[] q, byte[] d) {
|
public static long ipByteBinByte(byte[] q, byte[] d) {
|
||||||
if (q.length != d.length * B_QUERY) {
|
if (q.length != d.length * B_QUERY) {
|
||||||
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);
|
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);
|
||||||
|
@ -211,4 +217,40 @@ public class ESVectorUtil {
|
||||||
assert stats.length == 6;
|
assert stats.length == 6;
|
||||||
IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats);
|
IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculates the difference between two vectors and stores the result in a third vector.
|
||||||
|
* @param v1 the first vector
|
||||||
|
* @param v2 the second vector
|
||||||
|
* @param result the result vector, must be the same length as the input vectors
|
||||||
|
*/
|
||||||
|
public static void subtract(float[] v1, float[] v2, float[] result) {
|
||||||
|
if (v1.length != v2.length) {
|
||||||
|
throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + v2.length);
|
||||||
|
}
|
||||||
|
if (result.length != v1.length) {
|
||||||
|
throw new IllegalArgumentException("vector dimensions differ: " + result.length + "!=" + v1.length);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < v1.length; i++) {
|
||||||
|
result[i] = v1[i] - v2[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* calculates the spill-over score for a vector and a centroid, given its residual with
|
||||||
|
* its actually nearest centroid
|
||||||
|
* @param v1 the vector
|
||||||
|
* @param centroid the centroid
|
||||||
|
* @param originalResidual the residual with the actually nearest centroid
|
||||||
|
* @return the spill-over score (soar)
|
||||||
|
*/
|
||||||
|
public static float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
|
||||||
|
if (v1.length != centroid.length) {
|
||||||
|
throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + centroid.length);
|
||||||
|
}
|
||||||
|
if (originalResidual.length != v1.length) {
|
||||||
|
throw new IllegalArgumentException("vector dimensions differ: " + originalResidual.length + "!=" + v1.length);
|
||||||
|
}
|
||||||
|
return IMPL.soarResidual(v1, centroid, originalResidual);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -138,6 +138,18 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
|
||||||
stats[5] = centroidDot;
|
stats[5] = centroidDot;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
|
||||||
|
assert v1.length == centroid.length;
|
||||||
|
assert v1.length == originalResidual.length;
|
||||||
|
float proj = 0;
|
||||||
|
for (int i = 0; i < v1.length; i++) {
|
||||||
|
float djk = v1[i] - centroid[i];
|
||||||
|
proj = fma(djk, originalResidual[i], proj);
|
||||||
|
}
|
||||||
|
return proj;
|
||||||
|
}
|
||||||
|
|
||||||
public static int ipByteBitImpl(byte[] q, byte[] d) {
|
public static int ipByteBitImpl(byte[] q, byte[] d) {
|
||||||
return ipByteBitImpl(q, d, 0);
|
return ipByteBitImpl(q, d, 0);
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
package org.elasticsearch.simdvec.internal.vectorization;
|
package org.elasticsearch.simdvec.internal.vectorization;
|
||||||
|
|
||||||
import org.apache.lucene.store.IndexInput;
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
|
|
|
@ -28,4 +28,7 @@ public interface ESVectorUtilSupport {
|
||||||
void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats);
|
void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats);
|
||||||
|
|
||||||
void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);
|
void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);
|
||||||
|
|
||||||
|
float soarResidual(float[] v1, float[] centroid, float[] originalResidual);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
package org.elasticsearch.simdvec.internal.vectorization;
|
package org.elasticsearch.simdvec.internal.vectorization;
|
||||||
|
|
||||||
import org.apache.lucene.store.IndexInput;
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
|
@ -13,6 +13,7 @@ import org.apache.lucene.store.IndexInput;
|
||||||
import org.apache.lucene.util.Constants;
|
import org.apache.lucene.util.Constants;
|
||||||
import org.elasticsearch.logging.LogManager;
|
import org.elasticsearch.logging.LogManager;
|
||||||
import org.elasticsearch.logging.Logger;
|
import org.elasticsearch.logging.Logger;
|
||||||
|
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Locale;
|
import java.util.Locale;
|
||||||
|
|
|
@ -20,6 +20,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.store.IndexInput;
|
import org.apache.lucene.store.IndexInput;
|
||||||
import org.apache.lucene.util.VectorUtil;
|
import org.apache.lucene.util.VectorUtil;
|
||||||
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
|
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
|
||||||
|
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.lang.foreign.MemorySegment;
|
import java.lang.foreign.MemorySegment;
|
||||||
|
|
|
@ -367,6 +367,49 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
||||||
return (1f - lambda) * xe * xe / norm2 + lambda * e;
|
return (1f - lambda) * xe * xe / norm2 + lambda * e;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
|
||||||
|
assert v1.length == centroid.length;
|
||||||
|
assert v1.length == originalResidual.length;
|
||||||
|
float proj = 0;
|
||||||
|
int i = 0;
|
||||||
|
if (v1.length > 2 * FLOAT_SPECIES.length()) {
|
||||||
|
FloatVector projVec1 = FloatVector.zero(FLOAT_SPECIES);
|
||||||
|
FloatVector projVec2 = FloatVector.zero(FLOAT_SPECIES);
|
||||||
|
int unrolledLimit = FLOAT_SPECIES.loopBound(v1.length) - FLOAT_SPECIES.length();
|
||||||
|
for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) {
|
||||||
|
// one
|
||||||
|
FloatVector v1Vec0 = FloatVector.fromArray(FLOAT_SPECIES, v1, i);
|
||||||
|
FloatVector centroidVec0 = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
|
||||||
|
FloatVector originalResidualVec0 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
|
||||||
|
FloatVector djkVec0 = v1Vec0.sub(centroidVec0);
|
||||||
|
projVec1 = fma(djkVec0, originalResidualVec0, projVec1);
|
||||||
|
|
||||||
|
// two
|
||||||
|
FloatVector v1Vec1 = FloatVector.fromArray(FLOAT_SPECIES, v1, i + FLOAT_SPECIES.length());
|
||||||
|
FloatVector centroidVec1 = FloatVector.fromArray(FLOAT_SPECIES, centroid, i + FLOAT_SPECIES.length());
|
||||||
|
FloatVector originalResidualVec1 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i + FLOAT_SPECIES.length());
|
||||||
|
FloatVector djkVec1 = v1Vec1.sub(centroidVec1);
|
||||||
|
projVec2 = fma(djkVec1, originalResidualVec1, projVec2);
|
||||||
|
}
|
||||||
|
// vector tail
|
||||||
|
for (; i < FLOAT_SPECIES.loopBound(v1.length); i += FLOAT_SPECIES.length()) {
|
||||||
|
FloatVector v1Vec = FloatVector.fromArray(FLOAT_SPECIES, v1, i);
|
||||||
|
FloatVector centroidVec = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
|
||||||
|
FloatVector originalResidualVec = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
|
||||||
|
FloatVector djkVec = v1Vec.sub(centroidVec);
|
||||||
|
projVec1 = fma(djkVec, originalResidualVec, projVec1);
|
||||||
|
}
|
||||||
|
proj += projVec1.add(projVec2).reduceLanes(ADD);
|
||||||
|
}
|
||||||
|
// tail
|
||||||
|
for (; i < v1.length; i++) {
|
||||||
|
float djk = v1[i] - centroid[i];
|
||||||
|
proj = fma(djk, originalResidual[i], proj);
|
||||||
|
}
|
||||||
|
return proj;
|
||||||
|
}
|
||||||
|
|
||||||
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
|
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
|
||||||
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
|
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ package org.elasticsearch.simdvec.internal.vectorization;
|
||||||
|
|
||||||
import org.apache.lucene.store.IndexInput;
|
import org.apache.lucene.store.IndexInput;
|
||||||
import org.apache.lucene.store.MemorySegmentAccessInput;
|
import org.apache.lucene.store.MemorySegmentAccessInput;
|
||||||
|
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.lang.foreign.MemorySegment;
|
import java.lang.foreign.MemorySegment;
|
||||||
|
|
|
@ -268,6 +268,22 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testSoarOverspillScore() {
|
||||||
|
int size = random().nextInt(128, 512);
|
||||||
|
float deltaEps = 1e-5f * size;
|
||||||
|
var vector = new float[size];
|
||||||
|
var centroid = new float[size];
|
||||||
|
var preResidual = new float[size];
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
vector[i] = random().nextFloat();
|
||||||
|
centroid[i] = random().nextFloat();
|
||||||
|
preResidual[i] = random().nextFloat();
|
||||||
|
}
|
||||||
|
var expected = defaultedProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual);
|
||||||
|
var result = defOrPanamaProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual);
|
||||||
|
assertEquals(expected, result, deltaEps);
|
||||||
|
}
|
||||||
|
|
||||||
void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
|
void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
|
||||||
int iterations = atLeast(50);
|
int iterations = atLeast(50);
|
||||||
for (int i = 0; i < iterations; i++) {
|
for (int i = 0; i < iterations; i++) {
|
||||||
|
|
|
@ -16,6 +16,7 @@ import org.apache.lucene.store.IndexInput;
|
||||||
import org.apache.lucene.store.IndexOutput;
|
import org.apache.lucene.store.IndexOutput;
|
||||||
import org.apache.lucene.store.MMapDirectory;
|
import org.apache.lucene.store.MMapDirectory;
|
||||||
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
|
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
|
||||||
|
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.lessThan;
|
import static org.hamcrest.Matchers.lessThan;
|
||||||
|
|
||||||
|
|
|
@ -454,7 +454,8 @@ module org.elasticsearch.server {
|
||||||
org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat,
|
org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat,
|
||||||
org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat,
|
org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat,
|
||||||
org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat,
|
org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat,
|
||||||
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat;
|
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat,
|
||||||
|
org.elasticsearch.index.codec.vectors.IVFVectorsFormat;
|
||||||
|
|
||||||
provides org.apache.lucene.codecs.Codec
|
provides org.apache.lucene.codecs.Codec
|
||||||
with
|
with
|
||||||
|
|
|
@ -0,0 +1,420 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.index.codec.vectors;
|
||||||
|
|
||||||
|
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
||||||
|
import org.apache.lucene.index.FieldInfo;
|
||||||
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.search.KnnCollector;
|
||||||
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.apache.lucene.util.ArrayUtil;
|
||||||
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||||
|
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
|
||||||
|
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||||
|
import org.elasticsearch.simdvec.ESVectorUtil;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.function.IntPredicate;
|
||||||
|
|
||||||
|
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS;
|
||||||
|
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
|
||||||
|
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
|
||||||
|
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
|
||||||
|
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
|
||||||
|
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte;
|
||||||
|
import static org.elasticsearch.simdvec.ES91OSQVectorsScorer.BULK_SIZE;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Default implementation of {@link IVFVectorsReader}. It scores the posting lists centroids using
|
||||||
|
* brute force and then scores the top ones using the posting list.
|
||||||
|
*/
|
||||||
|
public class DefaultIVFVectorsReader extends IVFVectorsReader {
|
||||||
|
private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);
|
||||||
|
|
||||||
|
public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
|
||||||
|
super(state, rawVectorsReader);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
CentroidQueryScorer getCentroidScorer(
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
int numCentroids,
|
||||||
|
IndexInput centroids,
|
||||||
|
float[] targetQuery,
|
||||||
|
IndexInput clusters
|
||||||
|
) throws IOException {
|
||||||
|
FieldEntry fieldEntry = fields.get(fieldInfo.number);
|
||||||
|
float[] globalCentroid = fieldEntry.globalCentroid();
|
||||||
|
float globalCentroidDp = fieldEntry.globalCentroidDp();
|
||||||
|
OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
|
||||||
|
byte[] quantized = new byte[targetQuery.length];
|
||||||
|
float[] targetScratch = ArrayUtil.copyArray(targetQuery);
|
||||||
|
OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
|
||||||
|
targetScratch,
|
||||||
|
quantized,
|
||||||
|
(byte) 4,
|
||||||
|
globalCentroid
|
||||||
|
);
|
||||||
|
return new CentroidQueryScorer() {
|
||||||
|
int currentCentroid = -1;
|
||||||
|
private final byte[] quantizedCentroid = new byte[fieldInfo.getVectorDimension()];
|
||||||
|
private final float[] centroid = new float[fieldInfo.getVectorDimension()];
|
||||||
|
private final float[] centroidCorrectiveValues = new float[3];
|
||||||
|
private int quantizedCentroidComponentSum;
|
||||||
|
private final long centroidByteSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return numCentroids;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float[] centroid(int centroidOrdinal) throws IOException {
|
||||||
|
readQuantizedAndRawCentroid(centroidOrdinal);
|
||||||
|
return centroid;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void readQuantizedAndRawCentroid(int centroidOrdinal) throws IOException {
|
||||||
|
if (centroidOrdinal == currentCentroid) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
centroids.seek(centroidOrdinal * centroidByteSize);
|
||||||
|
quantizedCentroidComponentSum = readQuantizedValue(centroids, quantizedCentroid, centroidCorrectiveValues);
|
||||||
|
centroids.seek(numCentroids * centroidByteSize + (long) Float.BYTES * quantizedCentroid.length * centroidOrdinal);
|
||||||
|
centroids.readFloats(centroid, 0, centroid.length);
|
||||||
|
currentCentroid = centroidOrdinal;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float score(int centroidOrdinal) throws IOException {
|
||||||
|
readQuantizedAndRawCentroid(centroidOrdinal);
|
||||||
|
return int4QuantizedScore(
|
||||||
|
quantized,
|
||||||
|
queryParams,
|
||||||
|
fieldInfo.getVectorDimension(),
|
||||||
|
quantizedCentroid,
|
||||||
|
centroidCorrectiveValues,
|
||||||
|
quantizedCentroidComponentSum,
|
||||||
|
globalCentroidDp,
|
||||||
|
fieldInfo.getVectorSimilarityFunction()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) {
|
||||||
|
FieldEntry entry = fields.get(info.number);
|
||||||
|
if (entry == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return new OffHeapCentroidFloatVectorValues(numCentroids, indexInput, info.getVectorDimension());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe)
|
||||||
|
throws IOException {
|
||||||
|
NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true);
|
||||||
|
// TODO Off heap scoring for quantized centroids?
|
||||||
|
for (int centroid = 0; centroid < centroidQueryScorer.size(); centroid++) {
|
||||||
|
neighborQueue.add(centroid, centroidQueryScorer.score(centroid));
|
||||||
|
}
|
||||||
|
return neighborQueue;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring)
|
||||||
|
throws IOException {
|
||||||
|
FieldEntry entry = fields.get(fieldInfo.number);
|
||||||
|
return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, needsScoring);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO can we do this in off-heap blocks?
|
||||||
|
static float int4QuantizedScore(
|
||||||
|
byte[] quantizedQuery,
|
||||||
|
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
|
||||||
|
int dims,
|
||||||
|
byte[] binaryCode,
|
||||||
|
float[] targetCorrections,
|
||||||
|
int targetComponentSum,
|
||||||
|
float centroidDp,
|
||||||
|
VectorSimilarityFunction similarityFunction
|
||||||
|
) {
|
||||||
|
float qcDist = VectorUtil.int4DotProduct(quantizedQuery, binaryCode);
|
||||||
|
float ax = targetCorrections[0];
|
||||||
|
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
|
||||||
|
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
|
||||||
|
float ay = queryCorrections.lowerInterval();
|
||||||
|
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
|
||||||
|
float y1 = queryCorrections.quantizedComponentSum();
|
||||||
|
float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
|
||||||
|
if (similarityFunction == EUCLIDEAN) {
|
||||||
|
score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
|
||||||
|
return Math.max(1 / (1f + score), 0);
|
||||||
|
} else {
|
||||||
|
// For cosine and max inner product, we need to apply the additional correction, which is
|
||||||
|
// assumed to be the non-centered dot-product between the vector and the centroid
|
||||||
|
score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
|
||||||
|
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
|
||||||
|
return VectorUtil.scaleMaxInnerProductScore(score);
|
||||||
|
}
|
||||||
|
return Math.max((1f + score) / 2f, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static class OffHeapCentroidFloatVectorValues extends FloatVectorValues {
|
||||||
|
private final int numCentroids;
|
||||||
|
private final IndexInput input;
|
||||||
|
private final int dimension;
|
||||||
|
private final float[] centroid;
|
||||||
|
private final long centroidByteSize;
|
||||||
|
private int ord = -1;
|
||||||
|
|
||||||
|
OffHeapCentroidFloatVectorValues(int numCentroids, IndexInput input, int dimension) {
|
||||||
|
this.numCentroids = numCentroids;
|
||||||
|
this.input = input;
|
||||||
|
this.dimension = dimension;
|
||||||
|
this.centroid = new float[dimension];
|
||||||
|
this.centroidByteSize = dimension + 3 * Float.BYTES + Short.BYTES;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float[] vectorValue(int ord) throws IOException {
|
||||||
|
if (ord < 0 || ord >= numCentroids) {
|
||||||
|
throw new IllegalArgumentException("ord must be in [0, " + numCentroids + "]");
|
||||||
|
}
|
||||||
|
if (ord == this.ord) {
|
||||||
|
return centroid;
|
||||||
|
}
|
||||||
|
readQuantizedCentroid(ord);
|
||||||
|
return centroid;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void readQuantizedCentroid(int centroidOrdinal) throws IOException {
|
||||||
|
if (centroidOrdinal == ord) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
input.seek(numCentroids * centroidByteSize + (long) Float.BYTES * dimension * centroidOrdinal);
|
||||||
|
input.readFloats(centroid, 0, centroid.length);
|
||||||
|
ord = centroidOrdinal;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int dimension() {
|
||||||
|
return dimension;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return numCentroids;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FloatVectorValues copy() throws IOException {
|
||||||
|
return new OffHeapCentroidFloatVectorValues(numCentroids, input.clone(), dimension);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class MemorySegmentPostingsVisitor implements PostingVisitor {
|
||||||
|
final long quantizedByteLength;
|
||||||
|
final IndexInput indexInput;
|
||||||
|
final float[] target;
|
||||||
|
final FieldEntry entry;
|
||||||
|
final FieldInfo fieldInfo;
|
||||||
|
final IntPredicate needsScoring;
|
||||||
|
private final ES91OSQVectorsScorer osqVectorsScorer;
|
||||||
|
final float[] scores = new float[BULK_SIZE];
|
||||||
|
final float[] correctionsLower = new float[BULK_SIZE];
|
||||||
|
final float[] correctionsUpper = new float[BULK_SIZE];
|
||||||
|
final int[] correctionsSum = new int[BULK_SIZE];
|
||||||
|
final float[] correctionsAdd = new float[BULK_SIZE];
|
||||||
|
|
||||||
|
int[] docIdsScratch = new int[0];
|
||||||
|
int vectors;
|
||||||
|
boolean quantized = false;
|
||||||
|
float centroidDp;
|
||||||
|
float[] centroid;
|
||||||
|
long slicePos;
|
||||||
|
OptimizedScalarQuantizer.QuantizationResult queryCorrections;
|
||||||
|
DocIdsWriter docIdsWriter = new DocIdsWriter();
|
||||||
|
|
||||||
|
final float[] scratch;
|
||||||
|
final byte[] quantizationScratch;
|
||||||
|
final byte[] quantizedQueryScratch;
|
||||||
|
final OptimizedScalarQuantizer quantizer;
|
||||||
|
final float[] correctiveValues = new float[3];
|
||||||
|
final long quantizedVectorByteSize;
|
||||||
|
|
||||||
|
MemorySegmentPostingsVisitor(
|
||||||
|
float[] target,
|
||||||
|
IndexInput indexInput,
|
||||||
|
FieldEntry entry,
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
IntPredicate needsScoring
|
||||||
|
) throws IOException {
|
||||||
|
this.target = target;
|
||||||
|
this.indexInput = indexInput;
|
||||||
|
this.entry = entry;
|
||||||
|
this.fieldInfo = fieldInfo;
|
||||||
|
this.needsScoring = needsScoring;
|
||||||
|
|
||||||
|
scratch = new float[target.length];
|
||||||
|
quantizationScratch = new byte[target.length];
|
||||||
|
final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64);
|
||||||
|
quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8];
|
||||||
|
quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES;
|
||||||
|
quantizedVectorByteSize = (discretizedDimensions / 8);
|
||||||
|
quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
|
||||||
|
osqVectorsScorer = ESVectorUtil.getES91OSQVectorsScorer(indexInput, fieldInfo.getVectorDimension());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException {
|
||||||
|
quantized = false;
|
||||||
|
indexInput.seek(entry.postingListOffsets()[centroidOrdinal]);
|
||||||
|
vectors = indexInput.readVInt();
|
||||||
|
centroidDp = Float.intBitsToFloat(indexInput.readInt());
|
||||||
|
this.centroid = centroid;
|
||||||
|
// read the doc ids
|
||||||
|
docIdsScratch = vectors > docIdsScratch.length ? new int[vectors] : docIdsScratch;
|
||||||
|
docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
|
||||||
|
slicePos = indexInput.getFilePointer();
|
||||||
|
return vectors;
|
||||||
|
}
|
||||||
|
|
||||||
|
void scoreIndividually(int offset) throws IOException {
|
||||||
|
// score individually, first the quantized byte chunk
|
||||||
|
for (int j = 0; j < BULK_SIZE; j++) {
|
||||||
|
int doc = docIdsScratch[j + offset];
|
||||||
|
if (doc != -1) {
|
||||||
|
indexInput.seek(slicePos + (offset * quantizedByteLength) + (j * quantizedVectorByteSize));
|
||||||
|
float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
|
||||||
|
scores[j] = qcDist;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// read in all corrections
|
||||||
|
indexInput.seek(slicePos + (offset * quantizedByteLength) + (BULK_SIZE * quantizedVectorByteSize));
|
||||||
|
indexInput.readFloats(correctionsLower, 0, BULK_SIZE);
|
||||||
|
indexInput.readFloats(correctionsUpper, 0, BULK_SIZE);
|
||||||
|
for (int j = 0; j < BULK_SIZE; j++) {
|
||||||
|
correctionsSum[j] = Short.toUnsignedInt(indexInput.readShort());
|
||||||
|
}
|
||||||
|
indexInput.readFloats(correctionsAdd, 0, BULK_SIZE);
|
||||||
|
// Now apply corrections
|
||||||
|
for (int j = 0; j < BULK_SIZE; j++) {
|
||||||
|
int doc = docIdsScratch[offset + j];
|
||||||
|
if (doc != -1) {
|
||||||
|
scores[j] = osqVectorsScorer.score(
|
||||||
|
queryCorrections,
|
||||||
|
fieldInfo.getVectorSimilarityFunction(),
|
||||||
|
centroidDp,
|
||||||
|
correctionsLower[j],
|
||||||
|
correctionsUpper[j],
|
||||||
|
correctionsSum[j],
|
||||||
|
correctionsAdd[j],
|
||||||
|
scores[j]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int visit(KnnCollector knnCollector) throws IOException {
|
||||||
|
// block processing
|
||||||
|
int scoredDocs = 0;
|
||||||
|
int limit = vectors - BULK_SIZE + 1;
|
||||||
|
int i = 0;
|
||||||
|
for (; i < limit; i += BULK_SIZE) {
|
||||||
|
int docsToScore = BULK_SIZE;
|
||||||
|
for (int j = 0; j < BULK_SIZE; j++) {
|
||||||
|
int doc = docIdsScratch[i + j];
|
||||||
|
if (needsScoring.test(doc) == false) {
|
||||||
|
docIdsScratch[i + j] = -1;
|
||||||
|
docsToScore--;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (docsToScore == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
quantizeQueryIfNecessary();
|
||||||
|
indexInput.seek(slicePos + i * quantizedByteLength);
|
||||||
|
if (docsToScore < BULK_SIZE / 2) {
|
||||||
|
scoreIndividually(i);
|
||||||
|
} else {
|
||||||
|
osqVectorsScorer.scoreBulk(
|
||||||
|
quantizedQueryScratch,
|
||||||
|
queryCorrections,
|
||||||
|
fieldInfo.getVectorSimilarityFunction(),
|
||||||
|
centroidDp,
|
||||||
|
scores
|
||||||
|
);
|
||||||
|
}
|
||||||
|
for (int j = 0; j < BULK_SIZE; j++) {
|
||||||
|
int doc = docIdsScratch[i + j];
|
||||||
|
if (doc != -1) {
|
||||||
|
scoredDocs++;
|
||||||
|
knnCollector.collect(doc, scores[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// process tail
|
||||||
|
for (; i < vectors; i++) {
|
||||||
|
int doc = docIdsScratch[i];
|
||||||
|
if (needsScoring.test(doc)) {
|
||||||
|
quantizeQueryIfNecessary();
|
||||||
|
indexInput.seek(slicePos + i * quantizedByteLength);
|
||||||
|
float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
|
||||||
|
indexInput.readFloats(correctiveValues, 0, 3);
|
||||||
|
final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort());
|
||||||
|
float score = osqVectorsScorer.score(
|
||||||
|
queryCorrections,
|
||||||
|
fieldInfo.getVectorSimilarityFunction(),
|
||||||
|
centroidDp,
|
||||||
|
correctiveValues[0],
|
||||||
|
correctiveValues[1],
|
||||||
|
quantizedComponentSum,
|
||||||
|
correctiveValues[2],
|
||||||
|
qcDist
|
||||||
|
);
|
||||||
|
scoredDocs++;
|
||||||
|
knnCollector.collect(doc, score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (scoredDocs > 0) {
|
||||||
|
knnCollector.incVisitedCount(scoredDocs);
|
||||||
|
}
|
||||||
|
return scoredDocs;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void quantizeQueryIfNecessary() {
|
||||||
|
if (quantized == false) {
|
||||||
|
System.arraycopy(target, 0, scratch, 0, target.length);
|
||||||
|
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
|
||||||
|
VectorUtil.l2normalize(scratch);
|
||||||
|
}
|
||||||
|
queryCorrections = quantizer.scalarQuantize(scratch, quantizationScratch, (byte) 4, centroid);
|
||||||
|
transposeHalfByte(quantizationScratch, quantizedQueryScratch);
|
||||||
|
quantized = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static int readQuantizedValue(IndexInput indexInput, byte[] binaryValue, float[] corrections) throws IOException {
|
||||||
|
assert corrections.length == 3;
|
||||||
|
indexInput.readBytes(binaryValue, 0, binaryValue.length);
|
||||||
|
corrections[0] = Float.intBitsToFloat(indexInput.readInt());
|
||||||
|
corrections[1] = Float.intBitsToFloat(indexInput.readInt());
|
||||||
|
corrections[2] = Float.intBitsToFloat(indexInput.readInt());
|
||||||
|
return Short.toUnsignedInt(indexInput.readShort());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,736 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.index.codec.vectors;
|
||||||
|
|
||||||
|
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
|
||||||
|
import org.apache.lucene.index.FieldInfo;
|
||||||
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
|
import org.apache.lucene.index.MergeState;
|
||||||
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.internal.hppc.IntArrayList;
|
||||||
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.apache.lucene.store.IndexOutput;
|
||||||
|
import org.apache.lucene.util.InfoStream;
|
||||||
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
|
||||||
|
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||||
|
import org.elasticsearch.simdvec.ESVectorUtil;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.ByteOrder;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
|
||||||
|
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
|
||||||
|
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary;
|
||||||
|
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.IVF_VECTOR_COMPONENT;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Default implementation of {@link IVFVectorsWriter}. It uses {@link KMeans} algorithm to
|
||||||
|
* partition the vector space, and then stores the centroids an posting list in a sequential
|
||||||
|
* fashion.
|
||||||
|
*/
|
||||||
|
public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
|
||||||
|
|
||||||
|
static final float SOAR_LAMBDA = 1.0f;
|
||||||
|
// What percentage of the centroids do we do a second check on for SOAR assignment
|
||||||
|
static final float EXT_SOAR_LIMIT_CHECK_RATIO = 0.10f;
|
||||||
|
|
||||||
|
private final int vectorPerCluster;
|
||||||
|
|
||||||
|
public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate, int vectorPerCluster) throws IOException {
|
||||||
|
super(state, rawVectorDelegate);
|
||||||
|
this.vectorPerCluster = vectorPerCluster;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
CentroidAssignmentScorer calculateAndWriteCentroids(
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
FloatVectorValues floatVectorValues,
|
||||||
|
IndexOutput centroidOutput,
|
||||||
|
float[] globalCentroid
|
||||||
|
) throws IOException {
|
||||||
|
if (floatVectorValues.size() == 0) {
|
||||||
|
return CentroidAssignmentScorer.EMPTY;
|
||||||
|
}
|
||||||
|
// calculate the centroids
|
||||||
|
int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1;
|
||||||
|
int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters);
|
||||||
|
final KMeans.Results kMeans = KMeans.cluster(
|
||||||
|
floatVectorValues,
|
||||||
|
desiredClusters,
|
||||||
|
false,
|
||||||
|
42L,
|
||||||
|
KMeans.KmeansInitializationMethod.PLUS_PLUS,
|
||||||
|
null,
|
||||||
|
fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE,
|
||||||
|
1,
|
||||||
|
15,
|
||||||
|
desiredClusters * 256
|
||||||
|
);
|
||||||
|
float[][] centroids = kMeans.centroids();
|
||||||
|
// write them
|
||||||
|
writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput);
|
||||||
|
return new OnHeapCentroidAssignmentScorer(centroids);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
long[] buildAndWritePostingsLists(
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
InfoStream infoStream,
|
||||||
|
CentroidAssignmentScorer randomCentroidScorer,
|
||||||
|
FloatVectorValues floatVectorValues,
|
||||||
|
IndexOutput postingsOutput
|
||||||
|
) throws IOException {
|
||||||
|
IntArrayList[] clusters = new IntArrayList[randomCentroidScorer.size()];
|
||||||
|
for (int i = 0; i < randomCentroidScorer.size(); i++) {
|
||||||
|
clusters[i] = new IntArrayList(floatVectorValues.size() / randomCentroidScorer.size() / 4);
|
||||||
|
}
|
||||||
|
assignCentroids(randomCentroidScorer, floatVectorValues, clusters);
|
||||||
|
if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||||
|
printClusterQualityStatistics(clusters, infoStream);
|
||||||
|
}
|
||||||
|
// write the posting lists
|
||||||
|
final long[] offsets = new long[randomCentroidScorer.size()];
|
||||||
|
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
|
||||||
|
BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
|
||||||
|
DocIdsWriter docIdsWriter = new DocIdsWriter();
|
||||||
|
for (int i = 0; i < randomCentroidScorer.size(); i++) {
|
||||||
|
float[] centroid = randomCentroidScorer.centroid(i);
|
||||||
|
binarizedByteVectorValues.centroid = centroid;
|
||||||
|
// TODO sort by distance to the centroid
|
||||||
|
IntArrayList cluster = clusters[i];
|
||||||
|
// TODO align???
|
||||||
|
offsets[i] = postingsOutput.getFilePointer();
|
||||||
|
int size = cluster.size();
|
||||||
|
postingsOutput.writeVInt(size);
|
||||||
|
postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
|
||||||
|
// TODO we might want to consider putting the docIds in a separate file
|
||||||
|
// to aid with only having to fetch vectors from slower storage when they are required
|
||||||
|
// keeping them in the same file indicates we pull the entire file into cache
|
||||||
|
docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), cluster.size(), postingsOutput);
|
||||||
|
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
|
||||||
|
}
|
||||||
|
return offsets;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues)
|
||||||
|
throws IOException {
|
||||||
|
int limit = cluster.size() - ES91OSQVectorsScorer.BULK_SIZE + 1;
|
||||||
|
int cidx = 0;
|
||||||
|
OptimizedScalarQuantizer.QuantizationResult[] corrections =
|
||||||
|
new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE];
|
||||||
|
// Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE.
|
||||||
|
for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) {
|
||||||
|
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
|
||||||
|
int ord = cluster.get(cidx + j);
|
||||||
|
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
|
||||||
|
// write vector
|
||||||
|
postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
|
||||||
|
corrections[j] = binarizedByteVectorValues.getCorrectiveTerms(ord);
|
||||||
|
}
|
||||||
|
// write corrections
|
||||||
|
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
|
||||||
|
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].lowerInterval()));
|
||||||
|
}
|
||||||
|
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
|
||||||
|
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].upperInterval()));
|
||||||
|
}
|
||||||
|
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
|
||||||
|
int targetComponentSum = corrections[j].quantizedComponentSum();
|
||||||
|
assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
|
||||||
|
postingsOutput.writeShort((short) targetComponentSum);
|
||||||
|
}
|
||||||
|
for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
|
||||||
|
postingsOutput.writeInt(Float.floatToIntBits(corrections[j].additionalCorrection()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// write tail
|
||||||
|
for (; cidx < cluster.size(); cidx++) {
|
||||||
|
int ord = cluster.get(cidx);
|
||||||
|
// write vector
|
||||||
|
byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
|
||||||
|
OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord);
|
||||||
|
writeQuantizedValue(postingsOutput, binaryValue, correction);
|
||||||
|
binarizedByteVectorValues.getCorrectiveTerms(ord);
|
||||||
|
postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
|
||||||
|
postingsOutput.writeInt(Float.floatToIntBits(correction.lowerInterval()));
|
||||||
|
postingsOutput.writeInt(Float.floatToIntBits(correction.upperInterval()));
|
||||||
|
postingsOutput.writeInt(Float.floatToIntBits(correction.additionalCorrection()));
|
||||||
|
assert correction.quantizedComponentSum() >= 0 && correction.quantizedComponentSum() <= 0xffff;
|
||||||
|
postingsOutput.writeShort((short) correction.quantizedComponentSum());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
CentroidAssignmentScorer createCentroidScorer(
|
||||||
|
IndexInput centroidsInput,
|
||||||
|
int numCentroids,
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
float[] globalCentroid
|
||||||
|
) {
|
||||||
|
return new OffHeapCentroidAssignmentScorer(centroidsInput, numCentroids, fieldInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput)
|
||||||
|
throws IOException {
|
||||||
|
final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
|
||||||
|
byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()];
|
||||||
|
float[] centroidScratch = new float[fieldInfo.getVectorDimension()];
|
||||||
|
// TODO do we want to store these distances as well for future use?
|
||||||
|
float[] distances = new float[centroids.length];
|
||||||
|
for (int i = 0; i < centroids.length; i++) {
|
||||||
|
distances[i] = VectorUtil.squareDistance(centroids[i], globalCentroid);
|
||||||
|
}
|
||||||
|
// sort the centroids by distance to globalCentroid, nearest (smallest distance), to furthest
|
||||||
|
// (largest)
|
||||||
|
for (int i = 0; i < centroids.length; i++) {
|
||||||
|
for (int j = i + 1; j < centroids.length; j++) {
|
||||||
|
if (distances[i] > distances[j]) {
|
||||||
|
float[] tmp = centroids[i];
|
||||||
|
centroids[i] = centroids[j];
|
||||||
|
centroids[j] = tmp;
|
||||||
|
float tmpDistance = distances[i];
|
||||||
|
distances[i] = distances[j];
|
||||||
|
distances[j] = tmpDistance;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (float[] centroid : centroids) {
|
||||||
|
System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length);
|
||||||
|
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(
|
||||||
|
centroidScratch,
|
||||||
|
quantizedScratch,
|
||||||
|
(byte) 4,
|
||||||
|
globalCentroid
|
||||||
|
);
|
||||||
|
writeQuantizedValue(centroidOutput, quantizedScratch, result);
|
||||||
|
}
|
||||||
|
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||||
|
for (float[] centroid : centroids) {
|
||||||
|
buffer.asFloatBuffer().put(centroid);
|
||||||
|
centroidOutput.writeBytes(buffer.array(), buffer.array().length);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static float[][] gatherInitCentroids(
|
||||||
|
List<FloatVectorValues> centroidList,
|
||||||
|
List<SegmentCentroid> segmentCentroids,
|
||||||
|
int desiredClusters,
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
MergeState mergeState
|
||||||
|
) throws IOException {
|
||||||
|
if (centroidList.size() == 0) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
long startTime = System.nanoTime();
|
||||||
|
// sort centroid list by floatvector size
|
||||||
|
FloatVectorValues baseSegment = centroidList.get(0);
|
||||||
|
for (var l : centroidList) {
|
||||||
|
if (l.size() > baseSegment.size()) {
|
||||||
|
baseSegment = l;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
float[] scratch = new float[fieldInfo.getVectorDimension()];
|
||||||
|
float minimumDistance = Float.MAX_VALUE;
|
||||||
|
for (int j = 0; j < baseSegment.size(); j++) {
|
||||||
|
System.arraycopy(baseSegment.vectorValue(j), 0, scratch, 0, baseSegment.dimension());
|
||||||
|
for (int k = j + 1; k < baseSegment.size(); k++) {
|
||||||
|
float d = VectorUtil.squareDistance(scratch, baseSegment.vectorValue(k));
|
||||||
|
if (d < minimumDistance) {
|
||||||
|
minimumDistance = d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||||
|
mergeState.infoStream.message(
|
||||||
|
IVF_VECTOR_COMPONENT,
|
||||||
|
"Agglomerative cluster min distance: " + minimumDistance + " From biggest segment: " + baseSegment.size()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
int[] labels = new int[segmentCentroids.size()];
|
||||||
|
// loop over segments
|
||||||
|
int clusterIdx = 0;
|
||||||
|
// keep track of all inter-centroid distances,
|
||||||
|
// using less than centroid * centroid space (e.g. not keeping track of duplicates)
|
||||||
|
for (int i = 0; i < segmentCentroids.size(); i++) {
|
||||||
|
if (labels[i] == 0) {
|
||||||
|
clusterIdx += 1;
|
||||||
|
labels[i] = clusterIdx;
|
||||||
|
}
|
||||||
|
SegmentCentroid segmentCentroid = segmentCentroids.get(i);
|
||||||
|
System.arraycopy(
|
||||||
|
centroidList.get(segmentCentroid.segment()).vectorValue(segmentCentroid.centroid),
|
||||||
|
0,
|
||||||
|
scratch,
|
||||||
|
0,
|
||||||
|
baseSegment.dimension()
|
||||||
|
);
|
||||||
|
for (int j = i + 1; j < segmentCentroids.size(); j++) {
|
||||||
|
float d = VectorUtil.squareDistance(
|
||||||
|
scratch,
|
||||||
|
centroidList.get(segmentCentroids.get(j).segment()).vectorValue(segmentCentroids.get(j).centroid())
|
||||||
|
);
|
||||||
|
if (d < minimumDistance / 2) {
|
||||||
|
if (labels[j] == 0) {
|
||||||
|
labels[j] = labels[i];
|
||||||
|
} else {
|
||||||
|
for (int k = 0; k < labels.length; k++) {
|
||||||
|
if (labels[k] == labels[j]) {
|
||||||
|
labels[k] = labels[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
float[][] initCentroids = new float[clusterIdx][fieldInfo.getVectorDimension()];
|
||||||
|
int[] sum = new int[clusterIdx];
|
||||||
|
for (int i = 0; i < segmentCentroids.size(); i++) {
|
||||||
|
SegmentCentroid segmentCentroid = segmentCentroids.get(i);
|
||||||
|
int label = labels[i];
|
||||||
|
FloatVectorValues segment = centroidList.get(segmentCentroid.segment());
|
||||||
|
float[] vector = segment.vectorValue(segmentCentroid.centroid);
|
||||||
|
for (int j = 0; j < vector.length; j++) {
|
||||||
|
initCentroids[label - 1][j] += (vector[j] * segmentCentroid.centroidSize);
|
||||||
|
}
|
||||||
|
sum[label - 1] += segmentCentroid.centroidSize;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < initCentroids.length; i++) {
|
||||||
|
if (sum[i] == 0 || sum[i] == 1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (int j = 0; j < initCentroids[i].length; j++) {
|
||||||
|
initCentroids[i][j] /= sum[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||||
|
mergeState.infoStream.message(
|
||||||
|
IVF_VECTOR_COMPONENT,
|
||||||
|
"Agglomerative cluster time ms: " + ((System.nanoTime() - startTime) / 1000000.0)
|
||||||
|
);
|
||||||
|
mergeState.infoStream.message(
|
||||||
|
IVF_VECTOR_COMPONENT,
|
||||||
|
"Gathered initCentroids:" + initCentroids.length + " for desired: " + desiredClusters
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return initCentroids;
|
||||||
|
}
|
||||||
|
|
||||||
|
record SegmentCentroid(int segment, int centroid, int centroidSize) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate the centroids for the given field and write them to the given
|
||||||
|
* temporary centroid output.
|
||||||
|
* When merging, we first bootstrap the KMeans algorithm with the centroids contained in the merging segments.
|
||||||
|
* To prevent centroids that are too similar from having an outsized impact, all centroids that are closer than
|
||||||
|
* the largest segments intra-cluster distance are merged into a single centroid.
|
||||||
|
* The resulting centroids are then used to initialize the KMeans algorithm.
|
||||||
|
*
|
||||||
|
* @param fieldInfo merging field info
|
||||||
|
* @param floatVectorValues the float vector values to merge
|
||||||
|
* @param temporaryCentroidOutput the temporary centroid output
|
||||||
|
* @param mergeState the merge state
|
||||||
|
* @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
|
||||||
|
* @return the number of centroids written
|
||||||
|
* @throws IOException if an I/O error occurs
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
protected int calculateAndWriteCentroids(
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
FloatVectorValues floatVectorValues,
|
||||||
|
IndexOutput temporaryCentroidOutput,
|
||||||
|
MergeState mergeState,
|
||||||
|
float[] globalCentroid
|
||||||
|
) throws IOException {
|
||||||
|
if (floatVectorValues.size() == 0) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1;
|
||||||
|
int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters);
|
||||||
|
// init centroids from merge state
|
||||||
|
List<FloatVectorValues> centroidList = new ArrayList<>();
|
||||||
|
List<SegmentCentroid> segmentCentroids = new ArrayList<>(desiredClusters);
|
||||||
|
|
||||||
|
int segmentIdx = 0;
|
||||||
|
for (var reader : mergeState.knnVectorsReaders) {
|
||||||
|
IVFVectorsReader ivfVectorsReader = IVFVectorsFormat.getIVFReader(reader, fieldInfo.name);
|
||||||
|
if (ivfVectorsReader == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
FloatVectorValues centroid = ivfVectorsReader.getCentroids(fieldInfo);
|
||||||
|
if (centroid == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
centroidList.add(centroid);
|
||||||
|
for (int i = 0; i < centroid.size(); i++) {
|
||||||
|
int size = ivfVectorsReader.centroidSize(fieldInfo.name, i);
|
||||||
|
if (size == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
segmentCentroids.add(new SegmentCentroid(segmentIdx, i, size));
|
||||||
|
}
|
||||||
|
segmentIdx++;
|
||||||
|
}
|
||||||
|
|
||||||
|
float[][] initCentroids = gatherInitCentroids(centroidList, segmentCentroids, desiredClusters, fieldInfo, mergeState);
|
||||||
|
|
||||||
|
// FIXME: run a custom version of KMeans that is just better...
|
||||||
|
long nanoTime = System.nanoTime();
|
||||||
|
final KMeans.Results kMeans = KMeans.cluster(
|
||||||
|
floatVectorValues,
|
||||||
|
desiredClusters,
|
||||||
|
false,
|
||||||
|
42L,
|
||||||
|
KMeans.KmeansInitializationMethod.PLUS_PLUS,
|
||||||
|
initCentroids,
|
||||||
|
fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE,
|
||||||
|
1,
|
||||||
|
5,
|
||||||
|
desiredClusters * 64
|
||||||
|
);
|
||||||
|
if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
|
||||||
|
mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "KMeans time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0));
|
||||||
|
}
|
||||||
|
float[][] centroids = kMeans.centroids();
|
||||||
|
|
||||||
|
// write them
|
||||||
|
// calculate the global centroid from all the centroids:
|
||||||
|
for (float[] centroid : centroids) {
|
||||||
|
for (int j = 0; j < centroid.length; j++) {
|
||||||
|
globalCentroid[j] += centroid[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int j = 0; j < globalCentroid.length; j++) {
|
||||||
|
globalCentroid[j] /= centroids.length;
|
||||||
|
}
|
||||||
|
writeCentroids(centroids, fieldInfo, globalCentroid, temporaryCentroidOutput);
|
||||||
|
return centroids.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
    @Override
    long[] buildAndWritePostingsLists(
        FieldInfo fieldInfo,
        CentroidAssignmentScorer centroidAssignmentScorer,
        FloatVectorValues floatVectorValues,
        IndexOutput postingsOutput,
        MergeState mergeState
    ) throws IOException {
        IntArrayList[] clusters = new IntArrayList[centroidAssignmentScorer.size()];
        for (int i = 0; i < centroidAssignmentScorer.size(); i++) {
            clusters[i] = new IntArrayList(floatVectorValues.size() / centroidAssignmentScorer.size() / 4);
        }
        long nanoTime = System.nanoTime();
        // Can we do a pre-filter by finding the nearest centroids to the original vector centroids?
        // We need to be careful on vecOrd vs. doc as we need random access to the raw vector for posting list writing
        assignCentroids(centroidAssignmentScorer, floatVectorValues, clusters);
        if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
            mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "assignCentroids time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0));
        }

        if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
            printClusterQualityStatistics(clusters, mergeState.infoStream);
        }
        // write the posting lists
        final long[] offsets = new long[centroidAssignmentScorer.size()];
        OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
        BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
        DocIdsWriter docIdsWriter = new DocIdsWriter();
        for (int i = 0; i < centroidAssignmentScorer.size(); i++) {
            float[] centroid = centroidAssignmentScorer.centroid(i);
            binarizedByteVectorValues.centroid = centroid;
            // TODO: sort by distance to the centroid
            IntArrayList cluster = clusters[i];
            // TODO align???
            offsets[i] = postingsOutput.getFilePointer();
            int size = cluster.size();
            postingsOutput.writeVInt(size);
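            // Store the centroid's self dot-product next to the posting list so it does not have to be
            // recomputed when the quantized vectors are scored (assumption based on how the value is laid out here).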
            postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
            // TODO we might want to consider putting the docIds in a separate file
            // to aid with only having to fetch vectors from slower storage when they are required
            // keeping them in the same file means we end up pulling the entire file into the cache
            docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput);
            writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
        }
        return offsets;
    }
|
||||||
|
|
||||||
|
private static void printClusterQualityStatistics(IntArrayList[] clusters, InfoStream infoStream) {
|
||||||
|
float min = Float.MAX_VALUE;
|
||||||
|
float max = Float.MIN_VALUE;
|
||||||
|
float mean = 0;
|
||||||
|
float m2 = 0;
|
||||||
|
// iteratively compute the variance & mean
|
||||||
|
int count = 0;
|
||||||
|
for (IntArrayList cluster : clusters) {
|
||||||
|
count += 1;
|
||||||
|
if (cluster == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
float delta = cluster.size() - mean;
|
||||||
|
mean += delta / count;
|
||||||
|
m2 += delta * (cluster.size() - mean);
|
||||||
|
min = Math.min(min, cluster.size());
|
||||||
|
max = Math.max(max, cluster.size());
|
||||||
|
}
|
||||||
|
float variance = m2 / (clusters.length - 1);
|
||||||
|
infoStream.message(
|
||||||
|
IVF_VECTOR_COMPONENT,
|
||||||
|
"Centroid count: "
|
||||||
|
+ clusters.length
|
||||||
|
+ " min: "
|
||||||
|
+ min
|
||||||
|
+ " max: "
|
||||||
|
+ max
|
||||||
|
+ " mean: "
|
||||||
|
+ mean
|
||||||
|
+ " stdDev: "
|
||||||
|
+ Math.sqrt(variance)
|
||||||
|
+ " variance: "
|
||||||
|
+ variance
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues vectors, IntArrayList[] clusters) throws IOException {
|
||||||
|
int numCentroids = scorer.size();
|
||||||
|
// we at most will look at the EXT_SOAR_LIMIT_CHECK_RATIO nearest centroids if possible
|
||||||
|
int soarToCheck = (int) (numCentroids * EXT_SOAR_LIMIT_CHECK_RATIO);
|
||||||
|
int soarClusterCheckCount = Math.min(numCentroids - 1, soarToCheck);
|
||||||
|
NeighborQueue neighborsToCheck = new NeighborQueue(soarClusterCheckCount + 1, true);
|
||||||
|
OrdScoreIterator ordScoreIterator = new OrdScoreIterator(soarClusterCheckCount + 1);
|
||||||
|
float[] scratch = new float[vectors.dimension()];
|
||||||
|
for (int docID = 0; docID < vectors.size(); docID++) {
|
||||||
|
float[] vector = vectors.vectorValue(docID);
|
||||||
|
scorer.setScoringVector(vector);
|
||||||
|
int bestCentroid = 0;
|
||||||
|
float bestScore = Float.MAX_VALUE;
|
||||||
|
if (numCentroids > 1) {
|
||||||
|
for (short c = 0; c < numCentroids; c++) {
|
||||||
|
float squareDist = scorer.score(c);
|
||||||
|
neighborsToCheck.insertWithOverflow(c, squareDist);
|
||||||
|
}
|
||||||
|
// pop the best
|
||||||
|
int sz = neighborsToCheck.size();
|
||||||
|
int best = neighborsToCheck.consumeNodesAndScoresMin(ordScoreIterator.ords, ordScoreIterator.scores);
|
||||||
|
// Set the size to the number of neighbors we actually found
|
||||||
|
ordScoreIterator.setSize(sz);
|
||||||
|
bestScore = ordScoreIterator.getScore(best);
|
||||||
|
bestCentroid = ordScoreIterator.getOrd(best);
|
||||||
|
}
|
||||||
|
clusters[bestCentroid].add(docID);
|
||||||
|
if (soarClusterCheckCount > 0) {
|
||||||
|
assignCentroidSOAR(
|
||||||
|
ordScoreIterator,
|
||||||
|
docID,
|
||||||
|
bestCentroid,
|
||||||
|
scorer.centroid(bestCentroid),
|
||||||
|
bestScore,
|
||||||
|
scratch,
|
||||||
|
scorer,
|
||||||
|
vector,
|
||||||
|
clusters
|
||||||
|
);
|
||||||
|
}
|
||||||
|
neighborsToCheck.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
    static void assignCentroidSOAR(
        OrdScoreIterator centroidsToCheck,
        int vecOrd,
        int bestCentroidId,
        float[] bestCentroid,
        float bestScore,
        float[] scratch,
        CentroidAssignmentScorer scorer,
        float[] vector,
        IntArrayList[] clusters
    ) throws IOException {
        ESVectorUtil.subtract(vector, bestCentroid, scratch);
        int bestSecondaryCentroid = -1;
        float minDist = Float.MAX_VALUE;
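        // SOAR-style secondary assignment: each candidate's squared distance is penalized by
        // SOAR_LAMBDA * proj^2 / bestScore, where proj correlates the candidate residual with the residual to
        // the primary centroid (scratch). Candidates whose residual points in a different direction than the
        // primary one are therefore preferred for the spilled copy (interpretation of the code below, not verified).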
        for (int i = 0; i < centroidsToCheck.size(); i++) {
            float score = centroidsToCheck.getScore(i);
            int centroidOrdinal = centroidsToCheck.getOrd(i);
            if (centroidOrdinal == bestCentroidId) {
                continue;
            }
            float proj = ESVectorUtil.soarResidual(vector, scorer.centroid(centroidOrdinal), scratch);
            score += SOAR_LAMBDA * proj * proj / bestScore;
            if (score < minDist) {
                bestSecondaryCentroid = centroidOrdinal;
                minDist = score;
            }
        }
        if (bestSecondaryCentroid != -1) {
            clusters[bestSecondaryCentroid].add(vecOrd);
        }
    }
|
||||||
|
|
||||||
|
static class OrdScoreIterator {
|
||||||
|
private final int[] ords;
|
||||||
|
private final float[] scores;
|
||||||
|
private int idx = 0;
|
||||||
|
|
||||||
|
OrdScoreIterator(int size) {
|
||||||
|
this.ords = new int[size];
|
||||||
|
this.scores = new float[size];
|
||||||
|
}
|
||||||
|
|
||||||
|
int setSize(int size) {
|
||||||
|
if (size > ords.length) {
|
||||||
|
throw new IllegalArgumentException("size must be <= " + ords.length);
|
||||||
|
}
|
||||||
|
this.idx = size;
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
int getOrd(int idx) {
|
||||||
|
return ords[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
float getScore(int idx) {
|
||||||
|
return scores[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
int size() {
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO unify with OSQ format
|
||||||
|
static class BinarizedFloatVectorValues {
|
||||||
|
private OptimizedScalarQuantizer.QuantizationResult corrections;
|
||||||
|
private final byte[] binarized;
|
||||||
|
private final byte[] initQuantized;
|
||||||
|
private float[] centroid;
|
||||||
|
private final FloatVectorValues values;
|
||||||
|
private final OptimizedScalarQuantizer quantizer;
|
||||||
|
|
||||||
|
private int lastOrd = -1;
|
||||||
|
|
||||||
|
BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer) {
|
||||||
|
this.values = delegate;
|
||||||
|
this.quantizer = quantizer;
|
||||||
|
this.binarized = new byte[discretize(delegate.dimension(), 64) / 8];
|
||||||
|
this.initQuantized = new byte[delegate.dimension()];
|
||||||
|
}
|
||||||
|
|
||||||
|
public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) {
|
||||||
|
if (ord != lastOrd) {
|
||||||
|
throw new IllegalStateException(
|
||||||
|
"attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return corrections;
|
||||||
|
}
|
||||||
|
|
||||||
|
public byte[] vectorValue(int ord) throws IOException {
|
||||||
|
if (ord != lastOrd) {
|
||||||
|
binarize(ord);
|
||||||
|
lastOrd = ord;
|
||||||
|
}
|
||||||
|
return binarized;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void binarize(int ord) throws IOException {
|
||||||
|
corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid);
|
||||||
|
packAsBinary(initQuantized, binarized);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static class OffHeapCentroidAssignmentScorer implements CentroidAssignmentScorer {
|
||||||
|
private final IndexInput centroidsInput;
|
||||||
|
private final int numCentroids;
|
||||||
|
private final int dimension;
|
||||||
|
private final float[] scratch;
|
||||||
|
private float[] q;
|
||||||
|
private final long rawCentroidOffset;
|
||||||
|
private int currOrd = -1;
|
||||||
|
|
||||||
|
        OffHeapCentroidAssignmentScorer(IndexInput centroidsInput, int numCentroids, FieldInfo info) {
            this.centroidsInput = centroidsInput;
            this.numCentroids = numCentroids;
            this.dimension = info.getVectorDimension();
            this.scratch = new float[dimension];
            this.rawCentroidOffset = (dimension + 3 * Float.BYTES + Short.BYTES) * numCentroids;
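            // Assumed layout of the centroid file: the quantized centroid records come first, each taking
            // `dimension` bytes plus three float corrections and a short component sum (see writeQuantizedValue);
            // the raw float32 centroids start right after, at rawCentroidOffset.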
        }
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return numCentroids;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float[] centroid(int centroidOrdinal) throws IOException {
|
||||||
|
if (centroidOrdinal == currOrd) {
|
||||||
|
return scratch;
|
||||||
|
}
|
||||||
|
centroidsInput.seek(rawCentroidOffset + (long) centroidOrdinal * dimension * Float.BYTES);
|
||||||
|
centroidsInput.readFloats(scratch, 0, dimension);
|
||||||
|
this.currOrd = centroidOrdinal;
|
||||||
|
return scratch;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setScoringVector(float[] vector) {
|
||||||
|
q = vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float score(int centroidOrdinal) throws IOException {
|
||||||
|
return VectorUtil.squareDistance(centroid(centroidOrdinal), q);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO throw away rawCentroids
|
||||||
|
static class OnHeapCentroidAssignmentScorer implements CentroidAssignmentScorer {
|
||||||
|
private final float[][] centroids;
|
||||||
|
private float[] q;
|
||||||
|
|
||||||
|
OnHeapCentroidAssignmentScorer(float[][] centroids) {
|
||||||
|
this.centroids = centroids;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return centroids.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setScoringVector(float[] vector) {
|
||||||
|
q = vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float[] centroid(int centroidOrdinal) throws IOException {
|
||||||
|
return centroids[centroidOrdinal];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float score(int centroidOrdinal) throws IOException {
|
||||||
|
return VectorUtil.squareDistance(centroid(centroidOrdinal), q);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
    static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
        throws IOException {
        indexOutput.writeBytes(binaryValue, binaryValue.length);
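        // The OSQ corrective terms follow the packed value: lower and upper interval plus the additional
        // correction as raw float bits, then the quantized component sum as an unsigned 16-bit value.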
        indexOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
        indexOutput.writeInt(Float.floatToIntBits(corrections.upperInterval()));
        indexOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
        assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff;
        indexOutput.writeShort((short) corrections.quantizedComponentSum());
    }
}
@@ -0,0 +1,110 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.index.codec.vectors;
|
||||||
|
|
||||||
|
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||||
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
|
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||||
|
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
|
||||||
|
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
|
||||||
|
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
|
||||||
|
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||||
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
/**
 * Codec format for Inverted File Vector indexes. This index expects to break the dimensional space
 * into clusters and assign each vector to a cluster generating a posting list of vectors. Clusters
 * are represented by centroids.
 * The vector quantization format used here is a per-vector optimized scalar quantization. Also see
 * {@link OptimizedScalarQuantizer}.
 *
 * The format is stored in three files:
 *
 * <h2>.cenivf (centroid data) file</h2>
 * <p> Stores the raw and quantized centroid vectors.
 *
 * <h2>.clivf (cluster data) file</h2>
 *
 * <p> Stores the quantized vectors for each cluster, inline and stored in blocks. Additionally, the docIds of
 * each vector are stored.
 *
 * <h2>.mivf (centroid metadata) file</h2>
 *
 * <p> Stores metadata including the number of centroids and their offsets in the clivf file</p>
 */
public class IVFVectorsFormat extends KnnVectorsFormat {
|
||||||
|
|
||||||
|
public static final String IVF_VECTOR_COMPONENT = "IVF";
|
||||||
|
public static final String NAME = "IVFVectorsFormat";
|
||||||
|
// centroid ordinals -> centroid values, offsets
|
||||||
|
public static final String CENTROID_EXTENSION = "cenivf";
|
||||||
|
// offsets contained in cen_ivf, [vector ordinals, actually just docIds](long varint), quantized
|
||||||
|
// vectors (OSQ bit)
|
||||||
|
public static final String CLUSTER_EXTENSION = "clivf";
|
||||||
|
static final String IVF_META_EXTENSION = "mivf";
|
||||||
|
|
||||||
|
public static final int VERSION_START = 0;
|
||||||
|
public static final int VERSION_CURRENT = VERSION_START;
|
||||||
|
|
||||||
|
private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(
|
||||||
|
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
|
||||||
|
);
|
||||||
|
|
||||||
|
private static final int DEFAULT_VECTORS_PER_CLUSTER = 1000;
|
||||||
|
|
||||||
|
private final int vectorPerCluster;
|
||||||
|
|
||||||
|
public IVFVectorsFormat(int vectorPerCluster) {
|
||||||
|
super(NAME);
|
||||||
|
if (vectorPerCluster <= 0) {
|
||||||
|
throw new IllegalArgumentException("vectorPerCluster must be > 0");
|
||||||
|
}
|
||||||
|
this.vectorPerCluster = vectorPerCluster;
|
||||||
|
}
|
||||||
|

    /** Constructs a format using the default number of vectors per cluster. */
    public IVFVectorsFormat() {
        this(DEFAULT_VECTORS_PER_CLUSTER);
    }
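
    // Usage sketch (illustrative only, not part of this change): the format would typically be selected
    // per field through a codec's KnnVectorsFormat override, along the lines of
    //
    //   KnnVectorsFormat format = new IVFVectorsFormat(1000); // roughly 1000 vectors per posting list
    //
    // the value and wiring above are assumptions for illustration; only the constructors in this class
    // are implied by the code here.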
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
|
||||||
|
return new DefaultIVFVectorsWriter(state, rawVectorFormat.fieldsWriter(state), vectorPerCluster);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
|
||||||
|
return new DefaultIVFVectorsReader(state, rawVectorFormat.fieldsReader(state));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getMaxDimensions(String fieldName) {
|
||||||
|
return 1024;
|
||||||
|
}
|
||||||
|
|
||||||
|
    @Override
    public String toString() {
        return "IVFVectorsFormat";
    }
|
||||||
|
|
||||||
|
static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) {
|
||||||
|
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
|
||||||
|
vectorsReader = candidateReader.getFieldReader(fieldName);
|
||||||
|
}
|
||||||
|
if (vectorsReader instanceof IVFVectorsReader reader) {
|
||||||
|
return reader;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
@@ -0,0 +1,354 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.index.codec.vectors;
|
||||||
|
|
||||||
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
|
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
|
import org.apache.lucene.index.CorruptIndexException;
|
||||||
|
import org.apache.lucene.index.FieldInfo;
|
||||||
|
import org.apache.lucene.index.FieldInfos;
|
||||||
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.internal.hppc.IntObjectHashMap;
|
||||||
|
import org.apache.lucene.search.KnnCollector;
|
||||||
|
import org.apache.lucene.store.ChecksumIndexInput;
|
||||||
|
import org.apache.lucene.store.DataInput;
|
||||||
|
import org.apache.lucene.store.IOContext;
|
||||||
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.apache.lucene.util.BitSet;
|
||||||
|
import org.apache.lucene.util.Bits;
|
||||||
|
import org.apache.lucene.util.FixedBitSet;
|
||||||
|
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||||
|
import org.elasticsearch.core.IOUtils;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.function.IntPredicate;
|
||||||
|
|
||||||
|
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reader for IVF vectors. This reader is used to read the IVF vectors from the index.
|
||||||
|
*/
|
||||||
|
public abstract class IVFVectorsReader extends KnnVectorsReader {
|
||||||
|
|
||||||
|
private final IndexInput ivfCentroids, ivfClusters;
|
||||||
|
private final SegmentReadState state;
|
||||||
|
private final FieldInfos fieldInfos;
|
||||||
|
protected final IntObjectHashMap<FieldEntry> fields;
|
||||||
|
private final FlatVectorsReader rawVectorsReader;
|
||||||
|
|
||||||
|
protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
|
||||||
|
this.state = state;
|
||||||
|
this.fieldInfos = state.fieldInfos;
|
||||||
|
this.rawVectorsReader = rawVectorsReader;
|
||||||
|
this.fields = new IntObjectHashMap<>();
|
||||||
|
String meta = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, IVFVectorsFormat.IVF_META_EXTENSION);
|
||||||
|
|
||||||
|
int versionMeta = -1;
|
||||||
|
boolean success = false;
|
||||||
|
try (ChecksumIndexInput ivfMeta = state.directory.openChecksumInput(meta)) {
|
||||||
|
Throwable priorE = null;
|
||||||
|
try {
|
||||||
|
versionMeta = CodecUtil.checkIndexHeader(
|
||||||
|
ivfMeta,
|
||||||
|
IVFVectorsFormat.NAME,
|
||||||
|
IVFVectorsFormat.VERSION_START,
|
||||||
|
IVFVectorsFormat.VERSION_CURRENT,
|
||||||
|
state.segmentInfo.getId(),
|
||||||
|
state.segmentSuffix
|
||||||
|
);
|
||||||
|
readFields(ivfMeta);
|
||||||
|
} catch (Throwable exception) {
|
||||||
|
priorE = exception;
|
||||||
|
} finally {
|
||||||
|
CodecUtil.checkFooter(ivfMeta, priorE);
|
||||||
|
}
|
||||||
|
ivfCentroids = openDataInput(state, versionMeta, IVFVectorsFormat.CENTROID_EXTENSION, IVFVectorsFormat.NAME, state.context);
|
||||||
|
ivfClusters = openDataInput(state, versionMeta, IVFVectorsFormat.CLUSTER_EXTENSION, IVFVectorsFormat.NAME, state.context);
|
||||||
|
success = true;
|
||||||
|
} finally {
|
||||||
|
if (success == false) {
|
||||||
|
IOUtils.closeWhileHandlingException(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract CentroidQueryScorer getCentroidScorer(
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
int numCentroids,
|
||||||
|
IndexInput centroids,
|
||||||
|
float[] target,
|
||||||
|
IndexInput clusters
|
||||||
|
) throws IOException;
|
||||||
|
|
||||||
|
protected abstract FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) throws IOException;
|
||||||
|
|
||||||
|
public FloatVectorValues getCentroids(FieldInfo fieldInfo) throws IOException {
|
||||||
|
FieldEntry entry = fields.get(fieldInfo.number);
|
||||||
|
if (entry == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return getCentroids(entry.centroidSlice(ivfCentroids), entry.postingListOffsets.length, fieldInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
int centroidSize(String fieldName, int centroidOrdinal) throws IOException {
|
||||||
|
FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldName);
|
||||||
|
FieldEntry entry = fields.get(fieldInfo.number);
|
||||||
|
ivfClusters.seek(entry.postingListOffsets[centroidOrdinal]);
|
||||||
|
return ivfClusters.readVInt();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static IndexInput openDataInput(
|
||||||
|
SegmentReadState state,
|
||||||
|
int versionMeta,
|
||||||
|
String fileExtension,
|
||||||
|
String codecName,
|
||||||
|
IOContext context
|
||||||
|
) throws IOException {
|
||||||
|
final String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
|
||||||
|
final IndexInput in = state.directory.openInput(fileName, context);
|
||||||
|
boolean success = false;
|
||||||
|
try {
|
||||||
|
final int versionVectorData = CodecUtil.checkIndexHeader(
|
||||||
|
in,
|
||||||
|
codecName,
|
||||||
|
IVFVectorsFormat.VERSION_START,
|
||||||
|
IVFVectorsFormat.VERSION_CURRENT,
|
||||||
|
state.segmentInfo.getId(),
|
||||||
|
state.segmentSuffix
|
||||||
|
);
|
||||||
|
if (versionMeta != versionVectorData) {
|
||||||
|
throw new CorruptIndexException(
|
||||||
|
"Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData,
|
||||||
|
in
|
||||||
|
);
|
||||||
|
}
|
||||||
|
CodecUtil.retrieveChecksum(in);
|
||||||
|
success = true;
|
||||||
|
return in;
|
||||||
|
} finally {
|
||||||
|
if (success == false) {
|
||||||
|
IOUtils.closeWhileHandlingException(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void readFields(ChecksumIndexInput meta) throws IOException {
|
||||||
|
for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
|
||||||
|
final FieldInfo info = fieldInfos.fieldInfo(fieldNumber);
|
||||||
|
if (info == null) {
|
||||||
|
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
|
||||||
|
}
|
||||||
|
fields.put(info.number, readField(meta, info));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
|
||||||
|
final VectorEncoding vectorEncoding = readVectorEncoding(input);
|
||||||
|
final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
|
||||||
|
final long centroidOffset = input.readLong();
|
||||||
|
final long centroidLength = input.readLong();
|
||||||
|
final int numPostingLists = input.readVInt();
|
||||||
|
final long[] postingListOffsets = new long[numPostingLists];
|
||||||
|
for (int i = 0; i < numPostingLists; i++) {
|
||||||
|
postingListOffsets[i] = input.readLong();
|
||||||
|
}
|
||||||
|
final float[] globalCentroid = new float[info.getVectorDimension()];
|
||||||
|
float globalCentroidDp = 0;
|
||||||
|
if (numPostingLists > 0) {
|
||||||
|
input.readFloats(globalCentroid, 0, globalCentroid.length);
|
||||||
|
globalCentroidDp = Float.intBitsToFloat(input.readInt());
|
||||||
|
}
|
||||||
|
if (similarityFunction != info.getVectorSimilarityFunction()) {
|
||||||
|
throw new IllegalStateException(
|
||||||
|
"Inconsistent vector similarity function for field=\""
|
||||||
|
+ info.name
|
||||||
|
+ "\"; "
|
||||||
|
+ similarityFunction
|
||||||
|
+ " != "
|
||||||
|
+ info.getVectorSimilarityFunction()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return new FieldEntry(
|
||||||
|
similarityFunction,
|
||||||
|
vectorEncoding,
|
||||||
|
centroidOffset,
|
||||||
|
centroidLength,
|
||||||
|
postingListOffsets,
|
||||||
|
globalCentroid,
|
||||||
|
globalCentroidDp
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
|
||||||
|
final int i = input.readInt();
|
||||||
|
if (i < 0 || i >= SIMILARITY_FUNCTIONS.size()) {
|
||||||
|
throw new IllegalArgumentException("invalid distance function: " + i);
|
||||||
|
}
|
||||||
|
return SIMILARITY_FUNCTIONS.get(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static VectorEncoding readVectorEncoding(DataInput input) throws IOException {
|
||||||
|
final int encodingId = input.readInt();
|
||||||
|
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
|
||||||
|
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
|
||||||
|
}
|
||||||
|
return VectorEncoding.values()[encodingId];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final void checkIntegrity() throws IOException {
|
||||||
|
rawVectorsReader.checkIntegrity();
|
||||||
|
CodecUtil.checksumEntireFile(ivfCentroids);
|
||||||
|
CodecUtil.checksumEntireFile(ivfClusters);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
|
return rawVectorsReader.getFloatVectorValues(field);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
return rawVectorsReader.getByteVectorValues(field);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected float[] getGlobalCentroid(FieldInfo info) {
|
||||||
|
if (info == null || info.getVectorEncoding().equals(VectorEncoding.BYTE)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
FieldEntry entry = fields.get(info.number);
|
||||||
|
if (entry == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return entry.globalCentroid();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
|
||||||
|
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
|
||||||
|
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) {
|
||||||
|
rawVectorsReader.search(field, target, knnCollector, acceptDocs);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// TODO add new ivf search strategy
|
||||||
|
int nProbe = 10;
|
||||||
|
float percentFiltered = 1f;
|
||||||
|
if (acceptDocs instanceof BitSet bitSet) {
|
||||||
|
percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length()));
|
||||||
|
}
|
||||||
|
int numVectors = rawVectorsReader.getFloatVectorValues(field).size();
|
||||||
|
BitSet visitedDocs = new FixedBitSet(state.segmentInfo.maxDoc() + 1);
|
||||||
|
IntPredicate needsScoring = docId -> {
|
||||||
|
if (acceptDocs != null && acceptDocs.get(docId) == false) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return visitedDocs.getAndSet(docId) == false;
|
||||||
|
};
|
||||||
|
|
||||||
|
FieldEntry entry = fields.get(fieldInfo.number);
|
||||||
|
CentroidQueryScorer centroidQueryScorer = getCentroidScorer(
|
||||||
|
fieldInfo,
|
||||||
|
entry.postingListOffsets.length,
|
||||||
|
entry.centroidSlice(ivfCentroids),
|
||||||
|
target,
|
||||||
|
ivfClusters
|
||||||
|
);
|
||||||
|
final NeighborQueue centroidQueue = scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe);
|
||||||
|
PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring);
|
||||||
|
int centroidsVisited = 0;
|
||||||
|
long expectedDocs = 0;
|
||||||
|
long actualDocs = 0;
|
||||||
|
// initially we visit only the "centroids to search"
|
||||||
|
while (centroidQueue.size() > 0 && centroidsVisited < nProbe) {
|
||||||
|
++centroidsVisited;
|
||||||
|
// todo do we actually need to know the score???
|
||||||
|
int centroidOrdinal = centroidQueue.pop();
|
||||||
|
// todo do we need direct access to the raw centroid???
|
||||||
|
expectedDocs += scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal));
|
||||||
|
actualDocs += scorer.visit(knnCollector);
|
||||||
|
}
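        // When a filter is present, many candidates inside the probed clusters may have been skipped, so keep
        // draining centroids until roughly as many docs have been scored as the filter selectivity would
        // suggest (heuristic below), or at least k.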
        if (acceptDocs != null) {
            float unfilteredRatioVisited = (float) expectedDocs / numVectors;
            int filteredVectors = (int) Math.ceil(numVectors * percentFiltered);
            float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
            while (centroidQueue.size() > 0 && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
                int centroidOrdinal = centroidQueue.pop();
                scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal));
                actualDocs += scorer.visit(knnCollector);
            }
        }
    }
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
|
||||||
|
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
|
||||||
|
final ByteVectorValues values = rawVectorsReader.getByteVectorValues(field);
|
||||||
|
for (int i = 0; i < values.size(); i++) {
|
||||||
|
final float score = fieldInfo.getVectorSimilarityFunction().compare(target, values.vectorValue(i));
|
||||||
|
knnCollector.collect(values.ordToDoc(i), score);
|
||||||
|
if (knnCollector.earlyTerminated()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract NeighborQueue scorePostingLists(
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
KnnCollector knnCollector,
|
||||||
|
CentroidQueryScorer centroidQueryScorer,
|
||||||
|
int nProbe
|
||||||
|
) throws IOException;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() throws IOException {
|
||||||
|
IOUtils.close(rawVectorsReader, ivfCentroids, ivfClusters);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected record FieldEntry(
|
||||||
|
VectorSimilarityFunction similarityFunction,
|
||||||
|
VectorEncoding vectorEncoding,
|
||||||
|
long centroidOffset,
|
||||||
|
long centroidLength,
|
||||||
|
long[] postingListOffsets,
|
||||||
|
float[] globalCentroid,
|
||||||
|
float globalCentroidDp
|
||||||
|
) {
|
||||||
|
IndexInput centroidSlice(IndexInput centroidFile) throws IOException {
|
||||||
|
return centroidFile.slice("centroids", centroidOffset, centroidLength);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, IntPredicate needsScoring)
|
||||||
|
throws IOException;
|
||||||
|
|
||||||
|
interface CentroidQueryScorer {
|
||||||
|
int size();
|
||||||
|
|
||||||
|
float[] centroid(int centroidOrdinal) throws IOException;
|
||||||
|
|
||||||
|
float score(int centroidOrdinal) throws IOException;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface PostingVisitor {
|
||||||
|
// TODO maybe we can not specifically pass the centroid...
|
||||||
|
|
||||||
|
/** returns the number of documents in the posting list */
|
||||||
|
int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException;
|
||||||
|
|
||||||
|
/** returns the number of scored documents */
|
||||||
|
int visit(KnnCollector collector) throws IOException;
|
||||||
|
}
|
||||||
|
}
@@ -0,0 +1,486 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.index.codec.vectors;
|
||||||
|
|
||||||
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
|
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||||
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
|
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||||
|
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
|
||||||
|
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
|
||||||
|
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||||
|
import org.apache.lucene.index.FieldInfo;
|
||||||
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
|
import org.apache.lucene.index.KnnVectorValues;
|
||||||
|
import org.apache.lucene.index.MergeState;
|
||||||
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
|
import org.apache.lucene.index.Sorter;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
|
import org.apache.lucene.store.IOContext;
|
||||||
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.apache.lucene.store.IndexOutput;
|
||||||
|
import org.apache.lucene.util.InfoStream;
|
||||||
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
import org.elasticsearch.core.IOUtils;
|
||||||
|
import org.elasticsearch.core.SuppressForbidden;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.UncheckedIOException;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.ByteOrder;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
|
||||||
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Base class for IVF vectors writer.
|
||||||
|
*/
|
||||||
|
public abstract class IVFVectorsWriter extends KnnVectorsWriter {
|
||||||
|
|
||||||
|
private final List<FieldWriter> fieldWriters = new ArrayList<>();
|
||||||
|
private final IndexOutput ivfCentroids, ivfClusters;
|
||||||
|
private final IndexOutput ivfMeta;
|
||||||
|
private final FlatVectorsWriter rawVectorDelegate;
|
||||||
|
private final SegmentWriteState segmentWriteState;
|
||||||
|
|
||||||
|
protected IVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) throws IOException {
|
||||||
|
this.segmentWriteState = state;
|
||||||
|
this.rawVectorDelegate = rawVectorDelegate;
|
||||||
|
final String metaFileName = IndexFileNames.segmentFileName(
|
||||||
|
state.segmentInfo.name,
|
||||||
|
state.segmentSuffix,
|
||||||
|
IVFVectorsFormat.IVF_META_EXTENSION
|
||||||
|
);
|
||||||
|
|
||||||
|
final String ivfCentroidsFileName = IndexFileNames.segmentFileName(
|
||||||
|
state.segmentInfo.name,
|
||||||
|
state.segmentSuffix,
|
||||||
|
IVFVectorsFormat.CENTROID_EXTENSION
|
||||||
|
);
|
||||||
|
final String ivfClustersFileName = IndexFileNames.segmentFileName(
|
||||||
|
state.segmentInfo.name,
|
||||||
|
state.segmentSuffix,
|
||||||
|
IVFVectorsFormat.CLUSTER_EXTENSION
|
||||||
|
);
|
||||||
|
boolean success = false;
|
||||||
|
try {
|
||||||
|
ivfMeta = state.directory.createOutput(metaFileName, state.context);
|
||||||
|
CodecUtil.writeIndexHeader(
|
||||||
|
ivfMeta,
|
||||||
|
IVFVectorsFormat.NAME,
|
||||||
|
IVFVectorsFormat.VERSION_CURRENT,
|
||||||
|
state.segmentInfo.getId(),
|
||||||
|
state.segmentSuffix
|
||||||
|
);
|
||||||
|
ivfCentroids = state.directory.createOutput(ivfCentroidsFileName, state.context);
|
||||||
|
CodecUtil.writeIndexHeader(
|
||||||
|
ivfCentroids,
|
||||||
|
IVFVectorsFormat.NAME,
|
||||||
|
IVFVectorsFormat.VERSION_CURRENT,
|
||||||
|
state.segmentInfo.getId(),
|
||||||
|
state.segmentSuffix
|
||||||
|
);
|
||||||
|
ivfClusters = state.directory.createOutput(ivfClustersFileName, state.context);
|
||||||
|
CodecUtil.writeIndexHeader(
|
||||||
|
ivfClusters,
|
||||||
|
IVFVectorsFormat.NAME,
|
||||||
|
IVFVectorsFormat.VERSION_CURRENT,
|
||||||
|
state.segmentInfo.getId(),
|
||||||
|
state.segmentSuffix
|
||||||
|
);
|
||||||
|
success = true;
|
||||||
|
} finally {
|
||||||
|
if (success == false) {
|
||||||
|
IOUtils.closeWhileHandlingException(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
|
||||||
|
if (fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE) {
|
||||||
|
throw new IllegalArgumentException("IVF does not support cosine similarity");
|
||||||
|
}
|
||||||
|
final FlatFieldVectorsWriter<?> rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo);
|
||||||
|
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
final FlatFieldVectorsWriter<float[]> floatWriter = (FlatFieldVectorsWriter<float[]>) rawVectorDelegate;
|
||||||
|
fieldWriters.add(new FieldWriter(fieldInfo, floatWriter));
|
||||||
|
}
|
||||||
|
return rawVectorDelegate;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract int calculateAndWriteCentroids(
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
FloatVectorValues floatVectorValues,
|
||||||
|
IndexOutput temporaryCentroidOutput,
|
||||||
|
MergeState mergeState,
|
||||||
|
float[] globalCentroid
|
||||||
|
) throws IOException;
|
||||||
|
|
||||||
|
abstract long[] buildAndWritePostingsLists(
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
CentroidAssignmentScorer scorer,
|
||||||
|
FloatVectorValues floatVectorValues,
|
||||||
|
IndexOutput postingsOutput,
|
||||||
|
MergeState mergeState
|
||||||
|
) throws IOException;
|
||||||
|
|
||||||
|
abstract CentroidAssignmentScorer calculateAndWriteCentroids(
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
FloatVectorValues floatVectorValues,
|
||||||
|
IndexOutput centroidOutput,
|
||||||
|
float[] globalCentroid
|
||||||
|
) throws IOException;
|
||||||
|
|
||||||
|
abstract long[] buildAndWritePostingsLists(
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
InfoStream infoStream,
|
||||||
|
CentroidAssignmentScorer scorer,
|
||||||
|
FloatVectorValues floatVectorValues,
|
||||||
|
IndexOutput postingsOutput
|
||||||
|
) throws IOException;
|
||||||
|
|
||||||
|
abstract CentroidAssignmentScorer createCentroidScorer(
|
||||||
|
IndexInput centroidsInput,
|
||||||
|
int numCentroids,
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
float[] globalCentroid
|
||||||
|
) throws IOException;
|
||||||
|
|
||||||
|
    @Override
    public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
        rawVectorDelegate.flush(maxDoc, sortMap);
        for (FieldWriter fieldWriter : fieldWriters) {
            float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()];
            // calculate global centroid
            for (var vector : fieldWriter.delegate.getVectors()) {
                for (int i = 0; i < globalCentroid.length; i++) {
                    globalCentroid[i] += vector[i];
                }
            }
            for (int i = 0; i < globalCentroid.length; i++) {
                globalCentroid[i] /= fieldWriter.delegate.getVectors().size();
            }
            // build a float vector values with random access
            final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc);
            // build centroids
            long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
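            // The pointer was just aligned to Float.BYTES so the raw float32 centroids that follow can be read
            // directly as floats from an aligned offset (the read side is assumed to rely on this alignment).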
            final CentroidAssignmentScorer centroidAssignmentScorer = calculateAndWriteCentroids(
                fieldWriter.fieldInfo,
                floatVectorValues,
                ivfCentroids,
                globalCentroid
            );
            long centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
            final long[] offsets = buildAndWritePostingsLists(
                fieldWriter.fieldInfo,
                segmentWriteState.infoStream,
                centroidAssignmentScorer,
                floatVectorValues,
                ivfClusters
            );
            writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid);
        }
    }
|
||||||
|
|
||||||
|
private static FloatVectorValues getFloatVectorValues(
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
FlatFieldVectorsWriter<float[]> fieldVectorsWriter,
|
||||||
|
int maxDoc
|
||||||
|
) throws IOException {
|
||||||
|
List<float[]> vectors = fieldVectorsWriter.getVectors();
|
||||||
|
if (vectors.size() == maxDoc) {
|
||||||
|
return FloatVectorValues.fromFloats(vectors, fieldInfo.getVectorDimension());
|
||||||
|
}
|
||||||
|
final DocIdSetIterator iterator = fieldVectorsWriter.getDocsWithFieldSet().iterator();
|
||||||
|
final int[] docIds = new int[vectors.size()];
|
||||||
|
for (int i = 0; i < docIds.length; i++) {
|
||||||
|
docIds[i] = iterator.nextDoc();
|
||||||
|
}
|
||||||
|
assert iterator.nextDoc() == NO_MORE_DOCS;
|
||||||
|
return new FloatVectorValues() {
|
||||||
|
@Override
|
||||||
|
public float[] vectorValue(int ord) {
|
||||||
|
return vectors.get(ord);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FloatVectorValues copy() {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int dimension() {
|
||||||
|
return fieldInfo.getVectorDimension();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return vectors.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int ordToDoc(int ord) {
|
||||||
|
return docIds[ord];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) {
|
||||||
|
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
|
||||||
|
vectorsReader = candidateReader.getFieldReader(fieldName);
|
||||||
|
}
|
||||||
|
if (vectorsReader instanceof IVFVectorsReader reader) {
|
||||||
|
return reader;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
@SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)")
|
||||||
|
public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||||
|
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
|
||||||
|
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
|
||||||
|
final int numVectors;
|
||||||
|
String tempRawVectorsFileName = null;
|
||||||
|
boolean success = false;
|
||||||
|
// build a float vector values with random access. In order to do that we dump the vectors to
|
||||||
|
// a temporary file
|
||||||
|
// and write the docID follow by the vector
|
||||||
|
try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivf_", IOContext.DEFAULT)) {
|
||||||
|
tempRawVectorsFileName = out.getName();
|
||||||
|
// TODO do this better, we shouldn't have to write to a temp file, we should be able to
|
||||||
|
// to just from the merged vector values, the tricky part is the random access.
|
||||||
|
numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
|
||||||
|
CodecUtil.writeFooter(out);
|
||||||
|
success = true;
|
||||||
|
} finally {
|
||||||
|
if (success == false && tempRawVectorsFileName != null) {
|
||||||
|
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) {
|
||||||
|
float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
|
||||||
|
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors);
|
||||||
|
success = false;
|
||||||
|
CentroidAssignmentScorer centroidAssignmentScorer;
|
||||||
|
long centroidOffset;
|
||||||
|
long centroidLength;
|
||||||
|
String centroidTempName = null;
|
||||||
|
int numCentroids;
|
||||||
|
IndexOutput centroidTemp = null;
|
||||||
|
try {
|
||||||
|
centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT);
|
||||||
|
centroidTempName = centroidTemp.getName();
|
||||||
|
numCentroids = calculateAndWriteCentroids(
|
||||||
|
fieldInfo,
|
||||||
|
floatVectorValues,
|
||||||
|
centroidTemp,
|
||||||
|
mergeState,
|
||||||
|
calculatedGlobalCentroid
|
||||||
|
);
|
||||||
|
success = true;
|
||||||
|
} finally {
|
||||||
|
if (success == false && centroidTempName != null) {
|
||||||
|
IOUtils.closeWhileHandlingException(centroidTemp);
|
||||||
|
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
if (numCentroids == 0) {
|
||||||
|
centroidOffset = ivfCentroids.getFilePointer();
|
||||||
|
writeMeta(fieldInfo, centroidOffset, 0, new long[0], null);
|
||||||
|
CodecUtil.writeFooter(centroidTemp);
|
||||||
|
IOUtils.close(centroidTemp);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
CodecUtil.writeFooter(centroidTemp);
|
||||||
|
IOUtils.close(centroidTemp);
|
||||||
|
centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
|
||||||
|
try (IndexInput centroidInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) {
|
||||||
|
ivfCentroids.copyBytes(centroidInput, centroidInput.length() - CodecUtil.footerLength());
|
||||||
|
centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
|
||||||
|
centroidAssignmentScorer = createCentroidScorer(centroidInput, numCentroids, fieldInfo, calculatedGlobalCentroid);
|
||||||
|
assert centroidAssignmentScorer.size() == numCentroids;
|
||||||
|
// build a float vector values with random access
|
||||||
|
// build centroids
|
||||||
|
final long[] offsets = buildAndWritePostingsLists(
|
||||||
|
fieldInfo,
|
||||||
|
centroidAssignmentScorer,
|
||||||
|
floatVectorValues,
|
||||||
|
ivfClusters,
|
||||||
|
mergeState
|
||||||
|
);
|
||||||
|
assert offsets.length == centroidAssignmentScorer.size();
|
||||||
|
writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(
|
||||||
|
mergeState.segmentInfo.dir,
|
||||||
|
tempRawVectorsFileName,
|
||||||
|
centroidTempName
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput randomAccessInput, int numVectors) {
|
||||||
|
if (numVectors == 0) {
|
||||||
|
return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension());
|
||||||
|
}
|
||||||
|
final long length = (long) Float.BYTES * fieldInfo.getVectorDimension() + Integer.BYTES;
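        // Each record in the temporary file is an int docId followed by the little-endian float vector, so
        // `length` is the stride used for random access below (see writeFloatVectorValues).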
|
||||||
|
final float[] vector = new float[fieldInfo.getVectorDimension()];
|
||||||
|
return new FloatVectorValues() {
|
||||||
|
@Override
|
||||||
|
public float[] vectorValue(int ord) throws IOException {
|
||||||
|
randomAccessInput.seek(ord * length + Integer.BYTES);
|
||||||
|
randomAccessInput.readFloats(vector, 0, vector.length);
|
||||||
|
return vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FloatVectorValues copy() {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int dimension() {
|
||||||
|
return fieldInfo.getVectorDimension();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return numVectors;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int ordToDoc(int ord) {
|
||||||
|
try {
|
||||||
|
randomAccessInput.seek(ord * length);
|
||||||
|
return randomAccessInput.readInt();
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new UncheckedIOException(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out, FloatVectorValues floatVectorValues)
|
||||||
|
throws IOException {
|
||||||
|
int numVectors = 0;
|
||||||
|
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||||
|
final KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
|
||||||
|
for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
|
||||||
|
numVectors++;
|
||||||
|
float[] vector = floatVectorValues.vectorValue(iterator.index());
|
||||||
|
out.writeInt(iterator.docID());
|
||||||
|
buffer.asFloatBuffer().put(vector);
|
||||||
|
out.writeBytes(buffer.array(), buffer.array().length);
|
||||||
|
}
|
||||||
|
return numVectors;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void writeMeta(FieldInfo field, long centroidOffset, long centroidLength, long[] offsets, float[] globalCentroid)
|
||||||
|
throws IOException {
|
||||||
|
ivfMeta.writeInt(field.number);
|
||||||
|
ivfMeta.writeInt(field.getVectorEncoding().ordinal());
|
||||||
|
ivfMeta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction()));
|
||||||
|
ivfMeta.writeLong(centroidOffset);
|
||||||
|
ivfMeta.writeLong(centroidLength);
|
||||||
|
ivfMeta.writeVInt(offsets.length);
|
||||||
|
for (long offset : offsets) {
|
||||||
|
ivfMeta.writeLong(offset);
|
||||||
|
}
|
||||||
|
if (offsets.length > 0) {
|
||||||
|
final ByteBuffer buffer = ByteBuffer.allocate(globalCentroid.length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||||
|
buffer.asFloatBuffer().put(globalCentroid);
|
||||||
|
ivfMeta.writeBytes(buffer.array(), buffer.array().length);
|
||||||
|
ivfMeta.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(globalCentroid, globalCentroid)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int distFuncToOrd(VectorSimilarityFunction func) {
|
||||||
|
for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) {
|
||||||
|
if (SIMILARITY_FUNCTIONS.get(i).equals(func)) {
|
||||||
|
return (byte) i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException("invalid distance function: " + func);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final void finish() throws IOException {
|
||||||
|
rawVectorDelegate.finish();
|
||||||
|
if (ivfMeta != null) {
|
||||||
|
// write end of fields marker
|
||||||
|
ivfMeta.writeInt(-1);
|
||||||
|
CodecUtil.writeFooter(ivfMeta);
|
||||||
|
}
|
||||||
|
if (ivfCentroids != null) {
|
||||||
|
CodecUtil.writeFooter(ivfCentroids);
|
||||||
|
}
|
||||||
|
if (ivfClusters != null) {
|
||||||
|
CodecUtil.writeFooter(ivfClusters);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final void close() throws IOException {
|
||||||
|
IOUtils.close(rawVectorDelegate, ivfMeta, ivfCentroids, ivfClusters);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final long ramBytesUsed() {
|
||||||
|
return rawVectorDelegate.ramBytesUsed();
|
||||||
|
}
|
||||||
|
|
||||||
|
private record FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter<float[]> delegate) {}
|
||||||
|
|
||||||
|
interface CentroidAssignmentScorer {
|
||||||
|
CentroidAssignmentScorer EMPTY = new CentroidAssignmentScorer() {
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float[] centroid(int centroidOrdinal) {
|
||||||
|
throw new IllegalStateException("No centroids");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float score(int centroidOrdinal) {
|
||||||
|
throw new IllegalStateException("No centroids");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setScoringVector(float[] vector) {
|
||||||
|
throw new IllegalStateException("No centroids");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
int size();
|
||||||
|
|
||||||
|
float[] centroid(int centroidOrdinal) throws IOException;
|
||||||
|
|
||||||
|
void setScoringVector(float[] vector);
|
||||||
|
|
||||||
|
float score(int centroidOrdinal) throws IOException;
|
||||||
|
}
|
||||||
|
}
@@ -0,0 +1,159 @@
|
||||||
|
/*
|
||||||
|
* @notice
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
* Modifications copyright (C) 2025 Elasticsearch B.V.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.index.codec.vectors;
|
||||||
|
|
||||||
|
import org.apache.lucene.util.LongHeap;
|
||||||
|
import org.apache.lucene.util.NumericUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Copied from and modified from Apache Lucene.
|
||||||
|
*/
|
||||||
|
class NeighborQueue {
|
||||||
|
|
||||||
|
private enum Order {
|
||||||
|
MIN_HEAP {
|
||||||
|
@Override
|
||||||
|
long apply(long v) {
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
MAX_HEAP {
|
||||||
|
@Override
|
||||||
|
long apply(long v) {
|
||||||
|
// This cannot be just `-v` since Long.MIN_VALUE doesn't have a positive counterpart. It
|
||||||
|
// needs a function that returns MAX_VALUE for MIN_VALUE and vice-versa.
|
||||||
|
return -1 - v;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
abstract long apply(long v);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final LongHeap heap;
|
||||||
|
private final Order order;
|
||||||
|
|
||||||
|
NeighborQueue(int initialSize, boolean maxHeap) {
|
||||||
|
this.heap = new LongHeap(initialSize);
|
||||||
|
this.order = maxHeap ? Order.MAX_HEAP : Order.MIN_HEAP;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return the number of elements in the heap
|
||||||
|
*/
|
||||||
|
public int size() {
|
||||||
|
return heap.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds a new graph arc, extending the storage as needed.
|
||||||
|
*
|
||||||
|
* @param newNode the neighbor node id
|
||||||
|
* @param newScore the score of the neighbor, relative to some other node
|
||||||
|
*/
|
||||||
|
public void add(int newNode, float newScore) {
|
||||||
|
heap.push(encode(newNode, newScore));
|
||||||
|
}
|
||||||
|
|
||||||
|
    /**
     * If the heap is not full (size is less than the initialSize provided to the constructor), adds a
     * new node-and-score element. If the heap is full, compares the score against the current top
     * score and replaces the top element if newScore is better than the current top score (greater
     * than, unless the heap is reversed).
     *
     * @param newNode the neighbor node id
     * @param newScore the score of the neighbor, relative to some other node
     */
    public boolean insertWithOverflow(int newNode, float newScore) {
        return heap.insertWithOverflow(encode(newNode, newScore));
    }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Encodes the node ID and its similarity score as long, preserving the Lucene tie-breaking rule
|
||||||
|
* that when two scores are equal, the smaller node ID must win.
|
||||||
|
* @param node the node ID
|
||||||
|
* @param score the node score
|
||||||
|
* @return the encoded score, node ID
|
||||||
|
*/
|
||||||
|
private long encode(int node, float score) {
|
||||||
|
return order.apply((((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns the top element's node id. */
|
||||||
|
int topNode() {
|
||||||
|
return decodeNodeId(heap.top());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the top element's node score. For the min heap this is the minimum score. For the max
|
||||||
|
* heap this is the maximum score.
|
||||||
|
*/
|
||||||
|
float topScore() {
|
||||||
|
return decodeScore(heap.top());
|
||||||
|
}
|
||||||
|
|
||||||
|
private float decodeScore(long heapValue) {
|
||||||
|
return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32));
|
||||||
|
}
|
||||||
|
|
||||||
|
private int decodeNodeId(long heapValue) {
|
||||||
|
return (int) ~(order.apply(heapValue));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Removes the top element and returns its node id. */
|
||||||
|
public int pop() {
|
||||||
|
return decodeNodeId(heap.pop());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void consumeNodes(int[] dest) {
|
||||||
|
if (dest.length < size()) {
|
||||||
|
throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements.");
|
||||||
|
}
|
||||||
|
for (int i = 0; i < size(); i++) {
|
||||||
|
dest[i] = decodeNodeId(heap.get(i + 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public int consumeNodesAndScoresMin(int[] dest, float[] scores) {
|
||||||
|
if (dest.length < size() || scores.length < size()) {
|
||||||
|
throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements.");
|
||||||
|
}
|
||||||
|
float bestScore = Float.POSITIVE_INFINITY;
|
||||||
|
int bestIdx = 0;
|
||||||
|
for (int i = 0; i < size(); i++) {
|
||||||
|
long heapValue = heap.get(i + 1);
|
||||||
|
scores[i] = decodeScore(heapValue);
|
||||||
|
dest[i] = decodeNodeId(heapValue);
|
||||||
|
if (scores[i] < bestScore) {
|
||||||
|
bestScore = scores[i];
|
||||||
|
bestIdx = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bestIdx;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void clear() {
|
||||||
|
heap.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "Neighbors[" + heap.size() + "]";
|
||||||
|
}
|
||||||
|
}
|
|
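The long packing in encode/decode above is the subtle part: the sortable form of the float score occupies the high 32 bits and the bitwise complement of the node id the low 32 bits, so ordering the longs orders by score first and, on a tie, prefers the smaller node id. A self-contained sketch of that packing follows; the class name is invented, and the sortable-int conversion is inlined (mirroring Lucene's NumericUtils.floatToSortableInt) so the snippet runs without Lucene on the classpath:

// Sketch of the score/node packing used by NeighborQueue; not the shipped class.
public class NeighborEncodingSketch {

    // Mirrors NumericUtils.floatToSortableInt: flips the mantissa/exponent bits of
    // negative floats so the ints sort in the same order as the floats.
    static int floatToSortableInt(float f) {
        int bits = Float.floatToIntBits(f);
        return bits ^ ((bits >> 31) & 0x7fffffff);
    }

    static float sortableIntToFloat(int sortable) {
        return Float.intBitsToFloat(sortable ^ ((sortable >> 31) & 0x7fffffff));
    }

    // Score in the high 32 bits, complemented node id in the low 32 bits.
    static long encode(int node, float score) {
        return (((long) floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node);
    }

    static int decodeNodeId(long encoded) {
        return (int) ~encoded;
    }

    static float decodeScore(long encoded) {
        return sortableIntToFloat((int) (encoded >> 32));
    }

    public static void main(String[] args) {
        long a = encode(7, 0.5f);   // smaller node id, same score
        long b = encode(42, 0.5f);  // larger node id, same score
        // Higher scores encode to larger longs; on a score tie the smaller node id
        // encodes larger, so a min-heap evicts node 42 before node 7.
        System.out.println(a > b);                // true
        System.out.println(decodeNodeId(a));      // 7
        System.out.println(decodeScore(b));       // 0.5
        System.out.println(encode(7, 1.0f) > a);  // true: better score ranks higher
    }
}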
@ -7,3 +7,4 @@ org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat
org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat
org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat
org.elasticsearch.index.codec.vectors.IVFVectorsFormat
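This entry, presumably in the org.apache.lucene.codecs.KnnVectorsFormat service file, is what lets Lucene's SPI discover the new format by name when reading an index. A hedged lookup sketch; "IVFVectorsFormat" is assumed to be the string the format passes to its KnnVectorsFormat super constructor, which this hunk does not show:

import org.apache.lucene.codecs.KnnVectorsFormat;

class FormatLookupSketch {
    public static void main(String[] args) {
        // Resolves the implementation through META-INF/services entries like the one added above.
        // The lookup name is an assumption; the actual SPI name is defined by the format class itself.
        KnnVectorsFormat format = KnnVectorsFormat.forName("IVFVectorsFormat");
        System.out.println(format.getName());
    }
}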
@ -0,0 +1,65 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */
package org.elasticsearch.index.codec.vectors;

import com.carrotsearch.randomizedtesting.generators.RandomPicks;

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.elasticsearch.common.logging.LogConfigurator;
import org.junit.Before;

import java.util.List;

public class IVFVectorsFormatTests extends BaseKnnVectorsFormatTestCase {

    static {
        LogConfigurator.loadLog4jPlugins();
        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
    }

    KnnVectorsFormat format;

    @Before
    @Override
    public void setUp() throws Exception {
        format = new IVFVectorsFormat(random().nextInt(10, 1000));
        super.setUp();
    }

    @Override
    protected VectorSimilarityFunction randomSimilarity() {
        return RandomPicks.randomFrom(
            random(),
            List.of(
                VectorSimilarityFunction.DOT_PRODUCT,
                VectorSimilarityFunction.EUCLIDEAN,
                VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT
            )
        );
    }

    @Override
    protected VectorEncoding randomVectorEncoding() {
        return VectorEncoding.FLOAT32;
    }

    @Override
    public void testSearchWithVisitedLimit() {
        // ivf doesn't enforce visitation limit
    }

    @Override
    protected Codec getCodec() {
        return TestUtil.alwaysKnnVectorsFormat(format);
    }
}
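The test wires the format in through getCodec(), so every vector field in the randomized base-class tests goes through the IVF code path. A hedged sketch of the same wiring outside the test harness; the constructor argument (128 here) is assumed to be a clustering parameter, since the hunk above only shows that the constructor takes a single int:

import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.TestUtil;

class IVFIndexingSketch {
    public static void main(String[] args) throws Exception {
        try (Directory dir = new ByteBuffersDirectory()) {
            IndexWriterConfig config = new IndexWriterConfig();
            // Route every vector field through the IVF format, as the test's getCodec() does.
            config.setCodec(TestUtil.alwaysKnnVectorsFormat(new IVFVectorsFormat(128)));
            try (IndexWriter writer = new IndexWriter(dir, config)) {
                Document doc = new Document();
                doc.add(new KnnFloatVectorField("vec", new float[] { 0.1f, 0.2f, 0.3f }, VectorSimilarityFunction.DOT_PRODUCT));
                writer.addDocument(doc);
            }
            try (DirectoryReader reader = DirectoryReader.open(dir)) {
                System.out.println(reader.numDocs()); // 1
            }
        }
    }
}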
@ -0,0 +1,119 @@
/*
 * @notice
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Modifications copyright (C) 2025 Elasticsearch B.V.
 */

package org.elasticsearch.index.codec.vectors;

import org.elasticsearch.test.ESTestCase;

/**
 * copied and modified from Lucene
 */
public class NeighborQueueTests extends ESTestCase {
    public void testNeighborsProduct() {
        // make sure we have the sign correct
        NeighborQueue nn = new NeighborQueue(2, false);
        assertTrue(nn.insertWithOverflow(2, 0.5f));
        assertTrue(nn.insertWithOverflow(1, 0.2f));
        assertTrue(nn.insertWithOverflow(3, 1f));
        assertEquals(0.5f, nn.topScore(), 0);
        nn.pop();
        assertEquals(1f, nn.topScore(), 0);
        nn.pop();
    }

    public void testNeighborsMaxHeap() {
        NeighborQueue nn = new NeighborQueue(2, true);
        assertTrue(nn.insertWithOverflow(2, 2));
        assertTrue(nn.insertWithOverflow(1, 1));
        assertFalse(nn.insertWithOverflow(3, 3));
        assertEquals(2f, nn.topScore(), 0);
        nn.pop();
        assertEquals(1f, nn.topScore(), 0);
    }

    public void testTopMaxHeap() {
        NeighborQueue nn = new NeighborQueue(2, true);
        nn.add(1, 2);
        nn.add(2, 1);
        // lower scores are better; highest score on top
        assertEquals(2, nn.topScore(), 0);
        assertEquals(1, nn.topNode());
    }

    public void testTopMinHeap() {
        NeighborQueue nn = new NeighborQueue(2, false);
        nn.add(1, 0.5f);
        nn.add(2, -0.5f);
        // higher scores are better; lowest score on top
        assertEquals(-0.5f, nn.topScore(), 0);
        assertEquals(2, nn.topNode());
    }

    public void testClear() {
        NeighborQueue nn = new NeighborQueue(2, false);
        nn.add(1, 1.1f);
        nn.add(2, -2.2f);
        nn.clear();

        assertEquals(0, nn.size());
    }

    public void testMaxSizeQueue() {
        NeighborQueue nn = new NeighborQueue(2, false);
        nn.add(1, 1);
        nn.add(2, 2);
        assertEquals(2, nn.size());
        assertEquals(1, nn.topNode());

        // insertWithOverflow does not extend the queue
        nn.insertWithOverflow(3, 3);
        assertEquals(2, nn.size());
        assertEquals(2, nn.topNode());

        // add does extend the queue beyond maxSize
        nn.add(4, 1);
        assertEquals(3, nn.size());
    }

    public void testUnboundedQueue() {
        NeighborQueue nn = new NeighborQueue(1, true);
        float maxScore = -2;
        int maxNode = -1;
        for (int i = 0; i < 256; i++) {
            // initial size is 32
            float score = random().nextFloat();
            if (score > maxScore) {
                maxScore = score;
                maxNode = i;
            }
            nn.add(i, score);
        }
        assertEquals(maxScore, nn.topScore(), 0);
        assertEquals(maxNode, nn.topNode());
    }

    public void testInvalidArguments() {
        expectThrows(IllegalArgumentException.class, () -> new NeighborQueue(0, false));
    }

    public void testToString() {
        assertEquals("Neighbors[0]", new NeighborQueue(2, false).toString());
    }

}