mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 01:22:26 -04:00
Refactor JMH script vector distance benchmark to add panama benchmarks (#124351)
Add vector benchmarks vs scalar, and automatically pick up new implementations as they get added
This commit is contained in:
parent
f6538e86e2
commit
d7864f4af6
1 changed files with 142 additions and 418 deletions
|
@ -10,10 +10,12 @@
|
||||||
package org.elasticsearch.benchmark.vector;
|
package org.elasticsearch.benchmark.vector;
|
||||||
|
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
import org.elasticsearch.common.logging.LogConfigurator;
|
||||||
import org.elasticsearch.index.IndexVersion;
|
import org.elasticsearch.index.IndexVersion;
|
||||||
import org.elasticsearch.script.field.vectors.BinaryDenseVector;
|
import org.elasticsearch.script.field.vectors.BinaryDenseVector;
|
||||||
import org.elasticsearch.script.field.vectors.ByteBinaryDenseVector;
|
import org.elasticsearch.script.field.vectors.ByteBinaryDenseVector;
|
||||||
import org.elasticsearch.script.field.vectors.ByteKnnDenseVector;
|
import org.elasticsearch.script.field.vectors.ByteKnnDenseVector;
|
||||||
|
import org.elasticsearch.script.field.vectors.DenseVector;
|
||||||
import org.elasticsearch.script.field.vectors.KnnDenseVector;
|
import org.elasticsearch.script.field.vectors.KnnDenseVector;
|
||||||
import org.openjdk.jmh.annotations.Benchmark;
|
import org.openjdk.jmh.annotations.Benchmark;
|
||||||
import org.openjdk.jmh.annotations.BenchmarkMode;
|
import org.openjdk.jmh.annotations.BenchmarkMode;
|
||||||
|
@ -27,11 +29,12 @@ import org.openjdk.jmh.annotations.Scope;
|
||||||
import org.openjdk.jmh.annotations.Setup;
|
import org.openjdk.jmh.annotations.Setup;
|
||||||
import org.openjdk.jmh.annotations.State;
|
import org.openjdk.jmh.annotations.State;
|
||||||
import org.openjdk.jmh.annotations.Warmup;
|
import org.openjdk.jmh.annotations.Warmup;
|
||||||
|
import org.openjdk.jmh.infra.Blackhole;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.Random;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
import java.util.function.Consumer;
|
import java.util.function.DoubleSupplier;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Various benchmarks for the distance functions
|
* Various benchmarks for the distance functions
|
||||||
|
@ -51,451 +54,172 @@ import java.util.function.Consumer;
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
public class DistanceFunctionBenchmark {
|
public class DistanceFunctionBenchmark {
|
||||||
|
|
||||||
@Param({ "float", "byte" })
|
static {
|
||||||
private String element;
|
LogConfigurator.configureESLogging();
|
||||||
|
}
|
||||||
|
|
||||||
@Param({ "96" })
|
public enum VectorType {
|
||||||
|
FLOAT,
|
||||||
|
BYTE
|
||||||
|
}
|
||||||
|
|
||||||
|
public enum Function {
|
||||||
|
DOT,
|
||||||
|
COSINE,
|
||||||
|
L1,
|
||||||
|
L2,
|
||||||
|
HAMMING
|
||||||
|
}
|
||||||
|
|
||||||
|
public enum Implementation {
|
||||||
|
KNN,
|
||||||
|
BINARY
|
||||||
|
}
|
||||||
|
|
||||||
|
@Param
|
||||||
|
private VectorType docType;
|
||||||
|
|
||||||
|
@Param
|
||||||
|
private VectorType queryType;
|
||||||
|
|
||||||
|
@Param({ "1024" })
|
||||||
private int dims;
|
private int dims;
|
||||||
|
|
||||||
@Param({ "dot", "cosine", "l1", "l2", "hamming" })
|
@Param
|
||||||
private String function;
|
private Function function;
|
||||||
|
|
||||||
@Param({ "knn", "binary" })
|
@Param
|
||||||
private String type;
|
private Implementation type;
|
||||||
|
|
||||||
private abstract static class BenchmarkFunction {
|
private DoubleSupplier benchmarkImpl;
|
||||||
|
|
||||||
final int dims;
|
private static float calculateMag(float[] vector) {
|
||||||
|
float mag = 0;
|
||||||
private BenchmarkFunction(int dims) {
|
for (float f : vector) {
|
||||||
this.dims = dims;
|
mag += f * f;
|
||||||
}
|
}
|
||||||
|
return (float) Math.sqrt(mag);
|
||||||
abstract void execute(Consumer<Object> consumer);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private abstract static class KnnFloatBenchmarkFunction extends BenchmarkFunction {
|
private static float calculateMag(byte[] vector) {
|
||||||
|
float mag = 0;
|
||||||
final float[] docVector;
|
for (byte b : vector) {
|
||||||
final float[] queryVector;
|
mag += b * b;
|
||||||
|
|
||||||
private KnnFloatBenchmarkFunction(int dims, boolean normalize) {
|
|
||||||
super(dims);
|
|
||||||
|
|
||||||
docVector = new float[dims];
|
|
||||||
queryVector = new float[dims];
|
|
||||||
|
|
||||||
float docMagnitude = 0f;
|
|
||||||
float queryMagnitude = 0f;
|
|
||||||
|
|
||||||
for (int i = 0; i < dims; ++i) {
|
|
||||||
docVector[i] = (float) (dims - i);
|
|
||||||
queryVector[i] = (float) i;
|
|
||||||
|
|
||||||
docMagnitude += (float) (dims - i);
|
|
||||||
queryMagnitude += (float) i;
|
|
||||||
}
|
|
||||||
|
|
||||||
docMagnitude /= dims;
|
|
||||||
queryMagnitude /= dims;
|
|
||||||
|
|
||||||
if (normalize) {
|
|
||||||
for (int i = 0; i < dims; ++i) {
|
|
||||||
docVector[i] /= docMagnitude;
|
|
||||||
queryVector[i] /= queryMagnitude;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return (float) Math.sqrt(mag);
|
||||||
}
|
}
|
||||||
|
|
||||||
private abstract static class BinaryFloatBenchmarkFunction extends BenchmarkFunction {
|
private static float normalizeVector(float[] vector) {
|
||||||
|
float mag = calculateMag(vector);
|
||||||
final BytesRef docVector;
|
for (int i = 0; i < vector.length; i++) {
|
||||||
final float[] docFloatVector;
|
vector[i] /= mag;
|
||||||
final float[] queryVector;
|
|
||||||
|
|
||||||
private BinaryFloatBenchmarkFunction(int dims, boolean normalize) {
|
|
||||||
super(dims);
|
|
||||||
|
|
||||||
docFloatVector = new float[dims];
|
|
||||||
queryVector = new float[dims];
|
|
||||||
|
|
||||||
float docMagnitude = 0f;
|
|
||||||
float queryMagnitude = 0f;
|
|
||||||
|
|
||||||
for (int i = 0; i < dims; ++i) {
|
|
||||||
docFloatVector[i] = (float) (dims - i);
|
|
||||||
queryVector[i] = (float) i;
|
|
||||||
|
|
||||||
docMagnitude += (float) (dims - i);
|
|
||||||
queryMagnitude += (float) i;
|
|
||||||
}
|
|
||||||
|
|
||||||
docMagnitude /= dims;
|
|
||||||
queryMagnitude /= dims;
|
|
||||||
|
|
||||||
ByteBuffer byteBuffer = ByteBuffer.allocate(dims * 4 + 4);
|
|
||||||
|
|
||||||
for (int i = 0; i < dims; ++i) {
|
|
||||||
if (normalize) {
|
|
||||||
docFloatVector[i] /= docMagnitude;
|
|
||||||
queryVector[i] /= queryMagnitude;
|
|
||||||
}
|
|
||||||
|
|
||||||
byteBuffer.putFloat(docFloatVector[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
byteBuffer.putFloat(docMagnitude);
|
|
||||||
this.docVector = new BytesRef(byteBuffer.array());
|
|
||||||
}
|
}
|
||||||
|
return mag;
|
||||||
}
|
}
|
||||||
|
|
||||||
private abstract static class KnnByteBenchmarkFunction extends BenchmarkFunction {
|
private static BytesRef generateVectorData(float[] vector) {
|
||||||
|
return generateVectorData(vector, calculateMag(vector));
|
||||||
final byte[] docVector;
|
|
||||||
final byte[] queryVector;
|
|
||||||
|
|
||||||
final float queryMagnitude;
|
|
||||||
|
|
||||||
private KnnByteBenchmarkFunction(int dims) {
|
|
||||||
super(dims);
|
|
||||||
|
|
||||||
ByteBuffer docVector = ByteBuffer.allocate(dims);
|
|
||||||
queryVector = new byte[dims];
|
|
||||||
|
|
||||||
float queryMagnitude = 0f;
|
|
||||||
|
|
||||||
for (int i = 0; i < dims; ++i) {
|
|
||||||
docVector.put((byte) (dims - i));
|
|
||||||
queryVector[i] = (byte) i;
|
|
||||||
|
|
||||||
queryMagnitude += (float) i;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.docVector = docVector.array();
|
|
||||||
this.queryMagnitude = queryMagnitude / dims;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private abstract static class BinaryByteBenchmarkFunction extends BenchmarkFunction {
|
private static BytesRef generateVectorData(float[] vector, float mag) {
|
||||||
|
ByteBuffer buffer = ByteBuffer.allocate(vector.length * 4 + 4);
|
||||||
final BytesRef docVector;
|
for (float f : vector) {
|
||||||
final byte[] vectorValue;
|
buffer.putFloat(f);
|
||||||
final byte[] queryVector;
|
|
||||||
|
|
||||||
final float queryMagnitude;
|
|
||||||
|
|
||||||
private BinaryByteBenchmarkFunction(int dims) {
|
|
||||||
super(dims);
|
|
||||||
|
|
||||||
ByteBuffer docVector = ByteBuffer.allocate(dims + 4);
|
|
||||||
queryVector = new byte[dims];
|
|
||||||
vectorValue = new byte[dims];
|
|
||||||
|
|
||||||
float docMagnitude = 0f;
|
|
||||||
float queryMagnitude = 0f;
|
|
||||||
|
|
||||||
for (int i = 0; i < dims; ++i) {
|
|
||||||
docVector.put((byte) (dims - i));
|
|
||||||
vectorValue[i] = (byte) (dims - i);
|
|
||||||
queryVector[i] = (byte) i;
|
|
||||||
|
|
||||||
docMagnitude += (float) (dims - i);
|
|
||||||
queryMagnitude += (float) i;
|
|
||||||
}
|
|
||||||
|
|
||||||
docVector.putFloat(docMagnitude / dims);
|
|
||||||
this.docVector = new BytesRef(docVector.array());
|
|
||||||
this.queryMagnitude = queryMagnitude / dims;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
buffer.putFloat(mag);
|
||||||
|
return new BytesRef(buffer.array());
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class DotKnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
|
private static BytesRef generateVectorData(byte[] vector) {
|
||||||
|
float mag = calculateMag(vector);
|
||||||
|
|
||||||
private DotKnnFloatBenchmarkFunction(int dims) {
|
ByteBuffer buffer = ByteBuffer.allocate(vector.length + 4);
|
||||||
super(dims, false);
|
buffer.put(vector);
|
||||||
}
|
buffer.putFloat(mag);
|
||||||
|
return new BytesRef(buffer.array());
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new KnnDenseVector(docVector).dotProduct(queryVector);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class DotKnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
|
|
||||||
|
|
||||||
private DotKnnByteBenchmarkFunction(int dims) {
|
|
||||||
super(dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new ByteKnnDenseVector(docVector).dotProduct(queryVector);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class DotBinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
|
|
||||||
|
|
||||||
private DotBinaryFloatBenchmarkFunction(int dims) {
|
|
||||||
super(dims, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new BinaryDenseVector(docFloatVector, docVector, dims, IndexVersion.current()).dotProduct(queryVector);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class DotBinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
|
|
||||||
|
|
||||||
private DotBinaryByteBenchmarkFunction(int dims) {
|
|
||||||
super(dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new ByteBinaryDenseVector(vectorValue, docVector, dims).dotProduct(queryVector);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class CosineKnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
|
|
||||||
|
|
||||||
private CosineKnnFloatBenchmarkFunction(int dims) {
|
|
||||||
super(dims, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new KnnDenseVector(docVector).cosineSimilarity(queryVector, false);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class CosineKnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
|
|
||||||
|
|
||||||
private CosineKnnByteBenchmarkFunction(int dims) {
|
|
||||||
super(dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new ByteKnnDenseVector(docVector).cosineSimilarity(queryVector, queryMagnitude);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class CosineBinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
|
|
||||||
|
|
||||||
private CosineBinaryFloatBenchmarkFunction(int dims) {
|
|
||||||
super(dims, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new BinaryDenseVector(docFloatVector, docVector, dims, IndexVersion.current()).cosineSimilarity(queryVector, false);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class CosineBinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
|
|
||||||
|
|
||||||
private CosineBinaryByteBenchmarkFunction(int dims) {
|
|
||||||
super(dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new ByteBinaryDenseVector(vectorValue, docVector, dims).cosineSimilarity(queryVector, queryMagnitude);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class L1KnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
|
|
||||||
|
|
||||||
private L1KnnFloatBenchmarkFunction(int dims) {
|
|
||||||
super(dims, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new KnnDenseVector(docVector).l1Norm(queryVector);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class L1KnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
|
|
||||||
|
|
||||||
private L1KnnByteBenchmarkFunction(int dims) {
|
|
||||||
super(dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new ByteKnnDenseVector(docVector).l1Norm(queryVector);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class HammingKnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
|
|
||||||
|
|
||||||
private HammingKnnByteBenchmarkFunction(int dims) {
|
|
||||||
super(dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new ByteKnnDenseVector(docVector).hamming(queryVector);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class L1BinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
|
|
||||||
|
|
||||||
private L1BinaryFloatBenchmarkFunction(int dims) {
|
|
||||||
super(dims, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new BinaryDenseVector(docFloatVector, docVector, dims, IndexVersion.current()).l1Norm(queryVector);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class L1BinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
|
|
||||||
|
|
||||||
private L1BinaryByteBenchmarkFunction(int dims) {
|
|
||||||
super(dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new ByteBinaryDenseVector(vectorValue, docVector, dims).l1Norm(queryVector);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class HammingBinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
|
|
||||||
|
|
||||||
private HammingBinaryByteBenchmarkFunction(int dims) {
|
|
||||||
super(dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new ByteBinaryDenseVector(vectorValue, docVector, dims).hamming(queryVector);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class L2KnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
|
|
||||||
|
|
||||||
private L2KnnFloatBenchmarkFunction(int dims) {
|
|
||||||
super(dims, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new KnnDenseVector(docVector).l2Norm(queryVector);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class L2KnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
|
|
||||||
|
|
||||||
private L2KnnByteBenchmarkFunction(int dims) {
|
|
||||||
super(dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new ByteKnnDenseVector(docVector).l2Norm(queryVector);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class L2BinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
|
|
||||||
|
|
||||||
private L2BinaryFloatBenchmarkFunction(int dims) {
|
|
||||||
super(dims, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
new BinaryDenseVector(docFloatVector, docVector, dims, IndexVersion.current()).l1Norm(queryVector);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class L2BinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
|
|
||||||
|
|
||||||
private L2BinaryByteBenchmarkFunction(int dims) {
|
|
||||||
super(dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void execute(Consumer<Object> consumer) {
|
|
||||||
consumer.accept(new ByteBinaryDenseVector(vectorValue, docVector, dims).l2Norm(queryVector));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private BenchmarkFunction benchmarkFunction;
|
|
||||||
|
|
||||||
@Setup
|
@Setup
|
||||||
public void setBenchmarkFunction() {
|
public void findBenchmarkImpl() {
|
||||||
switch (element) {
|
Random r = new Random();
|
||||||
case "float" -> {
|
|
||||||
switch (function) {
|
float[] floatDocVector = new float[dims];
|
||||||
case "dot" -> benchmarkFunction = switch (type) {
|
byte[] byteDocVector = new byte[dims];
|
||||||
case "knn" -> new DotKnnFloatBenchmarkFunction(dims);
|
|
||||||
case "binary" -> new DotBinaryFloatBenchmarkFunction(dims);
|
float[] floatQueryVector = new float[dims];
|
||||||
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
|
byte[] byteQueryVector = new byte[dims];
|
||||||
};
|
|
||||||
case "cosine" -> benchmarkFunction = switch (type) {
|
r.nextBytes(byteDocVector);
|
||||||
case "knn" -> new CosineKnnFloatBenchmarkFunction(dims);
|
r.nextBytes(byteQueryVector);
|
||||||
case "binary" -> new CosineBinaryFloatBenchmarkFunction(dims);
|
for (int i = 0; i < dims; i++) {
|
||||||
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
|
floatDocVector[i] = r.nextFloat();
|
||||||
};
|
floatQueryVector[i] = r.nextFloat();
|
||||||
case "l1" -> benchmarkFunction = switch (type) {
|
|
||||||
case "knn" -> new L1KnnFloatBenchmarkFunction(dims);
|
|
||||||
case "binary" -> new L1BinaryFloatBenchmarkFunction(dims);
|
|
||||||
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
|
|
||||||
};
|
|
||||||
case "l2" -> benchmarkFunction = switch (type) {
|
|
||||||
case "knn" -> new L2KnnFloatBenchmarkFunction(dims);
|
|
||||||
case "binary" -> new L2BinaryFloatBenchmarkFunction(dims);
|
|
||||||
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
|
|
||||||
};
|
|
||||||
default -> throw new UnsupportedOperationException("unexpected function [" + function + "]");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "byte" -> {
|
|
||||||
switch (function) {
|
|
||||||
case "dot" -> benchmarkFunction = switch (type) {
|
|
||||||
case "knn" -> new DotKnnByteBenchmarkFunction(dims);
|
|
||||||
case "binary" -> new DotBinaryByteBenchmarkFunction(dims);
|
|
||||||
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
|
|
||||||
};
|
|
||||||
case "cosine" -> benchmarkFunction = switch (type) {
|
|
||||||
case "knn" -> new CosineKnnByteBenchmarkFunction(dims);
|
|
||||||
case "binary" -> new CosineBinaryByteBenchmarkFunction(dims);
|
|
||||||
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
|
|
||||||
};
|
|
||||||
case "l1" -> benchmarkFunction = switch (type) {
|
|
||||||
case "knn" -> new L1KnnByteBenchmarkFunction(dims);
|
|
||||||
case "binary" -> new L1BinaryByteBenchmarkFunction(dims);
|
|
||||||
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
|
|
||||||
};
|
|
||||||
case "l2" -> benchmarkFunction = switch (type) {
|
|
||||||
case "knn" -> new L2KnnByteBenchmarkFunction(dims);
|
|
||||||
case "binary" -> new L2BinaryByteBenchmarkFunction(dims);
|
|
||||||
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
|
|
||||||
};
|
|
||||||
case "hamming" -> benchmarkFunction = switch (type) {
|
|
||||||
case "knn" -> new HammingKnnByteBenchmarkFunction(dims);
|
|
||||||
case "binary" -> new HammingBinaryByteBenchmarkFunction(dims);
|
|
||||||
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
|
|
||||||
};
|
|
||||||
default -> throw new UnsupportedOperationException("unexpected function [" + function + "]");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default -> throw new UnsupportedOperationException("unexpected element [" + element + "]");
|
|
||||||
}
|
}
|
||||||
;
|
|
||||||
|
DenseVector vectorImpl = switch (docType) {
|
||||||
|
case FLOAT -> switch (type) {
|
||||||
|
case KNN -> {
|
||||||
|
if (function == Function.COSINE) {
|
||||||
|
normalizeVector(floatDocVector);
|
||||||
|
normalizeVector(floatQueryVector);
|
||||||
|
}
|
||||||
|
yield new KnnDenseVector(floatDocVector);
|
||||||
|
}
|
||||||
|
case BINARY -> {
|
||||||
|
BytesRef vectorData;
|
||||||
|
if (function == Function.COSINE || function == Function.L1 || function == Function.L2) {
|
||||||
|
float mag = normalizeVector(floatDocVector);
|
||||||
|
vectorData = generateVectorData(floatDocVector, mag);
|
||||||
|
normalizeVector(floatQueryVector);
|
||||||
|
} else {
|
||||||
|
vectorData = generateVectorData(floatDocVector);
|
||||||
|
}
|
||||||
|
yield new BinaryDenseVector(floatDocVector, vectorData, dims, IndexVersion.current());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
case BYTE -> switch (type) {
|
||||||
|
case KNN -> new ByteKnnDenseVector(byteDocVector);
|
||||||
|
case BINARY -> {
|
||||||
|
BytesRef vectorData = generateVectorData(byteDocVector);
|
||||||
|
yield new ByteBinaryDenseVector(byteDocVector, vectorData, dims);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
benchmarkImpl = switch (queryType) {
|
||||||
|
case FLOAT -> switch (function) {
|
||||||
|
case DOT -> () -> vectorImpl.dotProduct(floatQueryVector);
|
||||||
|
case COSINE -> () -> vectorImpl.cosineSimilarity(floatQueryVector, false);
|
||||||
|
case L1 -> () -> vectorImpl.l1Norm(floatQueryVector);
|
||||||
|
case L2 -> () -> vectorImpl.l2Norm(floatQueryVector);
|
||||||
|
case HAMMING -> throw new UnsupportedOperationException("Unsupported function " + function);
|
||||||
|
};
|
||||||
|
case BYTE -> switch (function) {
|
||||||
|
case DOT -> () -> vectorImpl.dotProduct(byteQueryVector);
|
||||||
|
case COSINE -> {
|
||||||
|
float mag = calculateMag(byteQueryVector);
|
||||||
|
yield () -> vectorImpl.cosineSimilarity(byteQueryVector, mag);
|
||||||
|
}
|
||||||
|
case L1 -> () -> vectorImpl.l1Norm(byteQueryVector);
|
||||||
|
case L2 -> () -> vectorImpl.l2Norm(byteQueryVector);
|
||||||
|
case HAMMING -> () -> vectorImpl.hamming(byteQueryVector);
|
||||||
|
};
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Fork(1)
|
||||||
@Benchmark
|
@Benchmark
|
||||||
public void benchmark() throws IOException {
|
public void benchmark(Blackhole blackhole) {
|
||||||
for (int i = 0; i < 25000; ++i) {
|
for (int i = 0; i < 25000; ++i) {
|
||||||
benchmarkFunction.execute(Object::toString);
|
blackhole.consume(benchmarkImpl.getAsDouble());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Fork(value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
|
||||||
|
@Benchmark
|
||||||
|
public void vectorBenchmark(Blackhole blackhole) {
|
||||||
|
for (int i = 0; i < 25000; ++i) {
|
||||||
|
blackhole.consume(benchmarkImpl.getAsDouble());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue