From 6ef486519549f37be893b39b7248b49b9518a9d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lorenzo=20Dematt=C3=A9?= Date: Mon, 29 Apr 2024 12:01:41 +0200 Subject: [PATCH] Add functionality to test if the host CPU supports native SIMD instructions (#107429) --- libs/native/libraries/build.gradle | 2 +- .../nativeaccess/PosixNativeAccess.java | 2 +- .../VectorSimilarityFunctions.java | 21 +- .../nativeaccess/lib/VectorLibrary.java | 9 +- .../nativeaccess/jdk/JdkVectorLibrary.java | 281 ++++++++++-------- libs/vec/native/publish_vec_binaries.sh | 2 +- libs/vec/native/src/vec/c/vec.c | 28 ++ libs/vec/native/src/vec/headers/vec.h | 8 + 8 files changed, 200 insertions(+), 153 deletions(-) diff --git a/libs/native/libraries/build.gradle b/libs/native/libraries/build.gradle index 73c2c6fe14ba..e07235962074 100644 --- a/libs/native/libraries/build.gradle +++ b/libs/native/libraries/build.gradle @@ -18,7 +18,7 @@ configurations { } var zstdVersion = "1.5.5" -var vecVersion = "1.0.1" +var vecVersion = "1.0.3" repositories { exclusiveContent { diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/PosixNativeAccess.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/PosixNativeAccess.java index 993c9d2a874b..56017d3a8a20 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/PosixNativeAccess.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/PosixNativeAccess.java @@ -27,7 +27,7 @@ abstract class PosixNativeAccess extends AbstractNativeAccess { static VectorSimilarityFunctions vectorSimilarityFunctionsOrNull(NativeLibraryProvider libraryProvider) { if (isNativeVectorLibSupported()) { - var lib = new VectorSimilarityFunctions(libraryProvider.getLibrary(VectorLibrary.class)); + var lib = libraryProvider.getLibrary(VectorLibrary.class).getVectorSimilarityFunctions(); logger.info("Using native vector library; to disable start with -D" + ENABLE_JDK_VECTOR_LIBRARY + "=false"); return lib; } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java index 7cb852ccf787..6b8f6048fe05 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java @@ -8,25 +8,16 @@ package org.elasticsearch.nativeaccess; -import org.elasticsearch.nativeaccess.lib.VectorLibrary; - import java.lang.invoke.MethodHandle; /** - * Utility class providing vector similarity functions. + * Utility interface providing vector similarity functions. * *

MethodHandles are returned to avoid a static reference to MemorySegment, * which is not in the currently lowest compile version, JDK 17. Code consuming * the method handles will, by definition, require access to MemorySegment. */ -public final class VectorSimilarityFunctions implements VectorLibrary { - - private final VectorLibrary vectorLibrary; - - VectorSimilarityFunctions(VectorLibrary vectorLibrary) { - this.vectorLibrary = vectorLibrary; - } - +public interface VectorSimilarityFunctions { /** * Produces a method handle returning the dot product of byte (signed int8) vectors. * @@ -34,9 +25,7 @@ public final class VectorSimilarityFunctions implements VectorLibrary { * its first and second arguments will be {@code MemorySegment}, whose contents is the * vector data bytes. The third argument is the length of the vector data. */ - public MethodHandle dotProductHandle() { - return vectorLibrary.dotProductHandle(); - } + MethodHandle dotProductHandle(); /** * Produces a method handle returning the square distance of byte (signed int8) vectors. @@ -45,7 +34,5 @@ public final class VectorSimilarityFunctions implements VectorLibrary { * its first and second arguments will be {@code MemorySegment}, whose contents is the * vector data bytes. The third argument is the length of the vector data. */ - public MethodHandle squareDistanceHandle() { - return vectorLibrary.squareDistanceHandle(); - } + MethodHandle squareDistanceHandle(); } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/VectorLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/VectorLibrary.java index a11533c29beb..86d1a82b2bdc 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/VectorLibrary.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/VectorLibrary.java @@ -8,7 +8,8 @@ package org.elasticsearch.nativeaccess.lib; -import java.lang.invoke.MethodHandle; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.nativeaccess.VectorSimilarityFunctions; /** * A VectorLibrary is just an adaptation of the factory for a NativeLibrary. @@ -16,8 +17,6 @@ import java.lang.invoke.MethodHandle; * for native implementations. */ public non-sealed interface VectorLibrary extends NativeLibrary { - - MethodHandle dotProductHandle(); - - MethodHandle squareDistanceHandle(); + @Nullable + VectorSimilarityFunctions getVectorSimilarityFunctions(); } diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java index d4ab57396e29..b988c9730fd1 100644 --- a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java @@ -8,6 +8,7 @@ package org.elasticsearch.nativeaccess.jdk; +import org.elasticsearch.nativeaccess.VectorSimilarityFunctions; import org.elasticsearch.nativeaccess.lib.VectorLibrary; import java.lang.foreign.FunctionDescriptor; @@ -23,142 +24,166 @@ import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.downcallHandle; public final class JdkVectorLibrary implements VectorLibrary { + static final VectorSimilarityFunctions INSTANCE; + static { System.loadLibrary("vec"); + final MethodHandle vecCaps$mh = downcallHandle("vec_caps", FunctionDescriptor.of(JAVA_INT)); + + try { + int caps = (int) vecCaps$mh.invokeExact(); + if (caps != 0) { + INSTANCE = new JdkVectorSimilarityFunctions(); + } else { + INSTANCE = null; + } + } catch (Throwable t) { + throw new AssertionError(t); + } } public JdkVectorLibrary() {} - static final MethodHandle dot8stride$mh = downcallHandle("dot8s_stride", FunctionDescriptor.of(JAVA_INT)); - static final MethodHandle sqr8stride$mh = downcallHandle("sqr8s_stride", FunctionDescriptor.of(JAVA_INT)); - - static final MethodHandle dot8s$mh = downcallHandle("dot8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); - static final MethodHandle sqr8s$mh = downcallHandle("sqr8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); - - // Stride of the native implementation - consumes this number of bytes per loop invocation. - // There must be at least this number of bytes/elements available when going native - static final int DOT_STRIDE = 32; - static final int SQR_STRIDE = 16; - - static { - assert DOT_STRIDE > 0 && (DOT_STRIDE & (DOT_STRIDE - 1)) == 0 : "Not a power of two"; - assert dot8Stride() == DOT_STRIDE : dot8Stride() + " != " + DOT_STRIDE; - assert SQR_STRIDE > 0 && (SQR_STRIDE & (SQR_STRIDE - 1)) == 0 : "Not a power of two"; - assert sqr8Stride() == SQR_STRIDE : sqr8Stride() + " != " + SQR_STRIDE; - } - - /** - * Computes the dot product of given byte vectors. - * @param a address of the first vector - * @param b address of the second vector - * @param length the vector dimensions - */ - static int dotProduct(MemorySegment a, MemorySegment b, int length) { - assert length >= 0; - if (a.byteSize() != b.byteSize()) { - throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize()); - } - if (length > a.byteSize()) { - throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize()); - } - int i = 0; - int res = 0; - if (length >= DOT_STRIDE) { - i += length & ~(DOT_STRIDE - 1); - res = dot8s(a, b, i); - } - - // tail - for (; i < length; i++) { - res += a.get(JAVA_BYTE, i) * b.get(JAVA_BYTE, i); - } - assert i == length; - return res; - } - - /** - * Computes the square distance of given byte vectors. - * @param a address of the first vector - * @param b address of the second vector - * @param length the vector dimensions - */ - static int squareDistance(MemorySegment a, MemorySegment b, int length) { - assert length >= 0; - if (a.byteSize() != b.byteSize()) { - throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize()); - } - if (length > a.byteSize()) { - throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize()); - } - int i = 0; - int res = 0; - if (length >= SQR_STRIDE) { - i += length & ~(SQR_STRIDE - 1); - res = sqr8s(a, b, i); - } - - // tail - for (; i < length; i++) { - int dist = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i); - res += dist * dist; - } - assert i == length; - return res; - } - - private static int dot8Stride() { - try { - return (int) dot8stride$mh.invokeExact(); - } catch (Throwable t) { - throw new AssertionError(t); - } - } - - private static int sqr8Stride() { - try { - return (int) sqr8stride$mh.invokeExact(); - } catch (Throwable t) { - throw new AssertionError(t); - } - } - - private static int dot8s(MemorySegment a, MemorySegment b, int length) { - try { - return (int) dot8s$mh.invokeExact(a, b, length); - } catch (Throwable t) { - throw new AssertionError(t); - } - } - - private static int sqr8s(MemorySegment a, MemorySegment b, int length) { - try { - return (int) sqr8s$mh.invokeExact(a, b, length); - } catch (Throwable t) { - throw new AssertionError(t); - } - } - - static final MethodHandle DOT_HANDLE; - static final MethodHandle SQR_HANDLE; - - static { - try { - var lookup = MethodHandles.lookup(); - var mt = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class); - DOT_HANDLE = lookup.findStatic(JdkVectorLibrary.class, "dotProduct", mt); - SQR_HANDLE = lookup.findStatic(JdkVectorLibrary.class, "squareDistance", mt); - } catch (NoSuchMethodException | IllegalAccessException e) { - throw new RuntimeException(e); - } - } - @Override - public MethodHandle dotProductHandle() { - return DOT_HANDLE; + public VectorSimilarityFunctions getVectorSimilarityFunctions() { + return INSTANCE; } - @Override - public MethodHandle squareDistanceHandle() { - return SQR_HANDLE; + private static final class JdkVectorSimilarityFunctions implements VectorSimilarityFunctions { + + static final MethodHandle dot8stride$mh = downcallHandle("dot8s_stride", FunctionDescriptor.of(JAVA_INT)); + static final MethodHandle sqr8stride$mh = downcallHandle("sqr8s_stride", FunctionDescriptor.of(JAVA_INT)); + + static final MethodHandle dot8s$mh = downcallHandle("dot8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); + static final MethodHandle sqr8s$mh = downcallHandle("sqr8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT)); + + // Stride of the native implementation - consumes this number of bytes per loop invocation. + // There must be at least this number of bytes/elements available when going native + static final int DOT_STRIDE = 32; + static final int SQR_STRIDE = 16; + + static { + assert DOT_STRIDE > 0 && (DOT_STRIDE & (DOT_STRIDE - 1)) == 0 : "Not a power of two"; + assert dot8Stride() == DOT_STRIDE : dot8Stride() + " != " + DOT_STRIDE; + assert SQR_STRIDE > 0 && (SQR_STRIDE & (SQR_STRIDE - 1)) == 0 : "Not a power of two"; + assert sqr8Stride() == SQR_STRIDE : sqr8Stride() + " != " + SQR_STRIDE; + } + + /** + * Computes the dot product of given byte vectors. + * + * @param a address of the first vector + * @param b address of the second vector + * @param length the vector dimensions + */ + static int dotProduct(MemorySegment a, MemorySegment b, int length) { + assert length >= 0; + if (a.byteSize() != b.byteSize()) { + throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize()); + } + if (length > a.byteSize()) { + throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize()); + } + int i = 0; + int res = 0; + if (length >= DOT_STRIDE) { + i += length & ~(DOT_STRIDE - 1); + res = dot8s(a, b, i); + } + + // tail + for (; i < length; i++) { + res += a.get(JAVA_BYTE, i) * b.get(JAVA_BYTE, i); + } + assert i == length; + return res; + } + + /** + * Computes the square distance of given byte vectors. + * + * @param a address of the first vector + * @param b address of the second vector + * @param length the vector dimensions + */ + static int squareDistance(MemorySegment a, MemorySegment b, int length) { + assert length >= 0; + if (a.byteSize() != b.byteSize()) { + throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize()); + } + if (length > a.byteSize()) { + throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize()); + } + int i = 0; + int res = 0; + if (length >= SQR_STRIDE) { + i += length & ~(SQR_STRIDE - 1); + res = sqr8s(a, b, i); + } + + // tail + for (; i < length; i++) { + int dist = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i); + res += dist * dist; + } + assert i == length; + return res; + } + + private static int dot8Stride() { + try { + return (int) dot8stride$mh.invokeExact(); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private static int sqr8Stride() { + try { + return (int) sqr8stride$mh.invokeExact(); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private static int dot8s(MemorySegment a, MemorySegment b, int length) { + try { + return (int) dot8s$mh.invokeExact(a, b, length); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private static int sqr8s(MemorySegment a, MemorySegment b, int length) { + try { + return (int) sqr8s$mh.invokeExact(a, b, length); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + static final MethodHandle DOT_HANDLE; + static final MethodHandle SQR_HANDLE; + + static { + try { + var lookup = MethodHandles.lookup(); + var mt = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class); + DOT_HANDLE = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct", mt); + SQR_HANDLE = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistance", mt); + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + @Override + public MethodHandle dotProductHandle() { + return DOT_HANDLE; + } + + @Override + public MethodHandle squareDistanceHandle() { + return SQR_HANDLE; + } } } diff --git a/libs/vec/native/publish_vec_binaries.sh b/libs/vec/native/publish_vec_binaries.sh index 6cdea109c2eb..7c460eb0321c 100755 --- a/libs/vec/native/publish_vec_binaries.sh +++ b/libs/vec/native/publish_vec_binaries.sh @@ -19,7 +19,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then exit 1; fi -VERSION="1.0.1" +VERSION="1.0.3" ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}" TEMP=$(mktemp -d) diff --git a/libs/vec/native/src/vec/c/vec.c b/libs/vec/native/src/vec/c/vec.c index 008129b665d0..46cc6722d01d 100644 --- a/libs/vec/native/src/vec/c/vec.c +++ b/libs/vec/native/src/vec/c/vec.c @@ -18,6 +18,34 @@ #define SQR8S_STRIDE_BYTES_LEN 16 #endif +#ifdef __linux__ + #include + #include + #ifndef HWCAP_NEON + #define HWCAP_NEON 0x1000 + #endif +#endif + +#ifdef __APPLE__ +#include +#endif + +EXPORT int vec_caps() { +#ifdef __APPLE__ + #ifdef TARGET_OS_OSX + // All M series Apple silicon support Neon instructions + return 1; + #else + #error "Unsupported Apple platform" + #endif +#elif __linux__ + int hwcap = getauxval(AT_HWCAP); + return (hwcap & HWCAP_NEON) != 0; +#else + #error "Unsupported aarch64 platform" +#endif +} + EXPORT int dot8s_stride() { return DOT8_STRIDE_BYTES_LEN; } diff --git a/libs/vec/native/src/vec/headers/vec.h b/libs/vec/native/src/vec/headers/vec.h index a717ad2712e1..380111107f38 100644 --- a/libs/vec/native/src/vec/headers/vec.h +++ b/libs/vec/native/src/vec/headers/vec.h @@ -6,7 +6,15 @@ * Side Public License, v 1. */ +#ifdef _MSC_VER +#define EXPORT extern "C" __declspec(dllexport) +#elif defined(__GNUC__) && !defined(__clang__) #define EXPORT __attribute__((externally_visible,visibility("default"))) +#elif __clang__ +#define EXPORT __attribute__((visibility("default"))) +#endif + +EXPORT int vec_caps(); EXPORT int dot8s_stride();