mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
Panama vector accelerated optimized scalar quantization (#127118)
* Adds accelerated optimized scalar quantization with vectorized functions * Adding benchmark * Update docs/changelog/127118.yaml * adjusting benchmark and delta
This commit is contained in:
parent
ad0fe78e3e
commit
059f91c90c
16 changed files with 702 additions and 99 deletions
|
@ -0,0 +1,78 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.benchmark.vector;
|
||||||
|
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.elasticsearch.common.logging.LogConfigurator;
|
||||||
|
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||||
|
import org.openjdk.jmh.annotations.Benchmark;
|
||||||
|
import org.openjdk.jmh.annotations.BenchmarkMode;
|
||||||
|
import org.openjdk.jmh.annotations.Fork;
|
||||||
|
import org.openjdk.jmh.annotations.Level;
|
||||||
|
import org.openjdk.jmh.annotations.Measurement;
|
||||||
|
import org.openjdk.jmh.annotations.Mode;
|
||||||
|
import org.openjdk.jmh.annotations.OutputTimeUnit;
|
||||||
|
import org.openjdk.jmh.annotations.Param;
|
||||||
|
import org.openjdk.jmh.annotations.Scope;
|
||||||
|
import org.openjdk.jmh.annotations.Setup;
|
||||||
|
import org.openjdk.jmh.annotations.State;
|
||||||
|
import org.openjdk.jmh.annotations.Warmup;
|
||||||
|
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
/**
 * JMH benchmark comparing the default (scalar) and Panama-vectorized code paths behind
 * {@link OptimizedScalarQuantizer#scalarQuantize}. Both benchmark methods make the same
 * call; the {@code vector} variant forks a JVM with {@code jdk.incubator.vector} enabled,
 * which selects the vectorized ESVectorUtil implementation at startup.
 */
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 3, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(value = 3)
public class OptimizedScalarQuantizerBenchmark {
    static {
        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
    }

    // Vector dimensionality; 702 is presumably chosen as a non-lane-multiple to exercise
    // the scalar tail handling — TODO confirm against the vectorized implementations.
    @Param({ "384", "702", "1024" })
    int dims;

    float[] vector;
    float[] centroid;
    byte[] destination;

    // Bits per quantized component.
    @Param({ "1", "4", "7" })
    byte bits;

    OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(VectorSimilarityFunction.DOT_PRODUCT);

    /** Re-randomizes the input vector and centroid before each iteration. */
    @Setup(Level.Iteration)
    public void init() {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        // random byte arrays for binary methods
        destination = new byte[dims];
        vector = new float[dims];
        centroid = new float[dims];
        for (int i = 0; i < dims; ++i) {
            vector[i] = random.nextFloat();
            centroid[i] = random.nextFloat();
        }
    }

    /** Baseline: quantize without the incubator vector module (scalar implementation). */
    @Benchmark
    public byte[] scalar() {
        osq.scalarQuantize(vector, destination, bits, centroid);
        return destination;
    }

    /** Same call, but forked with the Panama vector module so the vectorized path is used. */
    @Benchmark
    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public byte[] vector() {
        osq.scalarQuantize(vector, destination, bits, centroid);
        return destination;
    }
}
|
5
docs/changelog/127118.yaml
Normal file
5
docs/changelog/127118.yaml
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
pr: 127118
|
||||||
|
summary: Panama vector accelerated optimized scalar quantization
|
||||||
|
area: Vector Search
|
||||||
|
type: enhancement
|
||||||
|
issues: []
|
|
@ -144,4 +144,71 @@ public class ESVectorUtil {
|
||||||
}
|
}
|
||||||
return distance;
|
return distance;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Calculate the loss for optimized-scalar quantization for the given parameters
 * @param target The vector being quantized, assumed to be centered
 * @param interval The interval for which to calculate the loss
 * @param points the quantization points
 * @param norm2 The norm squared of the target vector
 * @param lambda The lambda parameter for controlling anisotropic loss calculation
 * @return The loss for the given parameters
 */
public static float calculateOSQLoss(float[] target, float[] interval, int points, float norm2, float lambda) {
    assert interval.length == 2;
    // Step between adjacent quantization points; the inverse is precomputed once so the
    // hot loop in the implementation multiplies instead of dividing.
    // NOTE(review): interval[1] == interval[0] would give step == 0 and an infinite
    // invStep — presumably callers never pass a degenerate interval; verify.
    float step = ((interval[1] - interval[0]) / (points - 1.0F));
    float invStep = 1f / step;
    return IMPL.calculateOSQLoss(target, interval, step, invStep, norm2, lambda);
}
|
||||||
|
|
||||||
|
/**
 * Calculate the grid points for optimized-scalar quantization
 * @param target The vector being quantized, assumed to be centered
 * @param interval The interval for which to calculate the grid points
 * @param points the quantization points
 * @param pts The array to store the grid points, must be of length 5
 */
public static void calculateOSQGridPoints(float[] target, float[] interval, int points, float[] pts) {
    assert interval.length == 2;
    assert pts.length == 5;
    // Quantization points are evenly spaced across the interval; pass the inverse step
    // so the implementation's inner loop avoids a division per component.
    float invStep = (points - 1.0F) / (interval[1] - interval[0]);
    IMPL.calculateOSQGridPoints(target, interval, points, invStep, pts);
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Center the target vector and calculate the optimized-scalar quantization statistics
|
||||||
|
* @param target The vector being quantized
|
||||||
|
* @param centroid The centroid of the target vector
|
||||||
|
* @param centered The destination of the centered vector, will be overwritten
|
||||||
|
* @param stats The array to store the statistics, must be of length 5
|
||||||
|
*/
|
||||||
|
public static void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
|
||||||
|
assert target.length == centroid.length;
|
||||||
|
assert stats.length == 5;
|
||||||
|
if (target.length != centroid.length) {
|
||||||
|
throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length);
|
||||||
|
}
|
||||||
|
if (centered.length != target.length) {
|
||||||
|
throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length);
|
||||||
|
}
|
||||||
|
IMPL.centerAndCalculateOSQStatsEuclidean(target, centroid, centered, stats);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Center the target vector and calculate the optimized-scalar quantization statistics
 * @param target The vector being quantized
 * @param centroid The centroid of the target vector
 * @param centered The destination of the centered vector, will be overwritten
 * @param stats The array to store the statistics, must be of length 6
 * @throws IllegalArgumentException if {@code target}, {@code centroid} and {@code centered}
 *         do not all have the same length
 */
public static void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
    if (target.length != centroid.length) {
        throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length);
    }
    if (centered.length != target.length) {
        throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length);
    }
    // Dot-product variant needs one extra slot (the target·centroid dot product).
    assert stats.length == 6;
    IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats);
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,6 +44,100 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
|
||||||
return ipFloatByteImpl(q, d);
|
return ipFloatByteImpl(q, d);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) {
|
||||||
|
float a = interval[0];
|
||||||
|
float b = interval[1];
|
||||||
|
float xe = 0f;
|
||||||
|
float e = 0f;
|
||||||
|
for (float xi : target) {
|
||||||
|
// this is quantizing and then dequantizing the vector
|
||||||
|
float xiq = fma(step, Math.round((Math.min(Math.max(xi, a), b) - a) * invStep), a);
|
||||||
|
// how much does the de-quantized value differ from the original value
|
||||||
|
float xiiq = xi - xiq;
|
||||||
|
e = fma(xiiq, xiiq, e);
|
||||||
|
xe = fma(xi, xiiq, xe);
|
||||||
|
}
|
||||||
|
return (1f - lambda) * xe * xe / norm2 + lambda * e;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) {
|
||||||
|
float a = interval[0];
|
||||||
|
float b = interval[1];
|
||||||
|
float daa = 0;
|
||||||
|
float dab = 0;
|
||||||
|
float dbb = 0;
|
||||||
|
float dax = 0;
|
||||||
|
float dbx = 0;
|
||||||
|
for (float v : target) {
|
||||||
|
float k = Math.round((Math.min(Math.max(v, a), b) - a) * invStep);
|
||||||
|
float s = k / (points - 1);
|
||||||
|
float ms = 1f - s;
|
||||||
|
daa = fma(ms, ms, daa);
|
||||||
|
dab = fma(ms, s, dab);
|
||||||
|
dbb = fma(s, s, dbb);
|
||||||
|
dax = fma(ms, v, dax);
|
||||||
|
dbx = fma(s, v, dbx);
|
||||||
|
}
|
||||||
|
pts[0] = daa;
|
||||||
|
pts[1] = dab;
|
||||||
|
pts[2] = dbb;
|
||||||
|
pts[3] = dax;
|
||||||
|
pts[4] = dbx;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
public void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
    // Single pass over the centered vector using Welford's online algorithm:
    // running mean plus an M2 accumulator (sum of squared deviations) in vecVar,
    // alongside squared norm, min and max.
    float vecMean = 0;
    float vecVar = 0;  // M2 accumulator; divided by length at the end to give variance
    float norm2 = 0;
    float min = Float.MAX_VALUE;
    float max = -Float.MAX_VALUE;
    for (int i = 0; i < target.length; i++) {
        centered[i] = target[i] - centroid[i];  // center the component in place
        min = Math.min(min, centered[i]);
        max = Math.max(max, centered[i]);
        norm2 = fma(centered[i], centered[i], norm2);
        // Welford update: delta uses the pre-update mean, delta2 the post-update mean.
        float delta = centered[i] - vecMean;
        vecMean += delta / (i + 1);
        float delta2 = centered[i] - vecMean;
        vecVar = fma(delta, delta2, vecVar);
    }
    // stats layout (length 5): [mean, variance, norm^2, min, max]
    stats[0] = vecMean;
    stats[1] = vecVar / target.length;
    stats[2] = norm2;
    stats[3] = min;
    stats[4] = max;
}
|
||||||
|
|
||||||
|
@Override
public void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
    // Same single-pass Welford computation as the Euclidean variant, plus the
    // target·centroid dot product needed by the dot-product similarity path.
    float vecMean = 0;
    float vecVar = 0;  // M2 accumulator; divided by length at the end to give variance
    float norm2 = 0;
    float centroidDot = 0;
    float min = Float.MAX_VALUE;
    float max = -Float.MAX_VALUE;
    for (int i = 0; i < target.length; i++) {
        centroidDot = fma(target[i], centroid[i], centroidDot);
        centered[i] = target[i] - centroid[i];  // center the component in place
        min = Math.min(min, centered[i]);
        max = Math.max(max, centered[i]);
        norm2 = fma(centered[i], centered[i], norm2);
        // Welford update: delta uses the pre-update mean, delta2 the post-update mean.
        float delta = centered[i] - vecMean;
        vecMean += delta / (i + 1);
        float delta2 = centered[i] - vecMean;
        vecVar = fma(delta, delta2, vecVar);
    }
    // stats layout (length 6): [mean, variance, norm^2, min, max, target·centroid]
    stats[0] = vecMean;
    stats[1] = vecVar / target.length;
    stats[2] = norm2;
    stats[3] = min;
    stats[4] = max;
    stats[5] = centroidDot;
}
|
||||||
|
|
||||||
// Convenience overload: inner product of a byte query with a bit-packed document vector,
// starting at offset 0 (delegates to the offset-taking 3-arg overload).
public static int ipByteBitImpl(byte[] q, byte[] d) {
    return ipByteBitImpl(q, d, 0);
}
|
||||||
|
|
|
@ -20,4 +20,12 @@ public interface ESVectorUtilSupport {
|
||||||
float ipFloatBit(float[] q, byte[] d);

float ipFloatByte(float[] q, byte[] d);

/**
 * Calculates the optimized-scalar-quantization loss for {@code target} over the given
 * interval; {@code step}/{@code invStep} are the precomputed grid spacing and its inverse.
 */
float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda);

/**
 * Accumulates the five OSQ grid-point sums for {@code target} into {@code pts}
 * (length 5; see {@code ESVectorUtil.calculateOSQGridPoints}).
 */
void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts);

/**
 * Centers {@code target} by {@code centroid} into {@code centered} and writes
 * [mean, variance, norm^2, min, max] into {@code stats} (length 5).
 */
void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats);

/**
 * Dot-product variant: as the Euclidean form, but {@code stats} (length 6) additionally
 * receives the target·centroid dot product in its last slot.
 */
void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import jdk.incubator.vector.ByteVector;
|
||||||
import jdk.incubator.vector.FloatVector;
|
import jdk.incubator.vector.FloatVector;
|
||||||
import jdk.incubator.vector.IntVector;
|
import jdk.incubator.vector.IntVector;
|
||||||
import jdk.incubator.vector.LongVector;
|
import jdk.incubator.vector.LongVector;
|
||||||
|
import jdk.incubator.vector.Vector;
|
||||||
import jdk.incubator.vector.VectorMask;
|
import jdk.incubator.vector.VectorMask;
|
||||||
import jdk.incubator.vector.VectorOperators;
|
import jdk.incubator.vector.VectorOperators;
|
||||||
import jdk.incubator.vector.VectorShape;
|
import jdk.incubator.vector.VectorShape;
|
||||||
|
@ -21,16 +22,22 @@ import jdk.incubator.vector.VectorSpecies;
|
||||||
import org.apache.lucene.util.BitUtil;
|
import org.apache.lucene.util.BitUtil;
|
||||||
import org.apache.lucene.util.Constants;
|
import org.apache.lucene.util.Constants;
|
||||||
|
|
||||||
|
import static jdk.incubator.vector.VectorOperators.ADD;
|
||||||
|
import static jdk.incubator.vector.VectorOperators.MAX;
|
||||||
|
import static jdk.incubator.vector.VectorOperators.MIN;
|
||||||
|
|
||||||
public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
||||||
|
|
||||||
static final int VECTOR_BITSIZE;
|
static final int VECTOR_BITSIZE;
|
||||||
|
|
||||||
|
private static final VectorSpecies<Float> FLOAT_SPECIES;
|
||||||
/** Whether integer vectors can be trusted to actually be fast. */
|
/** Whether integer vectors can be trusted to actually be fast. */
|
||||||
static final boolean HAS_FAST_INTEGER_VECTORS;
|
static final boolean HAS_FAST_INTEGER_VECTORS;
|
||||||
|
|
||||||
static {
|
static {
|
||||||
// default to platform supported bitsize
|
// default to platform supported bitsize
|
||||||
VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize();
|
VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize();
|
||||||
|
FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(VECTOR_BITSIZE));
|
||||||
|
|
||||||
// hotspot misses some SSE intrinsics, workaround it
|
// hotspot misses some SSE intrinsics, workaround it
|
||||||
// to be fair, they do document this thing only works well with AVX2/AVX3 and Neon
|
// to be fair, they do document this thing only works well with AVX2/AVX3 and Neon
|
||||||
|
@ -38,6 +45,22 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
||||||
HAS_FAST_INTEGER_VECTORS = isAMD64withoutAVX2 == false;
|
HAS_FAST_INTEGER_VECTORS = isAMD64withoutAVX2 == false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) {
|
||||||
|
if (Constants.HAS_FAST_VECTOR_FMA) {
|
||||||
|
return a.fma(b, c);
|
||||||
|
} else {
|
||||||
|
return a.mul(b).add(c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static float fma(float a, float b, float c) {
|
||||||
|
if (Constants.HAS_FAST_SCALAR_FMA) {
|
||||||
|
return Math.fma(a, b, c);
|
||||||
|
} else {
|
||||||
|
return a * b + c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long ipByteBinByte(byte[] q, byte[] d) {
|
public long ipByteBinByte(byte[] q, byte[] d) {
|
||||||
// 128 / 8 == 16
|
// 128 / 8 == 16
|
||||||
|
@ -83,6 +106,267 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
||||||
return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d);
|
return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
public void centerAndCalculateOSQStatsEuclidean(float[] vector, float[] centroid, float[] centered, float[] stats) {
    // Vectorized version of the scalar Welford pass: per-lane running mean/M2 over the
    // main body, a scalar Welford pass over the tail, then a merge of the two partials.
    assert vector.length == centroid.length;
    assert vector.length == centered.length;
    float vecMean = 0;
    float vecVar = 0;
    float norm2 = 0;
    float min = Float.MAX_VALUE;
    float max = -Float.MAX_VALUE;
    int i = 0;
    int vectCount = 0;  // number of elements consumed by the vectorized loop
    // only worth vectorizing when the array is clearly longer than two vector registers
    if (vector.length > 2 * FLOAT_SPECIES.length()) {
        FloatVector vecMeanVec = FloatVector.zero(FLOAT_SPECIES);
        FloatVector m2Vec = FloatVector.zero(FLOAT_SPECIES);
        FloatVector norm2Vec = FloatVector.zero(FLOAT_SPECIES);
        FloatVector minVec = FloatVector.broadcast(FLOAT_SPECIES, Float.MAX_VALUE);
        FloatVector maxVec = FloatVector.broadcast(FLOAT_SPECIES, -Float.MAX_VALUE);
        int count = 0;  // vector iterations, i.e. per-lane element count
        for (; i < FLOAT_SPECIES.loopBound(vector.length); i += FLOAT_SPECIES.length()) {
            ++count;
            FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
            FloatVector c = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
            FloatVector centeredVec = v.sub(c);
            // per-lane Welford update
            FloatVector deltaVec = centeredVec.sub(vecMeanVec);
            norm2Vec = fma(centeredVec, centeredVec, norm2Vec);
            vecMeanVec = vecMeanVec.add(deltaVec.div(count));
            FloatVector delta2Vec = centeredVec.sub(vecMeanVec);
            m2Vec = fma(deltaVec, delta2Vec, m2Vec);
            minVec = minVec.min(centeredVec);
            maxVec = maxVec.max(centeredVec);
            centeredVec.intoArray(centered, i);  // write the centered slice back out
        }
        min = minVec.reduceLanes(MIN);
        max = maxVec.reduceLanes(MAX);
        norm2 = norm2Vec.reduceLanes(ADD);
        // combine lane means into a single mean, then fold the lane-mean spread into M2
        vecMean = vecMeanVec.reduceLanes(ADD) / FLOAT_SPECIES.length();
        FloatVector d2Mean = vecMeanVec.sub(vecMean);
        // NOTE(review): this correction adds (laneMean - mean)^2 per lane without the
        // per-lane element-count factor of the textbook parallel-variance merge — it
        // matches the scalar result only within the test tolerance; confirm intent.
        m2Vec = fma(d2Mean, d2Mean, m2Vec);
        vectCount = count * FLOAT_SPECIES.length();
        vecVar = m2Vec.reduceLanes(ADD);
    }

    float tailMean = 0;
    float tailM2 = 0;
    int tailCount = 0;
    // handle the tail
    for (; i < vector.length; i++) {
        centered[i] = vector[i] - centroid[i];
        float delta = centered[i] - tailMean;
        ++tailCount;
        tailMean += delta / tailCount;
        float delta2 = centered[i] - tailMean;
        tailM2 = fma(delta, delta2, tailM2);
        min = Math.min(min, centered[i]);
        max = Math.max(max, centered[i]);
        norm2 = fma(centered[i], centered[i], norm2);
    }
    if (vectCount == 0) {
        // everything went through the scalar tail loop
        vecMean = tailMean;
        vecVar = tailM2;
    } else if (tailCount > 0) {
        // merge vectorized and tail partial statistics
        int totalCount = tailCount + vectCount;
        assert totalCount == vector.length;
        float alpha = (float) vectCount / totalCount;
        float beta = 1f - alpha;
        float completeMean = alpha * vecMean + beta * tailMean;
        float dMean2Lhs = (vecMean - completeMean) * (vecMean - completeMean);
        float dMean2Rhs = (tailMean - completeMean) * (tailMean - completeMean);
        // NOTE(review): weighting differs from the standard Chan et al. merge
        // (M2 = M2a + M2b + delta^2 * nA*nB/n); verify against DefaultESVectorUtilSupport.
        vecVar = (vecVar + dMean2Lhs) + beta * (tailM2 + dMean2Rhs);
        vecMean = completeMean;
    }
    // stats layout (length 5): [mean, variance, norm^2, min, max]
    stats[0] = vecMean;
    stats[1] = vecVar / vector.length;
    stats[2] = norm2;
    stats[3] = min;
    stats[4] = max;
}
|
||||||
|
|
||||||
|
@Override
public void centerAndCalculateOSQStatsDp(float[] vector, float[] centroid, float[] centered, float[] stats) {
    // Dot-product variant of the vectorized Euclidean stats pass: identical structure,
    // plus a vector·centroid dot-product accumulator written to stats[5].
    assert vector.length == centroid.length;
    assert vector.length == centered.length;
    float vecMean = 0;
    float vecVar = 0;
    float norm2 = 0;
    float min = Float.MAX_VALUE;
    float max = -Float.MAX_VALUE;
    float centroidDot = 0;
    int i = 0;
    int vectCount = 0;  // number of elements consumed by the vectorized loop
    int loopBound = FLOAT_SPECIES.loopBound(vector.length);
    // only worth vectorizing when the array is clearly longer than two vector registers
    if (vector.length > 2 * FLOAT_SPECIES.length()) {
        FloatVector vecMeanVec = FloatVector.zero(FLOAT_SPECIES);
        FloatVector m2Vec = FloatVector.zero(FLOAT_SPECIES);
        FloatVector norm2Vec = FloatVector.zero(FLOAT_SPECIES);
        FloatVector minVec = FloatVector.broadcast(FLOAT_SPECIES, Float.MAX_VALUE);
        FloatVector maxVec = FloatVector.broadcast(FLOAT_SPECIES, -Float.MAX_VALUE);
        FloatVector centroidDotVec = FloatVector.zero(FLOAT_SPECIES);
        int count = 0;  // vector iterations, i.e. per-lane element count
        for (; i < loopBound; i += FLOAT_SPECIES.length()) {
            ++count;
            FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
            FloatVector c = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
            centroidDotVec = fma(v, c, centroidDotVec);
            FloatVector centeredVec = v.sub(c);
            // per-lane Welford update
            FloatVector deltaVec = centeredVec.sub(vecMeanVec);
            norm2Vec = fma(centeredVec, centeredVec, norm2Vec);
            vecMeanVec = vecMeanVec.add(deltaVec.div(count));
            FloatVector delta2Vec = centeredVec.sub(vecMeanVec);
            m2Vec = fma(deltaVec, delta2Vec, m2Vec);
            minVec = minVec.min(centeredVec);
            maxVec = maxVec.max(centeredVec);
            centeredVec.intoArray(centered, i);  // write the centered slice back out
        }
        min = minVec.reduceLanes(MIN);
        max = maxVec.reduceLanes(MAX);
        norm2 = norm2Vec.reduceLanes(ADD);
        centroidDot = centroidDotVec.reduceLanes(ADD);
        // combine lane means into a single mean, then fold the lane-mean spread into M2
        vecMean = vecMeanVec.reduceLanes(ADD) / FLOAT_SPECIES.length();
        FloatVector d2Mean = vecMeanVec.sub(vecMean);
        // NOTE(review): adds (laneMean - mean)^2 per lane without a per-lane count factor —
        // differs from the textbook parallel-variance correction; confirm intent.
        m2Vec = fma(d2Mean, d2Mean, m2Vec);
        vectCount = count * FLOAT_SPECIES.length();
        vecVar = m2Vec.reduceLanes(ADD);
    }

    float tailMean = 0;
    float tailM2 = 0;
    int tailCount = 0;
    // handle the tail
    for (; i < vector.length; i++) {
        centroidDot = fma(vector[i], centroid[i], centroidDot);
        centered[i] = vector[i] - centroid[i];
        float delta = centered[i] - tailMean;
        ++tailCount;
        tailMean += delta / tailCount;
        float delta2 = centered[i] - tailMean;
        tailM2 = fma(delta, delta2, tailM2);
        min = Math.min(min, centered[i]);
        max = Math.max(max, centered[i]);
        norm2 = fma(centered[i], centered[i], norm2);
    }
    if (vectCount == 0) {
        // everything went through the scalar tail loop
        vecMean = tailMean;
        vecVar = tailM2;
    } else if (tailCount > 0) {
        // merge vectorized and tail partial statistics
        int totalCount = tailCount + vectCount;
        assert totalCount == vector.length;
        float alpha = (float) vectCount / totalCount;
        float beta = 1f - alpha;
        float completeMean = alpha * vecMean + beta * tailMean;
        float dMean2Lhs = (vecMean - completeMean) * (vecMean - completeMean);
        float dMean2Rhs = (tailMean - completeMean) * (tailMean - completeMean);
        // NOTE(review): weighting differs from the standard Chan et al. merge; verify
        // against DefaultESVectorUtilSupport within the test tolerance.
        vecVar = (vecVar + dMean2Lhs) + beta * (tailM2 + dMean2Rhs);
        vecMean = completeMean;
    }
    // stats layout (length 6): [mean, variance, norm^2, min, max, vector·centroid]
    stats[0] = vecMean;
    stats[1] = vecVar / vector.length;
    stats[2] = norm2;
    stats[3] = min;
    stats[4] = max;
    stats[5] = centroidDot;
}
|
||||||
|
|
||||||
|
@Override
public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) {
    // Vectorized accumulation of the five OSQ grid-point sums; scalar tail loop mirrors
    // DefaultESVectorUtilSupport.calculateOSQGridPoints.
    float a = interval[0];
    float b = interval[1];
    int i = 0;
    float daa = 0;
    float dab = 0;
    float dbb = 0;
    float dax = 0;
    float dbx = 0;

    FloatVector daaVec = FloatVector.zero(FLOAT_SPECIES);
    FloatVector dabVec = FloatVector.zero(FLOAT_SPECIES);
    FloatVector dbbVec = FloatVector.zero(FLOAT_SPECIES);
    FloatVector daxVec = FloatVector.zero(FLOAT_SPECIES);
    FloatVector dbxVec = FloatVector.zero(FLOAT_SPECIES);

    // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
    if (target.length > 2 * FLOAT_SPECIES.length()) {
        FloatVector ones = FloatVector.broadcast(FLOAT_SPECIES, 1f);
        FloatVector pmOnes = FloatVector.broadcast(FLOAT_SPECIES, points - 1f);
        for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) {
            FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i);
            FloatVector vClamped = v.max(a).min(b);
            // (x + 0.5) then F2I truncation rounds to nearest; valid here because the
            // clamped, shifted input is always non-negative (truncation == floor).
            Vector<Integer> xiqint = vClamped.sub(a)
                .mul(invStep)
                // round
                .add(0.5f)
                .convert(VectorOperators.F2I, 0);
            FloatVector kVec = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats();
            FloatVector sVec = kVec.div(pmOnes);   // grid position in [0, 1]
            FloatVector smVec = ones.sub(sVec);
            daaVec = fma(smVec, smVec, daaVec);
            dabVec = fma(smVec, sVec, dabVec);
            dbbVec = fma(sVec, sVec, dbbVec);
            daxVec = fma(v, smVec, daxVec);
            dbxVec = fma(v, sVec, dbxVec);
        }
        daa = daaVec.reduceLanes(ADD);
        dab = dabVec.reduceLanes(ADD);
        dbb = dbbVec.reduceLanes(ADD);
        dax = daxVec.reduceLanes(ADD);
        dbx = dbxVec.reduceLanes(ADD);
    }

    // scalar tail (and the whole array when too short to vectorize)
    for (; i < target.length; i++) {
        float k = Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep);
        float s = k / (points - 1);
        float ms = 1f - s;
        daa = fma(ms, ms, daa);
        dab = fma(ms, s, dab);
        dbb = fma(s, s, dbb);
        dax = fma(ms, target[i], dax);
        dbx = fma(s, target[i], dbx);
    }

    pts[0] = daa;
    pts[1] = dab;
    pts[2] = dbb;
    pts[3] = dax;
    pts[4] = dbx;
}
|
||||||
|
|
||||||
|
@Override
public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) {
    // Vectorized quantize/dequantize loss accumulation; scalar tail loop mirrors
    // DefaultESVectorUtilSupport.calculateOSQLoss.
    float a = interval[0];
    float b = interval[1];
    float xe = 0f;  // sum of x_i * (x_i - dequantized(x_i))
    float e = 0f;   // sum of squared reconstruction errors
    FloatVector xeVec = FloatVector.zero(FLOAT_SPECIES);
    FloatVector eVec = FloatVector.zero(FLOAT_SPECIES);
    int i = 0;
    // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
    if (target.length > 2 * FLOAT_SPECIES.length()) {
        for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) {
            FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i);
            FloatVector vClamped = v.max(a).min(b);
            // (x + 0.5) then F2I truncation rounds to nearest; the clamped, shifted input
            // is non-negative so truncation behaves like floor.
            Vector<Integer> xiqint = vClamped.sub(a).mul(invStep).add(0.5f).convert(VectorOperators.F2I, 0);
            FloatVector xiq = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats().mul(step).add(a);
            FloatVector xiiq = v.sub(xiq);  // reconstruction error per lane
            xeVec = fma(v, xiiq, xeVec);
            eVec = fma(xiiq, xiiq, eVec);
        }
        e = eVec.reduceLanes(ADD);
        xe = xeVec.reduceLanes(ADD);
    }

    // scalar tail (and the whole array when too short to vectorize)
    for (; i < target.length; i++) {
        // this is quantizing and then dequantizing the vector
        float xiq = fma(step, Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep), a);
        // how much does the de-quantized value differ from the original value
        float xiiq = target[i] - xiq;
        e = fma(xiiq, xiiq, e);
        xe = fma(target[i], xiiq, xe);
    }
    // Blend the parallel (xe) and orthogonal (e) loss components.
    return (1f - lambda) * xe * xe / norm2 + lambda * e;
}
|
||||||
|
|
||||||
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
|
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
|
||||||
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
|
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
|
|
||||||
package org.elasticsearch.simdvec;
|
package org.elasticsearch.simdvec;
|
||||||
|
|
||||||
|
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||||
import org.elasticsearch.simdvec.internal.vectorization.BaseVectorizationTests;
|
import org.elasticsearch.simdvec.internal.vectorization.BaseVectorizationTests;
|
||||||
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
|
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
|
||||||
|
|
||||||
|
@ -161,6 +162,112 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
|
||||||
testIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
|
testIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Checks the Panama implementation of centerAndCalculateOSQStatsDp against the scalar default. */
public void testCenterAndCalculateOSQStatsDp() {
    // random size >= 128 guarantees the vectorized path is taken (length > 2 lanes)
    int size = random().nextInt(128, 512);
    // tolerance scales with size: accumulation order differs between implementations
    float delta = 1e-3f * size;
    var vector = new float[size];
    var centroid = new float[size];
    for (int i = 0; i < size; ++i) {
        vector[i] = random().nextFloat();
        centroid[i] = random().nextFloat();
    }
    var centeredLucene = new float[size];
    var statsLucene = new float[6];
    defaultedProvider.getVectorUtilSupport().centerAndCalculateOSQStatsDp(vector, centroid, centeredLucene, statsLucene);
    var centeredPanama = new float[size];
    var statsPanama = new float[6];
    defOrPanamaProvider.getVectorUtilSupport().centerAndCalculateOSQStatsDp(vector, centroid, centeredPanama, statsPanama);
    assertArrayEquals(centeredLucene, centeredPanama, delta);
    assertArrayEquals(statsLucene, statsPanama, delta);
}
|
||||||
|
|
||||||
|
/** Checks the Panama implementation of centerAndCalculateOSQStatsEuclidean against the scalar default. */
public void testCenterAndCalculateOSQStatsEuclidean() {
    // random size >= 128 guarantees the vectorized path is taken (length > 2 lanes)
    int size = random().nextInt(128, 512);
    // tolerance scales with size: accumulation order differs between implementations
    float delta = 1e-3f * size;
    var vector = new float[size];
    var centroid = new float[size];
    for (int i = 0; i < size; ++i) {
        vector[i] = random().nextFloat();
        centroid[i] = random().nextFloat();
    }
    var centeredLucene = new float[size];
    var statsLucene = new float[5];
    defaultedProvider.getVectorUtilSupport().centerAndCalculateOSQStatsEuclidean(vector, centroid, centeredLucene, statsLucene);
    var centeredPanama = new float[size];
    var statsPanama = new float[5];
    defOrPanamaProvider.getVectorUtilSupport().centerAndCalculateOSQStatsEuclidean(vector, centroid, centeredPanama, statsPanama);
    assertArrayEquals(centeredLucene, centeredPanama, delta);
    assertArrayEquals(statsLucene, statsPanama, delta);
}
|
||||||
|
|
||||||
|
/** Checks the Panama calculateOSQLoss against the scalar default across all bit widths. */
public void testOsqLoss() {
    int size = random().nextInt(128, 512);
    float deltaEps = 1e-5f * size;
    var vector = new float[size];
    var min = Float.MAX_VALUE;
    var max = -Float.MAX_VALUE;
    float vecMean = 0;
    float vecVar = 0;
    float norm2 = 0;
    // build a random vector and compute min/max/mean/variance (Welford) and squared norm,
    // which are needed to initialize the quantization interval
    for (int i = 0; i < size; ++i) {
        vector[i] = random().nextFloat();
        min = Math.min(min, vector[i]);
        max = Math.max(max, vector[i]);
        float delta = vector[i] - vecMean;
        vecMean += delta / (i + 1);
        float delta2 = vector[i] - vecMean;
        vecVar += delta * delta2;
        norm2 += vector[i] * vector[i];
    }
    vecVar /= size;
    float vecStd = (float) Math.sqrt(vecVar);

    for (byte bits : new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }) {
        int points = 1 << bits;  // number of quantization points for this bit width
        float[] initInterval = new float[2];
        OptimizedScalarQuantizer.initInterval(bits, vecStd, vecMean, min, max, initInterval);
        float step = ((initInterval[1] - initInterval[0]) / (points - 1f));
        float stepInv = 1f / step;
        float expected = defaultedProvider.getVectorUtilSupport().calculateOSQLoss(vector, initInterval, step, stepInv, norm2, 0.1f);
        float result = defOrPanamaProvider.getVectorUtilSupport().calculateOSQLoss(vector, initInterval, step, stepInv, norm2, 0.1f);
        assertEquals(expected, result, deltaEps);
    }
}
|
||||||
|
|
||||||
|
public void testOsqGridPoints() {
|
||||||
|
int size = random().nextInt(128, 512);
|
||||||
|
float deltaEps = 1e-5f * size;
|
||||||
|
var vector = new float[size];
|
||||||
|
var min = Float.MAX_VALUE;
|
||||||
|
var max = -Float.MAX_VALUE;
|
||||||
|
float vecMean = 0;
|
||||||
|
float vecVar = 0;
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
vector[i] = random().nextFloat();
|
||||||
|
min = Math.min(min, vector[i]);
|
||||||
|
max = Math.max(max, vector[i]);
|
||||||
|
float delta = vector[i] - vecMean;
|
||||||
|
vecMean += delta / (i + 1);
|
||||||
|
float delta2 = vector[i] - vecMean;
|
||||||
|
vecVar += delta * delta2;
|
||||||
|
}
|
||||||
|
vecVar /= size;
|
||||||
|
float vecStd = (float) Math.sqrt(vecVar);
|
||||||
|
for (byte bits : new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }) {
|
||||||
|
int points = 1 << bits;
|
||||||
|
float[] initInterval = new float[2];
|
||||||
|
OptimizedScalarQuantizer.initInterval(bits, vecStd, vecMean, min, max, initInterval);
|
||||||
|
float step = ((initInterval[1] - initInterval[0]) / (points - 1f));
|
||||||
|
float stepInv = 1f / step;
|
||||||
|
float[] expected = new float[5];
|
||||||
|
defaultedProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, initInterval, points, stepInv, expected);
|
||||||
|
|
||||||
|
float[] result = new float[5];
|
||||||
|
defOrPanamaProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, initInterval, points, stepInv, result);
|
||||||
|
assertArrayEquals(expected, result, deltaEps);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
|
void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
|
||||||
int iterations = atLeast(50);
|
int iterations = atLeast(50);
|
||||||
for (int i = 0; i < iterations; i++) {
|
for (int i = 0; i < iterations; i++) {
|
||||||
|
|
|
@ -7,15 +7,21 @@
|
||||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.elasticsearch.index.codec.vectors.es818;
|
package org.elasticsearch.index.codec.vectors;
|
||||||
|
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.util.VectorUtil;
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
import org.elasticsearch.simdvec.ESVectorUtil;
|
||||||
|
|
||||||
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
|
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
|
||||||
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
|
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
|
||||||
|
|
||||||
class OptimizedScalarQuantizer {
|
public class OptimizedScalarQuantizer {
|
||||||
|
public static void initInterval(byte bits, float vecStd, float vecMean, float min, float max, float[] initInterval) {
|
||||||
|
initInterval[0] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][0] * vecStd + vecMean, min, max);
|
||||||
|
initInterval[1] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][1] * vecStd + vecMean, min, max);
|
||||||
|
}
|
||||||
|
|
||||||
// The initial interval is set to the minimum MSE grid for each number of bits
|
// The initial interval is set to the minimum MSE grid for each number of bits
|
||||||
// these starting points are derived from the optimal MSE grid for a uniform distribution
|
// these starting points are derived from the optimal MSE grid for a uniform distribution
|
||||||
static final float[][] MINIMUM_MSE_GRID = new float[][] {
|
static final float[][] MINIMUM_MSE_GRID = new float[][] {
|
||||||
|
@ -27,19 +33,25 @@ class OptimizedScalarQuantizer {
|
||||||
{ -3.278f, 3.278f },
|
{ -3.278f, 3.278f },
|
||||||
{ -3.611f, 3.611f },
|
{ -3.611f, 3.611f },
|
||||||
{ -3.922f, 3.922f } };
|
{ -3.922f, 3.922f } };
|
||||||
private static final float DEFAULT_LAMBDA = 0.1f;
|
public static final float DEFAULT_LAMBDA = 0.1f;
|
||||||
private static final int DEFAULT_ITERS = 5;
|
private static final int DEFAULT_ITERS = 5;
|
||||||
private final VectorSimilarityFunction similarityFunction;
|
private final VectorSimilarityFunction similarityFunction;
|
||||||
private final float lambda;
|
private final float lambda;
|
||||||
private final int iters;
|
private final int iters;
|
||||||
|
private final float[] statsScratch;
|
||||||
|
private final float[] gridScratch;
|
||||||
|
private final float[] intervalScratch;
|
||||||
|
|
||||||
OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction, float lambda, int iters) {
|
public OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction, float lambda, int iters) {
|
||||||
this.similarityFunction = similarityFunction;
|
this.similarityFunction = similarityFunction;
|
||||||
this.lambda = lambda;
|
this.lambda = lambda;
|
||||||
this.iters = iters;
|
this.iters = iters;
|
||||||
|
this.statsScratch = new float[similarityFunction == EUCLIDEAN ? 5 : 6];
|
||||||
|
this.gridScratch = new float[5];
|
||||||
|
this.intervalScratch = new float[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) {
|
public OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) {
|
||||||
this(similarityFunction, DEFAULT_LAMBDA, DEFAULT_ITERS);
|
this(similarityFunction, DEFAULT_LAMBDA, DEFAULT_ITERS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,34 +61,23 @@ class OptimizedScalarQuantizer {
|
||||||
assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
|
assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
|
||||||
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
|
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
|
||||||
assert bits.length == destinations.length;
|
assert bits.length == destinations.length;
|
||||||
float[] intervalScratch = new float[2];
|
if (similarityFunction == EUCLIDEAN) {
|
||||||
double vecMean = 0;
|
ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch);
|
||||||
double vecVar = 0;
|
} else {
|
||||||
float norm2 = 0;
|
ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch);
|
||||||
float centroidDot = 0;
|
|
||||||
float min = Float.MAX_VALUE;
|
|
||||||
float max = -Float.MAX_VALUE;
|
|
||||||
for (int i = 0; i < vector.length; ++i) {
|
|
||||||
if (similarityFunction != EUCLIDEAN) {
|
|
||||||
centroidDot += vector[i] * centroid[i];
|
|
||||||
}
|
|
||||||
vector[i] = vector[i] - centroid[i];
|
|
||||||
min = Math.min(min, vector[i]);
|
|
||||||
max = Math.max(max, vector[i]);
|
|
||||||
norm2 += (vector[i] * vector[i]);
|
|
||||||
double delta = vector[i] - vecMean;
|
|
||||||
vecMean += delta / (i + 1);
|
|
||||||
vecVar += delta * (vector[i] - vecMean);
|
|
||||||
}
|
}
|
||||||
vecVar /= vector.length;
|
float vecMean = statsScratch[0];
|
||||||
double vecStd = Math.sqrt(vecVar);
|
float vecVar = statsScratch[1];
|
||||||
|
float norm2 = statsScratch[2];
|
||||||
|
float min = statsScratch[3];
|
||||||
|
float max = statsScratch[4];
|
||||||
|
float vecStd = (float) Math.sqrt(vecVar);
|
||||||
QuantizationResult[] results = new QuantizationResult[bits.length];
|
QuantizationResult[] results = new QuantizationResult[bits.length];
|
||||||
for (int i = 0; i < bits.length; ++i) {
|
for (int i = 0; i < bits.length; ++i) {
|
||||||
assert bits[i] > 0 && bits[i] <= 8;
|
assert bits[i] > 0 && bits[i] <= 8;
|
||||||
int points = (1 << bits[i]);
|
int points = (1 << bits[i]);
|
||||||
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
|
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
|
||||||
intervalScratch[0] = (float) clamp(MINIMUM_MSE_GRID[bits[i] - 1][0] * vecStd + vecMean, min, max);
|
initInterval(bits[i], vecStd, vecMean, min, max, intervalScratch);
|
||||||
intervalScratch[1] = (float) clamp(MINIMUM_MSE_GRID[bits[i] - 1][1] * vecStd + vecMean, min, max);
|
|
||||||
optimizeIntervals(intervalScratch, vector, norm2, points);
|
optimizeIntervals(intervalScratch, vector, norm2, points);
|
||||||
float nSteps = ((1 << bits[i]) - 1);
|
float nSteps = ((1 << bits[i]) - 1);
|
||||||
float a = intervalScratch[0];
|
float a = intervalScratch[0];
|
||||||
|
@ -93,7 +94,7 @@ class OptimizedScalarQuantizer {
|
||||||
results[i] = new QuantizationResult(
|
results[i] = new QuantizationResult(
|
||||||
intervalScratch[0],
|
intervalScratch[0],
|
||||||
intervalScratch[1],
|
intervalScratch[1],
|
||||||
similarityFunction == EUCLIDEAN ? norm2 : centroidDot,
|
similarityFunction == EUCLIDEAN ? norm2 : statsScratch[5],
|
||||||
sumQuery
|
sumQuery
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -105,31 +106,20 @@ class OptimizedScalarQuantizer {
|
||||||
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
|
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
|
||||||
assert vector.length <= destination.length;
|
assert vector.length <= destination.length;
|
||||||
assert bits > 0 && bits <= 8;
|
assert bits > 0 && bits <= 8;
|
||||||
float[] intervalScratch = new float[2];
|
|
||||||
int points = 1 << bits;
|
int points = 1 << bits;
|
||||||
double vecMean = 0;
|
if (similarityFunction == EUCLIDEAN) {
|
||||||
double vecVar = 0;
|
ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch);
|
||||||
float norm2 = 0;
|
} else {
|
||||||
float centroidDot = 0;
|
ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch);
|
||||||
float min = Float.MAX_VALUE;
|
|
||||||
float max = -Float.MAX_VALUE;
|
|
||||||
for (int i = 0; i < vector.length; ++i) {
|
|
||||||
if (similarityFunction != EUCLIDEAN) {
|
|
||||||
centroidDot += vector[i] * centroid[i];
|
|
||||||
}
|
|
||||||
vector[i] = vector[i] - centroid[i];
|
|
||||||
min = Math.min(min, vector[i]);
|
|
||||||
max = Math.max(max, vector[i]);
|
|
||||||
norm2 += (vector[i] * vector[i]);
|
|
||||||
double delta = vector[i] - vecMean;
|
|
||||||
vecMean += delta / (i + 1);
|
|
||||||
vecVar += delta * (vector[i] - vecMean);
|
|
||||||
}
|
}
|
||||||
vecVar /= vector.length;
|
float vecMean = statsScratch[0];
|
||||||
double vecStd = Math.sqrt(vecVar);
|
float vecVar = statsScratch[1];
|
||||||
|
float norm2 = statsScratch[2];
|
||||||
|
float min = statsScratch[3];
|
||||||
|
float max = statsScratch[4];
|
||||||
|
float vecStd = (float) Math.sqrt(vecVar);
|
||||||
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
|
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
|
||||||
intervalScratch[0] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][0] * vecStd + vecMean, min, max);
|
initInterval(bits, vecStd, vecMean, min, max, intervalScratch);
|
||||||
intervalScratch[1] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][1] * vecStd + vecMean, min, max);
|
|
||||||
optimizeIntervals(intervalScratch, vector, norm2, points);
|
optimizeIntervals(intervalScratch, vector, norm2, points);
|
||||||
float nSteps = ((1 << bits) - 1);
|
float nSteps = ((1 << bits) - 1);
|
||||||
// Now we have the optimized intervals, quantize the vector
|
// Now we have the optimized intervals, quantize the vector
|
||||||
|
@ -146,37 +136,11 @@ class OptimizedScalarQuantizer {
|
||||||
return new QuantizationResult(
|
return new QuantizationResult(
|
||||||
intervalScratch[0],
|
intervalScratch[0],
|
||||||
intervalScratch[1],
|
intervalScratch[1],
|
||||||
similarityFunction == EUCLIDEAN ? norm2 : centroidDot,
|
similarityFunction == EUCLIDEAN ? norm2 : statsScratch[5],
|
||||||
sumQuery
|
sumQuery
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Compute the loss of the vector given the interval. Effectively, we are computing the MSE of a dequantized vector with the raw
|
|
||||||
* vector.
|
|
||||||
* @param vector raw vector
|
|
||||||
* @param interval interval to quantize the vector
|
|
||||||
* @param points number of quantization points
|
|
||||||
* @param norm2 squared norm of the vector
|
|
||||||
* @return the loss
|
|
||||||
*/
|
|
||||||
private double loss(float[] vector, float[] interval, int points, float norm2) {
|
|
||||||
double a = interval[0];
|
|
||||||
double b = interval[1];
|
|
||||||
double step = ((b - a) / (points - 1.0F));
|
|
||||||
double stepInv = 1.0 / step;
|
|
||||||
double xe = 0.0;
|
|
||||||
double e = 0.0;
|
|
||||||
for (double xi : vector) {
|
|
||||||
// this is quantizing and then dequantizing the vector
|
|
||||||
double xiq = (a + step * Math.round((clamp(xi, a, b) - a) * stepInv));
|
|
||||||
// how much does the de-quantized value differ from the original value
|
|
||||||
xe += xi * (xi - xiq);
|
|
||||||
e += (xi - xiq) * (xi - xiq);
|
|
||||||
}
|
|
||||||
return (1.0 - lambda) * xe * xe / norm2 + lambda * e;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimize the quantization interval for the given vector. This is done via a coordinate descent trying to minimize the quantization
|
* Optimize the quantization interval for the given vector. This is done via a coordinate descent trying to minimize the quantization
|
||||||
* loss. Note, the loss is not always guaranteed to decrease, so we have a maximum number of iterations and will exit early if the
|
* loss. Note, the loss is not always guaranteed to decrease, so we have a maximum number of iterations and will exit early if the
|
||||||
|
@ -187,30 +151,19 @@ class OptimizedScalarQuantizer {
|
||||||
* @param points number of quantization points
|
* @param points number of quantization points
|
||||||
*/
|
*/
|
||||||
private void optimizeIntervals(float[] initInterval, float[] vector, float norm2, int points) {
|
private void optimizeIntervals(float[] initInterval, float[] vector, float norm2, int points) {
|
||||||
double initialLoss = loss(vector, initInterval, points, norm2);
|
double initialLoss = ESVectorUtil.calculateOSQLoss(vector, initInterval, points, norm2, lambda);
|
||||||
final float scale = (1.0f - lambda) / norm2;
|
final float scale = (1.0f - lambda) / norm2;
|
||||||
if (Float.isFinite(scale) == false) {
|
if (Float.isFinite(scale) == false) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (int i = 0; i < iters; ++i) {
|
for (int i = 0; i < iters; ++i) {
|
||||||
float a = initInterval[0];
|
|
||||||
float b = initInterval[1];
|
|
||||||
float stepInv = (points - 1.0f) / (b - a);
|
|
||||||
// calculate the grid points for coordinate descent
|
// calculate the grid points for coordinate descent
|
||||||
double daa = 0;
|
ESVectorUtil.calculateOSQGridPoints(vector, initInterval, points, gridScratch);
|
||||||
double dab = 0;
|
float daa = gridScratch[0];
|
||||||
double dbb = 0;
|
float dab = gridScratch[1];
|
||||||
double dax = 0;
|
float dbb = gridScratch[2];
|
||||||
double dbx = 0;
|
float dax = gridScratch[3];
|
||||||
for (float xi : vector) {
|
float dbx = gridScratch[4];
|
||||||
float k = Math.round((clamp(xi, a, b) - a) * stepInv);
|
|
||||||
float s = k / (points - 1);
|
|
||||||
daa += (1.0 - s) * (1.0 - s);
|
|
||||||
dab += (1.0 - s) * s;
|
|
||||||
dbb += s * s;
|
|
||||||
dax += xi * (1.0 - s);
|
|
||||||
dbx += xi * s;
|
|
||||||
}
|
|
||||||
double m0 = scale * dax * dax + lambda * daa;
|
double m0 = scale * dax * dax + lambda * daa;
|
||||||
double m1 = scale * dax * dbx + lambda * dab;
|
double m1 = scale * dax * dbx + lambda * dab;
|
||||||
double m2 = scale * dbx * dbx + lambda * dbb;
|
double m2 = scale * dbx * dbx + lambda * dbb;
|
||||||
|
@ -225,7 +178,7 @@ class OptimizedScalarQuantizer {
|
||||||
if ((Math.abs(initInterval[0] - aOpt) < 1e-8 && Math.abs(initInterval[1] - bOpt) < 1e-8)) {
|
if ((Math.abs(initInterval[0] - aOpt) < 1e-8 && Math.abs(initInterval[1] - bOpt) < 1e-8)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
double newLoss = loss(vector, new float[] { aOpt, bOpt }, points, norm2);
|
double newLoss = ESVectorUtil.calculateOSQLoss(vector, new float[] { aOpt, bOpt }, points, norm2, lambda);
|
||||||
// If the new loss is worse, don't update the interval and exit
|
// If the new loss is worse, don't update the interval and exit
|
||||||
// This optimization, unlike kMeans, does not always converge to better loss
|
// This optimization, unlike kMeans, does not always converge to better loss
|
||||||
// So exit if we are getting worse
|
// So exit if we are getting worse
|
|
@ -23,6 +23,7 @@ import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.search.VectorScorer;
|
import org.apache.lucene.search.VectorScorer;
|
||||||
import org.apache.lucene.util.VectorUtil;
|
import org.apache.lucene.util.VectorUtil;
|
||||||
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
|
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
|
||||||
|
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||||
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
|
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
|
||||||
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
|
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
|
||||||
|
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||||
import org.elasticsearch.simdvec.ESVectorUtil;
|
import org.elasticsearch.simdvec.ESVectorUtil;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
|
@ -26,6 +26,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
|
||||||
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
|
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
import org.apache.lucene.index.SegmentWriteState;
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
|
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
|
@ -35,7 +36,7 @@ import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_
|
||||||
* Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
|
* Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
|
||||||
* Codec for encoding/decoding binary quantized vectors The binary quantization format used here
|
* Codec for encoding/decoding binary quantized vectors The binary quantization format used here
|
||||||
* is a per-vector optimized scalar quantization. Also see {@link
|
* is a per-vector optimized scalar quantization. Also see {@link
|
||||||
* org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer}. Some of key features are:
|
* OptimizedScalarQuantizer}. Some of key features are:
|
||||||
*
|
*
|
||||||
* <ul>
|
* <ul>
|
||||||
* <li>Estimating the distance between two vectors using their centroid normalized distance. This
|
* <li>Estimating the distance between two vectors using their centroid normalized distance. This
|
||||||
|
|
|
@ -44,6 +44,7 @@ import org.apache.lucene.util.SuppressForbidden;
|
||||||
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
|
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
|
||||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||||
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
|
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
|
||||||
|
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||||
import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils;
|
import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils;
|
||||||
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
|
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
|
||||||
|
|
||||||
|
|
|
@ -49,6 +49,7 @@ import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
|
||||||
import org.elasticsearch.core.SuppressForbidden;
|
import org.elasticsearch.core.SuppressForbidden;
|
||||||
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
|
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
|
||||||
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
|
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
|
||||||
|
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||||
|
|
||||||
import java.io.Closeable;
|
import java.io.Closeable;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
|
@ -30,6 +30,7 @@ import org.apache.lucene.util.Bits;
|
||||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||||
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
|
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
|
||||||
|
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
|
|
|
@ -7,13 +7,13 @@
|
||||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.elasticsearch.index.codec.vectors.es818;
|
package org.elasticsearch.index.codec.vectors;
|
||||||
|
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.util.VectorUtil;
|
import org.apache.lucene.util.VectorUtil;
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
|
||||||
import static org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer.MINIMUM_MSE_GRID;
|
import static org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer.MINIMUM_MSE_GRID;
|
||||||
|
|
||||||
public class OptimizedScalarQuantizerTests extends ESTestCase {
|
public class OptimizedScalarQuantizerTests extends ESTestCase {
|
||||||
|
|
|
@ -45,6 +45,7 @@ import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
|
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
|
||||||
import org.elasticsearch.common.logging.LogConfigurator;
|
import org.elasticsearch.common.logging.LogConfigurator;
|
||||||
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
|
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
|
||||||
|
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||||
import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils;
|
import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue