From 7f1203e472d196008570faeb75fd6ee06bc9da39 Mon Sep 17 00:00:00 2001 From: Simon Cooper Date: Tue, 25 Mar 2025 13:59:11 +0000 Subject: [PATCH] Add panama implementations of byte-bit and float-bit script operations (#124722) --- .../vector/DistanceFunctionBenchmark.java | 47 ++-- docs/changelog/124722.yaml | 6 + .../DefaultESVectorUtilSupport.java | 12 +- .../PanamaESVectorUtilSupport.java | 252 +++++++++++++++++- .../simdvec/ESVectorUtilTests.java | 47 ++-- 5 files changed, 329 insertions(+), 35 deletions(-) create mode 100644 docs/changelog/124722.yaml diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceFunctionBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceFunctionBenchmark.java index ed6ae4e46459..546b72cf446a 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceFunctionBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceFunctionBenchmark.java @@ -13,6 +13,8 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.script.field.vectors.BinaryDenseVector; +import org.elasticsearch.script.field.vectors.BitBinaryDenseVector; +import org.elasticsearch.script.field.vectors.BitKnnDenseVector; import org.elasticsearch.script.field.vectors.ByteBinaryDenseVector; import org.elasticsearch.script.field.vectors.ByteKnnDenseVector; import org.elasticsearch.script.field.vectors.DenseVector; @@ -37,30 +39,30 @@ import java.util.concurrent.TimeUnit; import java.util.function.DoubleSupplier; /** - * Various benchmarks for the distance functions - * used by indexed and non-indexed vectors. - * Parameters include element, dims, function, and type. + * Various benchmarks for the distance functions used by indexed and non-indexed vectors. + * Parameters include doc and query type, dims, function, and implementation. * For individual local tests it may be useful to increase - * fork, measurement, and operations per invocation. (Note - * to also update the benchmark loop if operations per invocation - * is increased.) + * fork, measurement, and operations per invocation. */ @Fork(1) @Warmup(iterations = 1) @Measurement(iterations = 2) @BenchmarkMode(Mode.AverageTime) @OutputTimeUnit(TimeUnit.NANOSECONDS) -@OperationsPerInvocation(25000) +@OperationsPerInvocation(DistanceFunctionBenchmark.OPERATIONS) @State(Scope.Benchmark) public class DistanceFunctionBenchmark { + public static final int OPERATIONS = 25000; + static { LogConfigurator.configureESLogging(); } public enum VectorType { FLOAT, - BYTE + BYTE, + BIT } public enum Function { @@ -122,7 +124,7 @@ public class DistanceFunctionBenchmark { } private static BytesRef generateVectorData(float[] vector, float mag) { - ByteBuffer buffer = ByteBuffer.allocate(vector.length * 4 + 4); + ByteBuffer buffer = ByteBuffer.allocate(vector.length * Float.BYTES + Float.BYTES); for (float f : vector) { buffer.putFloat(f); } @@ -133,7 +135,7 @@ public class DistanceFunctionBenchmark { private static BytesRef generateVectorData(byte[] vector) { float mag = calculateMag(vector); - ByteBuffer buffer = ByteBuffer.allocate(vector.length + 4); + ByteBuffer buffer = ByteBuffer.allocate(vector.length + Float.BYTES); buffer.put(vector); buffer.putFloat(mag); return new BytesRef(buffer.array()); @@ -141,16 +143,21 @@ public class DistanceFunctionBenchmark { @Setup public void findBenchmarkImpl() { + if (dims % 8 != 0) throw new IllegalArgumentException("Dims must be a multiple of 8"); Random r = new Random(); float[] floatDocVector = new float[dims]; byte[] byteDocVector = new byte[dims]; + byte[] bitDocVector = new byte[dims / 8]; float[] floatQueryVector = new float[dims]; byte[] byteQueryVector = new byte[dims]; + byte[] bitQueryVector = new byte[dims / 8]; r.nextBytes(byteDocVector); + r.nextBytes(bitDocVector); r.nextBytes(byteQueryVector); + r.nextBytes(bitQueryVector); for (int i = 0; i < dims; i++) { floatDocVector[i] = r.nextFloat(); floatQueryVector[i] = r.nextFloat(); @@ -179,10 +186,11 @@ public class DistanceFunctionBenchmark { }; case BYTE -> switch (type) { case KNN -> new ByteKnnDenseVector(byteDocVector); - case BINARY -> { - BytesRef vectorData = generateVectorData(byteDocVector); - yield new ByteBinaryDenseVector(byteDocVector, vectorData, dims); - } + case BINARY -> new ByteBinaryDenseVector(byteDocVector, generateVectorData(byteDocVector), dims); + }; + case BIT -> switch (type) { + case KNN -> new BitKnnDenseVector(bitDocVector); + case BINARY -> new BitBinaryDenseVector(bitDocVector, new BytesRef(bitDocVector), bitDocVector.length); }; }; @@ -204,13 +212,20 @@ public class DistanceFunctionBenchmark { case L2 -> () -> vectorImpl.l2Norm(byteQueryVector); case HAMMING -> () -> vectorImpl.hamming(byteQueryVector); }; + case BIT -> switch (function) { + case DOT -> () -> vectorImpl.dotProduct(bitQueryVector); + case COSINE -> throw new UnsupportedOperationException("Unsupported function " + function); + case L1 -> () -> vectorImpl.l1Norm(bitQueryVector); + case L2 -> () -> vectorImpl.l2Norm(bitQueryVector); + case HAMMING -> () -> vectorImpl.hamming(bitQueryVector); + }; }; } @Fork(1) @Benchmark public void benchmark(Blackhole blackhole) { - for (int i = 0; i < 25000; ++i) { + for (int i = 0; i < OPERATIONS; ++i) { blackhole.consume(benchmarkImpl.getAsDouble()); } } @@ -218,7 +233,7 @@ public class DistanceFunctionBenchmark { @Fork(value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) @Benchmark public void vectorBenchmark(Blackhole blackhole) { - for (int i = 0; i < 25000; ++i) { + for (int i = 0; i < OPERATIONS; ++i) { blackhole.consume(benchmarkImpl.getAsDouble()); } } diff --git a/docs/changelog/124722.yaml b/docs/changelog/124722.yaml new file mode 100644 index 000000000000..9bbfd846ef2e --- /dev/null +++ b/docs/changelog/124722.yaml @@ -0,0 +1,6 @@ +pr: 124722 +summary: Add panama implementations of byte-bit and float-bit script operations +area: Vector Search +type: enhancement +issues: + - 117096 diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java index 058982aaf551..90226c3b0e94 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java @@ -45,13 +45,17 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport { } public static int ipByteBitImpl(byte[] q, byte[] d) { + return ipByteBitImpl(q, d, 0); + } + + public static int ipByteBitImpl(byte[] q, byte[] d, int start) { assert q.length == d.length * Byte.SIZE; int acc0 = 0; int acc1 = 0; int acc2 = 0; int acc3 = 0; // now combine the two vectors, summing the byte dimensions where the bit in d is `1` - for (int i = 0; i < d.length; i++) { + for (int i = start; i < d.length; i++) { byte mask = d[i]; // Make sure its just 1 or 0 @@ -69,13 +73,17 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport { } public static float ipFloatBitImpl(float[] q, byte[] d) { + return ipFloatBitImpl(q, d, 0); + } + + static float ipFloatBitImpl(float[] q, byte[] d, int start) { assert q.length == d.length * Byte.SIZE; float acc0 = 0; float acc1 = 0; float acc2 = 0; float acc3 = 0; // now combine the two vectors, summing the byte dimensions where the bit in d is `1` - for (int i = 0; i < d.length; i++) { + for (int i = start; i < d.length; i++) { byte mask = d[i]; acc0 = fma(q[i * Byte.SIZE + 0], (mask >> 7) & 1, acc0); acc1 = fma(q[i * Byte.SIZE + 1], (mask >> 6) & 1, acc1); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java index 8ef3f2a7f988..53f5924cbe72 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java @@ -13,10 +13,12 @@ import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.IntVector; import jdk.incubator.vector.LongVector; +import jdk.incubator.vector.VectorMask; import jdk.incubator.vector.VectorOperators; import jdk.incubator.vector.VectorShape; import jdk.incubator.vector.VectorSpecies; +import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.Constants; public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport { @@ -51,11 +53,25 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport { @Override public int ipByteBit(byte[] q, byte[] d) { + if (d.length >= 16 && HAS_FAST_INTEGER_VECTORS) { + if (VECTOR_BITSIZE >= 512) { + return ipByteBit512(q, d); + } else if (VECTOR_BITSIZE == 256) { + return ipByteBit256(q, d); + } + } return DefaultESVectorUtilSupport.ipByteBitImpl(q, d); } @Override public float ipFloatBit(float[] q, byte[] d) { + if (q.length >= 16) { + if (VECTOR_BITSIZE >= 512) { + return ipFloatBit512(q, d); + } else if (VECTOR_BITSIZE == 256) { + return ipFloatBit256(q, d); + } + } return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d); } @@ -170,6 +186,240 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport { return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); } + private static final VectorSpecies INT_SPECIES_512 = IntVector.SPECIES_512; + private static final VectorSpecies BYTE_SPECIES_FOR_INT_512 = VectorSpecies.of( + byte.class, + VectorShape.forBitSize(INT_SPECIES_512.vectorBitSize() / Integer.BYTES) + ); + private static final VectorSpecies INT_SPECIES_256 = IntVector.SPECIES_256; + private static final VectorSpecies BYTE_SPECIES_FOR_INT_256 = VectorSpecies.of( + byte.class, + VectorShape.forBitSize(INT_SPECIES_256.vectorBitSize() / Integer.BYTES) + ); + + private static int limit(int length, int sectionSize) { + return length - (length % sectionSize); + } + + static int ipByteBit512(byte[] q, byte[] d) { + assert q.length == d.length * Byte.SIZE; + int i = 0; + int sum = 0; + + int sectionLength = INT_SPECIES_512.length() * 4; + if (q.length >= sectionLength) { + IntVector acc0 = IntVector.zero(INT_SPECIES_512); + IntVector acc1 = IntVector.zero(INT_SPECIES_512); + IntVector acc2 = IntVector.zero(INT_SPECIES_512); + IntVector acc3 = IntVector.zero(INT_SPECIES_512); + int limit = limit(q.length, sectionLength); + for (; i < limit; i += sectionLength) { + var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i).castShape(INT_SPECIES_512, 0); + var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length()).castShape(INT_SPECIES_512, 0); + var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length() * 2) + .castShape(INT_SPECIES_512, 0); + var vals3 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i + INT_SPECIES_512.length() * 3) + .castShape(INT_SPECIES_512, 0); + + long maskBits = Long.reverse((long) BitUtil.VH_BE_LONG.get(d, i / 8)); + var mask0 = VectorMask.fromLong(INT_SPECIES_512, maskBits); + var mask1 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 16); + var mask2 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 32); + var mask3 = VectorMask.fromLong(INT_SPECIES_512, maskBits >> 48); + + acc0 = acc0.add(vals0, mask0); + acc1 = acc1.add(vals1, mask1); + acc2 = acc2.add(vals2, mask2); + acc3 = acc3.add(vals3, mask3); + } + sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD) + + acc3.reduceLanes(VectorOperators.ADD); + } + + sectionLength = INT_SPECIES_256.length(); + if (q.length - i >= sectionLength) { + IntVector acc = IntVector.zero(INT_SPECIES_256); + int limit = limit(q.length, sectionLength); + for (; i < limit; i += sectionLength) { + var vals = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0); + + long maskBits = Integer.reverse(d[i / 8]) >> 24; + var mask = VectorMask.fromLong(INT_SPECIES_256, maskBits); + + acc = acc.add(vals, mask); + } + sum += acc.reduceLanes(VectorOperators.ADD); + } + + // that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector) + assert i == q.length; + return sum; + } + + static int ipByteBit256(byte[] q, byte[] d) { + assert q.length == d.length * Byte.SIZE; + int i = 0; + int sum = 0; + + int sectionLength = INT_SPECIES_256.length() * 4; + if (q.length >= sectionLength) { + IntVector acc0 = IntVector.zero(INT_SPECIES_256); + IntVector acc1 = IntVector.zero(INT_SPECIES_256); + IntVector acc2 = IntVector.zero(INT_SPECIES_256); + IntVector acc3 = IntVector.zero(INT_SPECIES_256); + int limit = limit(q.length, sectionLength); + for (; i < limit; i += sectionLength) { + var vals0 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0); + var vals1 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length()).castShape(INT_SPECIES_256, 0); + var vals2 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 2) + .castShape(INT_SPECIES_256, 0); + var vals3 = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i + INT_SPECIES_256.length() * 3) + .castShape(INT_SPECIES_256, 0); + + long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8)); + var mask0 = VectorMask.fromLong(INT_SPECIES_256, maskBits); + var mask1 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 8); + var mask2 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 16); + var mask3 = VectorMask.fromLong(INT_SPECIES_256, maskBits >> 24); + + acc0 = acc0.add(vals0, mask0); + acc1 = acc1.add(vals1, mask1); + acc2 = acc2.add(vals2, mask2); + acc3 = acc3.add(vals3, mask3); + } + sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD) + + acc3.reduceLanes(VectorOperators.ADD); + } + + sectionLength = INT_SPECIES_256.length(); + if (q.length - i >= sectionLength) { + IntVector acc = IntVector.zero(INT_SPECIES_256); + int limit = limit(q.length, sectionLength); + for (; i < limit; i += sectionLength) { + var vals = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0); + + long maskBits = Integer.reverse(d[i / 8]) >> 24; + var mask = VectorMask.fromLong(INT_SPECIES_256, maskBits); + + acc = acc.add(vals, mask); + } + sum += acc.reduceLanes(VectorOperators.ADD); + } + + // that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector) + assert i == q.length; + return sum; + } + + private static final VectorSpecies FLOAT_SPECIES_512 = FloatVector.SPECIES_512; + private static final VectorSpecies FLOAT_SPECIES_256 = FloatVector.SPECIES_256; + + static float ipFloatBit512(float[] q, byte[] d) { + assert q.length == d.length * Byte.SIZE; + int i = 0; + float sum = 0; + + int sectionLength = FLOAT_SPECIES_512.length() * 4; + if (q.length >= sectionLength) { + FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_512); + FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_512); + FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_512); + FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_512); + int limit = limit(q.length, sectionLength); + for (; i < limit; i += sectionLength) { + var floats0 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i); + var floats1 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length()); + var floats2 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length() * 2); + var floats3 = FloatVector.fromArray(FLOAT_SPECIES_512, q, i + FLOAT_SPECIES_512.length() * 3); + + long maskBits = Long.reverse((long) BitUtil.VH_BE_LONG.get(d, i / 8)); + var mask0 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits); + var mask1 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 16); + var mask2 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 32); + var mask3 = VectorMask.fromLong(FLOAT_SPECIES_512, maskBits >> 48); + + acc0 = acc0.add(floats0, mask0); + acc1 = acc1.add(floats1, mask1); + acc2 = acc2.add(floats2, mask2); + acc3 = acc3.add(floats3, mask3); + } + sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD) + + acc3.reduceLanes(VectorOperators.ADD); + } + + sectionLength = FLOAT_SPECIES_256.length(); + if (q.length - i >= sectionLength) { + FloatVector acc = FloatVector.zero(FLOAT_SPECIES_256); + int limit = limit(q.length, sectionLength); + for (; i < limit; i += sectionLength) { + var floats = FloatVector.fromArray(FLOAT_SPECIES_256, q, i); + + long maskBits = Integer.reverse(d[i / 8]) >> 24; + var mask = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits); + + acc = acc.add(floats, mask); + } + sum += acc.reduceLanes(VectorOperators.ADD); + } + + // that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector) + assert i == q.length; + return sum; + } + + static float ipFloatBit256(float[] q, byte[] d) { + assert q.length == d.length * Byte.SIZE; + int i = 0; + float sum = 0; + + int sectionLength = FLOAT_SPECIES_256.length() * 4; + if (q.length >= sectionLength) { + FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES_256); + FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES_256); + FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES_256); + FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES_256); + int limit = limit(q.length, sectionLength); + for (; i < limit; i += sectionLength) { + var floats0 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i); + var floats1 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length()); + var floats2 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 2); + var floats3 = FloatVector.fromArray(FLOAT_SPECIES_256, q, i + FLOAT_SPECIES_256.length() * 3); + + long maskBits = Integer.reverse((int) BitUtil.VH_BE_INT.get(d, i / 8)); + var mask0 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits); + var mask1 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 8); + var mask2 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 16); + var mask3 = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits >> 24); + + acc0 = acc0.add(floats0, mask0); + acc1 = acc1.add(floats1, mask1); + acc2 = acc2.add(floats2, mask2); + acc3 = acc3.add(floats3, mask3); + } + sum += acc0.reduceLanes(VectorOperators.ADD) + acc1.reduceLanes(VectorOperators.ADD) + acc2.reduceLanes(VectorOperators.ADD) + + acc3.reduceLanes(VectorOperators.ADD); + } + + sectionLength = FLOAT_SPECIES_256.length(); + if (q.length - i >= sectionLength) { + FloatVector acc = FloatVector.zero(FLOAT_SPECIES_256); + int limit = limit(q.length, sectionLength); + for (; i < limit; i += sectionLength) { + var floats = FloatVector.fromArray(FLOAT_SPECIES_256, q, i); + + long maskBits = Integer.reverse(d[i / 8]) >> 24; + var mask = VectorMask.fromLong(FLOAT_SPECIES_256, maskBits); + + acc = acc.add(floats, mask); + } + sum += acc.reduceLanes(VectorOperators.ADD); + } + + // that should have got them all (q.length is a multiple of 8, which fits in a 256-bit vector) + assert i == q.length; + return sum; + } + private static final VectorSpecies PREFERRED_FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED; private static final VectorSpecies BYTE_SPECIES_FOR_PREFFERED_FLOATS; @@ -177,7 +427,7 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport { VectorSpecies byteForFloat; try { // calculate vector size to convert from single bytes to 4-byte floats - byteForFloat = VectorSpecies.of(byte.class, VectorShape.forBitSize(PREFERRED_FLOAT_SPECIES.vectorBitSize() / Integer.BYTES)); + byteForFloat = VectorSpecies.of(byte.class, VectorShape.forBitSize(PREFERRED_FLOAT_SPECIES.vectorBitSize() / Float.BYTES)); } catch (IllegalArgumentException e) { // can't get a byte vector size small enough, just use default impl byteForFloat = null; diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java index 173cb0455a29..3b60df19109b 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.simdvec.internal.vectorization.BaseVectorizationTests; import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider; import java.util.Arrays; -import java.util.function.ToDoubleBiFunction; import java.util.function.ToLongBiFunction; import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY; @@ -25,30 +24,44 @@ public class ESVectorUtilTests extends BaseVectorizationTests { static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider(); public void testIpByteBit() { - byte[] q = new byte[16]; - byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) }; + byte[] d = new byte[random().nextInt(128)]; + byte[] q = new byte[d.length * 8]; + random().nextBytes(d); random().nextBytes(q); - int expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15]; - assertEquals(expected, ESVectorUtil.ipByteBit(q, d)); + + int sum = 0; + for (int i = 0; i < q.length; i++) { + if (((d[i / 8] << (i % 8)) & 0x80) == 0x80) { + sum += q[i]; + } + } + + assertEquals(sum, ESVectorUtil.ipByteBit(q, d)); + assertEquals(sum, defaultedProvider.getVectorUtilSupport().ipByteBit(q, d)); + assertEquals(sum, defOrPanamaProvider.getVectorUtilSupport().ipByteBit(q, d)); } public void testIpFloatBit() { - float[] q = new float[16]; - byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) }; + byte[] d = new byte[random().nextInt(128)]; + float[] q = new float[d.length * 8]; + random().nextBytes(d); + + float sum = 0; for (int i = 0; i < q.length; i++) { q[i] = random().nextFloat(); + if (((d[i / 8] << (i % 8)) & 0x80) == 0x80) { + sum += q[i]; + } } - float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15]; - assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6); + + double delta = 1e-5 * q.length; + + assertEquals(sum, ESVectorUtil.ipFloatBit(q, d), delta); + assertEquals(sum, defaultedProvider.getVectorUtilSupport().ipFloatBit(q, d), delta); + assertEquals(sum, defOrPanamaProvider.getVectorUtilSupport().ipFloatBit(q, d), delta); } public void testIpFloatByte() { - testIpFloatByteImpl(ESVectorUtil::ipFloatByte); - testIpFloatByteImpl(defaultedProvider.getVectorUtilSupport()::ipFloatByte); - testIpFloatByteImpl(defOrPanamaProvider.getVectorUtilSupport()::ipFloatByte); - } - - private void testIpFloatByteImpl(ToDoubleBiFunction impl) { int vectorSize = randomIntBetween(1, 1024); // scale the delta according to the vector size double delta = 1e-5 * vectorSize; @@ -64,7 +77,9 @@ public class ESVectorUtilTests extends BaseVectorizationTests { for (int i = 0; i < q.length; i++) { expected += q[i] * d[i]; } - assertThat(impl.applyAsDouble(q, d), closeTo(expected, delta)); + assertThat((double) ESVectorUtil.ipFloatByte(q, d), closeTo(expected, delta)); + assertThat((double) defaultedProvider.getVectorUtilSupport().ipFloatByte(q, d), closeTo(expected, delta)); + assertThat((double) defOrPanamaProvider.getVectorUtilSupport().ipFloatByte(q, d), closeTo(expected, delta)); } public void testBitAndCount() {