JDKVectorLibrary: update low-level bounds checks and add benchmark (#130216)

This commit updates the low-level bounds checks in JDKVectorLibrary and adds a benchmark, so that we can more easily benchmark the low-level operations.

Note: I added the mr-jar gradle plugin to the benchmarks so that we can compile with preview features in Java 21, namely MemorySegment.
This commit is contained in:
Chris Hegarty 2025-06-27 19:21:04 +01:00 committed by GitHub
parent 5dcded20a9
commit 4d3b699067
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 225 additions and 18 deletions

View file

@ -13,6 +13,7 @@ import org.elasticsearch.gradle.OS
apply plugin: org.elasticsearch.gradle.internal.ElasticsearchJavaBasePlugin
apply plugin: 'java-library'
apply plugin: 'application'
apply plugin: 'elasticsearch.mrjar'
var os = org.gradle.internal.os.OperatingSystem.current()
@ -46,6 +47,7 @@ dependencies {
api(project(':x-pack:plugin:core'))
api(project(':x-pack:plugin:esql'))
api(project(':x-pack:plugin:esql:compute'))
implementation project(path: ':libs:native')
implementation project(path: ':libs:simdvec')
expression(project(path: ':modules:lang-expression', configuration: 'zip'))
painless(project(path: ':modules:lang-painless', configuration: 'zip'))

View file

@ -0,0 +1,129 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.benchmark.vector;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.common.logging.NodeNamePatternConverter;
import org.elasticsearch.nativeaccess.NativeAccess;
import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
/**
 * JMH benchmark for the dot product of unsigned int7 byte vectors.
 *
 * <p>Three variants are measured over the same randomly generated data: Lucene's
 * {@link VectorUtil#dotProduct(byte[], byte[])} on the raw arrays, and the native
 * {@code dotProductHandle7u} method handle invoked with heap-backed and with
 * arena-allocated memory segments.
 */
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 3, time = 1)
@Measurement(iterations = 5, time = 1)
public class JDKVectorInt7uBenchmark {

    static {
        NodeNamePatternConverter.setGlobalNodeName("foo");
        LogConfigurator.loadLog4jPlugins();
        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
    }

    // The two input vectors, as plain arrays.
    byte[] byteArrayA;
    byte[] byteArrayB;
    // Segments wrapping the arrays above.
    MemorySegment heapSegA, heapSegB;
    // Segments allocated from the arena, holding copies of the arrays above.
    MemorySegment nativeSegA, nativeSegB;

    // Owns nativeSegA/nativeSegB; opened in init() and closed in teardown().
    Arena arena;

    /** Vector dimensions (number of bytes per vector). */
    @Param({ "1", "128", "207", "256", "300", "512", "702", "1024" })
    public int size;

    /**
     * Generates fresh random int7 vectors of {@link #size} bytes and exposes them
     * as arrays, heap segments and arena-allocated segments.
     */
    @Setup(Level.Iteration)
    public void init() {
        byteArrayA = new byte[size];
        byteArrayB = new byte[size];
        // randomInt7BytesBetween fills the entire array, so a single call per array suffices.
        randomInt7BytesBetween(byteArrayA);
        randomInt7BytesBetween(byteArrayB);

        heapSegA = MemorySegment.ofArray(byteArrayA);
        heapSegB = MemorySegment.ofArray(byteArrayB);

        arena = Arena.ofConfined();
        nativeSegA = arena.allocate((long) byteArrayA.length);
        MemorySegment.copy(heapSegA, 0L, nativeSegA, 0L, byteArrayA.length);
        nativeSegB = arena.allocate((long) byteArrayB.length);
        MemorySegment.copy(heapSegB, 0L, nativeSegB, 0L, byteArrayB.length);
    }

    /**
     * Closes the arena opened by {@link #init()}.
     *
     * <p>Runs at the same level as the setup so that every arena that is opened is
     * also closed, rather than only the one from the final iteration.
     */
    @TearDown(Level.Iteration)
    public void teardown() {
        arena.close();
    }

    /** Dot product of the two byte arrays computed by Lucene's {@code VectorUtil}. */
    @Benchmark
    @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public int dotProductLucene() {
        return VectorUtil.dotProduct(byteArrayA, byteArrayB);
    }

    /** Dot product via the native method handle, using the arena-allocated segments. */
    @Benchmark
    @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public int dotProductNativeWithNativeSeg() {
        return dotProduct7u(nativeSegA, nativeSegB, size);
    }

    /** Dot product via the native method handle, using the array-backed segments. */
    @Benchmark
    @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public int dotProductNativeWithHeapSeg() {
        return dotProduct7u(heapSegA, heapSegB, size);
    }

    static final VectorSimilarityFunctions vectorSimilarityFunctions = vectorSimilarityFunctions();

    /** Looks up the vector similarity functions; fails if the optional is empty. */
    static VectorSimilarityFunctions vectorSimilarityFunctions() {
        return NativeAccess.instance().getVectorSimilarityFunctions().get();
    }

    /**
     * Invokes the {@code dotProductHandle7u} method handle.
     *
     * @param a      the first vector
     * @param b      the second vector
     * @param length the number of dimensions passed to the handle
     * @return the value returned by the handle
     * @throws RuntimeException wrapping any checked throwable raised by the invocation;
     *                          errors and runtime exceptions are rethrown unchanged
     */
    int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
        try {
            return (int) vectorSimilarityFunctions.dotProductHandle7u().invokeExact(a, b, length);
        } catch (Throwable e) {
            if (e instanceof Error err) {
                throw err;
            } else if (e instanceof RuntimeException re) {
                throw re;
            } else {
                throw new RuntimeException(e);
            }
        }
    }

    // Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
    static final byte MIN_INT7_VALUE = 0;
    static final byte MAX_INT7_VALUE = 127;

    /** Overwrites every element of {@code bytes} with a random value in the int7 range. */
    static void randomInt7BytesBetween(byte[] bytes) {
        var random = ThreadLocalRandom.current();
        for (int i = 0, len = bytes.length; i < len;) {
            bytes[i++] = (byte) random.nextInt(MIN_INT7_VALUE, MAX_INT7_VALUE + 1);
        }
    }
}

View file

@ -0,0 +1,62 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.benchmark.vector;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.elasticsearch.test.ESTestCase;
import org.openjdk.jmh.annotations.Param;
import java.util.Arrays;
/**
 * Checks that every variant measured by {@link JDKVectorInt7uBenchmark} returns the
 * same dot product as a plain scalar loop, for each vector size the benchmark is
 * parameterized with.
 */
public class JDKVectorInt7uBenchmarkTests extends ESTestCase {

    /** Vector dimensions under test, supplied by {@link #parametersFactory()}. */
    final int size;

    public JDKVectorInt7uBenchmarkTests(int size) {
        this.size = size;
    }

    /** Compares each benchmark variant against the scalar reference on 100 random inputs. */
    public void testDotProduct() {
        for (int i = 0; i < 100; i++) {
            var bench = new JDKVectorInt7uBenchmark();
            bench.size = size;
            bench.init();
            try {
                // The dot product is an integer: keep it as an int and compare exactly,
                // rather than routing it through a float, which is only exact up to 2^24.
                int expected = dotProductScalar(bench.byteArrayA, bench.byteArrayB);
                assertEquals(expected, bench.dotProductLucene());
                assertEquals(expected, bench.dotProductNativeWithHeapSeg());
                assertEquals(expected, bench.dotProductNativeWithNativeSeg());
            } finally {
                bench.teardown();
            }
        }
    }

    /**
     * Derives the test parameters from the {@code @Param} values declared on the
     * benchmark's {@code size} field, one test instance per value.
     */
    @ParametersFactory
    public static Iterable<Object[]> parametersFactory() {
        try {
            var params = JDKVectorInt7uBenchmark.class.getField("size").getAnnotationsByType(Param.class)[0].value();
            return () -> Arrays.stream(params).map(Integer::parseInt).map(i -> new Object[] { i }).iterator();
        } catch (NoSuchFieldException e) {
            throw new AssertionError(e);
        }
    }

    /** Computes the dot product of the given vectors a and b. */
    static int dotProductScalar(byte[] a, byte[] b) {
        int res = 0;
        for (int i = 0; i < a.length; i++) {
            res += a[i] * b[i];
        }
        return res;
    }
}

View file

@ -20,6 +20,7 @@ import java.lang.foreign.MemorySegment;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.Objects;
import static java.lang.foreign.ValueLayout.ADDRESS;
import static java.lang.foreign.ValueLayout.JAVA_INT;
@ -99,13 +100,8 @@ public final class JdkVectorLibrary implements VectorLibrary {
* @param length the vector dimensions
*/
static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
assert length >= 0;
if (a.byteSize() != b.byteSize()) {
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
}
if (length > a.byteSize()) {
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
}
checkByteSize(a, b);
Objects.checkFromIndexSize(0, length, (int) a.byteSize());
return dot7u(a, b, length);
}
@ -119,14 +115,15 @@ public final class JdkVectorLibrary implements VectorLibrary {
* @param length the vector dimensions
*/
static int squareDistance7u(MemorySegment a, MemorySegment b, int length) {
assert length >= 0;
checkByteSize(a, b);
Objects.checkFromIndexSize(0, length, (int) a.byteSize());
return sqr7u(a, b, length);
}
static void checkByteSize(MemorySegment a, MemorySegment b) {
if (a.byteSize() != b.byteSize()) {
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
}
if (length > a.byteSize()) {
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
}
return sqr7u(a, b, length);
}
private static int dot7u(MemorySegment a, MemorySegment b, int length) {

View file

@ -28,6 +28,7 @@ public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {
static final byte MAX_INT7_VALUE = 127;
static final Class<IllegalArgumentException> IAE = IllegalArgumentException.class;
static final Class<IndexOutOfBoundsException> IOOBE = IndexOutOfBoundsException.class;
static final int[] VECTOR_DIMS = { 1, 4, 6, 8, 13, 16, 25, 31, 32, 33, 64, 100, 128, 207, 256, 300, 512, 702, 1023, 1024, 1025 };
@ -35,8 +36,11 @@ public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {
static Arena arena;
final double delta;
public JDKVectorLibraryTests(int size) {
this.size = size;
this.delta = 1e-5 * size; // scale the delta with the size
}
@BeforeClass
@ -103,11 +107,24 @@ public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {
public void testIllegalDims() {
assumeTrue(notSupportedMsg(), supported());
var segment = arena.allocate((long) size * 3);
var e = expectThrows(IAE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size));
assertThat(e.getMessage(), containsString("dimensions differ"));
e = expectThrows(IAE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1));
assertThat(e.getMessage(), containsString("greater than vector dimensions"));
var e1 = expectThrows(IAE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size));
assertThat(e1.getMessage(), containsString("dimensions differ"));
var e2 = expectThrows(IOOBE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1));
assertThat(e2.getMessage(), containsString("out of bounds for length"));
var e3 = expectThrows(IOOBE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size), -1));
assertThat(e3.getMessage(), containsString("out of bounds for length"));
var e4 = expectThrows(IAE, () -> squareDistance7u(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size));
assertThat(e4.getMessage(), containsString("dimensions differ"));
var e5 = expectThrows(IOOBE, () -> squareDistance7u(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1));
assertThat(e5.getMessage(), containsString("out of bounds for length"));
var e6 = expectThrows(IOOBE, () -> squareDistance7u(segment.asSlice(0L, size), segment.asSlice(size, size), -1));
assertThat(e6.getMessage(), containsString("out of bounds for length"));
}
int dotProduct7u(MemorySegment a, MemorySegment b, int length) {

View file

@ -23,10 +23,10 @@ var os = org.gradle.internal.os.OperatingSystem.current()
// 1. Temporarily comment out the download in libs/native/library/build.gradle
// libs "org.elasticsearch:vec:${vecVersion}@zip"
// 2. Copy your locally built libvec binary, e.g.
// cp libs/vec/native/build/libs/vec/shared/libvec.dylib libs/native/libraries/build/platform/darwin-aarch64/libvec.dylib
// cp libs/simdvec/native/build/libs/vec/shared/aarch64/libvec.dylib libs/native/libraries/build/platform/darwin-aarch64/libvec.dylib
//
// Look at the disassembly:
// objdump --disassemble-symbols=_dot8s build/libs/vec/shared/libvec.dylib
// objdump --disassemble-symbols=_dot7u build/libs/vec/shared/aarch64/libvec.dylib
// Note: symbol decoration may differ on Linux, i.e. the leading underscore is not present
//
// gcc -shared -fpic -o libvec.so -I src/vec/headers/ src/vec/c/vec.c -O3