mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
Add panama implementations of byte-bit and float-bit script operations (#124722)
This commit is contained in:
parent
11fed4502c
commit
7f1203e472
5 changed files with 329 additions and 35 deletions
|
@ -13,6 +13,8 @@ import org.apache.lucene.util.BytesRef;
|
||||||
import org.elasticsearch.common.logging.LogConfigurator;
|
import org.elasticsearch.common.logging.LogConfigurator;
|
||||||
import org.elasticsearch.index.IndexVersion;
|
import org.elasticsearch.index.IndexVersion;
|
||||||
import org.elasticsearch.script.field.vectors.BinaryDenseVector;
|
import org.elasticsearch.script.field.vectors.BinaryDenseVector;
|
||||||
|
import org.elasticsearch.script.field.vectors.BitBinaryDenseVector;
|
||||||
|
import org.elasticsearch.script.field.vectors.BitKnnDenseVector;
|
||||||
import org.elasticsearch.script.field.vectors.ByteBinaryDenseVector;
|
import org.elasticsearch.script.field.vectors.ByteBinaryDenseVector;
|
||||||
import org.elasticsearch.script.field.vectors.ByteKnnDenseVector;
|
import org.elasticsearch.script.field.vectors.ByteKnnDenseVector;
|
||||||
import org.elasticsearch.script.field.vectors.DenseVector;
|
import org.elasticsearch.script.field.vectors.DenseVector;
|
||||||
|
@ -37,30 +39,30 @@ import java.util.concurrent.TimeUnit;
|
||||||
import java.util.function.DoubleSupplier;
|
import java.util.function.DoubleSupplier;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Various benchmarks for the distance functions
|
* Various benchmarks for the distance functions used by indexed and non-indexed vectors.
|
||||||
* used by indexed and non-indexed vectors.
|
* Parameters include doc and query type, dims, function, and implementation.
|
||||||
* Parameters include element, dims, function, and type.
|
|
||||||
* For individual local tests it may be useful to increase
|
* For individual local tests it may be useful to increase
|
||||||
* fork, measurement, and operations per invocation. (Note
|
* fork, measurement, and operations per invocation.
|
||||||
* to also update the benchmark loop if operations per invocation
|
|
||||||
* is increased.)
|
|
||||||
*/
|
*/
|
||||||
@Fork(1)
|
@Fork(1)
|
||||||
@Warmup(iterations = 1)
|
@Warmup(iterations = 1)
|
||||||
@Measurement(iterations = 2)
|
@Measurement(iterations = 2)
|
||||||
@BenchmarkMode(Mode.AverageTime)
|
@BenchmarkMode(Mode.AverageTime)
|
||||||
@OutputTimeUnit(TimeUnit.NANOSECONDS)
|
@OutputTimeUnit(TimeUnit.NANOSECONDS)
|
||||||
@OperationsPerInvocation(25000)
|
@OperationsPerInvocation(DistanceFunctionBenchmark.OPERATIONS)
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
public class DistanceFunctionBenchmark {
|
public class DistanceFunctionBenchmark {
|
||||||
|
|
||||||
|
public static final int OPERATIONS = 25000;
|
||||||
|
|
||||||
static {
|
static {
|
||||||
LogConfigurator.configureESLogging();
|
LogConfigurator.configureESLogging();
|
||||||
}
|
}
|
||||||
|
|
||||||
public enum VectorType {
|
public enum VectorType {
|
||||||
FLOAT,
|
FLOAT,
|
||||||
BYTE
|
BYTE,
|
||||||
|
BIT
|
||||||
}
|
}
|
||||||
|
|
||||||
public enum Function {
|
public enum Function {
|
||||||
|
@ -122,7 +124,7 @@ public class DistanceFunctionBenchmark {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static BytesRef generateVectorData(float[] vector, float mag) {
|
private static BytesRef generateVectorData(float[] vector, float mag) {
|
||||||
ByteBuffer buffer = ByteBuffer.allocate(vector.length * 4 + 4);
|
ByteBuffer buffer = ByteBuffer.allocate(vector.length * Float.BYTES + Float.BYTES);
|
||||||
for (float f : vector) {
|
for (float f : vector) {
|
||||||
buffer.putFloat(f);
|
buffer.putFloat(f);
|
||||||
}
|
}
|
||||||
|
@ -133,7 +135,7 @@ public class DistanceFunctionBenchmark {
|
||||||
private static BytesRef generateVectorData(byte[] vector) {
|
private static BytesRef generateVectorData(byte[] vector) {
|
||||||
float mag = calculateMag(vector);
|
float mag = calculateMag(vector);
|
||||||
|
|
||||||
ByteBuffer buffer = ByteBuffer.allocate(vector.length + 4);
|
ByteBuffer buffer = ByteBuffer.allocate(vector.length + Float.BYTES);
|
||||||
buffer.put(vector);
|
buffer.put(vector);
|
||||||
buffer.putFloat(mag);
|
buffer.putFloat(mag);
|
||||||
return new BytesRef(buffer.array());
|
return new BytesRef(buffer.array());
|
||||||
|
@ -141,16 +143,21 @@ public class DistanceFunctionBenchmark {
|
||||||
|
|
||||||
@Setup
|
@Setup
|
||||||
public void findBenchmarkImpl() {
|
public void findBenchmarkImpl() {
|
||||||
|
if (dims % 8 != 0) throw new IllegalArgumentException("Dims must be a multiple of 8");
|
||||||
Random r = new Random();
|
Random r = new Random();
|
||||||
|
|
||||||
float[] floatDocVector = new float[dims];
|
float[] floatDocVector = new float[dims];
|
||||||
byte[] byteDocVector = new byte[dims];
|
byte[] byteDocVector = new byte[dims];
|
||||||
|
byte[] bitDocVector = new byte[dims / 8];
|
||||||
|
|
||||||
float[] floatQueryVector = new float[dims];
|
float[] floatQueryVector = new float[dims];
|
||||||
byte[] byteQueryVector = new byte[dims];
|
byte[] byteQueryVector = new byte[dims];
|
||||||
|
byte[] bitQueryVector = new byte[dims / 8];
|
||||||
|
|
||||||
r.nextBytes(byteDocVector);
|
r.nextBytes(byteDocVector);
|
||||||
|
r.nextBytes(bitDocVector);
|
||||||
r.nextBytes(byteQueryVector);
|
r.nextBytes(byteQueryVector);
|
||||||
|
r.nextBytes(bitQueryVector);
|
||||||
for (int i = 0; i < dims; i++) {
|
for (int i = 0; i < dims; i++) {
|
||||||
floatDocVector[i] = r.nextFloat();
|
floatDocVector[i] = r.nextFloat();
|
||||||
floatQueryVector[i] = r.nextFloat();
|
floatQueryVector[i] = r.nextFloat();
|
||||||
|
@ -179,10 +186,11 @@ public class DistanceFunctionBenchmark {
|
||||||
};
|
};
|
||||||
case BYTE -> switch (type) {
|
case BYTE -> switch (type) {
|
||||||
case KNN -> new ByteKnnDenseVector(byteDocVector);
|
case KNN -> new ByteKnnDenseVector(byteDocVector);
|
||||||
case BINARY -> {
|
case BINARY -> new ByteBinaryDenseVector(byteDocVector, generateVectorData(byteDocVector), dims);
|
||||||
BytesRef vectorData = generateVectorData(byteDocVector);
|
};
|
||||||
yield new ByteBinaryDenseVector(byteDocVector, vectorData, dims);
|
case BIT -> switch (type) {
|
||||||
}
|
case KNN -> new BitKnnDenseVector(bitDocVector);
|
||||||
|
case BINARY -> new BitBinaryDenseVector(bitDocVector, new BytesRef(bitDocVector), bitDocVector.length);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -204,13 +212,20 @@ public class DistanceFunctionBenchmark {
|
||||||
case L2 -> () -> vectorImpl.l2Norm(byteQueryVector);
|
case L2 -> () -> vectorImpl.l2Norm(byteQueryVector);
|
||||||
case HAMMING -> () -> vectorImpl.hamming(byteQueryVector);
|
case HAMMING -> () -> vectorImpl.hamming(byteQueryVector);
|
||||||
};
|
};
|
||||||
|
case BIT -> switch (function) {
|
||||||
|
case DOT -> () -> vectorImpl.dotProduct(bitQueryVector);
|
||||||
|
case COSINE -> throw new UnsupportedOperationException("Unsupported function " + function);
|
||||||
|
case L1 -> () -> vectorImpl.l1Norm(bitQueryVector);
|
||||||
|
case L2 -> () -> vectorImpl.l2Norm(bitQueryVector);
|
||||||
|
case HAMMING -> () -> vectorImpl.hamming(bitQueryVector);
|
||||||
|
};
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@Fork(1)
|
@Fork(1)
|
||||||
@Benchmark
|
@Benchmark
|
||||||
public void benchmark(Blackhole blackhole) {
|
public void benchmark(Blackhole blackhole) {
|
||||||
for (int i = 0; i < 25000; ++i) {
|
for (int i = 0; i < OPERATIONS; ++i) {
|
||||||
blackhole.consume(benchmarkImpl.getAsDouble());
|
blackhole.consume(benchmarkImpl.getAsDouble());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -218,7 +233,7 @@ public class DistanceFunctionBenchmark {
|
||||||
@Fork(value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
|
@Fork(value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
|
||||||
@Benchmark
|
@Benchmark
|
||||||
public void vectorBenchmark(Blackhole blackhole) {
|
public void vectorBenchmark(Blackhole blackhole) {
|
||||||
for (int i = 0; i < 25000; ++i) {
|
for (int i = 0; i < OPERATIONS; ++i) {
|
||||||
blackhole.consume(benchmarkImpl.getAsDouble());
|
blackhole.consume(benchmarkImpl.getAsDouble());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
6
docs/changelog/124722.yaml
Normal file
6
docs/changelog/124722.yaml
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
pr: 124722
|
||||||
|
summary: Add panama implementations of byte-bit and float-bit script operations
|
||||||
|
area: Vector Search
|
||||||
|
type: enhancement
|
||||||
|
issues:
|
||||||
|
- 117096
|
|
@ -45,13 +45,17 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
|
||||||
}
|
}
|
||||||
|
|
||||||
public static int ipByteBitImpl(byte[] q, byte[] d) {
|
public static int ipByteBitImpl(byte[] q, byte[] d) {
|
||||||
|
return ipByteBitImpl(q, d, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static int ipByteBitImpl(byte[] q, byte[] d, int start) {
|
||||||
assert q.length == d.length * Byte.SIZE;
|
assert q.length == d.length * Byte.SIZE;
|
||||||
int acc0 = 0;
|
int acc0 = 0;
|
||||||
int acc1 = 0;
|
int acc1 = 0;
|
||||||
int acc2 = 0;
|
int acc2 = 0;
|
||||||
int acc3 = 0;
|
int acc3 = 0;
|
||||||
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
|
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
|
||||||
for (int i = 0; i < d.length; i++) {
|
for (int i = start; i < d.length; i++) {
|
||||||
byte mask = d[i];
|
byte mask = d[i];
|
||||||
// Make sure its just 1 or 0
|
// Make sure its just 1 or 0
|
||||||
|
|
||||||
|
@ -69,13 +73,17 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
|
||||||
}
|
}
|
||||||
|
|
||||||
public static float ipFloatBitImpl(float[] q, byte[] d) {
|
public static float ipFloatBitImpl(float[] q, byte[] d) {
|
||||||
|
return ipFloatBitImpl(q, d, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static float ipFloatBitImpl(float[] q, byte[] d, int start) {
|
||||||
assert q.length == d.length * Byte.SIZE;
|
assert q.length == d.length * Byte.SIZE;
|
||||||
float acc0 = 0;
|
float acc0 = 0;
|
||||||
float acc1 = 0;
|
float acc1 = 0;
|
||||||
float acc2 = 0;
|
float acc2 = 0;
|
||||||
float acc3 = 0;
|
float acc3 = 0;
|
||||||
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
|
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
|
||||||
for (int i = 0; i < d.length; i++) {
|
for (int i = start; i < d.length; i++) {
|
||||||
byte mask = d[i];
|
byte mask = d[i];
|
||||||
acc0 = fma(q[i * Byte.SIZE + 0], (mask >> 7) & 1, acc0);
|
acc0 = fma(q[i * Byte.SIZE + 0], (mask >> 7) & 1, acc0);
|
||||||
acc1 = fma(q[i * Byte.SIZE + 1], (mask >> 6) & 1, acc1);
|
acc1 = fma(q[i * Byte.SIZE + 1], (mask >> 6) & 1, acc1);
|
||||||
|
|
|
@ -13,10 +13,12 @@ import jdk.incubator.vector.ByteVector;
|
||||||
import jdk.incubator.vector.FloatVector;
|
import jdk.incubator.vector.FloatVector;
|
||||||
import jdk.incubator.vector.IntVector;
|
import jdk.incubator.vector.IntVector;
|
||||||
import jdk.incubator.vector.LongVector;
|
import jdk.incubator.vector.LongVector;
|
||||||
|
import jdk.incubator.vector.VectorMask;
|
||||||
import jdk.incubator.vector.VectorOperators;
|
import jdk.incubator.vector.VectorOperators;
|
||||||
import jdk.incubator.vector.VectorShape;
|
import jdk.incubator.vector.VectorShape;
|
||||||
import jdk.incubator.vector.VectorSpecies;
|
import jdk.incubator.vector.VectorSpecies;
|
||||||
|
|
||||||
|
import org.apache.lucene.util.BitUtil;
|
||||||
import org.apache.lucene.util.Constants;
|
import org.apache.lucene.util.Constants;
|
||||||
|
|
||||||
public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
||||||
|
@ -51,11 +53,25 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int ipByteBit(byte[] q, byte[] d) {
|
public int ipByteBit(byte[] q, byte[] d) {
|
||||||
|
if (d.length >= 16 && HAS_FAST_INTEGER_VECTORS) {
|
||||||
|
if (VECTOR_BITSIZE >= 512) {
|
||||||
|
return ipByteBit512(q, d);
|
||||||
|
} else if (VECTOR_BITSIZE == 256) {
|
||||||
|
return ipByteBit256(q, d);
|
||||||
|
}
|
||||||
|
}
|
||||||
return DefaultESVectorUtilSupport.ipByteBitImpl(q, d);
|
return DefaultESVectorUtilSupport.ipByteBitImpl(q, d);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float ipFloatBit(float[] q, byte[] d) {
|
public float ipFloatBit(float[] q, byte[] d) {
|
||||||
|
if (q.length >= 16) {
|
||||||
|
if (VECTOR_BITSIZE >= 512) {
|
||||||
|
return ipFloatBit512(q, d);
|
||||||
|
} else if (VECTOR_BITSIZE == 256) {
|
||||||
|
return ipFloatBit256(q, d);
|
||||||
|
}
|
||||||
|
}
|
||||||
return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d);
|
return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,6 +186,240 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
||||||
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
|
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static final VectorSpecies<Integer> INT_SPECIES_512 = IntVector.SPECIES_512;
|
||||||
|
private static final VectorSpecies<Byte> BYTE_SPECIES_FOR_INT_512 = VectorSpecies.of(
|
||||||
|
byte.class,
|
||||||
|
VectorShape.forBitSize(INT_SPECIES_512.vectorBitSize() / Integer.BYTES)
|
||||||
|
);
|
||||||
|
private static final VectorSpecies<Integer> INT_SPECIES_256 = IntVector.SPECIES_256;
|
||||||
|
private static final VectorSpecies<Byte> BYTE_SPECIES_FOR_INT_256 = VectorSpecies.of(
|
||||||
|
byte.class,
|
||||||
|
VectorShape.forBitSize(INT_SPECIES_256.vectorBitSize() / Integer.BYTES)
|
||||||
|
);
|
||||||
|
|
||||||
|
private static int limit(int length, int sectionSize) {
|
||||||
|
return length - (length % sectionSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
static int ipByteBit512(byte[] q, byte[] d) {
|
||||||
|
assert q.length == d.length * Byte.SIZE;
|
||||||
|
int i = 0;
|
||||||
|
int sum = 0;
|
||||||
|
|
||||||
|
int sectionLength = INT_SPECIES_512.length() * 4;
|
||||||
|
if (q.length >= sectionLength) {
|
||||||
|
IntVector acc0 = IntVector.zero(INT_SPECIES_512);
|
||||||
|
IntVector acc1 = IntVector.zero(INT_SPECIES_512);
|
||||||
|
IntVector acc2 = IntVector.zero(INT_SPECIES_512);
|
||||||
|
IntVector acc3 = IntVector.zero(INT_SPECIES_512);
|
||||||
|
int limit = limit(q.length, sectionLength);
|
||||||
|
for (; i < limit; i += sectionLength) {
|
||||||
|
var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i).castShape(INT_SPECIES_512, 0);
|
||||||
|
var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length()).castShape(INT_SPECIES_512, 0);
|
||||||
|
var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length() * 2)
|
||||||
|
.castShape(INT_SPECIES_512, 0);
|
||||||
|
var vals3 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length() * 3)
|
||||||
|
.castShape(INT_SPECIES_512, 0);
|
||||||
|
|
||||||
|
long maskBits = Long.reverse((long) BitUtil.VH_BE_LONG.get(d, i / 8));
|
||||||
|
var mask0 = VectorMask.fromLong(INT_SPECIES_512, maskBits);
|
||||||
|
var mask1 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 16);
|
||||||
|
var mask2 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 32);
|
||||||
|
var mask3 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 48);
|
||||||
|
|
||||||
|
acc0 = acc0.add(vals0, mask0);
|
||||||
|
acc1 = acc1.add(vals1, mask1);
|
||||||
|
acc2 = acc2.add(vals2, mask2);
|
||||||
|
acc3 = acc3.add(vals3, mask3);
|
||||||
|
}
|
||||||
|
sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
|
||||||
|
+ acc3.reduceLanes(VectorOperators.ADD);
|
||||||
|
}
|
||||||
|
|
||||||
|
sectionLength = INT_SPECIES_256.length();
|
||||||
|
if (q.length - i >= sectionLength) {
|
||||||
|
IntVector acc = IntVector.zero(INT_SPECIES_256);
|
||||||
|
int limit = limit(q.length, sectionLength);
|
||||||
|
for (; i < limit; i += sectionLength) {
|
||||||
|
var vals = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
|
||||||
|
|
||||||
|
long maskBits = Integer.reverse(d[i / 8]) >> 24;
|
||||||
|
var mask = VectorMask.fromLong(INT_SPECIES_256, maskBits);
|
||||||
|
|
||||||
|
acc = acc.add(vals, mask);
|
||||||
|
}
|
||||||
|
sum += acc.reduceLanes(VectorOperators.ADD);
|
||||||
|
}
|
||||||
|
|
||||||
|
// that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector)
|
||||||
|
assert i == q.length;
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int ipByteBit256(byte[] q, byte[] d) {
|
||||||
|
assert q.length == d.length * Byte.SIZE;
|
||||||
|
int i = 0;
|
||||||
|
int sum = 0;
|
||||||
|
|
||||||
|
int sectionLength = INT_SPECIES_256.length() * 4;
|
||||||
|
if (q.length >= sectionLength) {
|
||||||
|
IntVector acc0 = IntVector.zero(INT_SPECIES_256);
|
||||||
|
IntVector acc1 = IntVector.zero(INT_SPECIES_256);
|
||||||
|
IntVector acc2 = IntVector.zero(INT_SPECIES_256);
|
||||||
|
IntVector acc3 = IntVector.zero(INT_SPECIES_256);
|
||||||
|
int limit = limit(q.length, sectionLength);
|
||||||
|
for (; i < limit; i += sectionLength) {
|
||||||
|
var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
|
||||||
|
var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length()).castShape(INT_SPECIES_256, 0);
|
||||||
|
var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 2)
|
||||||
|
.castShape(INT_SPECIES_256, 0);
|
||||||
|
var vals3 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 3)
|
||||||
|
.castShape(INT_SPECIES_256, 0);
|
||||||
|
|
||||||
|
long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8));
|
||||||
|
var mask0 = VectorMask.fromLong(INT_SPECIES_256, maskBits);
|
||||||
|
var mask1 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 8);
|
||||||
|
var mask2 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 16);
|
||||||
|
var mask3 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 24);
|
||||||
|
|
||||||
|
acc0 = acc0.add(vals0, mask0);
|
||||||
|
acc1 = acc1.add(vals1, mask1);
|
||||||
|
acc2 = acc2.add(vals2, mask2);
|
||||||
|
acc3 = acc3.add(vals3, mask3);
|
||||||
|
}
|
||||||
|
sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
|
||||||
|
+ acc3.reduceLanes(VectorOperators.ADD);
|
||||||
|
}
|
||||||
|
|
||||||
|
sectionLength = INT_SPECIES_256.length();
|
||||||
|
if (q.length - i >= sectionLength) {
|
||||||
|
IntVector acc = IntVector.zero(INT_SPECIES_256);
|
||||||
|
int limit = limit(q.length, sectionLength);
|
||||||
|
for (; i < limit; i += sectionLength) {
|
||||||
|
var vals = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
|
||||||
|
|
||||||
|
long maskBits = Integer.reverse(d[i / 8]) >> 24;
|
||||||
|
var mask = VectorMask.fromLong(INT_SPECIES_256, maskBits);
|
||||||
|
|
||||||
|
acc = acc.add(vals, mask);
|
||||||
|
}
|
||||||
|
sum += acc.reduceLanes(VectorOperators.ADD);
|
||||||
|
}
|
||||||
|
|
||||||
|
// that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector)
|
||||||
|
assert i == q.length;
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final VectorSpecies<Float> FLOAT_SPECIES_512 = FloatVector.SPECIES_512;
|
||||||
|
private static final VectorSpecies<Float> FLOAT_SPECIES_256 = FloatVector.SPECIES_256;
|
||||||
|
|
||||||
|
static float ipFloatBit512(float[] q, byte[] d) {
|
||||||
|
assert q.length == d.length * Byte.SIZE;
|
||||||
|
int i = 0;
|
||||||
|
float sum = 0;
|
||||||
|
|
||||||
|
int sectionLength = FLOAT_SPECIES_512.length() * 4;
|
||||||
|
if (q.length >= sectionLength) {
|
||||||
|
FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_512);
|
||||||
|
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_512);
|
||||||
|
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_512);
|
||||||
|
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_512);
|
||||||
|
int limit = limit(q.length, sectionLength);
|
||||||
|
for (; i < limit; i += sectionLength) {
|
||||||
|
var floats0 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i);
|
||||||
|
var floats1 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length());
|
||||||
|
var floats2 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length() * 2);
|
||||||
|
var floats3 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length() * 3);
|
||||||
|
|
||||||
|
long maskBits = Long.reverse((long) BitUtil.VH_BE_LONG.get(d, i / 8));
|
||||||
|
var mask0 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits);
|
||||||
|
var mask1 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 16);
|
||||||
|
var mask2 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 32);
|
||||||
|
var mask3 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 48);
|
||||||
|
|
||||||
|
acc0 = acc0.add(floats0, mask0);
|
||||||
|
acc1 = acc1.add(floats1, mask1);
|
||||||
|
acc2 = acc2.add(floats2, mask2);
|
||||||
|
acc3 = acc3.add(floats3, mask3);
|
||||||
|
}
|
||||||
|
sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
|
||||||
|
+ acc3.reduceLanes(VectorOperators.ADD);
|
||||||
|
}
|
||||||
|
|
||||||
|
sectionLength = FLOAT_SPECIES_256.length();
|
||||||
|
if (q.length - i >= sectionLength) {
|
||||||
|
FloatVector acc = FloatVector.zero(FLOAT_SPECIES_256);
|
||||||
|
int limit = limit(q.length, sectionLength);
|
||||||
|
for (; i < limit; i += sectionLength) {
|
||||||
|
var floats = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
|
||||||
|
|
||||||
|
long maskBits = Integer.reverse(d[i / 8]) >> 24;
|
||||||
|
var mask = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits);
|
||||||
|
|
||||||
|
acc = acc.add(floats, mask);
|
||||||
|
}
|
||||||
|
sum += acc.reduceLanes(VectorOperators.ADD);
|
||||||
|
}
|
||||||
|
|
||||||
|
// that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector)
|
||||||
|
assert i == q.length;
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
static float ipFloatBit256(float[] q, byte[] d) {
|
||||||
|
assert q.length == d.length * Byte.SIZE;
|
||||||
|
int i = 0;
|
||||||
|
float sum = 0;
|
||||||
|
|
||||||
|
int sectionLength = FLOAT_SPECIES_256.length() * 4;
|
||||||
|
if (q.length >= sectionLength) {
|
||||||
|
FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_256);
|
||||||
|
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_256);
|
||||||
|
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_256);
|
||||||
|
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_256);
|
||||||
|
int limit = limit(q.length, sectionLength);
|
||||||
|
for (; i < limit; i += sectionLength) {
|
||||||
|
var floats0 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
|
||||||
|
var floats1 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length());
|
||||||
|
var floats2 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 2);
|
||||||
|
var floats3 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 3);
|
||||||
|
|
||||||
|
long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8));
|
||||||
|
var mask0 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits);
|
||||||
|
var mask1 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 8);
|
||||||
|
var mask2 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 16);
|
||||||
|
var mask3 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 24);
|
||||||
|
|
||||||
|
acc0 = acc0.add(floats0, mask0);
|
||||||
|
acc1 = acc1.add(floats1, mask1);
|
||||||
|
acc2 = acc2.add(floats2, mask2);
|
||||||
|
acc3 = acc3.add(floats3, mask3);
|
||||||
|
}
|
||||||
|
sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD)
|
||||||
|
+ acc3.reduceLanes(VectorOperators.ADD);
|
||||||
|
}
|
||||||
|
|
||||||
|
sectionLength = FLOAT_SPECIES_256.length();
|
||||||
|
if (q.length - i >= sectionLength) {
|
||||||
|
FloatVector acc = FloatVector.zero(FLOAT_SPECIES_256);
|
||||||
|
int limit = limit(q.length, sectionLength);
|
||||||
|
for (; i < limit; i += sectionLength) {
|
||||||
|
var floats = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
|
||||||
|
|
||||||
|
long maskBits = Integer.reverse(d[i / 8]) >> 24;
|
||||||
|
var mask = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits);
|
||||||
|
|
||||||
|
acc = acc.add(floats, mask);
|
||||||
|
}
|
||||||
|
sum += acc.reduceLanes(VectorOperators.ADD);
|
||||||
|
}
|
||||||
|
|
||||||
|
// that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector)
|
||||||
|
assert i == q.length;
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
private static final VectorSpecies<Float> PREFERRED_FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED;
|
private static final VectorSpecies<Float> PREFERRED_FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED;
|
||||||
private static final VectorSpecies<Byte> BYTE_SPECIES_FOR_PREFFERED_FLOATS;
|
private static final VectorSpecies<Byte> BYTE_SPECIES_FOR_PREFFERED_FLOATS;
|
||||||
|
|
||||||
|
@ -177,7 +427,7 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
|
||||||
VectorSpecies<Byte> byteForFloat;
|
VectorSpecies<Byte> byteForFloat;
|
||||||
try {
|
try {
|
||||||
// calculate vector size to convert from single bytes to 4-byte floats
|
// calculate vector size to convert from single bytes to 4-byte floats
|
||||||
byteForFloat = VectorSpecies.of(byte.class, VectorShape.forBitSize(PREFERRED_FLOAT_SPECIES.vectorBitSize() / Integer.BYTES));
|
byteForFloat = VectorSpecies.of(byte.class, VectorShape.forBitSize(PREFERRED_FLOAT_SPECIES.vectorBitSize() / Float.BYTES));
|
||||||
} catch (IllegalArgumentException e) {
|
} catch (IllegalArgumentException e) {
|
||||||
// can't get a byte vector size small enough, just use default impl
|
// can't get a byte vector size small enough, just use default impl
|
||||||
byteForFloat = null;
|
byteForFloat = null;
|
||||||
|
|
|
@ -13,7 +13,6 @@ import org.elasticsearch.simdvec.internal.vectorization.BaseVectorizationTests;
|
||||||
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
|
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.function.ToDoubleBiFunction;
|
|
||||||
import java.util.function.ToLongBiFunction;
|
import java.util.function.ToLongBiFunction;
|
||||||
|
|
||||||
import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;
|
import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;
|
||||||
|
@ -25,30 +24,44 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
|
||||||
static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider();
|
static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider();
|
||||||
|
|
||||||
public void testIpByteBit() {
|
public void testIpByteBit() {
|
||||||
byte[] q = new byte[16];
|
byte[] d = new byte[random().nextInt(128)];
|
||||||
byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
|
byte[] q = new byte[d.length * 8];
|
||||||
|
random().nextBytes(d);
|
||||||
random().nextBytes(q);
|
random().nextBytes(q);
|
||||||
int expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
|
|
||||||
assertEquals(expected, ESVectorUtil.ipByteBit(q, d));
|
int sum = 0;
|
||||||
|
for (int i = 0; i < q.length; i++) {
|
||||||
|
if (((d[i / 8] << (i % 8)) & 0x80) == 0x80) {
|
||||||
|
sum += q[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(sum, ESVectorUtil.ipByteBit(q, d));
|
||||||
|
assertEquals(sum, defaultedProvider.getVectorUtilSupport().ipByteBit(q, d));
|
||||||
|
assertEquals(sum, defOrPanamaProvider.getVectorUtilSupport().ipByteBit(q, d));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testIpFloatBit() {
|
public void testIpFloatBit() {
|
||||||
float[] q = new float[16];
|
byte[] d = new byte[random().nextInt(128)];
|
||||||
byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
|
float[] q = new float[d.length * 8];
|
||||||
|
random().nextBytes(d);
|
||||||
|
|
||||||
|
float sum = 0;
|
||||||
for (int i = 0; i < q.length; i++) {
|
for (int i = 0; i < q.length; i++) {
|
||||||
q[i] = random().nextFloat();
|
q[i] = random().nextFloat();
|
||||||
|
if (((d[i / 8] << (i % 8)) & 0x80) == 0x80) {
|
||||||
|
sum += q[i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
|
|
||||||
assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6);
|
double delta = 1e-5 * q.length;
|
||||||
|
|
||||||
|
assertEquals(sum, ESVectorUtil.ipFloatBit(q, d), delta);
|
||||||
|
assertEquals(sum, defaultedProvider.getVectorUtilSupport().ipFloatBit(q, d), delta);
|
||||||
|
assertEquals(sum, defOrPanamaProvider.getVectorUtilSupport().ipFloatBit(q, d), delta);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testIpFloatByte() {
|
public void testIpFloatByte() {
|
||||||
testIpFloatByteImpl(ESVectorUtil::ipFloatByte);
|
|
||||||
testIpFloatByteImpl(defaultedProvider.getVectorUtilSupport()::ipFloatByte);
|
|
||||||
testIpFloatByteImpl(defOrPanamaProvider.getVectorUtilSupport()::ipFloatByte);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void testIpFloatByteImpl(ToDoubleBiFunction<float[], byte[]> impl) {
|
|
||||||
int vectorSize = randomIntBetween(1, 1024);
|
int vectorSize = randomIntBetween(1, 1024);
|
||||||
// scale the delta according to the vector size
|
// scale the delta according to the vector size
|
||||||
double delta = 1e-5 * vectorSize;
|
double delta = 1e-5 * vectorSize;
|
||||||
|
@ -64,7 +77,9 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
|
||||||
for (int i = 0; i < q.length; i++) {
|
for (int i = 0; i < q.length; i++) {
|
||||||
expected += q[i] * d[i];
|
expected += q[i] * d[i];
|
||||||
}
|
}
|
||||||
assertThat(impl.applyAsDouble(q, d), closeTo(expected, delta));
|
assertThat((double) ESVectorUtil.ipFloatByte(q, d), closeTo(expected, delta));
|
||||||
|
assertThat((double) defaultedProvider.getVectorUtilSupport().ipFloatByte(q, d), closeTo(expected, delta));
|
||||||
|
assertThat((double) defOrPanamaProvider.getVectorUtilSupport().ipFloatByte(q, d), closeTo(expected, delta));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testBitAndCount() {
|
public void testBitAndCount() {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue