diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OptimizedScalarQuantizerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OptimizedScalarQuantizerBenchmark.java new file mode 100644 index 000000000000..4fa0a1f95495 --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OptimizedScalarQuantizerBenchmark.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.benchmark.vector; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +@Warmup(iterations = 3, time = 1) +@Measurement(iterations = 5, time = 1) +@Fork(value = 3) +public class OptimizedScalarQuantizerBenchmark { + static 
{ + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + @Param({ "384", "702", "1024" }) + int dims; + + float[] vector; + float[] centroid; + byte[] destination; + + @Param({ "1", "4", "7" }) + byte bits; + + OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(VectorSimilarityFunction.DOT_PRODUCT); + + @Setup(Level.Iteration) + public void init() { + ThreadLocalRandom random = ThreadLocalRandom.current(); + // random byte arrays for binary methods + destination = new byte[dims]; + vector = new float[dims]; + centroid = new float[dims]; + for (int i = 0; i < dims; ++i) { + vector[i] = random.nextFloat(); + centroid[i] = random.nextFloat(); + } + } + + @Benchmark + public byte[] scalar() { + osq.scalarQuantize(vector, destination, bits, centroid); + return destination; + } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public byte[] vector() { + osq.scalarQuantize(vector, destination, bits, centroid); + return destination; + } +} diff --git a/docs/changelog/127118.yaml b/docs/changelog/127118.yaml new file mode 100644 index 000000000000..cf3bd807d4a2 --- /dev/null +++ b/docs/changelog/127118.yaml @@ -0,0 +1,5 @@ +pr: 127118 +summary: Panama vector accelerated optimized scalar quantization +area: Vector Search +type: enhancement +issues: [] diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java index c5b6791724c6..41bf6ff58d14 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java @@ -144,4 +144,71 @@ public class ESVectorUtil { } return distance; } + + /** + * Calculate the loss for optimized-scalar quantization for the given parameters + * @param target The vector being quantized, assumed to be centered + * @param interval The interval for which to calculate the loss + * 
@param points the quantization points + * @param norm2 The norm squared of the target vector + * @param lambda The lambda parameter for controlling anisotropic loss calculation + * @return The loss for the given parameters + */ + public static float calculateOSQLoss(float[] target, float[] interval, int points, float norm2, float lambda) { + assert interval.length == 2; + float step = ((interval[1] - interval[0]) / (points - 1.0F)); + float invStep = 1f / step; + return IMPL.calculateOSQLoss(target, interval, step, invStep, norm2, lambda); + } + + /** + * Calculate the grid points for optimized-scalar quantization + * @param target The vector being quantized, assumed to be centered + * @param interval The interval for which to calculate the grid points + * @param points the quantization points + * @param pts The array to store the grid points, must be of length 5 + */ + public static void calculateOSQGridPoints(float[] target, float[] interval, int points, float[] pts) { + assert interval.length == 2; + assert pts.length == 5; + float invStep = (points - 1.0F) / (interval[1] - interval[0]); + IMPL.calculateOSQGridPoints(target, interval, points, invStep, pts); + } + + /** + * Center the target vector and calculate the optimized-scalar quantization statistics + * @param target The vector being quantized + * @param centroid The centroid of the target vector + * @param centered The destination of the centered vector, will be overwritten + * @param stats The array to store the statistics, must be of length 5 + */ + public static void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) { + assert target.length == centroid.length; + assert stats.length == 5; + if (target.length != centroid.length) { + throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length); + } + if (centered.length != target.length) { + throw new IllegalArgumentException("vector dimensions differ: " + 
centered.length + "!=" + target.length); + } + IMPL.centerAndCalculateOSQStatsEuclidean(target, centroid, centered, stats); + } + + /** + * Center the target vector and calculate the optimized-scalar quantization statistics + * @param target The vector being quantized + * @param centroid The centroid of the target vector + * @param centered The destination of the centered vector, will be overwritten + * @param stats The array to store the statistics, must be of length 6 + */ + public static void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) { + if (target.length != centroid.length) { + throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length); + } + if (centered.length != target.length) { + throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length); + } + assert stats.length == 6; + IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats); + } } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java index 90226c3b0e94..ce8fce7e68b7 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java @@ -44,6 +44,100 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport { return ipFloatByteImpl(q, d); } + @Override + public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) { + float a = interval[0]; + float b = interval[1]; + float xe = 0f; + float e = 0f; + for (float xi : target) { + // this is quantizing and then dequantizing the vector + float xiq = fma(step, Math.round((Math.min(Math.max(xi, 
a), b) - a) * invStep), a); + // how much does the de-quantized value differ from the original value + float xiiq = xi - xiq; + e = fma(xiiq, xiiq, e); + xe = fma(xi, xiiq, xe); + } + return (1f - lambda) * xe * xe / norm2 + lambda * e; + } + + @Override + public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) { + float a = interval[0]; + float b = interval[1]; + float daa = 0; + float dab = 0; + float dbb = 0; + float dax = 0; + float dbx = 0; + for (float v : target) { + float k = Math.round((Math.min(Math.max(v, a), b) - a) * invStep); + float s = k / (points - 1); + float ms = 1f - s; + daa = fma(ms, ms, daa); + dab = fma(ms, s, dab); + dbb = fma(s, s, dbb); + dax = fma(ms, v, dax); + dbx = fma(s, v, dbx); + } + pts[0] = daa; + pts[1] = dab; + pts[2] = dbb; + pts[3] = dax; + pts[4] = dbx; + } + + @Override + public void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) { + float vecMean = 0; + float vecVar = 0; + float norm2 = 0; + float min = Float.MAX_VALUE; + float max = -Float.MAX_VALUE; + for (int i = 0; i < target.length; i++) { + centered[i] = target[i] - centroid[i]; + min = Math.min(min, centered[i]); + max = Math.max(max, centered[i]); + norm2 = fma(centered[i], centered[i], norm2); + float delta = centered[i] - vecMean; + vecMean += delta / (i + 1); + float delta2 = centered[i] - vecMean; + vecVar = fma(delta, delta2, vecVar); + } + stats[0] = vecMean; + stats[1] = vecVar / target.length; + stats[2] = norm2; + stats[3] = min; + stats[4] = max; + } + + @Override + public void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) { + float vecMean = 0; + float vecVar = 0; + float norm2 = 0; + float centroidDot = 0; + float min = Float.MAX_VALUE; + float max = -Float.MAX_VALUE; + for (int i = 0; i < target.length; i++) { + centroidDot = fma(target[i], centroid[i], centroidDot); + centered[i] = target[i] - 
centroid[i]; + min = Math.min(min, centered[i]); + max = Math.max(max, centered[i]); + norm2 = fma(centered[i], centered[i], norm2); + float delta = centered[i] - vecMean; + vecMean += delta / (i + 1); + float delta2 = centered[i] - vecMean; + vecVar = fma(delta, delta2, vecVar); + } + stats[0] = vecMean; + stats[1] = vecVar / target.length; + stats[2] = norm2; + stats[3] = min; + stats[4] = max; + stats[5] = centroidDot; + } + public static int ipByteBitImpl(byte[] q, byte[] d) { return ipByteBitImpl(q, d, 0); } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java index fec073f04383..b2615c55e64e 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java @@ -20,4 +20,12 @@ public interface ESVectorUtilSupport { float ipFloatBit(float[] q, byte[] d); float ipFloatByte(float[] q, byte[] d); + + float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda); + + void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts); + + void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats); + + void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats); } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java index 53f5924cbe72..cf856c5322f0 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java +++ 
b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java @@ -13,6 +13,7 @@ import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.IntVector; import jdk.incubator.vector.LongVector; +import jdk.incubator.vector.Vector; import jdk.incubator.vector.VectorMask; import jdk.incubator.vector.VectorOperators; import jdk.incubator.vector.VectorShape; @@ -21,16 +22,22 @@ import jdk.incubator.vector.VectorSpecies; import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.Constants; +import static jdk.incubator.vector.VectorOperators.ADD; +import static jdk.incubator.vector.VectorOperators.MAX; +import static jdk.incubator.vector.VectorOperators.MIN; + public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport { static final int VECTOR_BITSIZE; + private static final VectorSpecies FLOAT_SPECIES; /** Whether integer vectors can be trusted to actually be fast. */ static final boolean HAS_FAST_INTEGER_VECTORS; static { // default to platform supported bitsize VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize(); + FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(VECTOR_BITSIZE)); // hotspot misses some SSE intrinsics, workaround it // to be fair, they do document this thing only works well with AVX2/AVX3 and Neon @@ -38,6 +45,22 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport { HAS_FAST_INTEGER_VECTORS = isAMD64withoutAVX2 == false; } + private static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) { + if (Constants.HAS_FAST_VECTOR_FMA) { + return a.fma(b, c); + } else { + return a.mul(b).add(c); + } + } + + private static float fma(float a, float b, float c) { + if (Constants.HAS_FAST_SCALAR_FMA) { + return Math.fma(a, b, c); + } else { + return a * b + c; + } + } + @Override public long ipByteBinByte(byte[] q, byte[] d) { // 128 / 8 == 16 @@ -83,6 +106,267 @@ public final 
class PanamaESVectorUtilSupport implements ESVectorUtilSupport { return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d); } + @Override + public void centerAndCalculateOSQStatsEuclidean(float[] vector, float[] centroid, float[] centered, float[] stats) { + assert vector.length == centroid.length; + assert vector.length == centered.length; + float vecMean = 0; + float vecVar = 0; + float norm2 = 0; + float min = Float.MAX_VALUE; + float max = -Float.MAX_VALUE; + int i = 0; + int vectCount = 0; + if (vector.length > 2 * FLOAT_SPECIES.length()) { + FloatVector vecMeanVec = FloatVector.zero(FLOAT_SPECIES); + FloatVector m2Vec = FloatVector.zero(FLOAT_SPECIES); + FloatVector norm2Vec = FloatVector.zero(FLOAT_SPECIES); + FloatVector minVec = FloatVector.broadcast(FLOAT_SPECIES, Float.MAX_VALUE); + FloatVector maxVec = FloatVector.broadcast(FLOAT_SPECIES, -Float.MAX_VALUE); + int count = 0; + for (; i < FLOAT_SPECIES.loopBound(vector.length); i += FLOAT_SPECIES.length()) { + ++count; + FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i); + FloatVector c = FloatVector.fromArray(FLOAT_SPECIES, centroid, i); + FloatVector centeredVec = v.sub(c); + FloatVector deltaVec = centeredVec.sub(vecMeanVec); + norm2Vec = fma(centeredVec, centeredVec, norm2Vec); + vecMeanVec = vecMeanVec.add(deltaVec.div(count)); + FloatVector delta2Vec = centeredVec.sub(vecMeanVec); + m2Vec = fma(deltaVec, delta2Vec, m2Vec); + minVec = minVec.min(centeredVec); + maxVec = maxVec.max(centeredVec); + centeredVec.intoArray(centered, i); + } + min = minVec.reduceLanes(MIN); + max = maxVec.reduceLanes(MAX); + norm2 = norm2Vec.reduceLanes(ADD); + vecMean = vecMeanVec.reduceLanes(ADD) / FLOAT_SPECIES.length(); + FloatVector d2Mean = vecMeanVec.sub(vecMean); + m2Vec = fma(d2Mean, d2Mean, m2Vec); + vectCount = count * FLOAT_SPECIES.length(); + vecVar = m2Vec.reduceLanes(ADD); + } + + float tailMean = 0; + float tailM2 = 0; + int tailCount = 0; + // handle the tail + for (; i < vector.length; i++) { 
+ centered[i] = vector[i] - centroid[i]; + float delta = centered[i] - tailMean; + ++tailCount; + tailMean += delta / tailCount; + float delta2 = centered[i] - tailMean; + tailM2 = fma(delta, delta2, tailM2); + min = Math.min(min, centered[i]); + max = Math.max(max, centered[i]); + norm2 = fma(centered[i], centered[i], norm2); + } + if (vectCount == 0) { + vecMean = tailMean; + vecVar = tailM2; + } else if (tailCount > 0) { + int totalCount = tailCount + vectCount; + assert totalCount == vector.length; + float alpha = (float) vectCount / totalCount; + float beta = 1f - alpha; + float completeMean = alpha * vecMean + beta * tailMean; + float dMean2Lhs = (vecMean - completeMean) * (vecMean - completeMean); + float dMean2Rhs = (tailMean - completeMean) * (tailMean - completeMean); + vecVar = (vecVar + dMean2Lhs) + beta * (tailM2 + dMean2Rhs); + vecMean = completeMean; + } + stats[0] = vecMean; + stats[1] = vecVar / vector.length; + stats[2] = norm2; + stats[3] = min; + stats[4] = max; + } + + @Override + public void centerAndCalculateOSQStatsDp(float[] vector, float[] centroid, float[] centered, float[] stats) { + assert vector.length == centroid.length; + assert vector.length == centered.length; + float vecMean = 0; + float vecVar = 0; + float norm2 = 0; + float min = Float.MAX_VALUE; + float max = -Float.MAX_VALUE; + float centroidDot = 0; + int i = 0; + int vectCount = 0; + int loopBound = FLOAT_SPECIES.loopBound(vector.length); + if (vector.length > 2 * FLOAT_SPECIES.length()) { + FloatVector vecMeanVec = FloatVector.zero(FLOAT_SPECIES); + FloatVector m2Vec = FloatVector.zero(FLOAT_SPECIES); + FloatVector norm2Vec = FloatVector.zero(FLOAT_SPECIES); + FloatVector minVec = FloatVector.broadcast(FLOAT_SPECIES, Float.MAX_VALUE); + FloatVector maxVec = FloatVector.broadcast(FLOAT_SPECIES, -Float.MAX_VALUE); + FloatVector centroidDotVec = FloatVector.zero(FLOAT_SPECIES); + int count = 0; + for (; i < loopBound; i += FLOAT_SPECIES.length()) { + ++count; + FloatVector v = 
FloatVector.fromArray(FLOAT_SPECIES, vector, i); + FloatVector c = FloatVector.fromArray(FLOAT_SPECIES, centroid, i); + centroidDotVec = fma(v, c, centroidDotVec); + FloatVector centeredVec = v.sub(c); + FloatVector deltaVec = centeredVec.sub(vecMeanVec); + norm2Vec = fma(centeredVec, centeredVec, norm2Vec); + vecMeanVec = vecMeanVec.add(deltaVec.div(count)); + FloatVector delta2Vec = centeredVec.sub(vecMeanVec); + m2Vec = fma(deltaVec, delta2Vec, m2Vec); + minVec = minVec.min(centeredVec); + maxVec = maxVec.max(centeredVec); + centeredVec.intoArray(centered, i); + } + min = minVec.reduceLanes(MIN); + max = maxVec.reduceLanes(MAX); + norm2 = norm2Vec.reduceLanes(ADD); + centroidDot = centroidDotVec.reduceLanes(ADD); + vecMean = vecMeanVec.reduceLanes(ADD) / FLOAT_SPECIES.length(); + FloatVector d2Mean = vecMeanVec.sub(vecMean); + m2Vec = fma(d2Mean, d2Mean, m2Vec); + vectCount = count * FLOAT_SPECIES.length(); + vecVar = m2Vec.reduceLanes(ADD); + } + + float tailMean = 0; + float tailM2 = 0; + int tailCount = 0; + // handle the tail + for (; i < vector.length; i++) { + centroidDot = fma(vector[i], centroid[i], centroidDot); + centered[i] = vector[i] - centroid[i]; + float delta = centered[i] - tailMean; + ++tailCount; + tailMean += delta / tailCount; + float delta2 = centered[i] - tailMean; + tailM2 = fma(delta, delta2, tailM2); + min = Math.min(min, centered[i]); + max = Math.max(max, centered[i]); + norm2 = fma(centered[i], centered[i], norm2); + } + if (vectCount == 0) { + vecMean = tailMean; + vecVar = tailM2; + } else if (tailCount > 0) { + int totalCount = tailCount + vectCount; + assert totalCount == vector.length; + float alpha = (float) vectCount / totalCount; + float beta = 1f - alpha; + float completeMean = alpha * vecMean + beta * tailMean; + float dMean2Lhs = (vecMean - completeMean) * (vecMean - completeMean); + float dMean2Rhs = (tailMean - completeMean) * (tailMean - completeMean); + vecVar = (vecVar + dMean2Lhs) + beta * (tailM2 + dMean2Rhs); + 
vecMean = completeMean; + } + stats[0] = vecMean; + stats[1] = vecVar / vector.length; + stats[2] = norm2; + stats[3] = min; + stats[4] = max; + stats[5] = centroidDot; + } + + @Override + public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) { + float a = interval[0]; + float b = interval[1]; + int i = 0; + float daa = 0; + float dab = 0; + float dbb = 0; + float dax = 0; + float dbx = 0; + + FloatVector daaVec = FloatVector.zero(FLOAT_SPECIES); + FloatVector dabVec = FloatVector.zero(FLOAT_SPECIES); + FloatVector dbbVec = FloatVector.zero(FLOAT_SPECIES); + FloatVector daxVec = FloatVector.zero(FLOAT_SPECIES); + FloatVector dbxVec = FloatVector.zero(FLOAT_SPECIES); + + // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize + if (target.length > 2 * FLOAT_SPECIES.length()) { + FloatVector ones = FloatVector.broadcast(FLOAT_SPECIES, 1f); + FloatVector pmOnes = FloatVector.broadcast(FLOAT_SPECIES, points - 1f); + for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) { + FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i); + FloatVector vClamped = v.max(a).min(b); + Vector xiqint = vClamped.sub(a) + .mul(invStep) + // round + .add(0.5f) + .convert(VectorOperators.F2I, 0); + FloatVector kVec = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats(); + FloatVector sVec = kVec.div(pmOnes); + FloatVector smVec = ones.sub(sVec); + daaVec = fma(smVec, smVec, daaVec); + dabVec = fma(smVec, sVec, dabVec); + dbbVec = fma(sVec, sVec, dbbVec); + daxVec = fma(v, smVec, daxVec); + dbxVec = fma(v, sVec, dbxVec); + } + daa = daaVec.reduceLanes(ADD); + dab = dabVec.reduceLanes(ADD); + dbb = dbbVec.reduceLanes(ADD); + dax = daxVec.reduceLanes(ADD); + dbx = dbxVec.reduceLanes(ADD); + } + + for (; i < target.length; i++) { + float k = Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep); + float s = k / (points - 1); + float ms = 1f - s; + 
daa = fma(ms, ms, daa); + dab = fma(ms, s, dab); + dbb = fma(s, s, dbb); + dax = fma(ms, target[i], dax); + dbx = fma(s, target[i], dbx); + } + + pts[0] = daa; + pts[1] = dab; + pts[2] = dbb; + pts[3] = dax; + pts[4] = dbx; + } + + @Override + public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) { + float a = interval[0]; + float b = interval[1]; + float xe = 0f; + float e = 0f; + FloatVector xeVec = FloatVector.zero(FLOAT_SPECIES); + FloatVector eVec = FloatVector.zero(FLOAT_SPECIES); + int i = 0; + // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize + if (target.length > 2 * FLOAT_SPECIES.length()) { + for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) { + FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i); + FloatVector vClamped = v.max(a).min(b); + Vector xiqint = vClamped.sub(a).mul(invStep).add(0.5f).convert(VectorOperators.F2I, 0); + FloatVector xiq = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats().mul(step).add(a); + FloatVector xiiq = v.sub(xiq); + xeVec = fma(v, xiiq, xeVec); + eVec = fma(xiiq, xiiq, eVec); + } + e = eVec.reduceLanes(ADD); + xe = xeVec.reduceLanes(ADD); + } + + for (; i < target.length; i++) { + // this is quantizing and then dequantizing the vector + float xiq = fma(step, Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep), a); + // how much does the de-quantized value differ from the original value + float xiiq = target[i] - xiq; + e = fma(xiiq, xiiq, e); + xe = fma(target[i], xiiq, xe); + } + return (1f - lambda) * xe * xe / norm2 + lambda * e; + } + private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256; diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java 
b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java index 3b60df19109b..0c99fad2d3d5 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.simdvec; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.simdvec.internal.vectorization.BaseVectorizationTests; import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider; @@ -161,6 +162,112 @@ public class ESVectorUtilTests extends BaseVectorizationTests { testIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte); } + public void testCenterAndCalculateOSQStatsDp() { + int size = random().nextInt(128, 512); + float delta = 1e-3f * size; + var vector = new float[size]; + var centroid = new float[size]; + for (int i = 0; i < size; ++i) { + vector[i] = random().nextFloat(); + centroid[i] = random().nextFloat(); + } + var centeredLucene = new float[size]; + var statsLucene = new float[6]; + defaultedProvider.getVectorUtilSupport().centerAndCalculateOSQStatsDp(vector, centroid, centeredLucene, statsLucene); + var centeredPanama = new float[size]; + var statsPanama = new float[6]; + defOrPanamaProvider.getVectorUtilSupport().centerAndCalculateOSQStatsDp(vector, centroid, centeredPanama, statsPanama); + assertArrayEquals(centeredLucene, centeredPanama, delta); + assertArrayEquals(statsLucene, statsPanama, delta); + } + + public void testCenterAndCalculateOSQStatsEuclidean() { + int size = random().nextInt(128, 512); + float delta = 1e-3f * size; + var vector = new float[size]; + var centroid = new float[size]; + for (int i = 0; i < size; ++i) { + vector[i] = random().nextFloat(); + centroid[i] = random().nextFloat(); + } + var centeredLucene = new float[size]; + var statsLucene = new float[5]; + 
defaultedProvider.getVectorUtilSupport().centerAndCalculateOSQStatsEuclidean(vector, centroid, centeredLucene, statsLucene); + var centeredPanama = new float[size]; + var statsPanama = new float[5]; + defOrPanamaProvider.getVectorUtilSupport().centerAndCalculateOSQStatsEuclidean(vector, centroid, centeredPanama, statsPanama); + assertArrayEquals(centeredLucene, centeredPanama, delta); + assertArrayEquals(statsLucene, statsPanama, delta); + } + + public void testOsqLoss() { + int size = random().nextInt(128, 512); + float deltaEps = 1e-5f * size; + var vector = new float[size]; + var min = Float.MAX_VALUE; + var max = -Float.MAX_VALUE; + float vecMean = 0; + float vecVar = 0; + float norm2 = 0; + for (int i = 0; i < size; ++i) { + vector[i] = random().nextFloat(); + min = Math.min(min, vector[i]); + max = Math.max(max, vector[i]); + float delta = vector[i] - vecMean; + vecMean += delta / (i + 1); + float delta2 = vector[i] - vecMean; + vecVar += delta * delta2; + norm2 += vector[i] * vector[i]; + } + vecVar /= size; + float vecStd = (float) Math.sqrt(vecVar); + + for (byte bits : new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }) { + int points = 1 << bits; + float[] initInterval = new float[2]; + OptimizedScalarQuantizer.initInterval(bits, vecStd, vecMean, min, max, initInterval); + float step = ((initInterval[1] - initInterval[0]) / (points - 1f)); + float stepInv = 1f / step; + float expected = defaultedProvider.getVectorUtilSupport().calculateOSQLoss(vector, initInterval, step, stepInv, norm2, 0.1f); + float result = defOrPanamaProvider.getVectorUtilSupport().calculateOSQLoss(vector, initInterval, step, stepInv, norm2, 0.1f); + assertEquals(expected, result, deltaEps); + } + } + + public void testOsqGridPoints() { + int size = random().nextInt(128, 512); + float deltaEps = 1e-5f * size; + var vector = new float[size]; + var min = Float.MAX_VALUE; + var max = -Float.MAX_VALUE; + float vecMean = 0; + float vecVar = 0; + for (int i = 0; i < size; ++i) { + vector[i] = 
random().nextFloat(); + min = Math.min(min, vector[i]); + max = Math.max(max, vector[i]); + float delta = vector[i] - vecMean; + vecMean += delta / (i + 1); + float delta2 = vector[i] - vecMean; + vecVar += delta * delta2; + } + vecVar /= size; + float vecStd = (float) Math.sqrt(vecVar); + for (byte bits : new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }) { + int points = 1 << bits; + float[] initInterval = new float[2]; + OptimizedScalarQuantizer.initInterval(bits, vecStd, vecMean, min, max, initInterval); + float step = ((initInterval[1] - initInterval[0]) / (points - 1f)); + float stepInv = 1f / step; + float[] expected = new float[5]; + defaultedProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, initInterval, points, stepInv, expected); + + float[] result = new float[5]; + defOrPanamaProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, initInterval, points, stepInv, result); + assertArrayEquals(expected, result, deltaEps); + } + } + void testIpByteBinImpl(ToLongBiFunction ipByteBinFunc) { int iterations = atLeast(50); for (int i = 0; i < iterations; i++) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java similarity index 60% rename from server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java rename to server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java index 31b254b5de56..565e8116edc2 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java @@ -7,15 +7,21 @@ * License v3.0 only", or the "Server Side Public License, v 1". 
*/ -package org.elasticsearch.index.codec.vectors.es818; +package org.elasticsearch.index.codec.vectors; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.simdvec.ESVectorUtil; import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; -class OptimizedScalarQuantizer { +public class OptimizedScalarQuantizer { + public static void initInterval(byte bits, float vecStd, float vecMean, float min, float max, float[] initInterval) { + initInterval[0] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][0] * vecStd + vecMean, min, max); + initInterval[1] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][1] * vecStd + vecMean, min, max); + } + // The initial interval is set to the minimum MSE grid for each number of bits // these starting points are derived from the optimal MSE grid for a uniform distribution static final float[][] MINIMUM_MSE_GRID = new float[][] { @@ -27,19 +33,25 @@ class OptimizedScalarQuantizer { { -3.278f, 3.278f }, { -3.611f, 3.611f }, { -3.922f, 3.922f } }; - private static final float DEFAULT_LAMBDA = 0.1f; + public static final float DEFAULT_LAMBDA = 0.1f; private static final int DEFAULT_ITERS = 5; private final VectorSimilarityFunction similarityFunction; private final float lambda; private final int iters; + private final float[] statsScratch; + private final float[] gridScratch; + private final float[] intervalScratch; - OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction, float lambda, int iters) { + public OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction, float lambda, int iters) { this.similarityFunction = similarityFunction; this.lambda = lambda; this.iters = iters; + this.statsScratch = new float[similarityFunction == EUCLIDEAN ? 
5 : 6]; + this.gridScratch = new float[5]; + this.intervalScratch = new float[2]; } - OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) { + public OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) { this(similarityFunction, DEFAULT_LAMBDA, DEFAULT_ITERS); } @@ -49,34 +61,23 @@ class OptimizedScalarQuantizer { assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); assert bits.length == destinations.length; - float[] intervalScratch = new float[2]; - double vecMean = 0; - double vecVar = 0; - float norm2 = 0; - float centroidDot = 0; - float min = Float.MAX_VALUE; - float max = -Float.MAX_VALUE; - for (int i = 0; i < vector.length; ++i) { - if (similarityFunction != EUCLIDEAN) { - centroidDot += vector[i] * centroid[i]; - } - vector[i] = vector[i] - centroid[i]; - min = Math.min(min, vector[i]); - max = Math.max(max, vector[i]); - norm2 += (vector[i] * vector[i]); - double delta = vector[i] - vecMean; - vecMean += delta / (i + 1); - vecVar += delta * (vector[i] - vecMean); + if (similarityFunction == EUCLIDEAN) { + ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch); + } else { + ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch); } - vecVar /= vector.length; - double vecStd = Math.sqrt(vecVar); + float vecMean = statsScratch[0]; + float vecVar = statsScratch[1]; + float norm2 = statsScratch[2]; + float min = statsScratch[3]; + float max = statsScratch[4]; + float vecStd = (float) Math.sqrt(vecVar); QuantizationResult[] results = new QuantizationResult[bits.length]; for (int i = 0; i < bits.length; ++i) { assert bits[i] > 0 && bits[i] <= 8; int points = (1 << bits[i]); // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds - intervalScratch[0] = (float) clamp(MINIMUM_MSE_GRID[bits[i] - 1][0] * vecStd + vecMean, 
min, max); - intervalScratch[1] = (float) clamp(MINIMUM_MSE_GRID[bits[i] - 1][1] * vecStd + vecMean, min, max); + initInterval(bits[i], vecStd, vecMean, min, max, intervalScratch); optimizeIntervals(intervalScratch, vector, norm2, points); float nSteps = ((1 << bits[i]) - 1); float a = intervalScratch[0]; @@ -93,7 +94,7 @@ class OptimizedScalarQuantizer { results[i] = new QuantizationResult( intervalScratch[0], intervalScratch[1], - similarityFunction == EUCLIDEAN ? norm2 : centroidDot, + similarityFunction == EUCLIDEAN ? norm2 : statsScratch[5], sumQuery ); } @@ -105,31 +106,20 @@ class OptimizedScalarQuantizer { assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); assert vector.length <= destination.length; assert bits > 0 && bits <= 8; - float[] intervalScratch = new float[2]; int points = 1 << bits; - double vecMean = 0; - double vecVar = 0; - float norm2 = 0; - float centroidDot = 0; - float min = Float.MAX_VALUE; - float max = -Float.MAX_VALUE; - for (int i = 0; i < vector.length; ++i) { - if (similarityFunction != EUCLIDEAN) { - centroidDot += vector[i] * centroid[i]; - } - vector[i] = vector[i] - centroid[i]; - min = Math.min(min, vector[i]); - max = Math.max(max, vector[i]); - norm2 += (vector[i] * vector[i]); - double delta = vector[i] - vecMean; - vecMean += delta / (i + 1); - vecVar += delta * (vector[i] - vecMean); + if (similarityFunction == EUCLIDEAN) { + ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch); + } else { + ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch); } - vecVar /= vector.length; - double vecStd = Math.sqrt(vecVar); + float vecMean = statsScratch[0]; + float vecVar = statsScratch[1]; + float norm2 = statsScratch[2]; + float min = statsScratch[3]; + float max = statsScratch[4]; + float vecStd = (float) Math.sqrt(vecVar); // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds - 
intervalScratch[0] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][0] * vecStd + vecMean, min, max); - intervalScratch[1] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][1] * vecStd + vecMean, min, max); + initInterval(bits, vecStd, vecMean, min, max, intervalScratch); optimizeIntervals(intervalScratch, vector, norm2, points); float nSteps = ((1 << bits) - 1); // Now we have the optimized intervals, quantize the vector @@ -146,37 +136,11 @@ class OptimizedScalarQuantizer { return new QuantizationResult( intervalScratch[0], intervalScratch[1], - similarityFunction == EUCLIDEAN ? norm2 : centroidDot, + similarityFunction == EUCLIDEAN ? norm2 : statsScratch[5], sumQuery ); } - /** - * Compute the loss of the vector given the interval. Effectively, we are computing the MSE of a dequantized vector with the raw - * vector. - * @param vector raw vector - * @param interval interval to quantize the vector - * @param points number of quantization points - * @param norm2 squared norm of the vector - * @return the loss - */ - private double loss(float[] vector, float[] interval, int points, float norm2) { - double a = interval[0]; - double b = interval[1]; - double step = ((b - a) / (points - 1.0F)); - double stepInv = 1.0 / step; - double xe = 0.0; - double e = 0.0; - for (double xi : vector) { - // this is quantizing and then dequantizing the vector - double xiq = (a + step * Math.round((clamp(xi, a, b) - a) * stepInv)); - // how much does the de-quantized value differ from the original value - xe += xi * (xi - xiq); - e += (xi - xiq) * (xi - xiq); - } - return (1.0 - lambda) * xe * xe / norm2 + lambda * e; - } - /** * Optimize the quantization interval for the given vector. This is done via a coordinate descent trying to minimize the quantization * loss. 
Note, the loss is not always guaranteed to decrease, so we have a maximum number of iterations and will exit early if the @@ -187,30 +151,19 @@ class OptimizedScalarQuantizer { * @param points number of quantization points */ private void optimizeIntervals(float[] initInterval, float[] vector, float norm2, int points) { - double initialLoss = loss(vector, initInterval, points, norm2); + double initialLoss = ESVectorUtil.calculateOSQLoss(vector, initInterval, points, norm2, lambda); final float scale = (1.0f - lambda) / norm2; if (Float.isFinite(scale) == false) { return; } for (int i = 0; i < iters; ++i) { - float a = initInterval[0]; - float b = initInterval[1]; - float stepInv = (points - 1.0f) / (b - a); // calculate the grid points for coordinate descent - double daa = 0; - double dab = 0; - double dbb = 0; - double dax = 0; - double dbx = 0; - for (float xi : vector) { - float k = Math.round((clamp(xi, a, b) - a) * stepInv); - float s = k / (points - 1); - daa += (1.0 - s) * (1.0 - s); - dab += (1.0 - s) * s; - dbb += s * s; - dax += xi * (1.0 - s); - dbx += xi * s; - } + ESVectorUtil.calculateOSQGridPoints(vector, initInterval, points, gridScratch); + float daa = gridScratch[0]; + float dab = gridScratch[1]; + float dbb = gridScratch[2]; + float dax = gridScratch[3]; + float dbx = gridScratch[4]; double m0 = scale * dax * dax + lambda * daa; double m1 = scale * dax * dbx + lambda * dab; double m2 = scale * dbx * dbx + lambda * dbb; @@ -225,7 +178,7 @@ class OptimizedScalarQuantizer { if ((Math.abs(initInterval[0] - aOpt) < 1e-8 && Math.abs(initInterval[1] - bOpt) < 1e-8)) { return; } - double newLoss = loss(vector, new float[] { aOpt, bOpt }, points, norm2); + double newLoss = ESVectorUtil.calculateOSQLoss(vector, new float[] { aOpt, bOpt }, points, norm2, lambda); // If the new loss is worse, don't update the interval and exit // This optimization, unlike kMeans, does not always converge to better loss // So exit if we are getting worse diff --git 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java index cc1f7b85e0f7..ca80ba52e2c2 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java @@ -23,6 +23,7 @@ import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.index.codec.vectors.BQVectorUtils; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import java.io.IOException; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java index 87b744b4e4ee..19cd66550699 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java @@ -28,6 +28,7 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.simdvec.ESVectorUtil; import java.io.IOException; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java index 1dee9599f985..893640dfb484 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java 
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java @@ -26,6 +26,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import java.io.IOException; @@ -35,7 +36,7 @@ import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_ * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 * Codec for encoding/decoding binary quantized vectors The binary quantization format used here * is a per-vector optimized scalar quantization. Also see {@link - * org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer}. Some of key features are: + * OptimizedScalarQuantizer}. Some of key features are: * *
    *
  • Estimating the distance between two vectors using their centroid normalized distance. This diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java index 6571b2dfa35b..ac707d155ea3 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java @@ -44,6 +44,7 @@ import org.apache.lucene.util.SuppressForbidden; import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.elasticsearch.index.codec.vectors.BQVectorUtils; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils; import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java index 8dba2dbee9f5..7cfa755c2610 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java @@ -49,6 +49,7 @@ import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.index.codec.vectors.BQSpaceUtils; import org.elasticsearch.index.codec.vectors.BQVectorUtils; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import java.io.Closeable; import java.io.IOException; diff --git 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java index 72333169b39b..0357468c6864 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java @@ -30,6 +30,7 @@ import org.apache.lucene.util.Bits; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; import org.elasticsearch.index.codec.vectors.BQVectorUtils; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import java.io.IOException; import java.nio.ByteBuffer; diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizerTests.java similarity index 98% rename from server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java rename to server/src/test/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizerTests.java index 73405ecc6d4f..55171f48f373 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizerTests.java @@ -7,13 +7,13 @@ * License v3.0 only", or the "Server Side Public License, v 1". 
*/ -package org.elasticsearch.index.codec.vectors.es818; +package org.elasticsearch.index.codec.vectors; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.VectorUtil; import org.elasticsearch.test.ESTestCase; -import static org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer.MINIMUM_MSE_GRID; +import static org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer.MINIMUM_MSE_GRID; public class OptimizedScalarQuantizerTests extends ESTestCase { diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java index 308f712371c6..4e431cbe71ea 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java @@ -45,6 +45,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.index.codec.vectors.BQVectorUtils; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils; import java.io.IOException;