mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-27 17:10:22 -04:00
JDKVectorLibrary: update low-level bounds checks and add benchmark (#130216)
This commit updates the low-level bounds checks in JDKVectorLibrary and add benchmark, so that we can more easily bench the low-level operations. Note: I added the mr-jar gradle plugin to the benchmarks so that we can compile with preview features in Java 21, namely MemorySegment.
This commit is contained in:
parent
5dcded20a9
commit
4d3b699067
6 changed files with 225 additions and 18 deletions
|
@ -13,6 +13,7 @@ import org.elasticsearch.gradle.OS
|
|||
apply plugin: org.elasticsearch.gradle.internal.ElasticsearchJavaBasePlugin
|
||||
apply plugin: 'java-library'
|
||||
apply plugin: 'application'
|
||||
apply plugin: 'elasticsearch.mrjar'
|
||||
|
||||
var os = org.gradle.internal.os.OperatingSystem.current()
|
||||
|
||||
|
@ -46,6 +47,7 @@ dependencies {
|
|||
api(project(':x-pack:plugin:core'))
|
||||
api(project(':x-pack:plugin:esql'))
|
||||
api(project(':x-pack:plugin:esql:compute'))
|
||||
implementation project(path: ':libs:native')
|
||||
implementation project(path: ':libs:simdvec')
|
||||
expression(project(path: ':modules:lang-expression', configuration: 'zip'))
|
||||
painless(project(path: ':modules:lang-painless', configuration: 'zip'))
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
package org.elasticsearch.benchmark.vector;
|
||||
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.elasticsearch.common.logging.LogConfigurator;
|
||||
import org.elasticsearch.common.logging.NodeNamePatternConverter;
|
||||
import org.elasticsearch.nativeaccess.NativeAccess;
|
||||
import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
|
||||
import org.openjdk.jmh.annotations.Benchmark;
|
||||
import org.openjdk.jmh.annotations.BenchmarkMode;
|
||||
import org.openjdk.jmh.annotations.Fork;
|
||||
import org.openjdk.jmh.annotations.Level;
|
||||
import org.openjdk.jmh.annotations.Measurement;
|
||||
import org.openjdk.jmh.annotations.Mode;
|
||||
import org.openjdk.jmh.annotations.OutputTimeUnit;
|
||||
import org.openjdk.jmh.annotations.Param;
|
||||
import org.openjdk.jmh.annotations.Scope;
|
||||
import org.openjdk.jmh.annotations.Setup;
|
||||
import org.openjdk.jmh.annotations.State;
|
||||
import org.openjdk.jmh.annotations.TearDown;
|
||||
import org.openjdk.jmh.annotations.Warmup;
|
||||
|
||||
import java.lang.foreign.Arena;
|
||||
import java.lang.foreign.MemorySegment;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@BenchmarkMode(Mode.AverageTime)
|
||||
@OutputTimeUnit(TimeUnit.NANOSECONDS)
|
||||
@State(Scope.Benchmark)
|
||||
@Warmup(iterations = 3, time = 1)
|
||||
@Measurement(iterations = 5, time = 1)
|
||||
public class JDKVectorInt7uBenchmark {
|
||||
|
||||
static {
|
||||
NodeNamePatternConverter.setGlobalNodeName("foo");
|
||||
LogConfigurator.loadLog4jPlugins();
|
||||
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
|
||||
}
|
||||
|
||||
byte[] byteArrayA;
|
||||
byte[] byteArrayB;
|
||||
MemorySegment heapSegA, heapSegB;
|
||||
MemorySegment nativeSegA, nativeSegB;
|
||||
|
||||
Arena arena;
|
||||
|
||||
@Param({ "1", "128", "207", "256", "300", "512", "702", "1024" })
|
||||
public int size;
|
||||
|
||||
@Setup(Level.Iteration)
|
||||
public void init() {
|
||||
byteArrayA = new byte[size];
|
||||
byteArrayB = new byte[size];
|
||||
for (int i = 0; i < size; ++i) {
|
||||
randomInt7BytesBetween(byteArrayA);
|
||||
randomInt7BytesBetween(byteArrayB);
|
||||
}
|
||||
heapSegA = MemorySegment.ofArray(byteArrayA);
|
||||
heapSegB = MemorySegment.ofArray(byteArrayB);
|
||||
|
||||
arena = Arena.ofConfined();
|
||||
nativeSegA = arena.allocate((long) byteArrayA.length);
|
||||
MemorySegment.copy(MemorySegment.ofArray(byteArrayA), 0L, nativeSegA, 0L, byteArrayA.length);
|
||||
nativeSegB = arena.allocate((long) byteArrayB.length);
|
||||
MemorySegment.copy(MemorySegment.ofArray(byteArrayB), 0L, nativeSegB, 0L, byteArrayB.length);
|
||||
}
|
||||
|
||||
@TearDown
|
||||
public void teardown() {
|
||||
arena.close();
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
|
||||
public int dotProductLucene() {
|
||||
return VectorUtil.dotProduct(byteArrayA, byteArrayB);
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
|
||||
public int dotProductNativeWithNativeSeg() {
|
||||
return dotProduct7u(nativeSegA, nativeSegB, size);
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
@Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
|
||||
public int dotProductNativeWithHeapSeg() {
|
||||
return dotProduct7u(heapSegA, heapSegB, size);
|
||||
}
|
||||
|
||||
static final VectorSimilarityFunctions vectorSimilarityFunctions = vectorSimilarityFunctions();
|
||||
|
||||
static VectorSimilarityFunctions vectorSimilarityFunctions() {
|
||||
return NativeAccess.instance().getVectorSimilarityFunctions().get();
|
||||
}
|
||||
|
||||
int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
|
||||
try {
|
||||
return (int) vectorSimilarityFunctions.dotProductHandle7u().invokeExact(a, b, length);
|
||||
} catch (Throwable e) {
|
||||
if (e instanceof Error err) {
|
||||
throw err;
|
||||
} else if (e instanceof RuntimeException re) {
|
||||
throw re;
|
||||
} else {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
|
||||
static final byte MIN_INT7_VALUE = 0;
|
||||
static final byte MAX_INT7_VALUE = 127;
|
||||
|
||||
static void randomInt7BytesBetween(byte[] bytes) {
|
||||
var random = ThreadLocalRandom.current();
|
||||
for (int i = 0, len = bytes.length; i < len;) {
|
||||
bytes[i++] = (byte) random.nextInt(MIN_INT7_VALUE, MAX_INT7_VALUE + 1);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.benchmark.vector;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
|
||||
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.openjdk.jmh.annotations.Param;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
public class JDKVectorInt7uBenchmarkTests extends ESTestCase {
|
||||
|
||||
final double delta = 1e-3;
|
||||
final int size;
|
||||
|
||||
public JDKVectorInt7uBenchmarkTests(int size) {
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
public void testDotProduct() {
|
||||
for (int i = 0; i < 100; i++) {
|
||||
var bench = new JDKVectorInt7uBenchmark();
|
||||
bench.size = size;
|
||||
bench.init();
|
||||
try {
|
||||
float expected = dotProductScalar(bench.byteArrayA, bench.byteArrayB);
|
||||
assertEquals(expected, bench.dotProductLucene(), delta);
|
||||
assertEquals(expected, bench.dotProductNativeWithHeapSeg(), delta);
|
||||
assertEquals(expected, bench.dotProductNativeWithNativeSeg(), delta);
|
||||
} finally {
|
||||
bench.teardown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ParametersFactory
|
||||
public static Iterable<Object[]> parametersFactory() {
|
||||
try {
|
||||
var params = JDKVectorInt7uBenchmark.class.getField("size").getAnnotationsByType(Param.class)[0].value();
|
||||
return () -> Arrays.stream(params).map(Integer::parseInt).map(i -> new Object[] { i }).iterator();
|
||||
} catch (NoSuchFieldException e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
}
|
||||
|
||||
/** Computes the dot product of the given vectors a and b. */
|
||||
static int dotProductScalar(byte[] a, byte[] b) {
|
||||
int res = 0;
|
||||
for (int i = 0; i < a.length; i++) {
|
||||
res += a[i] * b[i];
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
|
@ -20,6 +20,7 @@ import java.lang.foreign.MemorySegment;
|
|||
import java.lang.invoke.MethodHandle;
|
||||
import java.lang.invoke.MethodHandles;
|
||||
import java.lang.invoke.MethodType;
|
||||
import java.util.Objects;
|
||||
|
||||
import static java.lang.foreign.ValueLayout.ADDRESS;
|
||||
import static java.lang.foreign.ValueLayout.JAVA_INT;
|
||||
|
@ -99,13 +100,8 @@ public final class JdkVectorLibrary implements VectorLibrary {
|
|||
* @param length the vector dimensions
|
||||
*/
|
||||
static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
|
||||
assert length >= 0;
|
||||
if (a.byteSize() != b.byteSize()) {
|
||||
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
|
||||
}
|
||||
if (length > a.byteSize()) {
|
||||
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
|
||||
}
|
||||
checkByteSize(a, b);
|
||||
Objects.checkFromIndexSize(0, length, (int) a.byteSize());
|
||||
return dot7u(a, b, length);
|
||||
}
|
||||
|
||||
|
@ -119,14 +115,15 @@ public final class JdkVectorLibrary implements VectorLibrary {
|
|||
* @param length the vector dimensions
|
||||
*/
|
||||
static int squareDistance7u(MemorySegment a, MemorySegment b, int length) {
|
||||
assert length >= 0;
|
||||
checkByteSize(a, b);
|
||||
Objects.checkFromIndexSize(0, length, (int) a.byteSize());
|
||||
return sqr7u(a, b, length);
|
||||
}
|
||||
|
||||
static void checkByteSize(MemorySegment a, MemorySegment b) {
|
||||
if (a.byteSize() != b.byteSize()) {
|
||||
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
|
||||
}
|
||||
if (length > a.byteSize()) {
|
||||
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
|
||||
}
|
||||
return sqr7u(a, b, length);
|
||||
}
|
||||
|
||||
private static int dot7u(MemorySegment a, MemorySegment b, int length) {
|
||||
|
|
|
@ -28,6 +28,7 @@ public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {
|
|||
static final byte MAX_INT7_VALUE = 127;
|
||||
|
||||
static final Class<IllegalArgumentException> IAE = IllegalArgumentException.class;
|
||||
static final Class<IndexOutOfBoundsException> IOOBE = IndexOutOfBoundsException.class;
|
||||
|
||||
static final int[] VECTOR_DIMS = { 1, 4, 6, 8, 13, 16, 25, 31, 32, 33, 64, 100, 128, 207, 256, 300, 512, 702, 1023, 1024, 1025 };
|
||||
|
||||
|
@ -35,8 +36,11 @@ public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {
|
|||
|
||||
static Arena arena;
|
||||
|
||||
final double delta;
|
||||
|
||||
public JDKVectorLibraryTests(int size) {
|
||||
this.size = size;
|
||||
this.delta = 1e-5 * size; // scale the delta with the size
|
||||
}
|
||||
|
||||
@BeforeClass
|
||||
|
@ -103,11 +107,24 @@ public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {
|
|||
public void testIllegalDims() {
|
||||
assumeTrue(notSupportedMsg(), supported());
|
||||
var segment = arena.allocate((long) size * 3);
|
||||
var e = expectThrows(IAE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size));
|
||||
assertThat(e.getMessage(), containsString("dimensions differ"));
|
||||
|
||||
e = expectThrows(IAE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1));
|
||||
assertThat(e.getMessage(), containsString("greater than vector dimensions"));
|
||||
var e1 = expectThrows(IAE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size));
|
||||
assertThat(e1.getMessage(), containsString("dimensions differ"));
|
||||
|
||||
var e2 = expectThrows(IOOBE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1));
|
||||
assertThat(e2.getMessage(), containsString("out of bounds for length"));
|
||||
|
||||
var e3 = expectThrows(IOOBE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size), -1));
|
||||
assertThat(e3.getMessage(), containsString("out of bounds for length"));
|
||||
|
||||
var e4 = expectThrows(IAE, () -> squareDistance7u(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size));
|
||||
assertThat(e4.getMessage(), containsString("dimensions differ"));
|
||||
|
||||
var e5 = expectThrows(IOOBE, () -> squareDistance7u(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1));
|
||||
assertThat(e5.getMessage(), containsString("out of bounds for length"));
|
||||
|
||||
var e6 = expectThrows(IOOBE, () -> squareDistance7u(segment.asSlice(0L, size), segment.asSlice(size, size), -1));
|
||||
assertThat(e6.getMessage(), containsString("out of bounds for length"));
|
||||
}
|
||||
|
||||
int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
|
||||
|
|
|
@ -23,10 +23,10 @@ var os = org.gradle.internal.os.OperatingSystem.current()
|
|||
// 1. Temporarily comment out the download in libs/native/library/build.gradle
|
||||
// libs "org.elasticsearch:vec:${vecVersion}@zip"
|
||||
// 2. Copy your locally built libvec binary, e.g.
|
||||
// cp libs/vec/native/build/libs/vec/shared/libvec.dylib libs/native/libraries/build/platform/darwin-aarch64/libvec.dylib
|
||||
// cp libs/simdvec/native/build/libs/vec/shared/aarch64/libvec.dylib libs/native/libraries/build/platform/darwin-aarch64/libvec.dylib
|
||||
//
|
||||
// Look at the disassemble:
|
||||
// objdump --disassemble-symbols=_dot8s build/libs/vec/shared/libvec.dylib
|
||||
// objdump --disassemble-symbols=_dot7u build/libs/vec/shared/aarch64/libvec.dylib
|
||||
// Note: symbol decoration may differ on Linux, i.e. the leading underscore is not present
|
||||
//
|
||||
// gcc -shared -fpic -o libvec.so -I src/vec/headers/ src/vec/c/vec.c -O3
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue