From ffea6ca2bf3325f9dbfe5b48ecd57c9325860ca9 Mon Sep 17 00:00:00 2001 From: Ignacio Vera Date: Mon, 23 Jun 2025 18:44:12 +0200 Subject: [PATCH] Introduce an int4 off-heap vector scorer (#129824) * Introduce an int4 off-heap vector scorer * iter * Update server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java Co-authored-by: Benjamin Trent --------- Co-authored-by: Benjamin Trent --- .../benchmark/vector/Int4ScorerBenchmark.java | 123 +++++++++++ .../simdvec/ES91Int4VectorsScorer.java | 43 ++++ .../elasticsearch/simdvec/ESVectorUtil.java | 4 + .../DefaultESVectorizationProvider.java | 6 + .../ESVectorizationProvider.java | 4 + .../ESVectorizationProvider.java | 4 + .../MemorySegmentES91Int4VectorsScorer.java | 191 ++++++++++++++++++ .../PanamaESVectorizationProvider.java | 12 ++ .../ES91Int4VectorScorerTests.java | 60 ++++++ .../vectors/DefaultIVFVectorsReader.java | 129 ++++++------ .../index/codec/vectors/IVFVectorsReader.java | 2 +- 11 files changed, 506 insertions(+), 72 deletions(-) create mode 100644 benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int4ScorerBenchmark.java create mode 100644 libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java create mode 100644 libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91Int4VectorsScorer.java create mode 100644 libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91Int4VectorScorerTests.java diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int4ScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int4ScorerBenchmark.java new file mode 100644 index 000000000000..e104aa85cccb --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int4ScorerBenchmark.java @@ -0,0 +1,123 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.benchmark.vector; + +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.simdvec.ES91Int4VectorsScorer; +import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.io.IOException; +import java.nio.file.Files; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +// first iteration is complete garbage, so make sure we really warmup +@Warmup(iterations = 4, time = 1) +// real iterations. not useful to spend tons of time here, better to fork more +@Measurement(iterations = 5, time = 1) +// engage some noise reduction +@Fork(value = 1) +public class Int4ScorerBenchmark { + + static { + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + @Param({ "384", "702", "1024" }) + int dims; + + int numVectors = 200; + int numQueries = 10; + + byte[] scratch; + byte[][] binaryVectors; + byte[][] binaryQueries; + + ES91Int4VectorsScorer scorer; + Directory dir; + IndexInput in; + + @Setup + public void setup() throws IOException { + binaryVectors = new byte[numVectors][dims]; + dir = new MMapDirectory(Files.createTempDirectory("vectorData")); + try (IndexOutput out = dir.createOutput("vectors", IOContext.DEFAULT)) { + for (byte[] binaryVector : binaryVectors) { + for (int i = 0; i < dims; i++) { + // 4-bit quantization + binaryVector[i] = (byte) ThreadLocalRandom.current().nextInt(16); + } + out.writeBytes(binaryVector, 0, binaryVector.length); + } + } + + in = dir.openInput("vectors", IOContext.DEFAULT); + binaryQueries = new byte[numVectors][dims]; + for (byte[] binaryVector : binaryVectors) { + for (int i = 0; i < dims; i++) { + // 4-bit quantization + binaryVector[i] = (byte) ThreadLocalRandom.current().nextInt(16); + } + } + + scratch = new byte[dims]; + scorer = ESVectorizationProvider.getInstance().newES91Int4VectorsScorer(in, dims); + } + + @TearDown + public void teardown() throws IOException { + IOUtils.close(dir, in); + } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void scoreFromArray(Blackhole bh) throws IOException { + for (int j = 0; j < numQueries; j++) { + in.seek(0); + for (int i = 0; i < numVectors; i++) { + in.readBytes(scratch, 0, dims); + bh.consume(VectorUtil.int4DotProduct(binaryQueries[j], scratch)); + } + } + } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException { + for (int j = 0; j < numQueries; j++) { + in.seek(0); + for (int i = 0; i < numVectors; i++) { + bh.consume(scorer.int4DotProduct(binaryQueries[j])); + } + } + } +} diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java new file mode 100644 index 000000000000..803bdd523a6b --- /dev/null +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91Int4VectorsScorer.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.simdvec; + +import org.apache.lucene.store.IndexInput; + +import java.io.IOException; + +/** Scorer for quantized vectors stored as an {@link IndexInput}. + *

+ * Similar to {@link org.apache.lucene.util.VectorUtil#int4DotProduct(byte[], byte[])} but + * one value is read directly from an {@link IndexInput}. + * + * */ +public class ES91Int4VectorsScorer { + + /** The wrapper {@link IndexInput}. */ + protected final IndexInput in; + protected final int dimensions; + protected byte[] scratch; + + /** Sole constructor, called by sub-classes. */ + public ES91Int4VectorsScorer(IndexInput in, int dimensions) { + this.in = in; + this.dimensions = dimensions; + scratch = new byte[dimensions]; + } + + public long int4DotProduct(byte[] b) throws IOException { + in.readBytes(scratch, 0, dimensions); + int total = 0; + for (int i = 0; i < dimensions; i++) { + total += scratch[i] * b[i]; + } + return total; + } +} diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java index 5778c26e16e5..6671ed5084a8 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java @@ -47,6 +47,10 @@ public class ESVectorUtil { return ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(input, dimension); } + public static ES91Int4VectorsScorer getES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException { + return ESVectorizationProvider.getInstance().newES91Int4VectorsScorer(input, dimension); + } + public static long ipByteBinByte(byte[] q, byte[] d) { if (q.length != d.length * B_QUERY) { throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length); diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java index 51a78d3cd6c3..5bdd7a724ced 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java @@ -10,6 +10,7 @@ package org.elasticsearch.simdvec.internal.vectorization; import org.apache.lucene.store.IndexInput; +import org.elasticsearch.simdvec.ES91Int4VectorsScorer; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; @@ -30,4 +31,9 @@ final class DefaultESVectorizationProvider extends ESVectorizationProvider { public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException { return new ES91OSQVectorsScorer(input, dimension); } + + @Override + public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException { + return new ES91Int4VectorsScorer(input, dimension); + } } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java index 8c040484c7c0..719284f48471 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java @@ -10,6 +10,7 @@ package org.elasticsearch.simdvec.internal.vectorization; import org.apache.lucene.store.IndexInput; +import org.elasticsearch.simdvec.ES91Int4VectorsScorer; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; @@ -31,6 +32,9 @@ public abstract class ESVectorizationProvider { /** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */ public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException; + /** Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}. */ + public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException; + // visible for tests static ESVectorizationProvider lookup(boolean testMode) { return new DefaultESVectorizationProvider(); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java index ea4180b59565..4708a052b05d 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java @@ -13,6 +13,7 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Constants; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; +import org.elasticsearch.simdvec.ES91Int4VectorsScorer; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; @@ -38,6 +39,9 @@ public abstract class ESVectorizationProvider { /** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */ public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException; + /** Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}. */ + public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException; + // visible for tests static ESVectorizationProvider lookup(boolean testMode) { final int runtimeVersion = Runtime.version().feature(); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91Int4VectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91Int4VectorsScorer.java new file mode 100644 index 000000000000..9a314fc4c18e --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91Int4VectorsScorer.java @@ -0,0 +1,191 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.simdvec.internal.vectorization; + +import jdk.incubator.vector.ByteVector; +import jdk.incubator.vector.IntVector; +import jdk.incubator.vector.ShortVector; +import jdk.incubator.vector.Vector; +import jdk.incubator.vector.VectorSpecies; + +import org.apache.lucene.store.IndexInput; +import org.elasticsearch.simdvec.ES91Int4VectorsScorer; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; + +import static java.nio.ByteOrder.LITTLE_ENDIAN; +import static jdk.incubator.vector.VectorOperators.ADD; +import static jdk.incubator.vector.VectorOperators.B2I; +import static jdk.incubator.vector.VectorOperators.B2S; +import static jdk.incubator.vector.VectorOperators.S2I; + +/** Panamized scorer for quantized vectors stored as an {@link IndexInput}. + *

+ * Similar to {@link org.apache.lucene.util.VectorUtil#int4DotProduct(byte[], byte[])} but + * one value is read directly from a {@link MemorySegment}. + * */ +public final class MemorySegmentES91Int4VectorsScorer extends ES91Int4VectorsScorer { + + private static final VectorSpecies BYTE_SPECIES_64 = ByteVector.SPECIES_64; + private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; + + private static final VectorSpecies SHORT_SPECIES_128 = ShortVector.SPECIES_128; + private static final VectorSpecies SHORT_SPECIES_256 = ShortVector.SPECIES_256; + + private static final VectorSpecies INT_SPECIES_128 = IntVector.SPECIES_128; + private static final VectorSpecies INT_SPECIES_256 = IntVector.SPECIES_256; + private static final VectorSpecies INT_SPECIES_512 = IntVector.SPECIES_512; + + private final MemorySegment memorySegment; + + public MemorySegmentES91Int4VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) { + super(in, dimensions); + this.memorySegment = memorySegment; + } + + @Override + public long int4DotProduct(byte[] q) throws IOException { + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 512 || PanamaESVectorUtilSupport.VECTOR_BITSIZE == 256) { + return dotProduct(q); + } + int i = 0; + int res = 0; + if (dimensions >= 32 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + i += BYTE_SPECIES_128.loopBound(dimensions); + res += int4DotProductBody128(q, i); + } + in.readBytes(scratch, i, dimensions - i); + while (i < dimensions) { + res += scratch[i] * q[i++]; + } + return res; + } + + private int int4DotProductBody128(byte[] q, int limit) throws IOException { + int sum = 0; + long offset = in.getFilePointer(); + for (int i = 0; i < limit; i += 1024) { + ShortVector acc0 = ShortVector.zero(SHORT_SPECIES_128); + ShortVector acc1 = ShortVector.zero(SHORT_SPECIES_128); + int innerLimit = Math.min(limit - i, 1024); + for (int j = 0; j < innerLimit; j += BYTE_SPECIES_128.length()) { + ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i + j); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i + j, LITTLE_ENDIAN); + ByteVector prod8 = va8.mul(vb8); + ShortVector prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts(); + acc0 = acc0.add(prod16.and((short) 255)); + va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i + j + 8); + vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i + j + 8, LITTLE_ENDIAN); + prod8 = va8.mul(vb8); + prod16 = prod8.convertShape(B2S, SHORT_SPECIES_128, 0).reinterpretAsShorts(); + acc1 = acc1.add(prod16.and((short) 255)); + } + + IntVector intAcc0 = acc0.convertShape(S2I, INT_SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc1 = acc0.convertShape(S2I, INT_SPECIES_128, 1).reinterpretAsInts(); + IntVector intAcc2 = acc1.convertShape(S2I, INT_SPECIES_128, 0).reinterpretAsInts(); + IntVector intAcc3 = acc1.convertShape(S2I, INT_SPECIES_128, 1).reinterpretAsInts(); + sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD); + } + in.seek(offset + limit); + return sum; + } + + private long dotProduct(byte[] q) throws IOException { + int i = 0; + int res = 0; + + // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit + // vectors (256-bit on intel to dodge performance landmines) + if (dimensions >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { + // compute vectorized dot product consistent with VPDPBUSD instruction + if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 512) { + i += BYTE_SPECIES_128.loopBound(dimensions); + res += dotProductBody512(q, i); + } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 256) { + i += BYTE_SPECIES_64.loopBound(dimensions); + res += dotProductBody256(q, i); + } else { + // tricky: we don't have SPECIES_32, so we workaround with "overlapping read" + i += BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length()); + res += dotProductBody128(q, i); + } + } + // scalar tail + for (; i < q.length; i++) { + res += in.readByte() * q[i]; + } + return res; + } + + /** vectorized dot product body (512 bit vectors) */ + private int dotProductBody512(byte[] q, int limit) throws IOException { + IntVector acc = IntVector.zero(INT_SPECIES_512); + long offset = in.getFilePointer(); + for (int i = 0; i < limit; i += BYTE_SPECIES_128.length()) { + ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN); + + // 16-bit multiply: avoid AVX-512 heavy multiply on zmm + Vector va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0); + Vector vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0); + Vector prod16 = va16.mul(vb16); + + // 32-bit add + Vector prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0); + acc = acc.add(prod32); + } + + in.seek(offset + limit); // advance the input stream + // reduce + return acc.reduceLanes(ADD); + } + + /** vectorized dot product body (256 bit vectors) */ + private int dotProductBody256(byte[] q, int limit) throws IOException { + IntVector acc = IntVector.zero(INT_SPECIES_256); + long offset = in.getFilePointer(); + for (int i = 0; i < limit; i += BYTE_SPECIES_64.length()) { + ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN); + + // 32-bit multiply and add into accumulator + Vector va32 = va8.convertShape(B2I, INT_SPECIES_256, 0); + Vector vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0); + acc = acc.add(va32.mul(vb32)); + } + in.seek(offset + limit); + // reduce + return acc.reduceLanes(ADD); + } + + /** vectorized dot product body (128 bit vectors) */ + private int dotProductBody128(byte[] q, int limit) throws IOException { + IntVector acc = IntVector.zero(INT_SPECIES_128); + long offset = in.getFilePointer(); + // 4 bytes at a time (re-loading half the vector each time!) + for (int i = 0; i < limit; i += BYTE_SPECIES_64.length() >> 1) { + // load 8 bytes + ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i); + ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN); + + // process first "half" only: 16-bit multiply + Vector va16 = va8.convert(B2S, 0); + Vector vb16 = vb8.convert(B2S, 0); + Vector prod16 = va16.mul(vb16); + + // 32-bit add + acc = acc.add(prod16.convertShape(S2I, INT_SPECIES_128, 0)); + } + in.seek(offset + limit); + // reduce + return acc.reduceLanes(ADD); + } +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java index 5ff8c19c90a5..abb75352da2f 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java @@ -11,6 +11,7 @@ package org.elasticsearch.simdvec.internal.vectorization; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; +import org.elasticsearch.simdvec.ES91Int4VectorsScorer; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; @@ -39,4 +40,15 @@ final class PanamaESVectorizationProvider extends ESVectorizationProvider { } return new ES91OSQVectorsScorer(input, dimension); } + + @Override + public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException { + if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS && input instanceof MemorySegmentAccessInput msai) { + MemorySegment ms = msai.segmentSliceOrNull(0, input.length()); + if (ms != null) { + return new MemorySegmentES91Int4VectorsScorer(input, dimension, ms); + } + } + return new ES91Int4VectorsScorer(input, dimension); + } } diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91Int4VectorScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91Int4VectorScorerTests.java new file mode 100644 index 000000000000..c19211585a76 --- /dev/null +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91Int4VectorScorerTests.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal.vectorization; + +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.simdvec.ES91Int4VectorsScorer; + +public class ES91Int4VectorScorerTests extends BaseVectorizationTests { + + public void testInt4DotProduct() throws Exception { + // only even dimensions are supported + final int dimensions = random().nextInt(1, 1000) * 2; + final int numVectors = random().nextInt(1, 100); + final byte[] vector = new byte[dimensions]; + try (Directory dir = new MMapDirectory(createTempDir())) { + try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) { + for (int i = 0; i < numVectors; i++) { + for (int j = 0; j < dimensions; j++) { + vector[j] = (byte) random().nextInt(16); // 4-bit quantization + } + out.writeBytes(vector, 0, dimensions); + } + } + final byte[] query = new byte[dimensions]; + for (int j = 0; j < dimensions; j++) { + query[j] = (byte) random().nextInt(16); // 4-bit quantization + } + try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) { + // Work on a slice that has just the right number of bytes to make the test fail with an + // index-out-of-bounds in case the implementation reads more than the allowed number of + // padding bytes. + final IndexInput slice = in.slice("test", 0, (long) dimensions * numVectors); + final IndexInput slice2 = in.slice("test2", 0, (long) dimensions * numVectors); + final ES91Int4VectorsScorer defaultScorer = defaultProvider().newES91Int4VectorsScorer(slice, dimensions); + final ES91Int4VectorsScorer panamaScorer = maybePanamaProvider().newES91Int4VectorsScorer(slice2, dimensions); + for (int i = 0; i < numVectors; i++) { + in.readBytes(vector, 0, dimensions); + long val = VectorUtil.int4DotProduct(vector, query); + assertEquals(val, defaultScorer.int4DotProduct(query)); + assertEquals(val, panamaScorer.int4DotProduct(query)); + assertEquals(in.getFilePointer(), slice.getFilePointer()); + assertEquals(in.getFilePointer(), slice2.getFilePointer()); + } + assertEquals((long) dimensions * numVectors, in.getFilePointer()); + } + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index 5c86f602a654..e7b41d005d7e 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -19,6 +19,7 @@ import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.NeighborQueue; import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats; +import org.elasticsearch.simdvec.ES91Int4VectorsScorer; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import org.elasticsearch.simdvec.ESVectorUtil; @@ -48,25 +49,23 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap @Override CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery) throws IOException { - FieldEntry fieldEntry = fields.get(fieldInfo.number); - float[] globalCentroid = fieldEntry.globalCentroid(); - float globalCentroidDp = fieldEntry.globalCentroidDp(); - OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); - byte[] quantized = new byte[targetQuery.length]; - float[] targetScratch = ArrayUtil.copyArray(targetQuery); - OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize( - targetScratch, + final FieldEntry fieldEntry = fields.get(fieldInfo.number); + final float globalCentroidDp = fieldEntry.globalCentroidDp(); + final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + final byte[] quantized = new byte[targetQuery.length]; + final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize( + ArrayUtil.copyArray(targetQuery), quantized, (byte) 4, - globalCentroid + fieldEntry.globalCentroid() ); + final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension()); return new CentroidQueryScorer() { int currentCentroid = -1; - private final byte[] quantizedCentroid = new byte[fieldInfo.getVectorDimension()]; private final float[] centroid = new float[fieldInfo.getVectorDimension()]; private final float[] centroidCorrectiveValues = new float[3]; - private int quantizedCentroidComponentSum; - private final long centroidByteSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES; + private final long rawCentroidsOffset = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES); + private final long rawCentroidsByteSize = (long) Float.BYTES * fieldInfo.getVectorDimension(); @Override public int size() { @@ -75,35 +74,67 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap @Override public float[] centroid(int centroidOrdinal) throws IOException { - readQuantizedAndRawCentroid(centroidOrdinal); + if (centroidOrdinal != currentCentroid) { + centroids.seek(rawCentroidsOffset + rawCentroidsByteSize * centroidOrdinal); + centroids.readFloats(centroid, 0, centroid.length); + currentCentroid = centroidOrdinal; + } return centroid; } - private void readQuantizedAndRawCentroid(int centroidOrdinal) throws IOException { - if (centroidOrdinal == currentCentroid) { - return; + public void bulkScore(NeighborQueue queue) throws IOException { + // TODO: bulk score centroids like we do with posting lists + centroids.seek(0L); + for (int i = 0; i < numCentroids; i++) { + queue.add(i, score()); } - centroids.seek(centroidOrdinal * centroidByteSize); - quantizedCentroidComponentSum = readQuantizedValue(centroids, quantizedCentroid, centroidCorrectiveValues); - centroids.seek(numCentroids * centroidByteSize + (long) Float.BYTES * quantizedCentroid.length * centroidOrdinal); - centroids.readFloats(centroid, 0, centroid.length); - currentCentroid = centroidOrdinal; } - @Override - public float score(int centroidOrdinal) throws IOException { - readQuantizedAndRawCentroid(centroidOrdinal); + private float score() throws IOException { + final float qcDist = scorer.int4DotProduct(quantized); + centroids.readFloats(centroidCorrectiveValues, 0, 3); + final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort()); return int4QuantizedScore( - quantized, + qcDist, queryParams, fieldInfo.getVectorDimension(), - quantizedCentroid, centroidCorrectiveValues, quantizedCentroidComponentSum, globalCentroidDp, fieldInfo.getVectorSimilarityFunction() ); } + + // TODO can we do this in off-heap blocks? + private float int4QuantizedScore( + float qcDist, + OptimizedScalarQuantizer.QuantizationResult queryCorrections, + int dims, + float[] targetCorrections, + int targetComponentSum, + float centroidDp, + VectorSimilarityFunction similarityFunction + ) { + float ax = targetCorrections[0]; + // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary + float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE; + float ay = queryCorrections.lowerInterval(); + float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; + float y1 = queryCorrections.quantizedComponentSum(); + float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist; + if (similarityFunction == EUCLIDEAN) { + score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score; + return Math.max(1 / (1f + score), 0); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp; + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + return VectorUtil.scaleMaxInnerProductScore(score); + } + return Math.max((1f + score) / 2f, 0); + } + } }; } @@ -111,10 +142,7 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe) throws IOException { NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true); - // TODO Off heap scoring for quantized centroids? - for (int centroid = 0; centroid < centroidQueryScorer.size(); centroid++) { - neighborQueue.add(centroid, centroidQueryScorer.score(centroid)); - } + centroidQueryScorer.bulkScore(neighborQueue); return neighborQueue; } @@ -125,39 +153,6 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap return new MemorySegmentPostingsVisitor(target, indexInput.clone(), entry, fieldInfo, needsScoring); } - // TODO can we do this in off-heap blocks? - static float int4QuantizedScore( - byte[] quantizedQuery, - OptimizedScalarQuantizer.QuantizationResult queryCorrections, - int dims, - byte[] binaryCode, - float[] targetCorrections, - int targetComponentSum, - float centroidDp, - VectorSimilarityFunction similarityFunction - ) { - float qcDist = VectorUtil.int4DotProduct(quantizedQuery, binaryCode); - float ax = targetCorrections[0]; - // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary - float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE; - float ay = queryCorrections.lowerInterval(); - float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; - float y1 = queryCorrections.quantizedComponentSum(); - float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist; - if (similarityFunction == EUCLIDEAN) { - score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score; - return Math.max(1 / (1f + score), 0); - } else { - // For cosine and max inner product, we need to apply the additional correction, which is - // assumed to be the non-centered dot-product between the vector and the centroid - score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp; - if (similarityFunction == MAXIMUM_INNER_PRODUCT) { - return VectorUtil.scaleMaxInnerProductScore(score); - } - return Math.max((1f + score) / 2f, 0); - } - } - @Override public Map getOffHeapByteSize(FieldInfo fieldInfo) { return Map.of(); @@ -356,12 +351,4 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap } } - static int readQuantizedValue(IndexInput indexInput, byte[] binaryValue, float[] corrections) throws IOException { - assert corrections.length == 3; - indexInput.readBytes(binaryValue, 0, binaryValue.length); - corrections[0] = Float.intBitsToFloat(indexInput.readInt()); - corrections[1] = Float.intBitsToFloat(indexInput.readInt()); - corrections[2] = Float.intBitsToFloat(indexInput.readInt()); - return Short.toUnsignedInt(indexInput.readShort()); - } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java index 453780466478..dbcdfd451df9 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -332,7 +332,7 @@ public abstract class IVFVectorsReader extends KnnVectorsReader { float[] centroid(int centroidOrdinal) throws IOException; - float score(int centroidOrdinal) throws IOException; + void bulkScore(NeighborQueue queue) throws IOException; } interface PostingVisitor {