Panama vector accelerated optimized scalar quantization (#127118)

* Accelerates optimized scalar quantization with vectorized functions

* Adding benchmark

* Update docs/changelog/127118.yaml

* adjusting benchmark and delta
Benjamin Trent 2025-04-23 12:51:04 -04:00, committed by GitHub
parent ad0fe78e3e
commit 059f91c90c
16 changed files with 702 additions and 99 deletions

@@ -0,0 +1,78 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.benchmark.vector;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 3, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(value = 3)
public class OptimizedScalarQuantizerBenchmark {
    static {
        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
    }

    @Param({ "384", "702", "1024" })
    int dims;

    float[] vector;
    float[] centroid;
    byte[] destination;

    @Param({ "1", "4", "7" })
    byte bits;

    OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(VectorSimilarityFunction.DOT_PRODUCT);

    @Setup(Level.Iteration)
    public void init() {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        // random vector and centroid; destination receives the quantized bytes
        destination = new byte[dims];
        vector = new float[dims];
        centroid = new float[dims];
        for (int i = 0; i < dims; ++i) {
            vector[i] = random.nextFloat();
            centroid[i] = random.nextFloat();
        }
    }

    @Benchmark
    public byte[] scalar() {
        osq.scalarQuantize(vector, destination, bits, centroid);
        return destination;
    }

    @Benchmark
    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public byte[] vector() {
        osq.scalarQuantize(vector, destination, bits, centroid);
        return destination;
    }
}
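
Both benchmark methods execute the same call; the vector variant differs only in forking a JVM with --add-modules=jdk.incubator.vector, which is what lets the Panama implementation be selected at runtime. As a sketch, the class can also be launched programmatically; this hypothetical runner is not part of the PR and assumes nothing beyond JMH itself being on the classpath:

    import org.openjdk.jmh.runner.Runner;
    import org.openjdk.jmh.runner.options.Options;
    import org.openjdk.jmh.runner.options.OptionsBuilder;

    public class BenchmarkRunner {
        public static void main(String[] args) throws Exception {
            // Select only this benchmark class; JMH picks up the annotations declared on it.
            Options opts = new OptionsBuilder().include(OptimizedScalarQuantizerBenchmark.class.getSimpleName()).build();
            new Runner(opts).run();
        }
    }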

@@ -0,0 +1,5 @@
pr: 127118
summary: Panama vector accelerated optimized scalar quantization
area: Vector Search
type: enhancement
issues: []

@@ -144,4 +144,71 @@ public class ESVectorUtil {
        }
        return distance;
    }

    /**
     * Calculate the loss for optimized-scalar quantization for the given parameters
     * @param target The vector being quantized, assumed to be centered
     * @param interval The interval for which to calculate the loss
     * @param points the quantization points
     * @param norm2 The norm squared of the target vector
     * @param lambda The lambda parameter for controlling anisotropic loss calculation
     * @return The loss for the given parameters
     */
    public static float calculateOSQLoss(float[] target, float[] interval, int points, float norm2, float lambda) {
        assert interval.length == 2;
        float step = ((interval[1] - interval[0]) / (points - 1.0F));
        float invStep = 1f / step;
        return IMPL.calculateOSQLoss(target, interval, step, invStep, norm2, lambda);
    }
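
For reference, the loss computed here matches the scalar implementation further below: for an interval [a, b], p quantization points, and step s = (b - a) / (p - 1), each component is quantized and dequantized as

    \hat{x}_i = a + s \cdot \mathrm{round}\left(\frac{\mathrm{clamp}(x_i, a, b) - a}{s}\right)

and the anisotropic loss is

    L = (1 - \lambda)\,\frac{\left(\sum_i x_i (x_i - \hat{x}_i)\right)^2}{\lVert x \rVert^2} + \lambda \sum_i (x_i - \hat{x}_i)^2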
    /**
     * Calculate the grid points for optimized-scalar quantization
     * @param target The vector being quantized, assumed to be centered
     * @param interval The interval for which to calculate the grid points
     * @param points the quantization points
     * @param pts The array to store the grid points, must be of length 5
     */
    public static void calculateOSQGridPoints(float[] target, float[] interval, int points, float[] pts) {
        assert interval.length == 2;
        assert pts.length == 5;
        float invStep = (points - 1.0F) / (interval[1] - interval[0]);
        IMPL.calculateOSQGridPoints(target, interval, points, invStep, pts);
    }
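
Writing k_i = \mathrm{round}((\mathrm{clamp}(x_i, a, b) - a) / s) and s_i = k_i / (p - 1), the five values stored in pts are (see the default implementation below)

    \sum_i (1 - s_i)^2, \quad \sum_i (1 - s_i) s_i, \quad \sum_i s_i^2, \quad \sum_i (1 - s_i) x_i, \quad \sum_i s_i x_i

These are exactly the coefficients of the linear system that coordinate descent solves for the interval endpoints.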
    /**
     * Center the target vector and calculate the optimized-scalar quantization statistics
     * @param target The vector being quantized
     * @param centroid The centroid of the target vector
     * @param centered The destination of the centered vector, will be overwritten
     * @param stats The array to store the statistics, must be of length 5
     */
    public static void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
        if (target.length != centroid.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length);
        }
        if (centered.length != target.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length);
        }
        assert stats.length == 5;
        IMPL.centerAndCalculateOSQStatsEuclidean(target, centroid, centered, stats);
    }
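
A minimal usage sketch; the stats layout (mean, variance, squared norm, min, max) follows from the default implementation below, and vector/centroid are assumed here to be equal-length float[] inputs:

    float[] centered = new float[vector.length];
    float[] stats = new float[5];
    ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, centered, stats);
    // centered now holds vector - centroid, element-wise, and:
    float mean = stats[0], variance = stats[1], norm2 = stats[2], min = stats[3], max = stats[4];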
    /**
     * Center the target vector and calculate the optimized-scalar quantization statistics
     * @param target The vector being quantized
     * @param centroid The centroid of the target vector
     * @param centered The destination of the centered vector, will be overwritten
     * @param stats The array to store the statistics, must be of length 6
     */
    public static void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
        if (target.length != centroid.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length);
        }
        if (centered.length != target.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length);
        }
        assert stats.length == 6;
        IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats);
    }
}

@@ -44,6 +44,100 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
        return ipFloatByteImpl(q, d);
    }

    @Override
    public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) {
        float a = interval[0];
        float b = interval[1];
        float xe = 0f;
        float e = 0f;
        for (float xi : target) {
            // this is quantizing and then dequantizing the vector
            float xiq = fma(step, Math.round((Math.min(Math.max(xi, a), b) - a) * invStep), a);
            // how much does the de-quantized value differ from the original value
            float xiiq = xi - xiq;
            e = fma(xiiq, xiiq, e);
            xe = fma(xi, xiiq, xe);
        }
        return (1f - lambda) * xe * xe / norm2 + lambda * e;
    }

    @Override
    public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) {
        float a = interval[0];
        float b = interval[1];
        float daa = 0;
        float dab = 0;
        float dbb = 0;
        float dax = 0;
        float dbx = 0;
        for (float v : target) {
            float k = Math.round((Math.min(Math.max(v, a), b) - a) * invStep);
            float s = k / (points - 1);
            float ms = 1f - s;
            daa = fma(ms, ms, daa);
            dab = fma(ms, s, dab);
            dbb = fma(s, s, dbb);
            dax = fma(ms, v, dax);
            dbx = fma(s, v, dbx);
        }
        pts[0] = daa;
        pts[1] = dab;
        pts[2] = dbb;
        pts[3] = dax;
        pts[4] = dbx;
    }

    @Override
    public void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
        float vecMean = 0;
        float vecVar = 0;
        float norm2 = 0;
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        for (int i = 0; i < target.length; i++) {
            centered[i] = target[i] - centroid[i];
            min = Math.min(min, centered[i]);
            max = Math.max(max, centered[i]);
            norm2 = fma(centered[i], centered[i], norm2);
            float delta = centered[i] - vecMean;
            vecMean += delta / (i + 1);
            float delta2 = centered[i] - vecMean;
            vecVar = fma(delta, delta2, vecVar);
        }
        stats[0] = vecMean;
        stats[1] = vecVar / target.length;
        stats[2] = norm2;
        stats[3] = min;
        stats[4] = max;
    }

    @Override
    public void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
        float vecMean = 0;
        float vecVar = 0;
        float norm2 = 0;
        float centroidDot = 0;
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        for (int i = 0; i < target.length; i++) {
            centroidDot = fma(target[i], centroid[i], centroidDot);
            centered[i] = target[i] - centroid[i];
            min = Math.min(min, centered[i]);
            max = Math.max(max, centered[i]);
            norm2 = fma(centered[i], centered[i], norm2);
            float delta = centered[i] - vecMean;
            vecMean += delta / (i + 1);
            float delta2 = centered[i] - vecMean;
            vecVar = fma(delta, delta2, vecVar);
        }
        stats[0] = vecMean;
        stats[1] = vecVar / target.length;
        stats[2] = norm2;
        stats[3] = min;
        stats[4] = max;
        stats[5] = centroidDot;
    }
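
The running mean and variance in these loops follow Welford's online algorithm: for element x_k it updates

    \mu_k = \mu_{k-1} + \frac{x_k - \mu_{k-1}}{k}, \qquad M_{2,k} = M_{2,k-1} + (x_k - \mu_{k-1})(x_k - \mu_k)

and the variance written to stats[1] is M_{2,n} / n. This keeps the statistics single-pass and numerically stable, which is what lets them be fused with the centering loop.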
    public static int ipByteBitImpl(byte[] q, byte[] d) {
        return ipByteBitImpl(q, d, 0);
    }

@@ -20,4 +20,12 @@ public interface ESVectorUtilSupport {
    float ipFloatBit(float[] q, byte[] d);

    float ipFloatByte(float[] q, byte[] d);

    float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda);

    void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts);

    void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats);

    void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);
}

@@ -13,6 +13,7 @@ import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.LongVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorShape;
@@ -21,16 +22,22 @@ import jdk.incubator.vector.VectorSpecies;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;
import static jdk.incubator.vector.VectorOperators.ADD;
import static jdk.incubator.vector.VectorOperators.MAX;
import static jdk.incubator.vector.VectorOperators.MIN;
public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {

    static final int VECTOR_BITSIZE;

    private static final VectorSpecies<Float> FLOAT_SPECIES;

    /** Whether integer vectors can be trusted to actually be fast. */
    static final boolean HAS_FAST_INTEGER_VECTORS;

    static {
        // default to platform supported bitsize
        VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize();
        FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(VECTOR_BITSIZE));

        // hotspot misses some SSE intrinsics, workaround it
        // to be fair, they do document this thing only works well with AVX2/AVX3 and Neon
@@ -38,6 +45,22 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
        HAS_FAST_INTEGER_VECTORS = isAMD64withoutAVX2 == false;
    }

    private static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) {
        if (Constants.HAS_FAST_VECTOR_FMA) {
            return a.fma(b, c);
        } else {
            return a.mul(b).add(c);
        }
    }

    private static float fma(float a, float b, float c) {
        if (Constants.HAS_FAST_SCALAR_FMA) {
            return Math.fma(a, b, c);
        } else {
            return a * b + c;
        }
    }
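
These guards exist because Math.fma falls back to slow software emulation on hardware without an FMA instruction, and because a fused multiply-add rounds once while mul-then-add rounds twice, so the two paths can differ in the last ulp. A small plain-Java illustration (nothing here beyond the JDK):

    float a = 1f / 3f, b = 3f, c = -1f;
    float fused = Math.fma(a, b, c); // exact product, single rounding: tiny nonzero residual
    float unfused = a * b + c;       // product rounds to 1.0f first, so this is exactly 0.0f
    System.out.println(fused + " vs " + unfused);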
    @Override
    public long ipByteBinByte(byte[] q, byte[] d) {
        // 128 / 8 == 16
@@ -83,6 +106,267 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
        return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d);
    }

    @Override
    public void centerAndCalculateOSQStatsEuclidean(float[] vector, float[] centroid, float[] centered, float[] stats) {
        assert vector.length == centroid.length;
        assert vector.length == centered.length;
        float vecMean = 0;
        float vecVar = 0;
        float norm2 = 0;
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        int i = 0;
        int vectCount = 0;
        if (vector.length > 2 * FLOAT_SPECIES.length()) {
            FloatVector vecMeanVec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector m2Vec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector norm2Vec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector minVec = FloatVector.broadcast(FLOAT_SPECIES, Float.MAX_VALUE);
            FloatVector maxVec = FloatVector.broadcast(FLOAT_SPECIES, -Float.MAX_VALUE);
            int count = 0;
            for (; i < FLOAT_SPECIES.loopBound(vector.length); i += FLOAT_SPECIES.length()) {
                ++count;
                FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
                FloatVector c = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
                FloatVector centeredVec = v.sub(c);
                FloatVector deltaVec = centeredVec.sub(vecMeanVec);
                norm2Vec = fma(centeredVec, centeredVec, norm2Vec);
                vecMeanVec = vecMeanVec.add(deltaVec.div(count));
                FloatVector delta2Vec = centeredVec.sub(vecMeanVec);
                m2Vec = fma(deltaVec, delta2Vec, m2Vec);
                minVec = minVec.min(centeredVec);
                maxVec = maxVec.max(centeredVec);
                centeredVec.intoArray(centered, i);
            }
            min = minVec.reduceLanes(MIN);
            max = maxVec.reduceLanes(MAX);
            norm2 = norm2Vec.reduceLanes(ADD);
            vecMean = vecMeanVec.reduceLanes(ADD) / FLOAT_SPECIES.length();
            FloatVector d2Mean = vecMeanVec.sub(vecMean);
            m2Vec = fma(d2Mean, d2Mean, m2Vec);
            vectCount = count * FLOAT_SPECIES.length();
            vecVar = m2Vec.reduceLanes(ADD);
        }
        float tailMean = 0;
        float tailM2 = 0;
        int tailCount = 0;
        // handle the tail
        for (; i < vector.length; i++) {
            centered[i] = vector[i] - centroid[i];
            float delta = centered[i] - tailMean;
            ++tailCount;
            tailMean += delta / tailCount;
            float delta2 = centered[i] - tailMean;
            tailM2 = fma(delta, delta2, tailM2);
            min = Math.min(min, centered[i]);
            max = Math.max(max, centered[i]);
            norm2 = fma(centered[i], centered[i], norm2);
        }
        if (vectCount == 0) {
            vecMean = tailMean;
            vecVar = tailM2;
        } else if (tailCount > 0) {
            int totalCount = tailCount + vectCount;
            assert totalCount == vector.length;
            float alpha = (float) vectCount / totalCount;
            float beta = 1f - alpha;
            float completeMean = alpha * vecMean + beta * tailMean;
            float dMean2Lhs = (vecMean - completeMean) * (vecMean - completeMean);
            float dMean2Rhs = (tailMean - completeMean) * (tailMean - completeMean);
            vecVar = (vecVar + dMean2Lhs) + beta * (tailM2 + dMean2Rhs);
            vecMean = completeMean;
        }
        stats[0] = vecMean;
        stats[1] = vecVar / vector.length;
        stats[2] = norm2;
        stats[3] = min;
        stats[4] = max;
    }
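
The vector loop and the scalar tail each keep independent Welford-style accumulators that are merged at the end. For comparison, the exact two-partition merge (Chan et al.) for counts n_A, n_B with means \mu_A, \mu_B and squared-deviation sums M_A, M_B is

    n = n_A + n_B, \qquad \mu = \frac{n_A \mu_A + n_B \mu_B}{n}, \qquad M = M_A + M_B + n_A (\mu_A - \mu)^2 + n_B (\mu_B - \mu)^2

The code folds the correction terms in with the alpha/beta weights instead; the unit tests below compare this path against the scalar implementation with a dimension-scaled tolerance.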
    @Override
    public void centerAndCalculateOSQStatsDp(float[] vector, float[] centroid, float[] centered, float[] stats) {
        assert vector.length == centroid.length;
        assert vector.length == centered.length;
        float vecMean = 0;
        float vecVar = 0;
        float norm2 = 0;
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        float centroidDot = 0;
        int i = 0;
        int vectCount = 0;
        int loopBound = FLOAT_SPECIES.loopBound(vector.length);
        if (vector.length > 2 * FLOAT_SPECIES.length()) {
            FloatVector vecMeanVec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector m2Vec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector norm2Vec = FloatVector.zero(FLOAT_SPECIES);
            FloatVector minVec = FloatVector.broadcast(FLOAT_SPECIES, Float.MAX_VALUE);
            FloatVector maxVec = FloatVector.broadcast(FLOAT_SPECIES, -Float.MAX_VALUE);
            FloatVector centroidDotVec = FloatVector.zero(FLOAT_SPECIES);
            int count = 0;
            for (; i < loopBound; i += FLOAT_SPECIES.length()) {
                ++count;
                FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
                FloatVector c = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
                centroidDotVec = fma(v, c, centroidDotVec);
                FloatVector centeredVec = v.sub(c);
                FloatVector deltaVec = centeredVec.sub(vecMeanVec);
                norm2Vec = fma(centeredVec, centeredVec, norm2Vec);
                vecMeanVec = vecMeanVec.add(deltaVec.div(count));
                FloatVector delta2Vec = centeredVec.sub(vecMeanVec);
                m2Vec = fma(deltaVec, delta2Vec, m2Vec);
                minVec = minVec.min(centeredVec);
                maxVec = maxVec.max(centeredVec);
                centeredVec.intoArray(centered, i);
            }
            min = minVec.reduceLanes(MIN);
            max = maxVec.reduceLanes(MAX);
            norm2 = norm2Vec.reduceLanes(ADD);
            centroidDot = centroidDotVec.reduceLanes(ADD);
            vecMean = vecMeanVec.reduceLanes(ADD) / FLOAT_SPECIES.length();
            FloatVector d2Mean = vecMeanVec.sub(vecMean);
            m2Vec = fma(d2Mean, d2Mean, m2Vec);
            vectCount = count * FLOAT_SPECIES.length();
            vecVar = m2Vec.reduceLanes(ADD);
        }
        float tailMean = 0;
        float tailM2 = 0;
        int tailCount = 0;
        // handle the tail
        for (; i < vector.length; i++) {
            centroidDot = fma(vector[i], centroid[i], centroidDot);
            centered[i] = vector[i] - centroid[i];
            float delta = centered[i] - tailMean;
            ++tailCount;
            tailMean += delta / tailCount;
            float delta2 = centered[i] - tailMean;
            tailM2 = fma(delta, delta2, tailM2);
            min = Math.min(min, centered[i]);
            max = Math.max(max, centered[i]);
            norm2 = fma(centered[i], centered[i], norm2);
        }
        if (vectCount == 0) {
            vecMean = tailMean;
            vecVar = tailM2;
        } else if (tailCount > 0) {
            int totalCount = tailCount + vectCount;
            assert totalCount == vector.length;
            float alpha = (float) vectCount / totalCount;
            float beta = 1f - alpha;
            float completeMean = alpha * vecMean + beta * tailMean;
            float dMean2Lhs = (vecMean - completeMean) * (vecMean - completeMean);
            float dMean2Rhs = (tailMean - completeMean) * (tailMean - completeMean);
            vecVar = (vecVar + dMean2Lhs) + beta * (tailM2 + dMean2Rhs);
            vecMean = completeMean;
        }
        stats[0] = vecMean;
        stats[1] = vecVar / vector.length;
        stats[2] = norm2;
        stats[3] = min;
        stats[4] = max;
        stats[5] = centroidDot;
    }
    @Override
    public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) {
        float a = interval[0];
        float b = interval[1];
        int i = 0;
        float daa = 0;
        float dab = 0;
        float dbb = 0;
        float dax = 0;
        float dbx = 0;
        FloatVector daaVec = FloatVector.zero(FLOAT_SPECIES);
        FloatVector dabVec = FloatVector.zero(FLOAT_SPECIES);
        FloatVector dbbVec = FloatVector.zero(FLOAT_SPECIES);
        FloatVector daxVec = FloatVector.zero(FLOAT_SPECIES);
        FloatVector dbxVec = FloatVector.zero(FLOAT_SPECIES);
        // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
        if (target.length > 2 * FLOAT_SPECIES.length()) {
            FloatVector ones = FloatVector.broadcast(FLOAT_SPECIES, 1f);
            FloatVector pmOnes = FloatVector.broadcast(FLOAT_SPECIES, points - 1f);
            for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) {
                FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i);
                FloatVector vClamped = v.max(a).min(b);
                Vector<Integer> xiqint = vClamped.sub(a)
                    .mul(invStep)
                    // round
                    .add(0.5f)
                    .convert(VectorOperators.F2I, 0);
                FloatVector kVec = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats();
                FloatVector sVec = kVec.div(pmOnes);
                FloatVector smVec = ones.sub(sVec);
                daaVec = fma(smVec, smVec, daaVec);
                dabVec = fma(smVec, sVec, dabVec);
                dbbVec = fma(sVec, sVec, dbbVec);
                daxVec = fma(v, smVec, daxVec);
                dbxVec = fma(v, sVec, dbxVec);
            }
            daa = daaVec.reduceLanes(ADD);
            dab = dabVec.reduceLanes(ADD);
            dbb = dbbVec.reduceLanes(ADD);
            dax = daxVec.reduceLanes(ADD);
            dbx = dbxVec.reduceLanes(ADD);
        }
        for (; i < target.length; i++) {
            float k = Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep);
            float s = k / (points - 1);
            float ms = 1f - s;
            daa = fma(ms, ms, daa);
            dab = fma(ms, s, dab);
            dbb = fma(s, s, dbb);
            dax = fma(ms, target[i], dax);
            dbx = fma(s, target[i], dbx);
        }
        pts[0] = daa;
        pts[1] = dab;
        pts[2] = dbb;
        pts[3] = dax;
        pts[4] = dbx;
    }
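
The vectorized rounding above relies on the add-0.5-then-truncate idiom: add(0.5f) followed by an F2I conversion, which truncates toward zero like a Java cast. That matches Math.round only for non-negative inputs, which holds here because each value is clamped to [a, b] and shifted by -a before scaling. A scalar illustration of where the idiom would otherwise diverge:

    float nonNegative = 2.7f;
    assert (int) (nonNegative + 0.5f) == Math.round(nonNegative); // both are 3
    float negative = -2.7f;
    // (int) (-2.7f + 0.5f) truncates -2.2f to -2, while Math.round(-2.7f) is -3
    assert (int) (negative + 0.5f) != Math.round(negative);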
    @Override
    public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) {
        float a = interval[0];
        float b = interval[1];
        float xe = 0f;
        float e = 0f;
        FloatVector xeVec = FloatVector.zero(FLOAT_SPECIES);
        FloatVector eVec = FloatVector.zero(FLOAT_SPECIES);
        int i = 0;
        // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
        if (target.length > 2 * FLOAT_SPECIES.length()) {
            for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) {
                FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i);
                FloatVector vClamped = v.max(a).min(b);
                Vector<Integer> xiqint = vClamped.sub(a).mul(invStep).add(0.5f).convert(VectorOperators.F2I, 0);
                FloatVector xiq = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats().mul(step).add(a);
                FloatVector xiiq = v.sub(xiq);
                xeVec = fma(v, xiiq, xeVec);
                eVec = fma(xiiq, xiiq, eVec);
            }
            e = eVec.reduceLanes(ADD);
            xe = xeVec.reduceLanes(ADD);
        }
        for (; i < target.length; i++) {
            // this is quantizing and then dequantizing the vector
            float xiq = fma(step, Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep), a);
            // how much does the de-quantized value differ from the original value
            float xiiq = target[i] - xiq;
            e = fma(xiiq, xiiq, e);
            xe = fma(target[i], xiiq, xe);
        }
        return (1f - lambda) * xe * xe / norm2 + lambda * e;
    }
    private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
    private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;

@@ -9,6 +9,7 @@
package org.elasticsearch.simdvec;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import org.elasticsearch.simdvec.internal.vectorization.BaseVectorizationTests;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
@@ -161,6 +162,112 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
        testIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
    }

    public void testCenterAndCalculateOSQStatsDp() {
        int size = random().nextInt(128, 512);
        float delta = 1e-3f * size;
        var vector = new float[size];
        var centroid = new float[size];
        for (int i = 0; i < size; ++i) {
            vector[i] = random().nextFloat();
            centroid[i] = random().nextFloat();
        }
        var centeredLucene = new float[size];
        var statsLucene = new float[6];
        defaultedProvider.getVectorUtilSupport().centerAndCalculateOSQStatsDp(vector, centroid, centeredLucene, statsLucene);
        var centeredPanama = new float[size];
        var statsPanama = new float[6];
        defOrPanamaProvider.getVectorUtilSupport().centerAndCalculateOSQStatsDp(vector, centroid, centeredPanama, statsPanama);
        assertArrayEquals(centeredLucene, centeredPanama, delta);
        assertArrayEquals(statsLucene, statsPanama, delta);
    }

    public void testCenterAndCalculateOSQStatsEuclidean() {
        int size = random().nextInt(128, 512);
        float delta = 1e-3f * size;
        var vector = new float[size];
        var centroid = new float[size];
        for (int i = 0; i < size; ++i) {
            vector[i] = random().nextFloat();
            centroid[i] = random().nextFloat();
        }
        var centeredLucene = new float[size];
        var statsLucene = new float[5];
        defaultedProvider.getVectorUtilSupport().centerAndCalculateOSQStatsEuclidean(vector, centroid, centeredLucene, statsLucene);
        var centeredPanama = new float[size];
        var statsPanama = new float[5];
        defOrPanamaProvider.getVectorUtilSupport().centerAndCalculateOSQStatsEuclidean(vector, centroid, centeredPanama, statsPanama);
        assertArrayEquals(centeredLucene, centeredPanama, delta);
        assertArrayEquals(statsLucene, statsPanama, delta);
    }

    public void testOsqLoss() {
        int size = random().nextInt(128, 512);
        float deltaEps = 1e-5f * size;
        var vector = new float[size];
        var min = Float.MAX_VALUE;
        var max = -Float.MAX_VALUE;
        float vecMean = 0;
        float vecVar = 0;
        float norm2 = 0;
        for (int i = 0; i < size; ++i) {
            vector[i] = random().nextFloat();
            min = Math.min(min, vector[i]);
            max = Math.max(max, vector[i]);
            float delta = vector[i] - vecMean;
            vecMean += delta / (i + 1);
            float delta2 = vector[i] - vecMean;
            vecVar += delta * delta2;
            norm2 += vector[i] * vector[i];
        }
        vecVar /= size;
        float vecStd = (float) Math.sqrt(vecVar);
        for (byte bits : new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }) {
            int points = 1 << bits;
            float[] initInterval = new float[2];
            OptimizedScalarQuantizer.initInterval(bits, vecStd, vecMean, min, max, initInterval);
            float step = ((initInterval[1] - initInterval[0]) / (points - 1f));
            float stepInv = 1f / step;
            float expected = defaultedProvider.getVectorUtilSupport().calculateOSQLoss(vector, initInterval, step, stepInv, norm2, 0.1f);
            float result = defOrPanamaProvider.getVectorUtilSupport().calculateOSQLoss(vector, initInterval, step, stepInv, norm2, 0.1f);
            assertEquals(expected, result, deltaEps);
        }
    }

    public void testOsqGridPoints() {
        int size = random().nextInt(128, 512);
        float deltaEps = 1e-5f * size;
        var vector = new float[size];
        var min = Float.MAX_VALUE;
        var max = -Float.MAX_VALUE;
        float vecMean = 0;
        float vecVar = 0;
        for (int i = 0; i < size; ++i) {
            vector[i] = random().nextFloat();
            min = Math.min(min, vector[i]);
            max = Math.max(max, vector[i]);
            float delta = vector[i] - vecMean;
            vecMean += delta / (i + 1);
            float delta2 = vector[i] - vecMean;
            vecVar += delta * delta2;
        }
        vecVar /= size;
        float vecStd = (float) Math.sqrt(vecVar);
        for (byte bits : new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }) {
            int points = 1 << bits;
            float[] initInterval = new float[2];
            OptimizedScalarQuantizer.initInterval(bits, vecStd, vecMean, min, max, initInterval);
            float step = ((initInterval[1] - initInterval[0]) / (points - 1f));
            float stepInv = 1f / step;
            float[] expected = new float[5];
            defaultedProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, initInterval, points, stepInv, expected);
            float[] result = new float[5];
            defOrPanamaProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, initInterval, points, stepInv, result);
            assertArrayEquals(expected, result, deltaEps);
        }
    }

    void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
        int iterations = atLeast(50);
        for (int i = 0; i < iterations; i++) {

@@ -7,15 +7,21 @@
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.index.codec.vectors.es818;
package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.simdvec.ESVectorUtil;

import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;

class OptimizedScalarQuantizer {
public class OptimizedScalarQuantizer {

    public static void initInterval(byte bits, float vecStd, float vecMean, float min, float max, float[] initInterval) {
        initInterval[0] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][0] * vecStd + vecMean, min, max);
        initInterval[1] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][1] * vecStd + vecMean, min, max);
    }

    // The initial interval is set to the minimum MSE grid for each number of bits
    // these starting points are derived from the optimal MSE grid for a uniform distribution
    static final float[][] MINIMUM_MSE_GRID = new float[][] {
@@ -27,19 +33,25 @@ class OptimizedScalarQuantizer {
        { -3.278f, 3.278f },
        { -3.611f, 3.611f },
        { -3.922f, 3.922f } };

    private static final float DEFAULT_LAMBDA = 0.1f;
    public static final float DEFAULT_LAMBDA = 0.1f;
    private static final int DEFAULT_ITERS = 5;

    private final VectorSimilarityFunction similarityFunction;
    private final float lambda;
    private final int iters;
    private final float[] statsScratch;
    private final float[] gridScratch;
    private final float[] intervalScratch;

    OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction, float lambda, int iters) {
    public OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction, float lambda, int iters) {
        this.similarityFunction = similarityFunction;
        this.lambda = lambda;
        this.iters = iters;
        this.statsScratch = new float[similarityFunction == EUCLIDEAN ? 5 : 6];
        this.gridScratch = new float[5];
        this.intervalScratch = new float[2];
    }

    OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) {
    public OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) {
        this(similarityFunction, DEFAULT_LAMBDA, DEFAULT_ITERS);
    }
@@ -49,34 +61,23 @@ class OptimizedScalarQuantizer {
        assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
        assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
        assert bits.length == destinations.length;
        float[] intervalScratch = new float[2];
        double vecMean = 0;
        double vecVar = 0;
        float norm2 = 0;
        float centroidDot = 0;
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        for (int i = 0; i < vector.length; ++i) {
            if (similarityFunction != EUCLIDEAN) {
                centroidDot += vector[i] * centroid[i];
            }
            vector[i] = vector[i] - centroid[i];
            min = Math.min(min, vector[i]);
            max = Math.max(max, vector[i]);
            norm2 += (vector[i] * vector[i]);
            double delta = vector[i] - vecMean;
            vecMean += delta / (i + 1);
            vecVar += delta * (vector[i] - vecMean);
        if (similarityFunction == EUCLIDEAN) {
            ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch);
        } else {
            ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch);
        }
        vecVar /= vector.length;
        double vecStd = Math.sqrt(vecVar);
        float vecMean = statsScratch[0];
        float vecVar = statsScratch[1];
        float norm2 = statsScratch[2];
        float min = statsScratch[3];
        float max = statsScratch[4];
        float vecStd = (float) Math.sqrt(vecVar);
        QuantizationResult[] results = new QuantizationResult[bits.length];
        for (int i = 0; i < bits.length; ++i) {
            assert bits[i] > 0 && bits[i] <= 8;
            int points = (1 << bits[i]);
            // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
            intervalScratch[0] = (float) clamp(MINIMUM_MSE_GRID[bits[i] - 1][0] * vecStd + vecMean, min, max);
            intervalScratch[1] = (float) clamp(MINIMUM_MSE_GRID[bits[i] - 1][1] * vecStd + vecMean, min, max);
            initInterval(bits[i], vecStd, vecMean, min, max, intervalScratch);
            optimizeIntervals(intervalScratch, vector, norm2, points);
            float nSteps = ((1 << bits[i]) - 1);
            float a = intervalScratch[0];
@@ -93,7 +94,7 @@ class OptimizedScalarQuantizer {
            results[i] = new QuantizationResult(
                intervalScratch[0],
                intervalScratch[1],
                similarityFunction == EUCLIDEAN ? norm2 : centroidDot,
                similarityFunction == EUCLIDEAN ? norm2 : statsScratch[5],
                sumQuery
            );
        }
@@ -105,31 +106,20 @@ class OptimizedScalarQuantizer {
        assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
        assert vector.length <= destination.length;
        assert bits > 0 && bits <= 8;
        float[] intervalScratch = new float[2];
        int points = 1 << bits;
        double vecMean = 0;
        double vecVar = 0;
        float norm2 = 0;
        float centroidDot = 0;
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        for (int i = 0; i < vector.length; ++i) {
            if (similarityFunction != EUCLIDEAN) {
                centroidDot += vector[i] * centroid[i];
            }
            vector[i] = vector[i] - centroid[i];
            min = Math.min(min, vector[i]);
            max = Math.max(max, vector[i]);
            norm2 += (vector[i] * vector[i]);
            double delta = vector[i] - vecMean;
            vecMean += delta / (i + 1);
            vecVar += delta * (vector[i] - vecMean);
        if (similarityFunction == EUCLIDEAN) {
            ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch);
        } else {
            ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch);
        }
        vecVar /= vector.length;
        double vecStd = Math.sqrt(vecVar);
        float vecMean = statsScratch[0];
        float vecVar = statsScratch[1];
        float norm2 = statsScratch[2];
        float min = statsScratch[3];
        float max = statsScratch[4];
        float vecStd = (float) Math.sqrt(vecVar);
        // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
        intervalScratch[0] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][0] * vecStd + vecMean, min, max);
        intervalScratch[1] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][1] * vecStd + vecMean, min, max);
        initInterval(bits, vecStd, vecMean, min, max, intervalScratch);
        optimizeIntervals(intervalScratch, vector, norm2, points);
        float nSteps = ((1 << bits) - 1);
        // Now we have the optimized intervals, quantize the vector
@@ -146,37 +136,11 @@ class OptimizedScalarQuantizer {
        return new QuantizationResult(
            intervalScratch[0],
            intervalScratch[1],
            similarityFunction == EUCLIDEAN ? norm2 : centroidDot,
            similarityFunction == EUCLIDEAN ? norm2 : statsScratch[5],
            sumQuery
        );
    }
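
A usage sketch of the now-public entry point, with toy values (hypothetical, not from the PR); note that scalarQuantize centers vector in place, overwriting it with vector - centroid, and that the nested QuantizationResult carries the optimized interval plus correction terms:

    OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(VectorSimilarityFunction.DOT_PRODUCT);
    float[] vector = new float[] { 0.1f, 0.7f, -0.2f, 0.4f };  // overwritten with the centered vector
    float[] centroid = new float[] { 0.2f, 0.5f, 0.0f, 0.3f }; // same length as vector
    byte[] destination = new byte[vector.length];
    OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, (byte) 4, centroid);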
    /**
     * Compute the loss of the vector given the interval. Effectively, we are computing the MSE of a dequantized vector with the raw
     * vector.
     * @param vector raw vector
     * @param interval interval to quantize the vector
     * @param points number of quantization points
     * @param norm2 squared norm of the vector
     * @return the loss
     */
    private double loss(float[] vector, float[] interval, int points, float norm2) {
        double a = interval[0];
        double b = interval[1];
        double step = ((b - a) / (points - 1.0F));
        double stepInv = 1.0 / step;
        double xe = 0.0;
        double e = 0.0;
        for (double xi : vector) {
            // this is quantizing and then dequantizing the vector
            double xiq = (a + step * Math.round((clamp(xi, a, b) - a) * stepInv));
            // how much does the de-quantized value differ from the original value
            xe += xi * (xi - xiq);
            e += (xi - xiq) * (xi - xiq);
        }
        return (1.0 - lambda) * xe * xe / norm2 + lambda * e;
    }

    /**
     * Optimize the quantization interval for the given vector. This is done via a coordinate descent trying to minimize the quantization
     * loss. Note, the loss is not always guaranteed to decrease, so we have a maximum number of iterations and will exit early if the
@@ -187,30 +151,19 @@ class OptimizedScalarQuantizer {
     * @param points number of quantization points
     */
    private void optimizeIntervals(float[] initInterval, float[] vector, float norm2, int points) {
        double initialLoss = loss(vector, initInterval, points, norm2);
        double initialLoss = ESVectorUtil.calculateOSQLoss(vector, initInterval, points, norm2, lambda);
        final float scale = (1.0f - lambda) / norm2;
        if (Float.isFinite(scale) == false) {
            return;
        }
        for (int i = 0; i < iters; ++i) {
            float a = initInterval[0];
            float b = initInterval[1];
            float stepInv = (points - 1.0f) / (b - a);
            // calculate the grid points for coordinate descent
            double daa = 0;
            double dab = 0;
            double dbb = 0;
            double dax = 0;
            double dbx = 0;
            for (float xi : vector) {
                float k = Math.round((clamp(xi, a, b) - a) * stepInv);
                float s = k / (points - 1);
                daa += (1.0 - s) * (1.0 - s);
                dab += (1.0 - s) * s;
                dbb += s * s;
                dax += xi * (1.0 - s);
                dbx += xi * s;
            }
            ESVectorUtil.calculateOSQGridPoints(vector, initInterval, points, gridScratch);
            float daa = gridScratch[0];
            float dab = gridScratch[1];
            float dbb = gridScratch[2];
            float dax = gridScratch[3];
            float dbx = gridScratch[4];
            double m0 = scale * dax * dax + lambda * daa;
            double m1 = scale * dax * dbx + lambda * dab;
            double m2 = scale * dbx * dbx + lambda * dbb;
@@ -225,7 +178,7 @@ class OptimizedScalarQuantizer {
            if ((Math.abs(initInterval[0] - aOpt) < 1e-8 && Math.abs(initInterval[1] - bOpt) < 1e-8)) {
                return;
            }
            double newLoss = loss(vector, new float[] { aOpt, bOpt }, points, norm2);
            double newLoss = ESVectorUtil.calculateOSQLoss(vector, new float[] { aOpt, bOpt }, points, norm2, lambda);
            // If the new loss is worse, don't update the interval and exit
            // This optimization, unlike kMeans, does not always converge to better loss
            // So exit if we are getting worse
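
For context on the solve that follows (elided from this hunk): since the dequantized value is \hat{x}_i = (1 - s_i) a + s_i b, minimizing the plain MSE over the endpoints (a, b) is linear least squares with normal equations

    \begin{pmatrix} daa & dab \\ dab & dbb \end{pmatrix} \begin{pmatrix} a \\ b \end{pmatrix} = \begin{pmatrix} dax \\ dbx \end{pmatrix}

and the m0, m1, m2 terms above blend in the anisotropic part of the loss through scale and lambda before the same 2x2 system is solved for aOpt and bOpt.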

@@ -23,6 +23,7 @@ import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import java.io.IOException;

@@ -28,6 +28,7 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import org.elasticsearch.simdvec.ESVectorUtil;
import java.io.IOException;

@@ -26,6 +26,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import java.io.IOException;
@@ -35,7 +36,7 @@ import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_
 * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
 * Codec for encoding/decoding binary quantized vectors. The binary quantization format used here
 * is a per-vector optimized scalar quantization. Also see {@link
 * org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer}. Some of the key features are:
 * OptimizedScalarQuantizer}. Some of the key features are:
 *
 * <ul>
 *   <li>Estimating the distance between two vectors using their centroid normalized distance. This

@@ -44,6 +44,7 @@ import org.apache.lucene.util.SuppressForbidden;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils;
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;

@@ -49,6 +49,7 @@ import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import java.io.Closeable;
import java.io.IOException;

@@ -30,6 +30,7 @@ import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicReader;
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import java.io.IOException;
import java.nio.ByteBuffer;

@@ -7,13 +7,13 @@
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.index.codec.vectors.es818;
package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.test.ESTestCase;

import static org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer.MINIMUM_MSE_GRID;
import static org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer.MINIMUM_MSE_GRID;

public class OptimizedScalarQuantizerTests extends ESTestCase {

@@ -45,6 +45,7 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils;
import java.io.IOException;