Add functionality to test if the host CPU supports native SIMD instructions (#107429)

This commit is contained in:
Lorenzo Dematté 2024-04-29 12:01:41 +02:00 committed by GitHub
parent 8594290263
commit 6ef4865195
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 200 additions and 153 deletions

View file

@ -18,7 +18,7 @@ configurations {
} }
var zstdVersion = "1.5.5" var zstdVersion = "1.5.5"
var vecVersion = "1.0.1" var vecVersion = "1.0.3"
repositories { repositories {
exclusiveContent { exclusiveContent {

View file

@ -27,7 +27,7 @@ abstract class PosixNativeAccess extends AbstractNativeAccess {
static VectorSimilarityFunctions vectorSimilarityFunctionsOrNull(NativeLibraryProvider libraryProvider) { static VectorSimilarityFunctions vectorSimilarityFunctionsOrNull(NativeLibraryProvider libraryProvider) {
if (isNativeVectorLibSupported()) { if (isNativeVectorLibSupported()) {
var lib = new VectorSimilarityFunctions(libraryProvider.getLibrary(VectorLibrary.class)); var lib = libraryProvider.getLibrary(VectorLibrary.class).getVectorSimilarityFunctions();
logger.info("Using native vector library; to disable start with -D" + ENABLE_JDK_VECTOR_LIBRARY + "=false"); logger.info("Using native vector library; to disable start with -D" + ENABLE_JDK_VECTOR_LIBRARY + "=false");
return lib; return lib;
} }

View file

@ -8,25 +8,16 @@
package org.elasticsearch.nativeaccess; package org.elasticsearch.nativeaccess;
import org.elasticsearch.nativeaccess.lib.VectorLibrary;
import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandle;
/** /**
* Utility class providing vector similarity functions. * Utility interface providing vector similarity functions.
* *
* <p> MethodHandles are returned to avoid a static reference to MemorySegment, * <p> MethodHandles are returned to avoid a static reference to MemorySegment,
* which is not in the currently lowest compile version, JDK 17. Code consuming * which is not in the currently lowest compile version, JDK 17. Code consuming
* the method handles will, by definition, require access to MemorySegment. * the method handles will, by definition, require access to MemorySegment.
*/ */
public final class VectorSimilarityFunctions implements VectorLibrary { public interface VectorSimilarityFunctions {
private final VectorLibrary vectorLibrary;
VectorSimilarityFunctions(VectorLibrary vectorLibrary) {
this.vectorLibrary = vectorLibrary;
}
/** /**
* Produces a method handle returning the dot product of byte (signed int8) vectors. * Produces a method handle returning the dot product of byte (signed int8) vectors.
* *
@ -34,9 +25,7 @@ public final class VectorSimilarityFunctions implements VectorLibrary {
* its first and second arguments will be {@code MemorySegment}, whose contents is the * its first and second arguments will be {@code MemorySegment}, whose contents is the
* vector data bytes. The third argument is the length of the vector data. * vector data bytes. The third argument is the length of the vector data.
*/ */
public MethodHandle dotProductHandle() { MethodHandle dotProductHandle();
return vectorLibrary.dotProductHandle();
}
/** /**
* Produces a method handle returning the square distance of byte (signed int8) vectors. * Produces a method handle returning the square distance of byte (signed int8) vectors.
@ -45,7 +34,5 @@ public final class VectorSimilarityFunctions implements VectorLibrary {
* its first and second arguments will be {@code MemorySegment}, whose contents is the * its first and second arguments will be {@code MemorySegment}, whose contents is the
* vector data bytes. The third argument is the length of the vector data. * vector data bytes. The third argument is the length of the vector data.
*/ */
public MethodHandle squareDistanceHandle() { MethodHandle squareDistanceHandle();
return vectorLibrary.squareDistanceHandle();
}
} }

View file

@ -8,7 +8,8 @@
package org.elasticsearch.nativeaccess.lib; package org.elasticsearch.nativeaccess.lib;
import java.lang.invoke.MethodHandle; import org.elasticsearch.core.Nullable;
import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
/** /**
* A VectorLibrary is just an adaptation of the factory for a NativeLibrary. * A VectorLibrary is just an adaptation of the factory for a NativeLibrary.
@ -16,8 +17,6 @@ import java.lang.invoke.MethodHandle;
* for native implementations. * for native implementations.
*/ */
public non-sealed interface VectorLibrary extends NativeLibrary { public non-sealed interface VectorLibrary extends NativeLibrary {
@Nullable
MethodHandle dotProductHandle(); VectorSimilarityFunctions getVectorSimilarityFunctions();
MethodHandle squareDistanceHandle();
} }

View file

@ -8,6 +8,7 @@
package org.elasticsearch.nativeaccess.jdk; package org.elasticsearch.nativeaccess.jdk;
import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
import org.elasticsearch.nativeaccess.lib.VectorLibrary; import org.elasticsearch.nativeaccess.lib.VectorLibrary;
import java.lang.foreign.FunctionDescriptor; import java.lang.foreign.FunctionDescriptor;
@ -23,142 +24,166 @@ import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.downcallHandle;
public final class JdkVectorLibrary implements VectorLibrary { public final class JdkVectorLibrary implements VectorLibrary {
static final VectorSimilarityFunctions INSTANCE;
static { static {
System.loadLibrary("vec"); System.loadLibrary("vec");
final MethodHandle vecCaps$mh = downcallHandle("vec_caps", FunctionDescriptor.of(JAVA_INT));
try {
int caps = (int) vecCaps$mh.invokeExact();
if (caps != 0) {
INSTANCE = new JdkVectorSimilarityFunctions();
} else {
INSTANCE = null;
}
} catch (Throwable t) {
throw new AssertionError(t);
}
} }
public JdkVectorLibrary() {} public JdkVectorLibrary() {}
static final MethodHandle dot8stride$mh = downcallHandle("dot8s_stride", FunctionDescriptor.of(JAVA_INT));
static final MethodHandle sqr8stride$mh = downcallHandle("sqr8s_stride", FunctionDescriptor.of(JAVA_INT));
static final MethodHandle dot8s$mh = downcallHandle("dot8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
static final MethodHandle sqr8s$mh = downcallHandle("sqr8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
// Stride of the native implementation - consumes this number of bytes per loop invocation.
// There must be at least this number of bytes/elements available when going native
static final int DOT_STRIDE = 32;
static final int SQR_STRIDE = 16;
static {
assert DOT_STRIDE > 0 && (DOT_STRIDE & (DOT_STRIDE - 1)) == 0 : "Not a power of two";
assert dot8Stride() == DOT_STRIDE : dot8Stride() + " != " + DOT_STRIDE;
assert SQR_STRIDE > 0 && (SQR_STRIDE & (SQR_STRIDE - 1)) == 0 : "Not a power of two";
assert sqr8Stride() == SQR_STRIDE : sqr8Stride() + " != " + SQR_STRIDE;
}
/**
* Computes the dot product of given byte vectors.
* @param a address of the first vector
* @param b address of the second vector
* @param length the vector dimensions
*/
static int dotProduct(MemorySegment a, MemorySegment b, int length) {
assert length >= 0;
if (a.byteSize() != b.byteSize()) {
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
}
if (length > a.byteSize()) {
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
}
int i = 0;
int res = 0;
if (length >= DOT_STRIDE) {
i += length & ~(DOT_STRIDE - 1);
res = dot8s(a, b, i);
}
// tail
for (; i < length; i++) {
res += a.get(JAVA_BYTE, i) * b.get(JAVA_BYTE, i);
}
assert i == length;
return res;
}
/**
* Computes the square distance of given byte vectors.
* @param a address of the first vector
* @param b address of the second vector
* @param length the vector dimensions
*/
static int squareDistance(MemorySegment a, MemorySegment b, int length) {
assert length >= 0;
if (a.byteSize() != b.byteSize()) {
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
}
if (length > a.byteSize()) {
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
}
int i = 0;
int res = 0;
if (length >= SQR_STRIDE) {
i += length & ~(SQR_STRIDE - 1);
res = sqr8s(a, b, i);
}
// tail
for (; i < length; i++) {
int dist = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i);
res += dist * dist;
}
assert i == length;
return res;
}
private static int dot8Stride() {
try {
return (int) dot8stride$mh.invokeExact();
} catch (Throwable t) {
throw new AssertionError(t);
}
}
private static int sqr8Stride() {
try {
return (int) sqr8stride$mh.invokeExact();
} catch (Throwable t) {
throw new AssertionError(t);
}
}
private static int dot8s(MemorySegment a, MemorySegment b, int length) {
try {
return (int) dot8s$mh.invokeExact(a, b, length);
} catch (Throwable t) {
throw new AssertionError(t);
}
}
private static int sqr8s(MemorySegment a, MemorySegment b, int length) {
try {
return (int) sqr8s$mh.invokeExact(a, b, length);
} catch (Throwable t) {
throw new AssertionError(t);
}
}
static final MethodHandle DOT_HANDLE;
static final MethodHandle SQR_HANDLE;
static {
try {
var lookup = MethodHandles.lookup();
var mt = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class);
DOT_HANDLE = lookup.findStatic(JdkVectorLibrary.class, "dotProduct", mt);
SQR_HANDLE = lookup.findStatic(JdkVectorLibrary.class, "squareDistance", mt);
} catch (NoSuchMethodException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}
@Override @Override
public MethodHandle dotProductHandle() { public VectorSimilarityFunctions getVectorSimilarityFunctions() {
return DOT_HANDLE; return INSTANCE;
} }
@Override private static final class JdkVectorSimilarityFunctions implements VectorSimilarityFunctions {
public MethodHandle squareDistanceHandle() {
return SQR_HANDLE; static final MethodHandle dot8stride$mh = downcallHandle("dot8s_stride", FunctionDescriptor.of(JAVA_INT));
static final MethodHandle sqr8stride$mh = downcallHandle("sqr8s_stride", FunctionDescriptor.of(JAVA_INT));
static final MethodHandle dot8s$mh = downcallHandle("dot8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
static final MethodHandle sqr8s$mh = downcallHandle("sqr8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
// Stride of the native implementation - consumes this number of bytes per loop invocation.
// There must be at least this number of bytes/elements available when going native
static final int DOT_STRIDE = 32;
static final int SQR_STRIDE = 16;
static {
assert DOT_STRIDE > 0 && (DOT_STRIDE & (DOT_STRIDE - 1)) == 0 : "Not a power of two";
assert dot8Stride() == DOT_STRIDE : dot8Stride() + " != " + DOT_STRIDE;
assert SQR_STRIDE > 0 && (SQR_STRIDE & (SQR_STRIDE - 1)) == 0 : "Not a power of two";
assert sqr8Stride() == SQR_STRIDE : sqr8Stride() + " != " + SQR_STRIDE;
}
/**
* Computes the dot product of given byte vectors.
*
* @param a address of the first vector
* @param b address of the second vector
* @param length the vector dimensions
*/
static int dotProduct(MemorySegment a, MemorySegment b, int length) {
assert length >= 0;
if (a.byteSize() != b.byteSize()) {
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
}
if (length > a.byteSize()) {
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
}
int i = 0;
int res = 0;
if (length >= DOT_STRIDE) {
i += length & ~(DOT_STRIDE - 1);
res = dot8s(a, b, i);
}
// tail
for (; i < length; i++) {
res += a.get(JAVA_BYTE, i) * b.get(JAVA_BYTE, i);
}
assert i == length;
return res;
}
/**
* Computes the square distance of given byte vectors.
*
* @param a address of the first vector
* @param b address of the second vector
* @param length the vector dimensions
*/
static int squareDistance(MemorySegment a, MemorySegment b, int length) {
assert length >= 0;
if (a.byteSize() != b.byteSize()) {
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
}
if (length > a.byteSize()) {
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
}
int i = 0;
int res = 0;
if (length >= SQR_STRIDE) {
i += length & ~(SQR_STRIDE - 1);
res = sqr8s(a, b, i);
}
// tail
for (; i < length; i++) {
int dist = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i);
res += dist * dist;
}
assert i == length;
return res;
}
private static int dot8Stride() {
try {
return (int) dot8stride$mh.invokeExact();
} catch (Throwable t) {
throw new AssertionError(t);
}
}
private static int sqr8Stride() {
try {
return (int) sqr8stride$mh.invokeExact();
} catch (Throwable t) {
throw new AssertionError(t);
}
}
private static int dot8s(MemorySegment a, MemorySegment b, int length) {
try {
return (int) dot8s$mh.invokeExact(a, b, length);
} catch (Throwable t) {
throw new AssertionError(t);
}
}
private static int sqr8s(MemorySegment a, MemorySegment b, int length) {
try {
return (int) sqr8s$mh.invokeExact(a, b, length);
} catch (Throwable t) {
throw new AssertionError(t);
}
}
static final MethodHandle DOT_HANDLE;
static final MethodHandle SQR_HANDLE;
static {
try {
var lookup = MethodHandles.lookup();
var mt = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class);
DOT_HANDLE = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct", mt);
SQR_HANDLE = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistance", mt);
} catch (NoSuchMethodException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}
@Override
public MethodHandle dotProductHandle() {
return DOT_HANDLE;
}
@Override
public MethodHandle squareDistanceHandle() {
return SQR_HANDLE;
}
} }
} }

View file

@ -19,7 +19,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then
exit 1; exit 1;
fi fi
VERSION="1.0.1" VERSION="1.0.3"
ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}" ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}"
TEMP=$(mktemp -d) TEMP=$(mktemp -d)

View file

@ -18,6 +18,34 @@
#define SQR8S_STRIDE_BYTES_LEN 16 #define SQR8S_STRIDE_BYTES_LEN 16
#endif #endif
#ifdef __linux__
#include <sys/auxv.h>
#include <asm/hwcap.h>
#ifndef HWCAP_NEON
#define HWCAP_NEON 0x1000
#endif
#endif
#ifdef __APPLE__
#include <TargetConditionals.h>
#endif
EXPORT int vec_caps() {
#ifdef __APPLE__
#ifdef TARGET_OS_OSX
// All M series Apple silicon support Neon instructions
return 1;
#else
#error "Unsupported Apple platform"
#endif
#elif __linux__
int hwcap = getauxval(AT_HWCAP);
return (hwcap & HWCAP_NEON) != 0;
#else
#error "Unsupported aarch64 platform"
#endif
}
EXPORT int dot8s_stride() { EXPORT int dot8s_stride() {
return DOT8_STRIDE_BYTES_LEN; return DOT8_STRIDE_BYTES_LEN;
} }

View file

@ -6,7 +6,15 @@
* Side Public License, v 1. * Side Public License, v 1.
*/ */
#ifdef _MSC_VER
#define EXPORT extern "C" __declspec(dllexport)
#elif defined(__GNUC__) && !defined(__clang__)
#define EXPORT __attribute__((externally_visible,visibility("default"))) #define EXPORT __attribute__((externally_visible,visibility("default")))
#elif __clang__
#define EXPORT __attribute__((visibility("default")))
#endif
EXPORT int vec_caps();
EXPORT int dot8s_stride(); EXPORT int dot8s_stride();