diff --git a/libs/native/libraries/build.gradle b/libs/native/libraries/build.gradle index 73c2c6fe14ba..e07235962074 100644 --- a/libs/native/libraries/build.gradle +++ b/libs/native/libraries/build.gradle @@ -18,7 +18,7 @@ configurations { } var zstdVersion = "1.5.5" -var vecVersion = "1.0.1" +var vecVersion = "1.0.3" repositories { exclusiveContent { diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/PosixNativeAccess.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/PosixNativeAccess.java index 993c9d2a874b..56017d3a8a20 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/PosixNativeAccess.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/PosixNativeAccess.java @@ -27,7 +27,7 @@ abstract class PosixNativeAccess extends AbstractNativeAccess { static VectorSimilarityFunctions vectorSimilarityFunctionsOrNull(NativeLibraryProvider libraryProvider) { if (isNativeVectorLibSupported()) { - var lib = new VectorSimilarityFunctions(libraryProvider.getLibrary(VectorLibrary.class)); + var lib = libraryProvider.getLibrary(VectorLibrary.class).getVectorSimilarityFunctions(); logger.info("Using native vector library; to disable start with -D" + ENABLE_JDK_VECTOR_LIBRARY + "=false"); return lib; } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java index 7cb852ccf787..6b8f6048fe05 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java @@ -8,25 +8,16 @@ package org.elasticsearch.nativeaccess; -import org.elasticsearch.nativeaccess.lib.VectorLibrary; - import java.lang.invoke.MethodHandle; /** - * Utility class providing vector similarity functions. + * Utility interface providing vector similarity functions. * *
MethodHandles are returned to avoid a static reference to MemorySegment,
* which is not in the currently lowest compile version, JDK 17. Code consuming
* the method handles will, by definition, require access to MemorySegment.
*/
-public final class VectorSimilarityFunctions implements VectorLibrary {
-
- private final VectorLibrary vectorLibrary;
-
- VectorSimilarityFunctions(VectorLibrary vectorLibrary) {
- this.vectorLibrary = vectorLibrary;
- }
-
+public interface VectorSimilarityFunctions {
/**
* Produces a method handle returning the dot product of byte (signed int8) vectors.
*
@@ -34,9 +25,7 @@ public final class VectorSimilarityFunctions implements VectorLibrary {
* its first and second arguments will be {@code MemorySegment}, whose contents is the
* vector data bytes. The third argument is the length of the vector data.
*/
- public MethodHandle dotProductHandle() {
- return vectorLibrary.dotProductHandle();
- }
+ MethodHandle dotProductHandle();
/**
* Produces a method handle returning the square distance of byte (signed int8) vectors.
@@ -45,7 +34,5 @@ public final class VectorSimilarityFunctions implements VectorLibrary {
* its first and second arguments will be {@code MemorySegment}, whose contents is the
* vector data bytes. The third argument is the length of the vector data.
*/
- public MethodHandle squareDistanceHandle() {
- return vectorLibrary.squareDistanceHandle();
- }
+ MethodHandle squareDistanceHandle();
}
diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/VectorLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/VectorLibrary.java
index a11533c29beb..86d1a82b2bdc 100644
--- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/VectorLibrary.java
+++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/VectorLibrary.java
@@ -8,7 +8,8 @@
package org.elasticsearch.nativeaccess.lib;
-import java.lang.invoke.MethodHandle;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
/**
* A VectorLibrary is just an adaptation of the factory for a NativeLibrary.
@@ -16,8 +17,6 @@ import java.lang.invoke.MethodHandle;
* for native implementations.
*/
public non-sealed interface VectorLibrary extends NativeLibrary {
-
- MethodHandle dotProductHandle();
-
- MethodHandle squareDistanceHandle();
+ @Nullable
+ VectorSimilarityFunctions getVectorSimilarityFunctions();
}
diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java
index d4ab57396e29..b988c9730fd1 100644
--- a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java
+++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java
@@ -8,6 +8,7 @@
package org.elasticsearch.nativeaccess.jdk;
+import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
import org.elasticsearch.nativeaccess.lib.VectorLibrary;
import java.lang.foreign.FunctionDescriptor;
@@ -23,142 +24,166 @@ import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.downcallHandle;
public final class JdkVectorLibrary implements VectorLibrary {
+ static final VectorSimilarityFunctions INSTANCE;
+
static {
System.loadLibrary("vec");
+ final MethodHandle vecCaps$mh = downcallHandle("vec_caps", FunctionDescriptor.of(JAVA_INT));
+
+ try {
+ int caps = (int) vecCaps$mh.invokeExact();
+ if (caps != 0) {
+ INSTANCE = new JdkVectorSimilarityFunctions();
+ } else {
+ INSTANCE = null;
+ }
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
}
public JdkVectorLibrary() {}
- static final MethodHandle dot8stride$mh = downcallHandle("dot8s_stride", FunctionDescriptor.of(JAVA_INT));
- static final MethodHandle sqr8stride$mh = downcallHandle("sqr8s_stride", FunctionDescriptor.of(JAVA_INT));
-
- static final MethodHandle dot8s$mh = downcallHandle("dot8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
- static final MethodHandle sqr8s$mh = downcallHandle("sqr8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
-
- // Stride of the native implementation - consumes this number of bytes per loop invocation.
- // There must be at least this number of bytes/elements available when going native
- static final int DOT_STRIDE = 32;
- static final int SQR_STRIDE = 16;
-
- static {
- assert DOT_STRIDE > 0 && (DOT_STRIDE & (DOT_STRIDE - 1)) == 0 : "Not a power of two";
- assert dot8Stride() == DOT_STRIDE : dot8Stride() + " != " + DOT_STRIDE;
- assert SQR_STRIDE > 0 && (SQR_STRIDE & (SQR_STRIDE - 1)) == 0 : "Not a power of two";
- assert sqr8Stride() == SQR_STRIDE : sqr8Stride() + " != " + SQR_STRIDE;
- }
-
- /**
- * Computes the dot product of given byte vectors.
- * @param a address of the first vector
- * @param b address of the second vector
- * @param length the vector dimensions
- */
- static int dotProduct(MemorySegment a, MemorySegment b, int length) {
- assert length >= 0;
- if (a.byteSize() != b.byteSize()) {
- throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
- }
- if (length > a.byteSize()) {
- throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
- }
- int i = 0;
- int res = 0;
- if (length >= DOT_STRIDE) {
- i += length & ~(DOT_STRIDE - 1);
- res = dot8s(a, b, i);
- }
-
- // tail
- for (; i < length; i++) {
- res += a.get(JAVA_BYTE, i) * b.get(JAVA_BYTE, i);
- }
- assert i == length;
- return res;
- }
-
- /**
- * Computes the square distance of given byte vectors.
- * @param a address of the first vector
- * @param b address of the second vector
- * @param length the vector dimensions
- */
- static int squareDistance(MemorySegment a, MemorySegment b, int length) {
- assert length >= 0;
- if (a.byteSize() != b.byteSize()) {
- throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
- }
- if (length > a.byteSize()) {
- throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
- }
- int i = 0;
- int res = 0;
- if (length >= SQR_STRIDE) {
- i += length & ~(SQR_STRIDE - 1);
- res = sqr8s(a, b, i);
- }
-
- // tail
- for (; i < length; i++) {
- int dist = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i);
- res += dist * dist;
- }
- assert i == length;
- return res;
- }
-
- private static int dot8Stride() {
- try {
- return (int) dot8stride$mh.invokeExact();
- } catch (Throwable t) {
- throw new AssertionError(t);
- }
- }
-
- private static int sqr8Stride() {
- try {
- return (int) sqr8stride$mh.invokeExact();
- } catch (Throwable t) {
- throw new AssertionError(t);
- }
- }
-
- private static int dot8s(MemorySegment a, MemorySegment b, int length) {
- try {
- return (int) dot8s$mh.invokeExact(a, b, length);
- } catch (Throwable t) {
- throw new AssertionError(t);
- }
- }
-
- private static int sqr8s(MemorySegment a, MemorySegment b, int length) {
- try {
- return (int) sqr8s$mh.invokeExact(a, b, length);
- } catch (Throwable t) {
- throw new AssertionError(t);
- }
- }
-
- static final MethodHandle DOT_HANDLE;
- static final MethodHandle SQR_HANDLE;
-
- static {
- try {
- var lookup = MethodHandles.lookup();
- var mt = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class);
- DOT_HANDLE = lookup.findStatic(JdkVectorLibrary.class, "dotProduct", mt);
- SQR_HANDLE = lookup.findStatic(JdkVectorLibrary.class, "squareDistance", mt);
- } catch (NoSuchMethodException | IllegalAccessException e) {
- throw new RuntimeException(e);
- }
- }
-
@Override
- public MethodHandle dotProductHandle() {
- return DOT_HANDLE;
+ public VectorSimilarityFunctions getVectorSimilarityFunctions() {
+ return INSTANCE;
}
- @Override
- public MethodHandle squareDistanceHandle() {
- return SQR_HANDLE;
+ private static final class JdkVectorSimilarityFunctions implements VectorSimilarityFunctions {
+
+ static final MethodHandle dot8stride$mh = downcallHandle("dot8s_stride", FunctionDescriptor.of(JAVA_INT));
+ static final MethodHandle sqr8stride$mh = downcallHandle("sqr8s_stride", FunctionDescriptor.of(JAVA_INT));
+
+ static final MethodHandle dot8s$mh = downcallHandle("dot8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
+ static final MethodHandle sqr8s$mh = downcallHandle("sqr8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
+
+ // Stride of the native implementation - consumes this number of bytes per loop invocation.
+ // There must be at least this number of bytes/elements available when going native
+ static final int DOT_STRIDE = 32;
+ static final int SQR_STRIDE = 16;
+
+ static {
+ assert DOT_STRIDE > 0 && (DOT_STRIDE & (DOT_STRIDE - 1)) == 0 : "Not a power of two";
+ assert dot8Stride() == DOT_STRIDE : dot8Stride() + " != " + DOT_STRIDE;
+ assert SQR_STRIDE > 0 && (SQR_STRIDE & (SQR_STRIDE - 1)) == 0 : "Not a power of two";
+ assert sqr8Stride() == SQR_STRIDE : sqr8Stride() + " != " + SQR_STRIDE;
+ }
+
+ /**
+ * Computes the dot product of given byte vectors.
+ *
+ * @param a address of the first vector
+ * @param b address of the second vector
+ * @param length the vector dimensions
+ */
+ static int dotProduct(MemorySegment a, MemorySegment b, int length) {
+ assert length >= 0;
+ if (a.byteSize() != b.byteSize()) {
+ throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
+ }
+ if (length > a.byteSize()) {
+ throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
+ }
+ int i = 0;
+ int res = 0;
+ if (length >= DOT_STRIDE) {
+ i += length & ~(DOT_STRIDE - 1);
+ res = dot8s(a, b, i);
+ }
+
+ // tail
+ for (; i < length; i++) {
+ res += a.get(JAVA_BYTE, i) * b.get(JAVA_BYTE, i);
+ }
+ assert i == length;
+ return res;
+ }
+
+ /**
+ * Computes the square distance of given byte vectors.
+ *
+ * @param a address of the first vector
+ * @param b address of the second vector
+ * @param length the vector dimensions
+ */
+ static int squareDistance(MemorySegment a, MemorySegment b, int length) {
+ assert length >= 0;
+ if (a.byteSize() != b.byteSize()) {
+ throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
+ }
+ if (length > a.byteSize()) {
+ throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
+ }
+ int i = 0;
+ int res = 0;
+ if (length >= SQR_STRIDE) {
+ i += length & ~(SQR_STRIDE - 1);
+ res = sqr8s(a, b, i);
+ }
+
+ // tail
+ for (; i < length; i++) {
+ int dist = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i);
+ res += dist * dist;
+ }
+ assert i == length;
+ return res;
+ }
+
+ private static int dot8Stride() {
+ try {
+ return (int) dot8stride$mh.invokeExact();
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private static int sqr8Stride() {
+ try {
+ return (int) sqr8stride$mh.invokeExact();
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private static int dot8s(MemorySegment a, MemorySegment b, int length) {
+ try {
+ return (int) dot8s$mh.invokeExact(a, b, length);
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ private static int sqr8s(MemorySegment a, MemorySegment b, int length) {
+ try {
+ return (int) sqr8s$mh.invokeExact(a, b, length);
+ } catch (Throwable t) {
+ throw new AssertionError(t);
+ }
+ }
+
+ static final MethodHandle DOT_HANDLE;
+ static final MethodHandle SQR_HANDLE;
+
+ static {
+ try {
+ var lookup = MethodHandles.lookup();
+ var mt = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class);
+ DOT_HANDLE = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct", mt);
+ SQR_HANDLE = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistance", mt);
+ } catch (NoSuchMethodException | IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public MethodHandle dotProductHandle() {
+ return DOT_HANDLE;
+ }
+
+ @Override
+ public MethodHandle squareDistanceHandle() {
+ return SQR_HANDLE;
+ }
}
}
diff --git a/libs/vec/native/publish_vec_binaries.sh b/libs/vec/native/publish_vec_binaries.sh
index 6cdea109c2eb..7c460eb0321c 100755
--- a/libs/vec/native/publish_vec_binaries.sh
+++ b/libs/vec/native/publish_vec_binaries.sh
@@ -19,7 +19,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then
exit 1;
fi
-VERSION="1.0.1"
+VERSION="1.0.3"
ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}"
TEMP=$(mktemp -d)
diff --git a/libs/vec/native/src/vec/c/vec.c b/libs/vec/native/src/vec/c/vec.c
index 008129b665d0..46cc6722d01d 100644
--- a/libs/vec/native/src/vec/c/vec.c
+++ b/libs/vec/native/src/vec/c/vec.c
@@ -18,6 +18,34 @@
#define SQR8S_STRIDE_BYTES_LEN 16
#endif
+#ifdef __linux__
+ #include