diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceFunctionBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceFunctionBenchmark.java index 187d30097947..ed6ae4e46459 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceFunctionBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceFunctionBenchmark.java @@ -10,10 +10,12 @@ package org.elasticsearch.benchmark.vector; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.script.field.vectors.BinaryDenseVector; import org.elasticsearch.script.field.vectors.ByteBinaryDenseVector; import org.elasticsearch.script.field.vectors.ByteKnnDenseVector; +import org.elasticsearch.script.field.vectors.DenseVector; import org.elasticsearch.script.field.vectors.KnnDenseVector; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; @@ -27,11 +29,12 @@ import org.openjdk.jmh.annotations.Scope; import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.annotations.State; import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; -import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Random; import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; +import java.util.function.DoubleSupplier; /** * Various benchmarks for the distance functions @@ -51,451 +54,172 @@ import java.util.function.Consumer; @State(Scope.Benchmark) public class DistanceFunctionBenchmark { - @Param({ "float", "byte" }) - private String element; + static { + LogConfigurator.configureESLogging(); + } - @Param({ "96" }) + public enum VectorType { + FLOAT, + BYTE + } + + public enum Function { + DOT, + COSINE, + L1, + L2, + HAMMING + } + + public enum Implementation { + KNN, + BINARY + } + + @Param + private VectorType docType; + + @Param + private VectorType queryType; + + @Param({ "1024" }) private int dims; - @Param({ "dot", "cosine", "l1", "l2", "hamming" }) - private String function; + @Param + private Function function; - @Param({ "knn", "binary" }) - private String type; + @Param + private Implementation type; - private abstract static class BenchmarkFunction { + private DoubleSupplier benchmarkImpl; - final int dims; - - private BenchmarkFunction(int dims) { - this.dims = dims; + private static float calculateMag(float[] vector) { + float mag = 0; + for (float f : vector) { + mag += f * f; } - - abstract void execute(Consumer consumer); + return (float) Math.sqrt(mag); } - private abstract static class KnnFloatBenchmarkFunction extends BenchmarkFunction { - - final float[] docVector; - final float[] queryVector; - - private KnnFloatBenchmarkFunction(int dims, boolean normalize) { - super(dims); - - docVector = new float[dims]; - queryVector = new float[dims]; - - float docMagnitude = 0f; - float queryMagnitude = 0f; - - for (int i = 0; i < dims; ++i) { - docVector[i] = (float) (dims - i); - queryVector[i] = (float) i; - - docMagnitude += (float) (dims - i); - queryMagnitude += (float) i; - } - - docMagnitude /= dims; - queryMagnitude /= dims; - - if (normalize) { - for (int i = 0; i < dims; ++i) { - docVector[i] /= docMagnitude; - queryVector[i] /= queryMagnitude; - } - } + private static float calculateMag(byte[] vector) { + float mag = 0; + for (byte b : vector) { + mag += b * b; } + return (float) Math.sqrt(mag); } - private abstract static class BinaryFloatBenchmarkFunction extends BenchmarkFunction { - - final BytesRef docVector; - final float[] docFloatVector; - final float[] queryVector; - - private BinaryFloatBenchmarkFunction(int dims, boolean normalize) { - super(dims); - - docFloatVector = new float[dims]; - queryVector = new float[dims]; - - float docMagnitude = 0f; - float queryMagnitude = 0f; - - for (int i = 0; i < dims; ++i) { - docFloatVector[i] = (float) (dims - i); - queryVector[i] = (float) i; - - docMagnitude += (float) (dims - i); - queryMagnitude += (float) i; - } - - docMagnitude /= dims; - queryMagnitude /= dims; - - ByteBuffer byteBuffer = ByteBuffer.allocate(dims * 4 + 4); - - for (int i = 0; i < dims; ++i) { - if (normalize) { - docFloatVector[i] /= docMagnitude; - queryVector[i] /= queryMagnitude; - } - - byteBuffer.putFloat(docFloatVector[i]); - } - - byteBuffer.putFloat(docMagnitude); - this.docVector = new BytesRef(byteBuffer.array()); + private static float normalizeVector(float[] vector) { + float mag = calculateMag(vector); + for (int i = 0; i < vector.length; i++) { + vector[i] /= mag; } + return mag; } - private abstract static class KnnByteBenchmarkFunction extends BenchmarkFunction { - - final byte[] docVector; - final byte[] queryVector; - - final float queryMagnitude; - - private KnnByteBenchmarkFunction(int dims) { - super(dims); - - ByteBuffer docVector = ByteBuffer.allocate(dims); - queryVector = new byte[dims]; - - float queryMagnitude = 0f; - - for (int i = 0; i < dims; ++i) { - docVector.put((byte) (dims - i)); - queryVector[i] = (byte) i; - - queryMagnitude += (float) i; - } - - this.docVector = docVector.array(); - this.queryMagnitude = queryMagnitude / dims; - } + private static BytesRef generateVectorData(float[] vector) { + return generateVectorData(vector, calculateMag(vector)); } - private abstract static class BinaryByteBenchmarkFunction extends BenchmarkFunction { - - final BytesRef docVector; - final byte[] vectorValue; - final byte[] queryVector; - - final float queryMagnitude; - - private BinaryByteBenchmarkFunction(int dims) { - super(dims); - - ByteBuffer docVector = ByteBuffer.allocate(dims + 4); - queryVector = new byte[dims]; - vectorValue = new byte[dims]; - - float docMagnitude = 0f; - float queryMagnitude = 0f; - - for (int i = 0; i < dims; ++i) { - docVector.put((byte) (dims - i)); - vectorValue[i] = (byte) (dims - i); - queryVector[i] = (byte) i; - - docMagnitude += (float) (dims - i); - queryMagnitude += (float) i; - } - - docVector.putFloat(docMagnitude / dims); - this.docVector = new BytesRef(docVector.array()); - this.queryMagnitude = queryMagnitude / dims; - + private static BytesRef generateVectorData(float[] vector, float mag) { + ByteBuffer buffer = ByteBuffer.allocate(vector.length * 4 + 4); + for (float f : vector) { + buffer.putFloat(f); } + buffer.putFloat(mag); + return new BytesRef(buffer.array()); } - private static class DotKnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction { + private static BytesRef generateVectorData(byte[] vector) { + float mag = calculateMag(vector); - private DotKnnFloatBenchmarkFunction(int dims) { - super(dims, false); - } - - @Override - public void execute(Consumer consumer) { - new KnnDenseVector(docVector).dotProduct(queryVector); - } + ByteBuffer buffer = ByteBuffer.allocate(vector.length + 4); + buffer.put(vector); + buffer.putFloat(mag); + return new BytesRef(buffer.array()); } - private static class DotKnnByteBenchmarkFunction extends KnnByteBenchmarkFunction { - - private DotKnnByteBenchmarkFunction(int dims) { - super(dims); - } - - @Override - public void execute(Consumer consumer) { - new ByteKnnDenseVector(docVector).dotProduct(queryVector); - } - } - - private static class DotBinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction { - - private DotBinaryFloatBenchmarkFunction(int dims) { - super(dims, false); - } - - @Override - public void execute(Consumer consumer) { - new BinaryDenseVector(docFloatVector, docVector, dims, IndexVersion.current()).dotProduct(queryVector); - } - } - - private static class DotBinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction { - - private DotBinaryByteBenchmarkFunction(int dims) { - super(dims); - } - - @Override - public void execute(Consumer consumer) { - new ByteBinaryDenseVector(vectorValue, docVector, dims).dotProduct(queryVector); - } - } - - private static class CosineKnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction { - - private CosineKnnFloatBenchmarkFunction(int dims) { - super(dims, true); - } - - @Override - public void execute(Consumer consumer) { - new KnnDenseVector(docVector).cosineSimilarity(queryVector, false); - } - } - - private static class CosineKnnByteBenchmarkFunction extends KnnByteBenchmarkFunction { - - private CosineKnnByteBenchmarkFunction(int dims) { - super(dims); - } - - @Override - public void execute(Consumer consumer) { - new ByteKnnDenseVector(docVector).cosineSimilarity(queryVector, queryMagnitude); - } - } - - private static class CosineBinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction { - - private CosineBinaryFloatBenchmarkFunction(int dims) { - super(dims, true); - } - - @Override - public void execute(Consumer consumer) { - new BinaryDenseVector(docFloatVector, docVector, dims, IndexVersion.current()).cosineSimilarity(queryVector, false); - } - } - - private static class CosineBinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction { - - private CosineBinaryByteBenchmarkFunction(int dims) { - super(dims); - } - - @Override - public void execute(Consumer consumer) { - new ByteBinaryDenseVector(vectorValue, docVector, dims).cosineSimilarity(queryVector, queryMagnitude); - } - } - - private static class L1KnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction { - - private L1KnnFloatBenchmarkFunction(int dims) { - super(dims, false); - } - - @Override - public void execute(Consumer consumer) { - new KnnDenseVector(docVector).l1Norm(queryVector); - } - } - - private static class L1KnnByteBenchmarkFunction extends KnnByteBenchmarkFunction { - - private L1KnnByteBenchmarkFunction(int dims) { - super(dims); - } - - @Override - public void execute(Consumer consumer) { - new ByteKnnDenseVector(docVector).l1Norm(queryVector); - } - } - - private static class HammingKnnByteBenchmarkFunction extends KnnByteBenchmarkFunction { - - private HammingKnnByteBenchmarkFunction(int dims) { - super(dims); - } - - @Override - public void execute(Consumer consumer) { - new ByteKnnDenseVector(docVector).hamming(queryVector); - } - } - - private static class L1BinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction { - - private L1BinaryFloatBenchmarkFunction(int dims) { - super(dims, true); - } - - @Override - public void execute(Consumer consumer) { - new BinaryDenseVector(docFloatVector, docVector, dims, IndexVersion.current()).l1Norm(queryVector); - } - } - - private static class L1BinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction { - - private L1BinaryByteBenchmarkFunction(int dims) { - super(dims); - } - - @Override - public void execute(Consumer consumer) { - new ByteBinaryDenseVector(vectorValue, docVector, dims).l1Norm(queryVector); - } - } - - private static class HammingBinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction { - - private HammingBinaryByteBenchmarkFunction(int dims) { - super(dims); - } - - @Override - public void execute(Consumer consumer) { - new ByteBinaryDenseVector(vectorValue, docVector, dims).hamming(queryVector); - } - } - - private static class L2KnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction { - - private L2KnnFloatBenchmarkFunction(int dims) { - super(dims, false); - } - - @Override - public void execute(Consumer consumer) { - new KnnDenseVector(docVector).l2Norm(queryVector); - } - } - - private static class L2KnnByteBenchmarkFunction extends KnnByteBenchmarkFunction { - - private L2KnnByteBenchmarkFunction(int dims) { - super(dims); - } - - @Override - public void execute(Consumer consumer) { - new ByteKnnDenseVector(docVector).l2Norm(queryVector); - } - } - - private static class L2BinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction { - - private L2BinaryFloatBenchmarkFunction(int dims) { - super(dims, true); - } - - @Override - public void execute(Consumer consumer) { - new BinaryDenseVector(docFloatVector, docVector, dims, IndexVersion.current()).l1Norm(queryVector); - } - } - - private static class L2BinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction { - - private L2BinaryByteBenchmarkFunction(int dims) { - super(dims); - } - - @Override - public void execute(Consumer consumer) { - consumer.accept(new ByteBinaryDenseVector(vectorValue, docVector, dims).l2Norm(queryVector)); - } - } - - private BenchmarkFunction benchmarkFunction; - @Setup - public void setBenchmarkFunction() { - switch (element) { - case "float" -> { - switch (function) { - case "dot" -> benchmarkFunction = switch (type) { - case "knn" -> new DotKnnFloatBenchmarkFunction(dims); - case "binary" -> new DotBinaryFloatBenchmarkFunction(dims); - default -> throw new UnsupportedOperationException("unexpected type [" + type + "]"); - }; - case "cosine" -> benchmarkFunction = switch (type) { - case "knn" -> new CosineKnnFloatBenchmarkFunction(dims); - case "binary" -> new CosineBinaryFloatBenchmarkFunction(dims); - default -> throw new UnsupportedOperationException("unexpected type [" + type + "]"); - }; - case "l1" -> benchmarkFunction = switch (type) { - case "knn" -> new L1KnnFloatBenchmarkFunction(dims); - case "binary" -> new L1BinaryFloatBenchmarkFunction(dims); - default -> throw new UnsupportedOperationException("unexpected type [" + type + "]"); - }; - case "l2" -> benchmarkFunction = switch (type) { - case "knn" -> new L2KnnFloatBenchmarkFunction(dims); - case "binary" -> new L2BinaryFloatBenchmarkFunction(dims); - default -> throw new UnsupportedOperationException("unexpected type [" + type + "]"); - }; - default -> throw new UnsupportedOperationException("unexpected function [" + function + "]"); - } - } - case "byte" -> { - switch (function) { - case "dot" -> benchmarkFunction = switch (type) { - case "knn" -> new DotKnnByteBenchmarkFunction(dims); - case "binary" -> new DotBinaryByteBenchmarkFunction(dims); - default -> throw new UnsupportedOperationException("unexpected type [" + type + "]"); - }; - case "cosine" -> benchmarkFunction = switch (type) { - case "knn" -> new CosineKnnByteBenchmarkFunction(dims); - case "binary" -> new CosineBinaryByteBenchmarkFunction(dims); - default -> throw new UnsupportedOperationException("unexpected type [" + type + "]"); - }; - case "l1" -> benchmarkFunction = switch (type) { - case "knn" -> new L1KnnByteBenchmarkFunction(dims); - case "binary" -> new L1BinaryByteBenchmarkFunction(dims); - default -> throw new UnsupportedOperationException("unexpected type [" + type + "]"); - }; - case "l2" -> benchmarkFunction = switch (type) { - case "knn" -> new L2KnnByteBenchmarkFunction(dims); - case "binary" -> new L2BinaryByteBenchmarkFunction(dims); - default -> throw new UnsupportedOperationException("unexpected type [" + type + "]"); - }; - case "hamming" -> benchmarkFunction = switch (type) { - case "knn" -> new HammingKnnByteBenchmarkFunction(dims); - case "binary" -> new HammingBinaryByteBenchmarkFunction(dims); - default -> throw new UnsupportedOperationException("unexpected type [" + type + "]"); - }; - default -> throw new UnsupportedOperationException("unexpected function [" + function + "]"); - } - } - default -> throw new UnsupportedOperationException("unexpected element [" + element + "]"); + public void findBenchmarkImpl() { + Random r = new Random(); + + float[] floatDocVector = new float[dims]; + byte[] byteDocVector = new byte[dims]; + + float[] floatQueryVector = new float[dims]; + byte[] byteQueryVector = new byte[dims]; + + r.nextBytes(byteDocVector); + r.nextBytes(byteQueryVector); + for (int i = 0; i < dims; i++) { + floatDocVector[i] = r.nextFloat(); + floatQueryVector[i] = r.nextFloat(); } - ; + + DenseVector vectorImpl = switch (docType) { + case FLOAT -> switch (type) { + case KNN -> { + if (function == Function.COSINE) { + normalizeVector(floatDocVector); + normalizeVector(floatQueryVector); + } + yield new KnnDenseVector(floatDocVector); + } + case BINARY -> { + BytesRef vectorData; + if (function == Function.COSINE || function == Function.L1 || function == Function.L2) { + float mag = normalizeVector(floatDocVector); + vectorData = generateVectorData(floatDocVector, mag); + normalizeVector(floatQueryVector); + } else { + vectorData = generateVectorData(floatDocVector); + } + yield new BinaryDenseVector(floatDocVector, vectorData, dims, IndexVersion.current()); + } + }; + case BYTE -> switch (type) { + case KNN -> new ByteKnnDenseVector(byteDocVector); + case BINARY -> { + BytesRef vectorData = generateVectorData(byteDocVector); + yield new ByteBinaryDenseVector(byteDocVector, vectorData, dims); + } + }; + }; + + benchmarkImpl = switch (queryType) { + case FLOAT -> switch (function) { + case DOT -> () -> vectorImpl.dotProduct(floatQueryVector); + case COSINE -> () -> vectorImpl.cosineSimilarity(floatQueryVector, false); + case L1 -> () -> vectorImpl.l1Norm(floatQueryVector); + case L2 -> () -> vectorImpl.l2Norm(floatQueryVector); + case HAMMING -> throw new UnsupportedOperationException("Unsupported function " + function); + }; + case BYTE -> switch (function) { + case DOT -> () -> vectorImpl.dotProduct(byteQueryVector); + case COSINE -> { + float mag = calculateMag(byteQueryVector); + yield () -> vectorImpl.cosineSimilarity(byteQueryVector, mag); + } + case L1 -> () -> vectorImpl.l1Norm(byteQueryVector); + case L2 -> () -> vectorImpl.l2Norm(byteQueryVector); + case HAMMING -> () -> vectorImpl.hamming(byteQueryVector); + }; + }; } + @Fork(1) @Benchmark - public void benchmark() throws IOException { + public void benchmark(Blackhole blackhole) { for (int i = 0; i < 25000; ++i) { - benchmarkFunction.execute(Object::toString); + blackhole.consume(benchmarkImpl.getAsDouble()); + } + } + + @Fork(value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + @Benchmark + public void vectorBenchmark(Blackhole blackhole) { + for (int i = 0; i < 25000; ++i) { + blackhole.consume(benchmarkImpl.getAsDouble()); } } }