Refactor JMH script vector distance benchmark to add panama benchmarks (#124351)

Add vector benchmarks vs scalar, and automatically pick up new implementations as they get added
This commit is contained in:
Simon Cooper 2025-03-12 13:15:16 +00:00 committed by GitHub
parent f6538e86e2
commit d7864f4af6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -10,10 +10,12 @@
package org.elasticsearch.benchmark.vector;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.script.field.vectors.BinaryDenseVector;
import org.elasticsearch.script.field.vectors.ByteBinaryDenseVector;
import org.elasticsearch.script.field.vectors.ByteKnnDenseVector;
import org.elasticsearch.script.field.vectors.DenseVector;
import org.elasticsearch.script.field.vectors.KnnDenseVector;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
@ -27,11 +29,12 @@ import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.DoubleSupplier;
/**
* Various benchmarks for the distance functions
@ -51,451 +54,172 @@ import java.util.function.Consumer;
@State(Scope.Benchmark)
public class DistanceFunctionBenchmark {
@Param({ "float", "byte" })
private String element;
static {
LogConfigurator.configureESLogging();
}
@Param({ "96" })
public enum VectorType {
FLOAT,
BYTE
}
public enum Function {
DOT,
COSINE,
L1,
L2,
HAMMING
}
public enum Implementation {
KNN,
BINARY
}
@Param
private VectorType docType;
@Param
private VectorType queryType;
@Param({ "1024" })
private int dims;
@Param({ "dot", "cosine", "l1", "l2", "hamming" })
private String function;
@Param
private Function function;
@Param({ "knn", "binary" })
private String type;
@Param
private Implementation type;
private abstract static class BenchmarkFunction {
private DoubleSupplier benchmarkImpl;
final int dims;
private BenchmarkFunction(int dims) {
this.dims = dims;
}
abstract void execute(Consumer<Object> consumer);
private static float calculateMag(float[] vector) {
float mag = 0;
for (float f : vector) {
mag += f * f;
}
private abstract static class KnnFloatBenchmarkFunction extends BenchmarkFunction {
final float[] docVector;
final float[] queryVector;
private KnnFloatBenchmarkFunction(int dims, boolean normalize) {
super(dims);
docVector = new float[dims];
queryVector = new float[dims];
float docMagnitude = 0f;
float queryMagnitude = 0f;
for (int i = 0; i < dims; ++i) {
docVector[i] = (float) (dims - i);
queryVector[i] = (float) i;
docMagnitude += (float) (dims - i);
queryMagnitude += (float) i;
return (float) Math.sqrt(mag);
}
docMagnitude /= dims;
queryMagnitude /= dims;
if (normalize) {
for (int i = 0; i < dims; ++i) {
docVector[i] /= docMagnitude;
queryVector[i] /= queryMagnitude;
}
}
private static float calculateMag(byte[] vector) {
float mag = 0;
for (byte b : vector) {
mag += b * b;
}
return (float) Math.sqrt(mag);
}
private abstract static class BinaryFloatBenchmarkFunction extends BenchmarkFunction {
final BytesRef docVector;
final float[] docFloatVector;
final float[] queryVector;
private BinaryFloatBenchmarkFunction(int dims, boolean normalize) {
super(dims);
docFloatVector = new float[dims];
queryVector = new float[dims];
float docMagnitude = 0f;
float queryMagnitude = 0f;
for (int i = 0; i < dims; ++i) {
docFloatVector[i] = (float) (dims - i);
queryVector[i] = (float) i;
docMagnitude += (float) (dims - i);
queryMagnitude += (float) i;
}
docMagnitude /= dims;
queryMagnitude /= dims;
ByteBuffer byteBuffer = ByteBuffer.allocate(dims * 4 + 4);
for (int i = 0; i < dims; ++i) {
if (normalize) {
docFloatVector[i] /= docMagnitude;
queryVector[i] /= queryMagnitude;
}
byteBuffer.putFloat(docFloatVector[i]);
}
byteBuffer.putFloat(docMagnitude);
this.docVector = new BytesRef(byteBuffer.array());
}
}
private abstract static class KnnByteBenchmarkFunction extends BenchmarkFunction {
final byte[] docVector;
final byte[] queryVector;
final float queryMagnitude;
private KnnByteBenchmarkFunction(int dims) {
super(dims);
ByteBuffer docVector = ByteBuffer.allocate(dims);
queryVector = new byte[dims];
float queryMagnitude = 0f;
for (int i = 0; i < dims; ++i) {
docVector.put((byte) (dims - i));
queryVector[i] = (byte) i;
queryMagnitude += (float) i;
}
this.docVector = docVector.array();
this.queryMagnitude = queryMagnitude / dims;
private static float normalizeVector(float[] vector) {
float mag = calculateMag(vector);
for (int i = 0; i < vector.length; i++) {
vector[i] /= mag;
}
return mag;
}
private abstract static class BinaryByteBenchmarkFunction extends BenchmarkFunction {
final BytesRef docVector;
final byte[] vectorValue;
final byte[] queryVector;
final float queryMagnitude;
private BinaryByteBenchmarkFunction(int dims) {
super(dims);
ByteBuffer docVector = ByteBuffer.allocate(dims + 4);
queryVector = new byte[dims];
vectorValue = new byte[dims];
float docMagnitude = 0f;
float queryMagnitude = 0f;
for (int i = 0; i < dims; ++i) {
docVector.put((byte) (dims - i));
vectorValue[i] = (byte) (dims - i);
queryVector[i] = (byte) i;
docMagnitude += (float) (dims - i);
queryMagnitude += (float) i;
}
docVector.putFloat(docMagnitude / dims);
this.docVector = new BytesRef(docVector.array());
this.queryMagnitude = queryMagnitude / dims;
}
}
private static class DotKnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
private DotKnnFloatBenchmarkFunction(int dims) {
super(dims, false);
}
@Override
public void execute(Consumer<Object> consumer) {
new KnnDenseVector(docVector).dotProduct(queryVector);
}
}
private static class DotKnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
private DotKnnByteBenchmarkFunction(int dims) {
super(dims);
}
@Override
public void execute(Consumer<Object> consumer) {
new ByteKnnDenseVector(docVector).dotProduct(queryVector);
}
}
private static class DotBinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
private DotBinaryFloatBenchmarkFunction(int dims) {
super(dims, false);
}
@Override
public void execute(Consumer<Object> consumer) {
new BinaryDenseVector(docFloatVector, docVector, dims, IndexVersion.current()).dotProduct(queryVector);
}
}
private static class DotBinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
private DotBinaryByteBenchmarkFunction(int dims) {
super(dims);
}
@Override
public void execute(Consumer<Object> consumer) {
new ByteBinaryDenseVector(vectorValue, docVector, dims).dotProduct(queryVector);
}
}
private static class CosineKnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
private CosineKnnFloatBenchmarkFunction(int dims) {
super(dims, true);
private static BytesRef generateVectorData(float[] vector) {
return generateVectorData(vector, calculateMag(vector));
}
@Override
public void execute(Consumer<Object> consumer) {
new KnnDenseVector(docVector).cosineSimilarity(queryVector, false);
private static BytesRef generateVectorData(float[] vector, float mag) {
ByteBuffer buffer = ByteBuffer.allocate(vector.length * 4 + 4);
for (float f : vector) {
buffer.putFloat(f);
}
buffer.putFloat(mag);
return new BytesRef(buffer.array());
}
private static class CosineKnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
private static BytesRef generateVectorData(byte[] vector) {
float mag = calculateMag(vector);
private CosineKnnByteBenchmarkFunction(int dims) {
super(dims);
ByteBuffer buffer = ByteBuffer.allocate(vector.length + 4);
buffer.put(vector);
buffer.putFloat(mag);
return new BytesRef(buffer.array());
}
@Override
public void execute(Consumer<Object> consumer) {
new ByteKnnDenseVector(docVector).cosineSimilarity(queryVector, queryMagnitude);
}
}
private static class CosineBinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
private CosineBinaryFloatBenchmarkFunction(int dims) {
super(dims, true);
}
@Override
public void execute(Consumer<Object> consumer) {
new BinaryDenseVector(docFloatVector, docVector, dims, IndexVersion.current()).cosineSimilarity(queryVector, false);
}
}
private static class CosineBinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
private CosineBinaryByteBenchmarkFunction(int dims) {
super(dims);
}
@Override
public void execute(Consumer<Object> consumer) {
new ByteBinaryDenseVector(vectorValue, docVector, dims).cosineSimilarity(queryVector, queryMagnitude);
}
}
private static class L1KnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
private L1KnnFloatBenchmarkFunction(int dims) {
super(dims, false);
}
@Override
public void execute(Consumer<Object> consumer) {
new KnnDenseVector(docVector).l1Norm(queryVector);
}
}
private static class L1KnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
private L1KnnByteBenchmarkFunction(int dims) {
super(dims);
}
@Override
public void execute(Consumer<Object> consumer) {
new ByteKnnDenseVector(docVector).l1Norm(queryVector);
}
}
private static class HammingKnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
private HammingKnnByteBenchmarkFunction(int dims) {
super(dims);
}
@Override
public void execute(Consumer<Object> consumer) {
new ByteKnnDenseVector(docVector).hamming(queryVector);
}
}
private static class L1BinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
private L1BinaryFloatBenchmarkFunction(int dims) {
super(dims, true);
}
@Override
public void execute(Consumer<Object> consumer) {
new BinaryDenseVector(docFloatVector, docVector, dims, IndexVersion.current()).l1Norm(queryVector);
}
}
private static class L1BinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
private L1BinaryByteBenchmarkFunction(int dims) {
super(dims);
}
@Override
public void execute(Consumer<Object> consumer) {
new ByteBinaryDenseVector(vectorValue, docVector, dims).l1Norm(queryVector);
}
}
private static class HammingBinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
private HammingBinaryByteBenchmarkFunction(int dims) {
super(dims);
}
@Override
public void execute(Consumer<Object> consumer) {
new ByteBinaryDenseVector(vectorValue, docVector, dims).hamming(queryVector);
}
}
private static class L2KnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
private L2KnnFloatBenchmarkFunction(int dims) {
super(dims, false);
}
@Override
public void execute(Consumer<Object> consumer) {
new KnnDenseVector(docVector).l2Norm(queryVector);
}
}
private static class L2KnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
private L2KnnByteBenchmarkFunction(int dims) {
super(dims);
}
@Override
public void execute(Consumer<Object> consumer) {
new ByteKnnDenseVector(docVector).l2Norm(queryVector);
}
}
private static class L2BinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
private L2BinaryFloatBenchmarkFunction(int dims) {
super(dims, true);
}
@Override
public void execute(Consumer<Object> consumer) {
new BinaryDenseVector(docFloatVector, docVector, dims, IndexVersion.current()).l1Norm(queryVector);
}
}
private static class L2BinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
private L2BinaryByteBenchmarkFunction(int dims) {
super(dims);
}
@Override
public void execute(Consumer<Object> consumer) {
consumer.accept(new ByteBinaryDenseVector(vectorValue, docVector, dims).l2Norm(queryVector));
}
}
private BenchmarkFunction benchmarkFunction;
@Setup
public void setBenchmarkFunction() {
switch (element) {
case "float" -> {
switch (function) {
case "dot" -> benchmarkFunction = switch (type) {
case "knn" -> new DotKnnFloatBenchmarkFunction(dims);
case "binary" -> new DotBinaryFloatBenchmarkFunction(dims);
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
};
case "cosine" -> benchmarkFunction = switch (type) {
case "knn" -> new CosineKnnFloatBenchmarkFunction(dims);
case "binary" -> new CosineBinaryFloatBenchmarkFunction(dims);
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
};
case "l1" -> benchmarkFunction = switch (type) {
case "knn" -> new L1KnnFloatBenchmarkFunction(dims);
case "binary" -> new L1BinaryFloatBenchmarkFunction(dims);
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
};
case "l2" -> benchmarkFunction = switch (type) {
case "knn" -> new L2KnnFloatBenchmarkFunction(dims);
case "binary" -> new L2BinaryFloatBenchmarkFunction(dims);
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
};
default -> throw new UnsupportedOperationException("unexpected function [" + function + "]");
}
}
case "byte" -> {
switch (function) {
case "dot" -> benchmarkFunction = switch (type) {
case "knn" -> new DotKnnByteBenchmarkFunction(dims);
case "binary" -> new DotBinaryByteBenchmarkFunction(dims);
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
};
case "cosine" -> benchmarkFunction = switch (type) {
case "knn" -> new CosineKnnByteBenchmarkFunction(dims);
case "binary" -> new CosineBinaryByteBenchmarkFunction(dims);
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
};
case "l1" -> benchmarkFunction = switch (type) {
case "knn" -> new L1KnnByteBenchmarkFunction(dims);
case "binary" -> new L1BinaryByteBenchmarkFunction(dims);
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
};
case "l2" -> benchmarkFunction = switch (type) {
case "knn" -> new L2KnnByteBenchmarkFunction(dims);
case "binary" -> new L2BinaryByteBenchmarkFunction(dims);
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
};
case "hamming" -> benchmarkFunction = switch (type) {
case "knn" -> new HammingKnnByteBenchmarkFunction(dims);
case "binary" -> new HammingBinaryByteBenchmarkFunction(dims);
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
};
default -> throw new UnsupportedOperationException("unexpected function [" + function + "]");
}
}
default -> throw new UnsupportedOperationException("unexpected element [" + element + "]");
}
;
public void findBenchmarkImpl() {
Random r = new Random();
float[] floatDocVector = new float[dims];
byte[] byteDocVector = new byte[dims];
float[] floatQueryVector = new float[dims];
byte[] byteQueryVector = new byte[dims];
r.nextBytes(byteDocVector);
r.nextBytes(byteQueryVector);
for (int i = 0; i < dims; i++) {
floatDocVector[i] = r.nextFloat();
floatQueryVector[i] = r.nextFloat();
}
DenseVector vectorImpl = switch (docType) {
case FLOAT -> switch (type) {
case KNN -> {
if (function == Function.COSINE) {
normalizeVector(floatDocVector);
normalizeVector(floatQueryVector);
}
yield new KnnDenseVector(floatDocVector);
}
case BINARY -> {
BytesRef vectorData;
if (function == Function.COSINE || function == Function.L1 || function == Function.L2) {
float mag = normalizeVector(floatDocVector);
vectorData = generateVectorData(floatDocVector, mag);
normalizeVector(floatQueryVector);
} else {
vectorData = generateVectorData(floatDocVector);
}
yield new BinaryDenseVector(floatDocVector, vectorData, dims, IndexVersion.current());
}
};
case BYTE -> switch (type) {
case KNN -> new ByteKnnDenseVector(byteDocVector);
case BINARY -> {
BytesRef vectorData = generateVectorData(byteDocVector);
yield new ByteBinaryDenseVector(byteDocVector, vectorData, dims);
}
};
};
benchmarkImpl = switch (queryType) {
case FLOAT -> switch (function) {
case DOT -> () -> vectorImpl.dotProduct(floatQueryVector);
case COSINE -> () -> vectorImpl.cosineSimilarity(floatQueryVector, false);
case L1 -> () -> vectorImpl.l1Norm(floatQueryVector);
case L2 -> () -> vectorImpl.l2Norm(floatQueryVector);
case HAMMING -> throw new UnsupportedOperationException("Unsupported function " + function);
};
case BYTE -> switch (function) {
case DOT -> () -> vectorImpl.dotProduct(byteQueryVector);
case COSINE -> {
float mag = calculateMag(byteQueryVector);
yield () -> vectorImpl.cosineSimilarity(byteQueryVector, mag);
}
case L1 -> () -> vectorImpl.l1Norm(byteQueryVector);
case L2 -> () -> vectorImpl.l2Norm(byteQueryVector);
case HAMMING -> () -> vectorImpl.hamming(byteQueryVector);
};
};
}
@Fork(1)
@Benchmark
public void benchmark() throws IOException {
public void benchmark(Blackhole blackhole) {
for (int i = 0; i < 25000; ++i) {
benchmarkFunction.execute(Object::toString);
blackhole.consume(benchmarkImpl.getAsDouble());
}
}
@Fork(value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
@Benchmark
public void vectorBenchmark(Blackhole blackhole) {
for (int i = 0; i < 25000; ++i) {
blackhole.consume(benchmarkImpl.getAsDouble());
}
}
}