mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
Add functionality to test if the host CPU supports native SIMD instructions (#107429)
This commit is contained in:
parent
8594290263
commit
6ef4865195
8 changed files with 200 additions and 153 deletions
|
@ -18,7 +18,7 @@ configurations {
|
|||
}
|
||||
|
||||
var zstdVersion = "1.5.5"
|
||||
var vecVersion = "1.0.1"
|
||||
var vecVersion = "1.0.3"
|
||||
|
||||
repositories {
|
||||
exclusiveContent {
|
||||
|
|
|
@ -27,7 +27,7 @@ abstract class PosixNativeAccess extends AbstractNativeAccess {
|
|||
|
||||
static VectorSimilarityFunctions vectorSimilarityFunctionsOrNull(NativeLibraryProvider libraryProvider) {
|
||||
if (isNativeVectorLibSupported()) {
|
||||
var lib = new VectorSimilarityFunctions(libraryProvider.getLibrary(VectorLibrary.class));
|
||||
var lib = libraryProvider.getLibrary(VectorLibrary.class).getVectorSimilarityFunctions();
|
||||
logger.info("Using native vector library; to disable start with -D" + ENABLE_JDK_VECTOR_LIBRARY + "=false");
|
||||
return lib;
|
||||
}
|
||||
|
|
|
@ -8,25 +8,16 @@
|
|||
|
||||
package org.elasticsearch.nativeaccess;
|
||||
|
||||
import org.elasticsearch.nativeaccess.lib.VectorLibrary;
|
||||
|
||||
import java.lang.invoke.MethodHandle;
|
||||
|
||||
/**
|
||||
* Utility class providing vector similarity functions.
|
||||
* Utility interface providing vector similarity functions.
|
||||
*
|
||||
* <p> MethodHandles are returned to avoid a static reference to MemorySegment,
|
||||
* which is not available in the current lowest compile version, JDK 17. Code consuming
|
||||
* the method handles will, by definition, require access to MemorySegment.
|
||||
*/
|
||||
public final class VectorSimilarityFunctions implements VectorLibrary {
|
||||
|
||||
private final VectorLibrary vectorLibrary;
|
||||
|
||||
VectorSimilarityFunctions(VectorLibrary vectorLibrary) {
|
||||
this.vectorLibrary = vectorLibrary;
|
||||
}
|
||||
|
||||
public interface VectorSimilarityFunctions {
|
||||
/**
|
||||
* Produces a method handle returning the dot product of byte (signed int8) vectors.
|
||||
*
|
||||
|
@ -34,9 +25,7 @@ public final class VectorSimilarityFunctions implements VectorLibrary {
|
|||
* its first and second arguments will be {@code MemorySegment}, whose contents are the
|
||||
* vector data bytes. The third argument is the length of the vector data.
|
||||
*/
|
||||
public MethodHandle dotProductHandle() {
|
||||
return vectorLibrary.dotProductHandle();
|
||||
}
|
||||
MethodHandle dotProductHandle();
|
||||
|
||||
/**
|
||||
* Produces a method handle returning the square distance of byte (signed int8) vectors.
|
||||
|
@ -45,7 +34,5 @@ public final class VectorSimilarityFunctions implements VectorLibrary {
|
|||
* its first and second arguments will be {@code MemorySegment}, whose contents are the
|
||||
* vector data bytes. The third argument is the length of the vector data.
|
||||
*/
|
||||
public MethodHandle squareDistanceHandle() {
|
||||
return vectorLibrary.squareDistanceHandle();
|
||||
}
|
||||
MethodHandle squareDistanceHandle();
|
||||
}
|
||||
|
|
|
@ -8,7 +8,8 @@
|
|||
|
||||
package org.elasticsearch.nativeaccess.lib;
|
||||
|
||||
import java.lang.invoke.MethodHandle;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
|
||||
|
||||
/**
|
||||
* A VectorLibrary is just an adaptation of the factory for a NativeLibrary.
|
||||
|
@ -16,8 +17,6 @@ import java.lang.invoke.MethodHandle;
|
|||
* for native implementations.
|
||||
*/
|
||||
public non-sealed interface VectorLibrary extends NativeLibrary {
|
||||
|
||||
MethodHandle dotProductHandle();
|
||||
|
||||
MethodHandle squareDistanceHandle();
|
||||
@Nullable
|
||||
VectorSimilarityFunctions getVectorSimilarityFunctions();
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
package org.elasticsearch.nativeaccess.jdk;
|
||||
|
||||
import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
|
||||
import org.elasticsearch.nativeaccess.lib.VectorLibrary;
|
||||
|
||||
import java.lang.foreign.FunctionDescriptor;
|
||||
|
@ -23,142 +24,166 @@ import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.downcallHandle;
|
|||
|
||||
public final class JdkVectorLibrary implements VectorLibrary {
|
||||
|
||||
static final VectorSimilarityFunctions INSTANCE;
|
||||
|
||||
static {
|
||||
System.loadLibrary("vec");
|
||||
final MethodHandle vecCaps$mh = downcallHandle("vec_caps", FunctionDescriptor.of(JAVA_INT));
|
||||
|
||||
try {
|
||||
int caps = (int) vecCaps$mh.invokeExact();
|
||||
if (caps != 0) {
|
||||
INSTANCE = new JdkVectorSimilarityFunctions();
|
||||
} else {
|
||||
INSTANCE = null;
|
||||
}
|
||||
} catch (Throwable t) {
|
||||
throw new AssertionError(t);
|
||||
}
|
||||
}
|
||||
|
||||
public JdkVectorLibrary() {}
|
||||
|
||||
static final MethodHandle dot8stride$mh = downcallHandle("dot8s_stride", FunctionDescriptor.of(JAVA_INT));
|
||||
static final MethodHandle sqr8stride$mh = downcallHandle("sqr8s_stride", FunctionDescriptor.of(JAVA_INT));
|
||||
|
||||
static final MethodHandle dot8s$mh = downcallHandle("dot8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
|
||||
static final MethodHandle sqr8s$mh = downcallHandle("sqr8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
|
||||
|
||||
// Stride of the native implementation - consumes this number of bytes per loop invocation.
|
||||
// There must be at least this number of bytes/elements available when going native
|
||||
static final int DOT_STRIDE = 32;
|
||||
static final int SQR_STRIDE = 16;
|
||||
|
||||
static {
|
||||
assert DOT_STRIDE > 0 && (DOT_STRIDE & (DOT_STRIDE - 1)) == 0 : "Not a power of two";
|
||||
assert dot8Stride() == DOT_STRIDE : dot8Stride() + " != " + DOT_STRIDE;
|
||||
assert SQR_STRIDE > 0 && (SQR_STRIDE & (SQR_STRIDE - 1)) == 0 : "Not a power of two";
|
||||
assert sqr8Stride() == SQR_STRIDE : sqr8Stride() + " != " + SQR_STRIDE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the dot product of given byte vectors.
|
||||
* @param a address of the first vector
|
||||
* @param b address of the second vector
|
||||
* @param length the vector dimensions
|
||||
*/
|
||||
static int dotProduct(MemorySegment a, MemorySegment b, int length) {
|
||||
assert length >= 0;
|
||||
if (a.byteSize() != b.byteSize()) {
|
||||
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
|
||||
}
|
||||
if (length > a.byteSize()) {
|
||||
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
|
||||
}
|
||||
int i = 0;
|
||||
int res = 0;
|
||||
if (length >= DOT_STRIDE) {
|
||||
i += length & ~(DOT_STRIDE - 1);
|
||||
res = dot8s(a, b, i);
|
||||
}
|
||||
|
||||
// tail
|
||||
for (; i < length; i++) {
|
||||
res += a.get(JAVA_BYTE, i) * b.get(JAVA_BYTE, i);
|
||||
}
|
||||
assert i == length;
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the square distance of given byte vectors.
|
||||
* @param a address of the first vector
|
||||
* @param b address of the second vector
|
||||
* @param length the vector dimensions
|
||||
*/
|
||||
static int squareDistance(MemorySegment a, MemorySegment b, int length) {
|
||||
assert length >= 0;
|
||||
if (a.byteSize() != b.byteSize()) {
|
||||
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
|
||||
}
|
||||
if (length > a.byteSize()) {
|
||||
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
|
||||
}
|
||||
int i = 0;
|
||||
int res = 0;
|
||||
if (length >= SQR_STRIDE) {
|
||||
i += length & ~(SQR_STRIDE - 1);
|
||||
res = sqr8s(a, b, i);
|
||||
}
|
||||
|
||||
// tail
|
||||
for (; i < length; i++) {
|
||||
int dist = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i);
|
||||
res += dist * dist;
|
||||
}
|
||||
assert i == length;
|
||||
return res;
|
||||
}
|
||||
|
||||
private static int dot8Stride() {
|
||||
try {
|
||||
return (int) dot8stride$mh.invokeExact();
|
||||
} catch (Throwable t) {
|
||||
throw new AssertionError(t);
|
||||
}
|
||||
}
|
||||
|
||||
private static int sqr8Stride() {
|
||||
try {
|
||||
return (int) sqr8stride$mh.invokeExact();
|
||||
} catch (Throwable t) {
|
||||
throw new AssertionError(t);
|
||||
}
|
||||
}
|
||||
|
||||
private static int dot8s(MemorySegment a, MemorySegment b, int length) {
|
||||
try {
|
||||
return (int) dot8s$mh.invokeExact(a, b, length);
|
||||
} catch (Throwable t) {
|
||||
throw new AssertionError(t);
|
||||
}
|
||||
}
|
||||
|
||||
private static int sqr8s(MemorySegment a, MemorySegment b, int length) {
|
||||
try {
|
||||
return (int) sqr8s$mh.invokeExact(a, b, length);
|
||||
} catch (Throwable t) {
|
||||
throw new AssertionError(t);
|
||||
}
|
||||
}
|
||||
|
||||
static final MethodHandle DOT_HANDLE;
|
||||
static final MethodHandle SQR_HANDLE;
|
||||
|
||||
static {
|
||||
try {
|
||||
var lookup = MethodHandles.lookup();
|
||||
var mt = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class);
|
||||
DOT_HANDLE = lookup.findStatic(JdkVectorLibrary.class, "dotProduct", mt);
|
||||
SQR_HANDLE = lookup.findStatic(JdkVectorLibrary.class, "squareDistance", mt);
|
||||
} catch (NoSuchMethodException | IllegalAccessException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public MethodHandle dotProductHandle() {
|
||||
return DOT_HANDLE;
|
||||
public VectorSimilarityFunctions getVectorSimilarityFunctions() {
|
||||
return INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MethodHandle squareDistanceHandle() {
|
||||
return SQR_HANDLE;
|
||||
private static final class JdkVectorSimilarityFunctions implements VectorSimilarityFunctions {
|
||||
|
||||
static final MethodHandle dot8stride$mh = downcallHandle("dot8s_stride", FunctionDescriptor.of(JAVA_INT));
|
||||
static final MethodHandle sqr8stride$mh = downcallHandle("sqr8s_stride", FunctionDescriptor.of(JAVA_INT));
|
||||
|
||||
static final MethodHandle dot8s$mh = downcallHandle("dot8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
|
||||
static final MethodHandle sqr8s$mh = downcallHandle("sqr8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
|
||||
|
||||
// Stride of the native implementation - consumes this number of bytes per loop invocation.
|
||||
// There must be at least this number of bytes/elements available when going native
|
||||
static final int DOT_STRIDE = 32;
|
||||
static final int SQR_STRIDE = 16;
|
||||
|
||||
static {
|
||||
assert DOT_STRIDE > 0 && (DOT_STRIDE & (DOT_STRIDE - 1)) == 0 : "Not a power of two";
|
||||
assert dot8Stride() == DOT_STRIDE : dot8Stride() + " != " + DOT_STRIDE;
|
||||
assert SQR_STRIDE > 0 && (SQR_STRIDE & (SQR_STRIDE - 1)) == 0 : "Not a power of two";
|
||||
assert sqr8Stride() == SQR_STRIDE : sqr8Stride() + " != " + SQR_STRIDE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the dot product of given byte vectors.
|
||||
*
|
||||
* @param a address of the first vector
|
||||
* @param b address of the second vector
|
||||
* @param length the vector dimensions
|
||||
*/
|
||||
static int dotProduct(MemorySegment a, MemorySegment b, int length) {
|
||||
assert length >= 0;
|
||||
if (a.byteSize() != b.byteSize()) {
|
||||
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
|
||||
}
|
||||
if (length > a.byteSize()) {
|
||||
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
|
||||
}
|
||||
int i = 0;
|
||||
int res = 0;
|
||||
if (length >= DOT_STRIDE) {
|
||||
i += length & ~(DOT_STRIDE - 1);
|
||||
res = dot8s(a, b, i);
|
||||
}
|
||||
|
||||
// tail
|
||||
for (; i < length; i++) {
|
||||
res += a.get(JAVA_BYTE, i) * b.get(JAVA_BYTE, i);
|
||||
}
|
||||
assert i == length;
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the square distance of given byte vectors.
|
||||
*
|
||||
* @param a address of the first vector
|
||||
* @param b address of the second vector
|
||||
* @param length the vector dimensions
|
||||
*/
|
||||
static int squareDistance(MemorySegment a, MemorySegment b, int length) {
|
||||
assert length >= 0;
|
||||
if (a.byteSize() != b.byteSize()) {
|
||||
throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
|
||||
}
|
||||
if (length > a.byteSize()) {
|
||||
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
|
||||
}
|
||||
int i = 0;
|
||||
int res = 0;
|
||||
if (length >= SQR_STRIDE) {
|
||||
i += length & ~(SQR_STRIDE - 1);
|
||||
res = sqr8s(a, b, i);
|
||||
}
|
||||
|
||||
// tail
|
||||
for (; i < length; i++) {
|
||||
int dist = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i);
|
||||
res += dist * dist;
|
||||
}
|
||||
assert i == length;
|
||||
return res;
|
||||
}
|
||||
|
||||
private static int dot8Stride() {
|
||||
try {
|
||||
return (int) dot8stride$mh.invokeExact();
|
||||
} catch (Throwable t) {
|
||||
throw new AssertionError(t);
|
||||
}
|
||||
}
|
||||
|
||||
private static int sqr8Stride() {
|
||||
try {
|
||||
return (int) sqr8stride$mh.invokeExact();
|
||||
} catch (Throwable t) {
|
||||
throw new AssertionError(t);
|
||||
}
|
||||
}
|
||||
|
||||
private static int dot8s(MemorySegment a, MemorySegment b, int length) {
|
||||
try {
|
||||
return (int) dot8s$mh.invokeExact(a, b, length);
|
||||
} catch (Throwable t) {
|
||||
throw new AssertionError(t);
|
||||
}
|
||||
}
|
||||
|
||||
private static int sqr8s(MemorySegment a, MemorySegment b, int length) {
|
||||
try {
|
||||
return (int) sqr8s$mh.invokeExact(a, b, length);
|
||||
} catch (Throwable t) {
|
||||
throw new AssertionError(t);
|
||||
}
|
||||
}
|
||||
|
||||
static final MethodHandle DOT_HANDLE;
|
||||
static final MethodHandle SQR_HANDLE;
|
||||
|
||||
static {
|
||||
try {
|
||||
var lookup = MethodHandles.lookup();
|
||||
var mt = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class);
|
||||
DOT_HANDLE = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct", mt);
|
||||
SQR_HANDLE = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistance", mt);
|
||||
} catch (NoSuchMethodException | IllegalAccessException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public MethodHandle dotProductHandle() {
|
||||
return DOT_HANDLE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MethodHandle squareDistanceHandle() {
|
||||
return SQR_HANDLE;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then
|
|||
exit 1;
|
||||
fi
|
||||
|
||||
VERSION="1.0.1"
|
||||
VERSION="1.0.3"
|
||||
ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}"
|
||||
TEMP=$(mktemp -d)
|
||||
|
||||
|
|
|
@ -18,6 +18,34 @@
|
|||
#define SQR8S_STRIDE_BYTES_LEN 16
|
||||
#endif
|
||||
|
||||
#ifdef __linux__
|
||||
#include <sys/auxv.h>
|
||||
#include <asm/hwcap.h>
|
||||
#ifndef HWCAP_NEON
|
||||
#define HWCAP_NEON 0x1000
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef __APPLE__
|
||||
#include <TargetConditionals.h>
|
||||
#endif
|
||||
|
||||
EXPORT int vec_caps() {
|
||||
#ifdef __APPLE__
|
||||
#ifdef TARGET_OS_OSX
|
||||
// All M-series Apple silicon chips support Neon instructions
|
||||
return 1;
|
||||
#else
|
||||
#error "Unsupported Apple platform"
|
||||
#endif
|
||||
#elif __linux__
|
||||
int hwcap = getauxval(AT_HWCAP);
|
||||
return (hwcap & HWCAP_NEON) != 0;
|
||||
#else
|
||||
#error "Unsupported aarch64 platform"
|
||||
#endif
|
||||
}
|
||||
|
||||
EXPORT int dot8s_stride() {
|
||||
return DOT8_STRIDE_BYTES_LEN;
|
||||
}
|
||||
|
|
|
@ -6,7 +6,15 @@
|
|||
* Side Public License, v 1.
|
||||
*/
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define EXPORT extern "C" __declspec(dllexport)
|
||||
#elif defined(__GNUC__) && !defined(__clang__)
|
||||
#define EXPORT __attribute__((externally_visible,visibility("default")))
|
||||
#elif __clang__
|
||||
#define EXPORT __attribute__((visibility("default")))
|
||||
#endif
|
||||
|
||||
EXPORT int vec_caps();
|
||||
|
||||
EXPORT int dot8s_stride();
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue