mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 01:22:26 -04:00
Panama vector accelerated optimized scalar quantization (#127118)
* Adds accelerates optimized scalar quantization with vectorized functions * Adding benchmark * Update docs/changelog/127118.yaml * adjusting benchmark and delta
This commit is contained in:
parent
ad0fe78e3e
commit
059f91c90c
16 changed files with 702 additions and 99 deletions
|
@ -0,0 +1,78 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.benchmark.vector;
|
||||
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.elasticsearch.common.logging.LogConfigurator;
|
||||
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||
import org.openjdk.jmh.annotations.Benchmark;
|
||||
import org.openjdk.jmh.annotations.BenchmarkMode;
|
||||
import org.openjdk.jmh.annotations.Fork;
|
||||
import org.openjdk.jmh.annotations.Level;
|
||||
import org.openjdk.jmh.annotations.Measurement;
|
||||
import org.openjdk.jmh.annotations.Mode;
|
||||
import org.openjdk.jmh.annotations.OutputTimeUnit;
|
||||
import org.openjdk.jmh.annotations.Param;
|
||||
import org.openjdk.jmh.annotations.Scope;
|
||||
import org.openjdk.jmh.annotations.Setup;
|
||||
import org.openjdk.jmh.annotations.State;
|
||||
import org.openjdk.jmh.annotations.Warmup;
|
||||
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@BenchmarkMode(Mode.Throughput)
|
||||
@OutputTimeUnit(TimeUnit.MILLISECONDS)
|
||||
@State(Scope.Benchmark)
|
||||
@Warmup(iterations = 3, time = 1)
|
||||
@Measurement(iterations = 5, time = 1)
|
||||
@Fork(value = 3)
|
||||
public class OptimizedScalarQuantizerBenchmark {
|
||||
static {
|
||||
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
|
||||
}
|
||||
@Param({ "384", "702", "1024" })
|
||||
int dims;
|
||||
|
||||
float[] vector;
|
||||
float[] centroid;
|
||||
byte[] destination;
|
||||
|
||||
@Param({ "1", "4", "7" })
|
||||
byte bits;
|
||||
|
||||
OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(VectorSimilarityFunction.DOT_PRODUCT);
|
||||
|
||||
@Setup(Level.Iteration)
|
||||
public void init() {
|
||||
ThreadLocalRandom random = ThreadLocalRandom.current();
|
||||
// random byte arrays for binary methods
|
||||
destination = new byte[dims];
|
||||
vector = new float[dims];
|
||||
centroid = new float[dims];
|
||||
for (int i = 0; i < dims; ++i) {
|
||||
vector[i] = random.nextFloat();
|
||||
centroid[i] = random.nextFloat();
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
public byte[] scalar() {
|
||||
osq.scalarQuantize(vector, destination, bits, centroid);
|
||||
return destination;
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
|
||||
public byte[] vector() {
|
||||
osq.scalarQuantize(vector, destination, bits, centroid);
|
||||
return destination;
|
||||
}
|
||||
}
|
5
docs/changelog/127118.yaml
Normal file
5
docs/changelog/127118.yaml
Normal file
|
@ -0,0 +1,5 @@
|
|||
pr: 127118
|
||||
summary: Panama vector accelerated optimized scalar quantization
|
||||
area: Vector Search
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -144,4 +144,71 @@ public class ESVectorUtil {
|
|||
}
|
||||
return distance;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the loss for optimized-scalar quantization for the given parameteres
|
||||
* @param target The vector being quantized, assumed to be centered
|
||||
* @param interval The interval for which to calculate the loss
|
||||
* @param points the quantization points
|
||||
* @param norm2 The norm squared of the target vector
|
||||
* @param lambda The lambda parameter for controlling anisotropic loss calculation
|
||||
* @return The loss for the given parameters
|
||||
*/
|
||||
public static float calculateOSQLoss(float[] target, float[] interval, int points, float norm2, float lambda) {
|
||||
assert interval.length == 2;
|
||||
float step = ((interval[1] - interval[0]) / (points - 1.0F));
|
||||
float invStep = 1f / step;
|
||||
return IMPL.calculateOSQLoss(target, interval, step, invStep, norm2, lambda);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the grid points for optimized-scalar quantization
|
||||
* @param target The vector being quantized, assumed to be centered
|
||||
* @param interval The interval for which to calculate the grid points
|
||||
* @param points the quantization points
|
||||
* @param pts The array to store the grid points, must be of length 5
|
||||
*/
|
||||
public static void calculateOSQGridPoints(float[] target, float[] interval, int points, float[] pts) {
|
||||
assert interval.length == 2;
|
||||
assert pts.length == 5;
|
||||
float invStep = (points - 1.0F) / (interval[1] - interval[0]);
|
||||
IMPL.calculateOSQGridPoints(target, interval, points, invStep, pts);
|
||||
}
|
||||
|
||||
/**
|
||||
* Center the target vector and calculate the optimized-scalar quantization statistics
|
||||
* @param target The vector being quantized
|
||||
* @param centroid The centroid of the target vector
|
||||
* @param centered The destination of the centered vector, will be overwritten
|
||||
* @param stats The array to store the statistics, must be of length 5
|
||||
*/
|
||||
public static void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
|
||||
assert target.length == centroid.length;
|
||||
assert stats.length == 5;
|
||||
if (target.length != centroid.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length);
|
||||
}
|
||||
if (centered.length != target.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length);
|
||||
}
|
||||
IMPL.centerAndCalculateOSQStatsEuclidean(target, centroid, centered, stats);
|
||||
}
|
||||
|
||||
/**
|
||||
* Center the target vector and calculate the optimized-scalar quantization statistics
|
||||
* @param target The vector being quantized
|
||||
* @param centroid The centroid of the target vector
|
||||
* @param centered The destination of the centered vector, will be overwritten
|
||||
* @param stats The array to store the statistics, must be of length 6
|
||||
*/
|
||||
public static void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
|
||||
if (target.length != centroid.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + target.length + "!=" + centroid.length);
|
||||
}
|
||||
if (centered.length != target.length) {
|
||||
throw new IllegalArgumentException("vector dimensions differ: " + centered.length + "!=" + target.length);
|
||||
}
|
||||
assert stats.length == 6;
|
||||
IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,6 +44,100 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
|
|||
return ipFloatByteImpl(q, d);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) {
|
||||
float a = interval[0];
|
||||
float b = interval[1];
|
||||
float xe = 0f;
|
||||
float e = 0f;
|
||||
for (float xi : target) {
|
||||
// this is quantizing and then dequantizing the vector
|
||||
float xiq = fma(step, Math.round((Math.min(Math.max(xi, a), b) - a) * invStep), a);
|
||||
// how much does the de-quantized value differ from the original value
|
||||
float xiiq = xi - xiq;
|
||||
e = fma(xiiq, xiiq, e);
|
||||
xe = fma(xi, xiiq, xe);
|
||||
}
|
||||
return (1f - lambda) * xe * xe / norm2 + lambda * e;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) {
|
||||
float a = interval[0];
|
||||
float b = interval[1];
|
||||
float daa = 0;
|
||||
float dab = 0;
|
||||
float dbb = 0;
|
||||
float dax = 0;
|
||||
float dbx = 0;
|
||||
for (float v : target) {
|
||||
float k = Math.round((Math.min(Math.max(v, a), b) - a) * invStep);
|
||||
float s = k / (points - 1);
|
||||
float ms = 1f - s;
|
||||
daa = fma(ms, ms, daa);
|
||||
dab = fma(ms, s, dab);
|
||||
dbb = fma(s, s, dbb);
|
||||
dax = fma(ms, v, dax);
|
||||
dbx = fma(s, v, dbx);
|
||||
}
|
||||
pts[0] = daa;
|
||||
pts[1] = dab;
|
||||
pts[2] = dbb;
|
||||
pts[3] = dax;
|
||||
pts[4] = dbx;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats) {
|
||||
float vecMean = 0;
|
||||
float vecVar = 0;
|
||||
float norm2 = 0;
|
||||
float min = Float.MAX_VALUE;
|
||||
float max = -Float.MAX_VALUE;
|
||||
for (int i = 0; i < target.length; i++) {
|
||||
centered[i] = target[i] - centroid[i];
|
||||
min = Math.min(min, centered[i]);
|
||||
max = Math.max(max, centered[i]);
|
||||
norm2 = fma(centered[i], centered[i], norm2);
|
||||
float delta = centered[i] - vecMean;
|
||||
vecMean += delta / (i + 1);
|
||||
float delta2 = centered[i] - vecMean;
|
||||
vecVar = fma(delta, delta2, vecVar);
|
||||
}
|
||||
stats[0] = vecMean;
|
||||
stats[1] = vecVar / target.length;
|
||||
stats[2] = norm2;
|
||||
stats[3] = min;
|
||||
stats[4] = max;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats) {
|
||||
float vecMean = 0;
|
||||
float vecVar = 0;
|
||||
float norm2 = 0;
|
||||
float centroidDot = 0;
|
||||
float min = Float.MAX_VALUE;
|
||||
float max = -Float.MAX_VALUE;
|
||||
for (int i = 0; i < target.length; i++) {
|
||||
centroidDot = fma(target[i], centroid[i], centroidDot);
|
||||
centered[i] = target[i] - centroid[i];
|
||||
min = Math.min(min, centered[i]);
|
||||
max = Math.max(max, centered[i]);
|
||||
norm2 = fma(centered[i], centered[i], norm2);
|
||||
float delta = centered[i] - vecMean;
|
||||
vecMean += delta / (i + 1);
|
||||
float delta2 = centered[i] - vecMean;
|
||||
vecVar = fma(delta, delta2, vecVar);
|
||||
}
|
||||
stats[0] = vecMean;
|
||||
stats[1] = vecVar / target.length;
|
||||
stats[2] = norm2;
|
||||
stats[3] = min;
|
||||
stats[4] = max;
|
||||
stats[5] = centroidDot;
|
||||
}
|
||||
|
||||
public static int ipByteBitImpl(byte[] q, byte[] d) {
|
||||
return ipByteBitImpl(q, d, 0);
|
||||
}
|
||||
|
|
|
@ -20,4 +20,12 @@ public interface ESVectorUtilSupport {
|
|||
float ipFloatBit(float[] q, byte[] d);
|
||||
|
||||
float ipFloatByte(float[] q, byte[] d);
|
||||
|
||||
float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda);
|
||||
|
||||
void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts);
|
||||
|
||||
void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats);
|
||||
|
||||
void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import jdk.incubator.vector.ByteVector;
|
|||
import jdk.incubator.vector.FloatVector;
|
||||
import jdk.incubator.vector.IntVector;
|
||||
import jdk.incubator.vector.LongVector;
|
||||
import jdk.incubator.vector.Vector;
|
||||
import jdk.incubator.vector.VectorMask;
|
||||
import jdk.incubator.vector.VectorOperators;
|
||||
import jdk.incubator.vector.VectorShape;
|
||||
|
@ -21,16 +22,22 @@ import jdk.incubator.vector.VectorSpecies;
|
|||
import org.apache.lucene.util.BitUtil;
|
||||
import org.apache.lucene.util.Constants;
|
||||
|
||||
import static jdk.incubator.vector.VectorOperators.ADD;
|
||||
import static jdk.incubator.vector.VectorOperators.MAX;
|
||||
import static jdk.incubator.vector.VectorOperators.MIN;
|
||||
|
||||
public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
||||
|
||||
static final int VECTOR_BITSIZE;
|
||||
|
||||
private static final VectorSpecies<Float> FLOAT_SPECIES;
|
||||
/** Whether integer vectors can be trusted to actually be fast. */
|
||||
static final boolean HAS_FAST_INTEGER_VECTORS;
|
||||
|
||||
static {
|
||||
// default to platform supported bitsize
|
||||
VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize();
|
||||
FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(VECTOR_BITSIZE));
|
||||
|
||||
// hotspot misses some SSE intrinsics, workaround it
|
||||
// to be fair, they do document this thing only works well with AVX2/AVX3 and Neon
|
||||
|
@ -38,6 +45,22 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
|||
HAS_FAST_INTEGER_VECTORS = isAMD64withoutAVX2 == false;
|
||||
}
|
||||
|
||||
private static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) {
|
||||
if (Constants.HAS_FAST_VECTOR_FMA) {
|
||||
return a.fma(b, c);
|
||||
} else {
|
||||
return a.mul(b).add(c);
|
||||
}
|
||||
}
|
||||
|
||||
private static float fma(float a, float b, float c) {
|
||||
if (Constants.HAS_FAST_SCALAR_FMA) {
|
||||
return Math.fma(a, b, c);
|
||||
} else {
|
||||
return a * b + c;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ipByteBinByte(byte[] q, byte[] d) {
|
||||
// 128 / 8 == 16
|
||||
|
@ -83,6 +106,267 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
|||
return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void centerAndCalculateOSQStatsEuclidean(float[] vector, float[] centroid, float[] centered, float[] stats) {
|
||||
assert vector.length == centroid.length;
|
||||
assert vector.length == centered.length;
|
||||
float vecMean = 0;
|
||||
float vecVar = 0;
|
||||
float norm2 = 0;
|
||||
float min = Float.MAX_VALUE;
|
||||
float max = -Float.MAX_VALUE;
|
||||
int i = 0;
|
||||
int vectCount = 0;
|
||||
if (vector.length > 2 * FLOAT_SPECIES.length()) {
|
||||
FloatVector vecMeanVec = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector m2Vec = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm2Vec = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector minVec = FloatVector.broadcast(FLOAT_SPECIES, Float.MAX_VALUE);
|
||||
FloatVector maxVec = FloatVector.broadcast(FLOAT_SPECIES, -Float.MAX_VALUE);
|
||||
int count = 0;
|
||||
for (; i < FLOAT_SPECIES.loopBound(vector.length); i += FLOAT_SPECIES.length()) {
|
||||
++count;
|
||||
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
|
||||
FloatVector c = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
|
||||
FloatVector centeredVec = v.sub(c);
|
||||
FloatVector deltaVec = centeredVec.sub(vecMeanVec);
|
||||
norm2Vec = fma(centeredVec, centeredVec, norm2Vec);
|
||||
vecMeanVec = vecMeanVec.add(deltaVec.div(count));
|
||||
FloatVector delta2Vec = centeredVec.sub(vecMeanVec);
|
||||
m2Vec = fma(deltaVec, delta2Vec, m2Vec);
|
||||
minVec = minVec.min(centeredVec);
|
||||
maxVec = maxVec.max(centeredVec);
|
||||
centeredVec.intoArray(centered, i);
|
||||
}
|
||||
min = minVec.reduceLanes(MIN);
|
||||
max = maxVec.reduceLanes(MAX);
|
||||
norm2 = norm2Vec.reduceLanes(ADD);
|
||||
vecMean = vecMeanVec.reduceLanes(ADD) / FLOAT_SPECIES.length();
|
||||
FloatVector d2Mean = vecMeanVec.sub(vecMean);
|
||||
m2Vec = fma(d2Mean, d2Mean, m2Vec);
|
||||
vectCount = count * FLOAT_SPECIES.length();
|
||||
vecVar = m2Vec.reduceLanes(ADD);
|
||||
}
|
||||
|
||||
float tailMean = 0;
|
||||
float tailM2 = 0;
|
||||
int tailCount = 0;
|
||||
// handle the tail
|
||||
for (; i < vector.length; i++) {
|
||||
centered[i] = vector[i] - centroid[i];
|
||||
float delta = centered[i] - tailMean;
|
||||
++tailCount;
|
||||
tailMean += delta / tailCount;
|
||||
float delta2 = centered[i] - tailMean;
|
||||
tailM2 = fma(delta, delta2, tailM2);
|
||||
min = Math.min(min, centered[i]);
|
||||
max = Math.max(max, centered[i]);
|
||||
norm2 = fma(centered[i], centered[i], norm2);
|
||||
}
|
||||
if (vectCount == 0) {
|
||||
vecMean = tailMean;
|
||||
vecVar = tailM2;
|
||||
} else if (tailCount > 0) {
|
||||
int totalCount = tailCount + vectCount;
|
||||
assert totalCount == vector.length;
|
||||
float alpha = (float) vectCount / totalCount;
|
||||
float beta = 1f - alpha;
|
||||
float completeMean = alpha * vecMean + beta * tailMean;
|
||||
float dMean2Lhs = (vecMean - completeMean) * (vecMean - completeMean);
|
||||
float dMean2Rhs = (tailMean - completeMean) * (tailMean - completeMean);
|
||||
vecVar = (vecVar + dMean2Lhs) + beta * (tailM2 + dMean2Rhs);
|
||||
vecMean = completeMean;
|
||||
}
|
||||
stats[0] = vecMean;
|
||||
stats[1] = vecVar / vector.length;
|
||||
stats[2] = norm2;
|
||||
stats[3] = min;
|
||||
stats[4] = max;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void centerAndCalculateOSQStatsDp(float[] vector, float[] centroid, float[] centered, float[] stats) {
|
||||
assert vector.length == centroid.length;
|
||||
assert vector.length == centered.length;
|
||||
float vecMean = 0;
|
||||
float vecVar = 0;
|
||||
float norm2 = 0;
|
||||
float min = Float.MAX_VALUE;
|
||||
float max = -Float.MAX_VALUE;
|
||||
float centroidDot = 0;
|
||||
int i = 0;
|
||||
int vectCount = 0;
|
||||
int loopBound = FLOAT_SPECIES.loopBound(vector.length);
|
||||
if (vector.length > 2 * FLOAT_SPECIES.length()) {
|
||||
FloatVector vecMeanVec = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector m2Vec = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm2Vec = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector minVec = FloatVector.broadcast(FLOAT_SPECIES, Float.MAX_VALUE);
|
||||
FloatVector maxVec = FloatVector.broadcast(FLOAT_SPECIES, -Float.MAX_VALUE);
|
||||
FloatVector centroidDotVec = FloatVector.zero(FLOAT_SPECIES);
|
||||
int count = 0;
|
||||
for (; i < loopBound; i += FLOAT_SPECIES.length()) {
|
||||
++count;
|
||||
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
|
||||
FloatVector c = FloatVector.fromArray(FLOAT_SPECIES, centroid, i);
|
||||
centroidDotVec = fma(v, c, centroidDotVec);
|
||||
FloatVector centeredVec = v.sub(c);
|
||||
FloatVector deltaVec = centeredVec.sub(vecMeanVec);
|
||||
norm2Vec = fma(centeredVec, centeredVec, norm2Vec);
|
||||
vecMeanVec = vecMeanVec.add(deltaVec.div(count));
|
||||
FloatVector delta2Vec = centeredVec.sub(vecMeanVec);
|
||||
m2Vec = fma(deltaVec, delta2Vec, m2Vec);
|
||||
minVec = minVec.min(centeredVec);
|
||||
maxVec = maxVec.max(centeredVec);
|
||||
centeredVec.intoArray(centered, i);
|
||||
}
|
||||
min = minVec.reduceLanes(MIN);
|
||||
max = maxVec.reduceLanes(MAX);
|
||||
norm2 = norm2Vec.reduceLanes(ADD);
|
||||
centroidDot = centroidDotVec.reduceLanes(ADD);
|
||||
vecMean = vecMeanVec.reduceLanes(ADD) / FLOAT_SPECIES.length();
|
||||
FloatVector d2Mean = vecMeanVec.sub(vecMean);
|
||||
m2Vec = fma(d2Mean, d2Mean, m2Vec);
|
||||
vectCount = count * FLOAT_SPECIES.length();
|
||||
vecVar = m2Vec.reduceLanes(ADD);
|
||||
}
|
||||
|
||||
float tailMean = 0;
|
||||
float tailM2 = 0;
|
||||
int tailCount = 0;
|
||||
// handle the tail
|
||||
for (; i < vector.length; i++) {
|
||||
centroidDot = fma(vector[i], centroid[i], centroidDot);
|
||||
centered[i] = vector[i] - centroid[i];
|
||||
float delta = centered[i] - tailMean;
|
||||
++tailCount;
|
||||
tailMean += delta / tailCount;
|
||||
float delta2 = centered[i] - tailMean;
|
||||
tailM2 = fma(delta, delta2, tailM2);
|
||||
min = Math.min(min, centered[i]);
|
||||
max = Math.max(max, centered[i]);
|
||||
norm2 = fma(centered[i], centered[i], norm2);
|
||||
}
|
||||
if (vectCount == 0) {
|
||||
vecMean = tailMean;
|
||||
vecVar = tailM2;
|
||||
} else if (tailCount > 0) {
|
||||
int totalCount = tailCount + vectCount;
|
||||
assert totalCount == vector.length;
|
||||
float alpha = (float) vectCount / totalCount;
|
||||
float beta = 1f - alpha;
|
||||
float completeMean = alpha * vecMean + beta * tailMean;
|
||||
float dMean2Lhs = (vecMean - completeMean) * (vecMean - completeMean);
|
||||
float dMean2Rhs = (tailMean - completeMean) * (tailMean - completeMean);
|
||||
vecVar = (vecVar + dMean2Lhs) + beta * (tailM2 + dMean2Rhs);
|
||||
vecMean = completeMean;
|
||||
}
|
||||
stats[0] = vecMean;
|
||||
stats[1] = vecVar / vector.length;
|
||||
stats[2] = norm2;
|
||||
stats[3] = min;
|
||||
stats[4] = max;
|
||||
stats[5] = centroidDot;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) {
|
||||
float a = interval[0];
|
||||
float b = interval[1];
|
||||
int i = 0;
|
||||
float daa = 0;
|
||||
float dab = 0;
|
||||
float dbb = 0;
|
||||
float dax = 0;
|
||||
float dbx = 0;
|
||||
|
||||
FloatVector daaVec = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector dabVec = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector dbbVec = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector daxVec = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector dbxVec = FloatVector.zero(FLOAT_SPECIES);
|
||||
|
||||
// if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
|
||||
if (target.length > 2 * FLOAT_SPECIES.length()) {
|
||||
FloatVector ones = FloatVector.broadcast(FLOAT_SPECIES, 1f);
|
||||
FloatVector pmOnes = FloatVector.broadcast(FLOAT_SPECIES, points - 1f);
|
||||
for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) {
|
||||
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i);
|
||||
FloatVector vClamped = v.max(a).min(b);
|
||||
Vector<Integer> xiqint = vClamped.sub(a)
|
||||
.mul(invStep)
|
||||
// round
|
||||
.add(0.5f)
|
||||
.convert(VectorOperators.F2I, 0);
|
||||
FloatVector kVec = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats();
|
||||
FloatVector sVec = kVec.div(pmOnes);
|
||||
FloatVector smVec = ones.sub(sVec);
|
||||
daaVec = fma(smVec, smVec, daaVec);
|
||||
dabVec = fma(smVec, sVec, dabVec);
|
||||
dbbVec = fma(sVec, sVec, dbbVec);
|
||||
daxVec = fma(v, smVec, daxVec);
|
||||
dbxVec = fma(v, sVec, dbxVec);
|
||||
}
|
||||
daa = daaVec.reduceLanes(ADD);
|
||||
dab = dabVec.reduceLanes(ADD);
|
||||
dbb = dbbVec.reduceLanes(ADD);
|
||||
dax = daxVec.reduceLanes(ADD);
|
||||
dbx = dbxVec.reduceLanes(ADD);
|
||||
}
|
||||
|
||||
for (; i < target.length; i++) {
|
||||
float k = Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep);
|
||||
float s = k / (points - 1);
|
||||
float ms = 1f - s;
|
||||
daa = fma(ms, ms, daa);
|
||||
dab = fma(ms, s, dab);
|
||||
dbb = fma(s, s, dbb);
|
||||
dax = fma(ms, target[i], dax);
|
||||
dbx = fma(s, target[i], dbx);
|
||||
}
|
||||
|
||||
pts[0] = daa;
|
||||
pts[1] = dab;
|
||||
pts[2] = dbb;
|
||||
pts[3] = dax;
|
||||
pts[4] = dbx;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) {
|
||||
float a = interval[0];
|
||||
float b = interval[1];
|
||||
float xe = 0f;
|
||||
float e = 0f;
|
||||
FloatVector xeVec = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector eVec = FloatVector.zero(FLOAT_SPECIES);
|
||||
int i = 0;
|
||||
// if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
|
||||
if (target.length > 2 * FLOAT_SPECIES.length()) {
|
||||
for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) {
|
||||
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i);
|
||||
FloatVector vClamped = v.max(a).min(b);
|
||||
Vector<Integer> xiqint = vClamped.sub(a).mul(invStep).add(0.5f).convert(VectorOperators.F2I, 0);
|
||||
FloatVector xiq = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats().mul(step).add(a);
|
||||
FloatVector xiiq = v.sub(xiq);
|
||||
xeVec = fma(v, xiiq, xeVec);
|
||||
eVec = fma(xiiq, xiiq, eVec);
|
||||
}
|
||||
e = eVec.reduceLanes(ADD);
|
||||
xe = xeVec.reduceLanes(ADD);
|
||||
}
|
||||
|
||||
for (; i < target.length; i++) {
|
||||
// this is quantizing and then dequantizing the vector
|
||||
float xiq = fma(step, Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep), a);
|
||||
// how much does the de-quantized value differ from the original value
|
||||
float xiiq = target[i] - xiq;
|
||||
e = fma(xiiq, xiiq, e);
|
||||
xe = fma(target[i], xiiq, xe);
|
||||
}
|
||||
return (1f - lambda) * xe * xe / norm2 + lambda * e;
|
||||
}
|
||||
|
||||
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
|
||||
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
|
||||
package org.elasticsearch.simdvec;
|
||||
|
||||
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||
import org.elasticsearch.simdvec.internal.vectorization.BaseVectorizationTests;
|
||||
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
|
||||
|
||||
|
@ -161,6 +162,112 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
|
|||
testIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
|
||||
}
|
||||
|
||||
public void testCenterAndCalculateOSQStatsDp() {
|
||||
int size = random().nextInt(128, 512);
|
||||
float delta = 1e-3f * size;
|
||||
var vector = new float[size];
|
||||
var centroid = new float[size];
|
||||
for (int i = 0; i < size; ++i) {
|
||||
vector[i] = random().nextFloat();
|
||||
centroid[i] = random().nextFloat();
|
||||
}
|
||||
var centeredLucene = new float[size];
|
||||
var statsLucene = new float[6];
|
||||
defaultedProvider.getVectorUtilSupport().centerAndCalculateOSQStatsDp(vector, centroid, centeredLucene, statsLucene);
|
||||
var centeredPanama = new float[size];
|
||||
var statsPanama = new float[6];
|
||||
defOrPanamaProvider.getVectorUtilSupport().centerAndCalculateOSQStatsDp(vector, centroid, centeredPanama, statsPanama);
|
||||
assertArrayEquals(centeredLucene, centeredPanama, delta);
|
||||
assertArrayEquals(statsLucene, statsPanama, delta);
|
||||
}
|
||||
|
||||
public void testCenterAndCalculateOSQStatsEuclidean() {
|
||||
int size = random().nextInt(128, 512);
|
||||
float delta = 1e-3f * size;
|
||||
var vector = new float[size];
|
||||
var centroid = new float[size];
|
||||
for (int i = 0; i < size; ++i) {
|
||||
vector[i] = random().nextFloat();
|
||||
centroid[i] = random().nextFloat();
|
||||
}
|
||||
var centeredLucene = new float[size];
|
||||
var statsLucene = new float[5];
|
||||
defaultedProvider.getVectorUtilSupport().centerAndCalculateOSQStatsEuclidean(vector, centroid, centeredLucene, statsLucene);
|
||||
var centeredPanama = new float[size];
|
||||
var statsPanama = new float[5];
|
||||
defOrPanamaProvider.getVectorUtilSupport().centerAndCalculateOSQStatsEuclidean(vector, centroid, centeredPanama, statsPanama);
|
||||
assertArrayEquals(centeredLucene, centeredPanama, delta);
|
||||
assertArrayEquals(statsLucene, statsPanama, delta);
|
||||
}
|
||||
|
||||
public void testOsqLoss() {
|
||||
int size = random().nextInt(128, 512);
|
||||
float deltaEps = 1e-5f * size;
|
||||
var vector = new float[size];
|
||||
var min = Float.MAX_VALUE;
|
||||
var max = -Float.MAX_VALUE;
|
||||
float vecMean = 0;
|
||||
float vecVar = 0;
|
||||
float norm2 = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
vector[i] = random().nextFloat();
|
||||
min = Math.min(min, vector[i]);
|
||||
max = Math.max(max, vector[i]);
|
||||
float delta = vector[i] - vecMean;
|
||||
vecMean += delta / (i + 1);
|
||||
float delta2 = vector[i] - vecMean;
|
||||
vecVar += delta * delta2;
|
||||
norm2 += vector[i] * vector[i];
|
||||
}
|
||||
vecVar /= size;
|
||||
float vecStd = (float) Math.sqrt(vecVar);
|
||||
|
||||
for (byte bits : new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }) {
|
||||
int points = 1 << bits;
|
||||
float[] initInterval = new float[2];
|
||||
OptimizedScalarQuantizer.initInterval(bits, vecStd, vecMean, min, max, initInterval);
|
||||
float step = ((initInterval[1] - initInterval[0]) / (points - 1f));
|
||||
float stepInv = 1f / step;
|
||||
float expected = defaultedProvider.getVectorUtilSupport().calculateOSQLoss(vector, initInterval, step, stepInv, norm2, 0.1f);
|
||||
float result = defOrPanamaProvider.getVectorUtilSupport().calculateOSQLoss(vector, initInterval, step, stepInv, norm2, 0.1f);
|
||||
assertEquals(expected, result, deltaEps);
|
||||
}
|
||||
}
|
||||
|
||||
public void testOsqGridPoints() {
|
||||
int size = random().nextInt(128, 512);
|
||||
float deltaEps = 1e-5f * size;
|
||||
var vector = new float[size];
|
||||
var min = Float.MAX_VALUE;
|
||||
var max = -Float.MAX_VALUE;
|
||||
float vecMean = 0;
|
||||
float vecVar = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
vector[i] = random().nextFloat();
|
||||
min = Math.min(min, vector[i]);
|
||||
max = Math.max(max, vector[i]);
|
||||
float delta = vector[i] - vecMean;
|
||||
vecMean += delta / (i + 1);
|
||||
float delta2 = vector[i] - vecMean;
|
||||
vecVar += delta * delta2;
|
||||
}
|
||||
vecVar /= size;
|
||||
float vecStd = (float) Math.sqrt(vecVar);
|
||||
for (byte bits : new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }) {
|
||||
int points = 1 << bits;
|
||||
float[] initInterval = new float[2];
|
||||
OptimizedScalarQuantizer.initInterval(bits, vecStd, vecMean, min, max, initInterval);
|
||||
float step = ((initInterval[1] - initInterval[0]) / (points - 1f));
|
||||
float stepInv = 1f / step;
|
||||
float[] expected = new float[5];
|
||||
defaultedProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, initInterval, points, stepInv, expected);
|
||||
|
||||
float[] result = new float[5];
|
||||
defOrPanamaProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, initInterval, points, stepInv, result);
|
||||
assertArrayEquals(expected, result, deltaEps);
|
||||
}
|
||||
}
|
||||
|
||||
void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
|
||||
int iterations = atLeast(50);
|
||||
for (int i = 0; i < iterations; i++) {
|
||||
|
|
|
@ -7,15 +7,21 @@
|
|||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors.es818;
|
||||
package org.elasticsearch.index.codec.vectors;
|
||||
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.elasticsearch.simdvec.ESVectorUtil;
|
||||
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
|
||||
|
||||
class OptimizedScalarQuantizer {
|
||||
public class OptimizedScalarQuantizer {
|
||||
public static void initInterval(byte bits, float vecStd, float vecMean, float min, float max, float[] initInterval) {
|
||||
initInterval[0] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][0] * vecStd + vecMean, min, max);
|
||||
initInterval[1] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][1] * vecStd + vecMean, min, max);
|
||||
}
|
||||
|
||||
// The initial interval is set to the minimum MSE grid for each number of bits
|
||||
// these starting points are derived from the optimal MSE grid for a uniform distribution
|
||||
static final float[][] MINIMUM_MSE_GRID = new float[][] {
|
||||
|
@ -27,19 +33,25 @@ class OptimizedScalarQuantizer {
|
|||
{ -3.278f, 3.278f },
|
||||
{ -3.611f, 3.611f },
|
||||
{ -3.922f, 3.922f } };
|
||||
private static final float DEFAULT_LAMBDA = 0.1f;
|
||||
public static final float DEFAULT_LAMBDA = 0.1f;
|
||||
private static final int DEFAULT_ITERS = 5;
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
private final float lambda;
|
||||
private final int iters;
|
||||
private final float[] statsScratch;
|
||||
private final float[] gridScratch;
|
||||
private final float[] intervalScratch;
|
||||
|
||||
OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction, float lambda, int iters) {
|
||||
public OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction, float lambda, int iters) {
|
||||
this.similarityFunction = similarityFunction;
|
||||
this.lambda = lambda;
|
||||
this.iters = iters;
|
||||
this.statsScratch = new float[similarityFunction == EUCLIDEAN ? 5 : 6];
|
||||
this.gridScratch = new float[5];
|
||||
this.intervalScratch = new float[2];
|
||||
}
|
||||
|
||||
OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) {
|
||||
public OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) {
|
||||
this(similarityFunction, DEFAULT_LAMBDA, DEFAULT_ITERS);
|
||||
}
|
||||
|
||||
|
@ -49,34 +61,23 @@ class OptimizedScalarQuantizer {
|
|||
assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
|
||||
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
|
||||
assert bits.length == destinations.length;
|
||||
float[] intervalScratch = new float[2];
|
||||
double vecMean = 0;
|
||||
double vecVar = 0;
|
||||
float norm2 = 0;
|
||||
float centroidDot = 0;
|
||||
float min = Float.MAX_VALUE;
|
||||
float max = -Float.MAX_VALUE;
|
||||
for (int i = 0; i < vector.length; ++i) {
|
||||
if (similarityFunction != EUCLIDEAN) {
|
||||
centroidDot += vector[i] * centroid[i];
|
||||
}
|
||||
vector[i] = vector[i] - centroid[i];
|
||||
min = Math.min(min, vector[i]);
|
||||
max = Math.max(max, vector[i]);
|
||||
norm2 += (vector[i] * vector[i]);
|
||||
double delta = vector[i] - vecMean;
|
||||
vecMean += delta / (i + 1);
|
||||
vecVar += delta * (vector[i] - vecMean);
|
||||
if (similarityFunction == EUCLIDEAN) {
|
||||
ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch);
|
||||
} else {
|
||||
ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch);
|
||||
}
|
||||
vecVar /= vector.length;
|
||||
double vecStd = Math.sqrt(vecVar);
|
||||
float vecMean = statsScratch[0];
|
||||
float vecVar = statsScratch[1];
|
||||
float norm2 = statsScratch[2];
|
||||
float min = statsScratch[3];
|
||||
float max = statsScratch[4];
|
||||
float vecStd = (float) Math.sqrt(vecVar);
|
||||
QuantizationResult[] results = new QuantizationResult[bits.length];
|
||||
for (int i = 0; i < bits.length; ++i) {
|
||||
assert bits[i] > 0 && bits[i] <= 8;
|
||||
int points = (1 << bits[i]);
|
||||
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
|
||||
intervalScratch[0] = (float) clamp(MINIMUM_MSE_GRID[bits[i] - 1][0] * vecStd + vecMean, min, max);
|
||||
intervalScratch[1] = (float) clamp(MINIMUM_MSE_GRID[bits[i] - 1][1] * vecStd + vecMean, min, max);
|
||||
initInterval(bits[i], vecStd, vecMean, min, max, intervalScratch);
|
||||
optimizeIntervals(intervalScratch, vector, norm2, points);
|
||||
float nSteps = ((1 << bits[i]) - 1);
|
||||
float a = intervalScratch[0];
|
||||
|
@ -93,7 +94,7 @@ class OptimizedScalarQuantizer {
|
|||
results[i] = new QuantizationResult(
|
||||
intervalScratch[0],
|
||||
intervalScratch[1],
|
||||
similarityFunction == EUCLIDEAN ? norm2 : centroidDot,
|
||||
similarityFunction == EUCLIDEAN ? norm2 : statsScratch[5],
|
||||
sumQuery
|
||||
);
|
||||
}
|
||||
|
@ -105,31 +106,20 @@ class OptimizedScalarQuantizer {
|
|||
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
|
||||
assert vector.length <= destination.length;
|
||||
assert bits > 0 && bits <= 8;
|
||||
float[] intervalScratch = new float[2];
|
||||
int points = 1 << bits;
|
||||
double vecMean = 0;
|
||||
double vecVar = 0;
|
||||
float norm2 = 0;
|
||||
float centroidDot = 0;
|
||||
float min = Float.MAX_VALUE;
|
||||
float max = -Float.MAX_VALUE;
|
||||
for (int i = 0; i < vector.length; ++i) {
|
||||
if (similarityFunction != EUCLIDEAN) {
|
||||
centroidDot += vector[i] * centroid[i];
|
||||
}
|
||||
vector[i] = vector[i] - centroid[i];
|
||||
min = Math.min(min, vector[i]);
|
||||
max = Math.max(max, vector[i]);
|
||||
norm2 += (vector[i] * vector[i]);
|
||||
double delta = vector[i] - vecMean;
|
||||
vecMean += delta / (i + 1);
|
||||
vecVar += delta * (vector[i] - vecMean);
|
||||
if (similarityFunction == EUCLIDEAN) {
|
||||
ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch);
|
||||
} else {
|
||||
ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch);
|
||||
}
|
||||
vecVar /= vector.length;
|
||||
double vecStd = Math.sqrt(vecVar);
|
||||
float vecMean = statsScratch[0];
|
||||
float vecVar = statsScratch[1];
|
||||
float norm2 = statsScratch[2];
|
||||
float min = statsScratch[3];
|
||||
float max = statsScratch[4];
|
||||
float vecStd = (float) Math.sqrt(vecVar);
|
||||
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
|
||||
intervalScratch[0] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][0] * vecStd + vecMean, min, max);
|
||||
intervalScratch[1] = (float) clamp(MINIMUM_MSE_GRID[bits - 1][1] * vecStd + vecMean, min, max);
|
||||
initInterval(bits, vecStd, vecMean, min, max, intervalScratch);
|
||||
optimizeIntervals(intervalScratch, vector, norm2, points);
|
||||
float nSteps = ((1 << bits) - 1);
|
||||
// Now we have the optimized intervals, quantize the vector
|
||||
|
@ -146,37 +136,11 @@ class OptimizedScalarQuantizer {
|
|||
return new QuantizationResult(
|
||||
intervalScratch[0],
|
||||
intervalScratch[1],
|
||||
similarityFunction == EUCLIDEAN ? norm2 : centroidDot,
|
||||
similarityFunction == EUCLIDEAN ? norm2 : statsScratch[5],
|
||||
sumQuery
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the loss of the vector given the interval. Effectively, we are computing the MSE of a dequantized vector with the raw
|
||||
* vector.
|
||||
* @param vector raw vector
|
||||
* @param interval interval to quantize the vector
|
||||
* @param points number of quantization points
|
||||
* @param norm2 squared norm of the vector
|
||||
* @return the loss
|
||||
*/
|
||||
private double loss(float[] vector, float[] interval, int points, float norm2) {
|
||||
double a = interval[0];
|
||||
double b = interval[1];
|
||||
double step = ((b - a) / (points - 1.0F));
|
||||
double stepInv = 1.0 / step;
|
||||
double xe = 0.0;
|
||||
double e = 0.0;
|
||||
for (double xi : vector) {
|
||||
// this is quantizing and then dequantizing the vector
|
||||
double xiq = (a + step * Math.round((clamp(xi, a, b) - a) * stepInv));
|
||||
// how much does the de-quantized value differ from the original value
|
||||
xe += xi * (xi - xiq);
|
||||
e += (xi - xiq) * (xi - xiq);
|
||||
}
|
||||
return (1.0 - lambda) * xe * xe / norm2 + lambda * e;
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimize the quantization interval for the given vector. This is done via a coordinate descent trying to minimize the quantization
|
||||
* loss. Note, the loss is not always guaranteed to decrease, so we have a maximum number of iterations and will exit early if the
|
||||
|
@ -187,30 +151,19 @@ class OptimizedScalarQuantizer {
|
|||
* @param points number of quantization points
|
||||
*/
|
||||
private void optimizeIntervals(float[] initInterval, float[] vector, float norm2, int points) {
|
||||
double initialLoss = loss(vector, initInterval, points, norm2);
|
||||
double initialLoss = ESVectorUtil.calculateOSQLoss(vector, initInterval, points, norm2, lambda);
|
||||
final float scale = (1.0f - lambda) / norm2;
|
||||
if (Float.isFinite(scale) == false) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
float a = initInterval[0];
|
||||
float b = initInterval[1];
|
||||
float stepInv = (points - 1.0f) / (b - a);
|
||||
// calculate the grid points for coordinate descent
|
||||
double daa = 0;
|
||||
double dab = 0;
|
||||
double dbb = 0;
|
||||
double dax = 0;
|
||||
double dbx = 0;
|
||||
for (float xi : vector) {
|
||||
float k = Math.round((clamp(xi, a, b) - a) * stepInv);
|
||||
float s = k / (points - 1);
|
||||
daa += (1.0 - s) * (1.0 - s);
|
||||
dab += (1.0 - s) * s;
|
||||
dbb += s * s;
|
||||
dax += xi * (1.0 - s);
|
||||
dbx += xi * s;
|
||||
}
|
||||
ESVectorUtil.calculateOSQGridPoints(vector, initInterval, points, gridScratch);
|
||||
float daa = gridScratch[0];
|
||||
float dab = gridScratch[1];
|
||||
float dbb = gridScratch[2];
|
||||
float dax = gridScratch[3];
|
||||
float dbx = gridScratch[4];
|
||||
double m0 = scale * dax * dax + lambda * daa;
|
||||
double m1 = scale * dax * dbx + lambda * dab;
|
||||
double m2 = scale * dbx * dbx + lambda * dbb;
|
||||
|
@ -225,7 +178,7 @@ class OptimizedScalarQuantizer {
|
|||
if ((Math.abs(initInterval[0] - aOpt) < 1e-8 && Math.abs(initInterval[1] - bOpt) < 1e-8)) {
|
||||
return;
|
||||
}
|
||||
double newLoss = loss(vector, new float[] { aOpt, bOpt }, points, norm2);
|
||||
double newLoss = ESVectorUtil.calculateOSQLoss(vector, new float[] { aOpt, bOpt }, points, norm2, lambda);
|
||||
// If the new loss is worse, don't update the interval and exit
|
||||
// This optimization, unlike kMeans, does not always converge to better loss
|
||||
// So exit if we are getting worse
|
|
@ -23,6 +23,7 @@ import org.apache.lucene.index.ByteVectorValues;
|
|||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
|
||||
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
|||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
|
||||
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
|
||||
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||
import org.elasticsearch.simdvec.ESVectorUtil;
|
||||
|
||||
import java.io.IOException;
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
|
|||
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
|
@ -35,7 +36,7 @@ import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_
|
|||
* Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
|
||||
* Codec for encoding/decoding binary quantized vectors The binary quantization format used here
|
||||
* is a per-vector optimized scalar quantization. Also see {@link
|
||||
* org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer}. Some of key features are:
|
||||
* OptimizedScalarQuantizer}. Some of key features are:
|
||||
*
|
||||
* <ul>
|
||||
* <li>Estimating the distance between two vectors using their centroid normalized distance. This
|
||||
|
|
|
@ -44,6 +44,7 @@ import org.apache.lucene.util.SuppressForbidden;
|
|||
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
|
||||
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||
import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils;
|
||||
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
|
||||
|
||||
|
|
|
@ -49,6 +49,7 @@ import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
|
|||
import org.elasticsearch.core.SuppressForbidden;
|
||||
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
|
||||
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
|
||||
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.apache.lucene.util.Bits;
|
|||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
|
||||
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
|
|
|
@ -7,13 +7,13 @@
|
|||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.codec.vectors.es818;
|
||||
package org.elasticsearch.index.codec.vectors;
|
||||
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import static org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer.MINIMUM_MSE_GRID;
|
||||
import static org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer.MINIMUM_MSE_GRID;
|
||||
|
||||
public class OptimizedScalarQuantizerTests extends ESTestCase {
|
||||
|
|
@ -45,6 +45,7 @@ import org.apache.lucene.store.Directory;
|
|||
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
|
||||
import org.elasticsearch.common.logging.LogConfigurator;
|
||||
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
|
||||
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
|
||||
import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils;
|
||||
|
||||
import java.io.IOException;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue