mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
Introduce an int4 off-heap vector scorer (#129824)
* Introduce an int4 off-heap vector scorer * iter * Update server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java Co-authored-by: Benjamin Trent <ben.w.trent@gmail.com> --------- Co-authored-by: Benjamin Trent <ben.w.trent@gmail.com>
This commit is contained in:
parent
321a39738a
commit
ffea6ca2bf
11 changed files with 506 additions and 72 deletions
|
@ -0,0 +1,123 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.benchmark.vector;
|
||||||
|
|
||||||
|
import org.apache.lucene.store.Directory;
|
||||||
|
import org.apache.lucene.store.IOContext;
|
||||||
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.apache.lucene.store.IndexOutput;
|
||||||
|
import org.apache.lucene.store.MMapDirectory;
|
||||||
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
import org.elasticsearch.common.logging.LogConfigurator;
|
||||||
|
import org.elasticsearch.core.IOUtils;
|
||||||
|
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
|
||||||
|
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
|
||||||
|
import org.openjdk.jmh.annotations.Benchmark;
|
||||||
|
import org.openjdk.jmh.annotations.BenchmarkMode;
|
||||||
|
import org.openjdk.jmh.annotations.Fork;
|
||||||
|
import org.openjdk.jmh.annotations.Measurement;
|
||||||
|
import org.openjdk.jmh.annotations.Mode;
|
||||||
|
import org.openjdk.jmh.annotations.OutputTimeUnit;
|
||||||
|
import org.openjdk.jmh.annotations.Param;
|
||||||
|
import org.openjdk.jmh.annotations.Scope;
|
||||||
|
import org.openjdk.jmh.annotations.Setup;
|
||||||
|
import org.openjdk.jmh.annotations.State;
|
||||||
|
import org.openjdk.jmh.annotations.TearDown;
|
||||||
|
import org.openjdk.jmh.annotations.Warmup;
|
||||||
|
import org.openjdk.jmh.infra.Blackhole;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.util.concurrent.ThreadLocalRandom;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
@BenchmarkMode(Mode.Throughput)
|
||||||
|
@OutputTimeUnit(TimeUnit.MILLISECONDS)
|
||||||
|
@State(Scope.Benchmark)
|
||||||
|
// first iteration is complete garbage, so make sure we really warmup
|
||||||
|
@Warmup(iterations = 4, time = 1)
|
||||||
|
// real iterations. not useful to spend tons of time here, better to fork more
|
||||||
|
@Measurement(iterations = 5, time = 1)
|
||||||
|
// engage some noise reduction
|
||||||
|
@Fork(value = 1)
|
||||||
|
public class Int4ScorerBenchmark {
|
||||||
|
|
||||||
|
static {
|
||||||
|
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
|
||||||
|
}
|
||||||
|
|
||||||
|
@Param({ "384", "702", "1024" })
|
||||||
|
int dims;
|
||||||
|
|
||||||
|
int numVectors = 200;
|
||||||
|
int numQueries = 10;
|
||||||
|
|
||||||
|
byte[] scratch;
|
||||||
|
byte[][] binaryVectors;
|
||||||
|
byte[][] binaryQueries;
|
||||||
|
|
||||||
|
ES91Int4VectorsScorer scorer;
|
||||||
|
Directory dir;
|
||||||
|
IndexInput in;
|
||||||
|
|
||||||
|
@Setup
|
||||||
|
public void setup() throws IOException {
|
||||||
|
binaryVectors = new byte[numVectors][dims];
|
||||||
|
dir = new MMapDirectory(Files.createTempDirectory("vectorData"));
|
||||||
|
try (IndexOutput out = dir.createOutput("vectors", IOContext.DEFAULT)) {
|
||||||
|
for (byte[] binaryVector : binaryVectors) {
|
||||||
|
for (int i = 0; i < dims; i++) {
|
||||||
|
// 4-bit quantization
|
||||||
|
binaryVector[i] = (byte) ThreadLocalRandom.current().nextInt(16);
|
||||||
|
}
|
||||||
|
out.writeBytes(binaryVector, 0, binaryVector.length);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
in = dir.openInput("vectors", IOContext.DEFAULT);
|
||||||
|
binaryQueries = new byte[numVectors][dims];
|
||||||
|
for (byte[] binaryVector : binaryVectors) {
|
||||||
|
for (int i = 0; i < dims; i++) {
|
||||||
|
// 4-bit quantization
|
||||||
|
binaryVector[i] = (byte) ThreadLocalRandom.current().nextInt(16);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
scratch = new byte[dims];
|
||||||
|
scorer = ESVectorizationProvider.getInstance().newES91Int4VectorsScorer(in, dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
@TearDown
|
||||||
|
public void teardown() throws IOException {
|
||||||
|
IOUtils.close(dir, in);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
|
||||||
|
public void scoreFromArray(Blackhole bh) throws IOException {
|
||||||
|
for (int j = 0; j < numQueries; j++) {
|
||||||
|
in.seek(0);
|
||||||
|
for (int i = 0; i < numVectors; i++) {
|
||||||
|
in.readBytes(scratch, 0, dims);
|
||||||
|
bh.consume(VectorUtil.int4DotProduct(binaryQueries[j], scratch));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
|
||||||
|
public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException {
|
||||||
|
for (int j = 0; j < numQueries; j++) {
|
||||||
|
in.seek(0);
|
||||||
|
for (int i = 0; i < numVectors; i++) {
|
||||||
|
bh.consume(scorer.int4DotProduct(binaryQueries[j]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.simdvec;
|
||||||
|
|
||||||
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
/** Scorer for quantized vectors stored as an {@link IndexInput}.
|
||||||
|
* <p>
|
||||||
|
* Similar to {@link org.apache.lucene.util.VectorUtil#int4DotProduct(byte[], byte[])} but
|
||||||
|
* one value is read directly from an {@link IndexInput}.
|
||||||
|
*
|
||||||
|
* */
|
||||||
|
public class ES91Int4VectorsScorer {
|
||||||
|
|
||||||
|
/** The wrapper {@link IndexInput}. */
|
||||||
|
protected final IndexInput in;
|
||||||
|
protected final int dimensions;
|
||||||
|
protected byte[] scratch;
|
||||||
|
|
||||||
|
/** Sole constructor, called by sub-classes. */
|
||||||
|
public ES91Int4VectorsScorer(IndexInput in, int dimensions) {
|
||||||
|
this.in = in;
|
||||||
|
this.dimensions = dimensions;
|
||||||
|
scratch = new byte[dimensions];
|
||||||
|
}
|
||||||
|
|
||||||
|
public long int4DotProduct(byte[] b) throws IOException {
|
||||||
|
in.readBytes(scratch, 0, dimensions);
|
||||||
|
int total = 0;
|
||||||
|
for (int i = 0; i < dimensions; i++) {
|
||||||
|
total += scratch[i] * b[i];
|
||||||
|
}
|
||||||
|
return total;
|
||||||
|
}
|
||||||
|
}
|
|
@ -47,6 +47,10 @@ public class ESVectorUtil {
|
||||||
return ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(input, dimension);
|
return ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(input, dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static ES91Int4VectorsScorer getES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException {
|
||||||
|
return ESVectorizationProvider.getInstance().newES91Int4VectorsScorer(input, dimension);
|
||||||
|
}
|
||||||
|
|
||||||
public static long ipByteBinByte(byte[] q, byte[] d) {
|
public static long ipByteBinByte(byte[] q, byte[] d) {
|
||||||
if (q.length != d.length * B_QUERY) {
|
if (q.length != d.length * B_QUERY) {
|
||||||
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);
|
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length);
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
package org.elasticsearch.simdvec.internal.vectorization;
|
package org.elasticsearch.simdvec.internal.vectorization;
|
||||||
|
|
||||||
import org.apache.lucene.store.IndexInput;
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
|
||||||
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -30,4 +31,9 @@ final class DefaultESVectorizationProvider extends ESVectorizationProvider {
|
||||||
public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
|
public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
|
||||||
return new ES91OSQVectorsScorer(input, dimension);
|
return new ES91OSQVectorsScorer(input, dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException {
|
||||||
|
return new ES91Int4VectorsScorer(input, dimension);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
package org.elasticsearch.simdvec.internal.vectorization;
|
package org.elasticsearch.simdvec.internal.vectorization;
|
||||||
|
|
||||||
import org.apache.lucene.store.IndexInput;
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
|
||||||
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -31,6 +32,9 @@ public abstract class ESVectorizationProvider {
|
||||||
/** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
|
/** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
|
||||||
public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
|
public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
|
||||||
|
|
||||||
|
/** Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}. */
|
||||||
|
public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException;
|
||||||
|
|
||||||
// visible for tests
|
// visible for tests
|
||||||
static ESVectorizationProvider lookup(boolean testMode) {
|
static ESVectorizationProvider lookup(boolean testMode) {
|
||||||
return new DefaultESVectorizationProvider();
|
return new DefaultESVectorizationProvider();
|
||||||
|
|
|
@ -13,6 +13,7 @@ import org.apache.lucene.store.IndexInput;
|
||||||
import org.apache.lucene.util.Constants;
|
import org.apache.lucene.util.Constants;
|
||||||
import org.elasticsearch.logging.LogManager;
|
import org.elasticsearch.logging.LogManager;
|
||||||
import org.elasticsearch.logging.Logger;
|
import org.elasticsearch.logging.Logger;
|
||||||
|
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
|
||||||
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -38,6 +39,9 @@ public abstract class ESVectorizationProvider {
|
||||||
/** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
|
/** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
|
||||||
public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
|
public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
|
||||||
|
|
||||||
|
/** Create a new {@link ES91Int4VectorsScorer} for the given {@link IndexInput}. */
|
||||||
|
public abstract ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException;
|
||||||
|
|
||||||
// visible for tests
|
// visible for tests
|
||||||
static ESVectorizationProvider lookup(boolean testMode) {
|
static ESVectorizationProvider lookup(boolean testMode) {
|
||||||
final int runtimeVersion = Runtime.version().feature();
|
final int runtimeVersion = Runtime.version().feature();
|
||||||
|
|
|
@ -0,0 +1,191 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.simdvec.internal.vectorization;
|
||||||
|
|
||||||
|
import jdk.incubator.vector.ByteVector;
|
||||||
|
import jdk.incubator.vector.IntVector;
|
||||||
|
import jdk.incubator.vector.ShortVector;
|
||||||
|
import jdk.incubator.vector.Vector;
|
||||||
|
import jdk.incubator.vector.VectorSpecies;
|
||||||
|
|
||||||
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.lang.foreign.MemorySegment;
|
||||||
|
|
||||||
|
import static java.nio.ByteOrder.LITTLE_ENDIAN;
|
||||||
|
import static jdk.incubator.vector.VectorOperators.ADD;
|
||||||
|
import static jdk.incubator.vector.VectorOperators.B2I;
|
||||||
|
import static jdk.incubator.vector.VectorOperators.B2S;
|
||||||
|
import static jdk.incubator.vector.VectorOperators.S2I;
|
||||||
|
|
||||||
|
/** Panamized scorer for quantized vectors stored as an {@link IndexInput}.
|
||||||
|
* <p>
|
||||||
|
* Similar to {@link org.apache.lucene.util.VectorUtil#int4DotProduct(byte[], byte[])} but
|
||||||
|
* one value is read directly from a {@link MemorySegment}.
|
||||||
|
* */
|
||||||
|
public final class MemorySegmentES91Int4VectorsScorer extends ES91Int4VectorsScorer {

    // species used to pick lane widths appropriate for the detected hardware
    private static final VectorSpecies<Byte> BYTE_SPECIES_64 = ByteVector.SPECIES_64;
    private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;

    private static final VectorSpecies<Short> SHORT_SPECIES_128 = ShortVector.SPECIES_128;
    private static final VectorSpecies<Short> SHORT_SPECIES_256 = ShortVector.SPECIES_256;

    private static final VectorSpecies<Integer> INT_SPECIES_128 = IntVector.SPECIES_128;
    private static final VectorSpecies<Integer> INT_SPECIES_256 = IntVector.SPECIES_256;
    private static final VectorSpecies<Integer> INT_SPECIES_512 = IntVector.SPECIES_512;

    // the off-heap segment backing the IndexInput; vector bytes are loaded
    // straight from here instead of being copied onto the heap first
    private final MemorySegment memorySegment;

    public MemorySegmentES91Int4VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
        super(in, dimensions);
        this.memorySegment = memorySegment;
    }

    @Override
    public long int4DotProduct(byte[] q) throws IOException {
        // wide (256/512-bit) hardware takes the dedicated path below
        if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 512 || PanamaESVectorUtilSupport.VECTOR_BITSIZE == 256) {
            return dotProduct(q);
        }
        int i = 0;
        int res = 0;
        // 128-bit path: only worth vectorizing past a minimum dimension count
        if (dimensions >= 32 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
            i += BYTE_SPECIES_128.loopBound(dimensions);
            res += int4DotProductBody128(q, i);
        }
        // scalar tail: copy the remaining bytes into scratch and finish on the heap
        in.readBytes(scratch, i, dimensions - i);
        while (i < dimensions) {
            res += scratch[i] * q[i++];
        }
        return res;
    }

    /**
     * 128-bit vectorized body. Accumulates in 16-bit lanes, widening to 32-bit
     * only once per 1024-element stripe, to amortize the cost of the widening
     * conversions. NOTE(review): the 1024 stripe bound presumably keeps the
     * short accumulators from overflowing (int4 products are at most 225) —
     * confirm against the upstream derivation.
     */
    private int int4DotProductBody128(byte[] q, int limit) throws IOException {
        int sum = 0;
        long offset = in.getFilePointer();
        for (int i = 0; i < limit; i += 1024) {
            ShortVector acc0 = ShortVector.zero(SHORT_SPECIES_128);
            ShortVector acc1 = ShortVector.zero(SHORT_SPECIES_128);
            int innerLimit = Math.min(limit - i, 1024);
            // each iteration consumes 16 bytes: two 8-byte loads feeding two accumulators
            for (int j = 0; j < innerLimit; j += BYTE_SPECIES_128.length()) {
                ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i + j);
                ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i + j, LITTLE_ENDIAN);
                ByteVector prod8 = va8.mul(vb8);
                ShortVector prod16 = prod8.convertShape(B2S, ShortVector.SPECIES_128, 0).reinterpretAsShorts();
                // products of 4-bit values fit in one unsigned byte; the mask drops
                // the sign-extension bits introduced by the byte->short conversion
                acc0 = acc0.add(prod16.and((short) 255));
                va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i + j + 8);
                vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i + j + 8, LITTLE_ENDIAN);
                prod8 = va8.mul(vb8);
                prod16 = prod8.convertShape(B2S, SHORT_SPECIES_128, 0).reinterpretAsShorts();
                acc1 = acc1.add(prod16.and((short) 255));
            }

            // widen both short accumulators (low and high halves) to int and reduce
            IntVector intAcc0 = acc0.convertShape(S2I, INT_SPECIES_128, 0).reinterpretAsInts();
            IntVector intAcc1 = acc0.convertShape(S2I, INT_SPECIES_128, 1).reinterpretAsInts();
            IntVector intAcc2 = acc1.convertShape(S2I, INT_SPECIES_128, 0).reinterpretAsInts();
            IntVector intAcc3 = acc1.convertShape(S2I, INT_SPECIES_128, 1).reinterpretAsInts();
            sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(ADD);
        }
        // the segment was read directly, so advance the input past the consumed bytes
        in.seek(offset + limit);
        return sum;
    }

    /** Dot product entry point for 256-bit and wider hardware. */
    private long dotProduct(byte[] q) throws IOException {
        int i = 0;
        int res = 0;

        // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
        // vectors (256-bit on intel to dodge performance landmines)
        if (dimensions >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
            // compute vectorized dot product consistent with VPDPBUSD instruction
            if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 512) {
                i += BYTE_SPECIES_128.loopBound(dimensions);
                res += dotProductBody512(q, i);
            } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 256) {
                i += BYTE_SPECIES_64.loopBound(dimensions);
                res += dotProductBody256(q, i);
            } else {
                // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
                i += BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length());
                res += dotProductBody128(q, i);
            }
        }
        // scalar tail: read remaining bytes one at a time from the input
        for (; i < q.length; i++) {
            res += in.readByte() * q[i];
        }
        return res;
    }

    /** vectorized dot product body (512 bit vectors) */
    private int dotProductBody512(byte[] q, int limit) throws IOException {
        IntVector acc = IntVector.zero(INT_SPECIES_512);
        long offset = in.getFilePointer();
        for (int i = 0; i < limit; i += BYTE_SPECIES_128.length()) {
            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i);
            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN);

            // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
            Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0);
            Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0);
            Vector<Short> prod16 = va16.mul(vb16);

            // 32-bit add
            Vector<Integer> prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0);
            acc = acc.add(prod32);
        }

        in.seek(offset + limit); // advance the input stream
        // reduce
        return acc.reduceLanes(ADD);
    }

    /** vectorized dot product body (256 bit vectors) */
    private int dotProductBody256(byte[] q, int limit) throws IOException {
        IntVector acc = IntVector.zero(INT_SPECIES_256);
        long offset = in.getFilePointer();
        for (int i = 0; i < limit; i += BYTE_SPECIES_64.length()) {
            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);

            // 32-bit multiply and add into accumulator
            Vector<Integer> va32 = va8.convertShape(B2I, INT_SPECIES_256, 0);
            Vector<Integer> vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0);
            acc = acc.add(va32.mul(vb32));
        }
        in.seek(offset + limit);
        // reduce
        return acc.reduceLanes(ADD);
    }

    /** vectorized dot product body (128 bit vectors) */
    private int dotProductBody128(byte[] q, int limit) throws IOException {
        IntVector acc = IntVector.zero(INT_SPECIES_128);
        long offset = in.getFilePointer();
        // 4 bytes at a time (re-loading half the vector each time!)
        for (int i = 0; i < limit; i += BYTE_SPECIES_64.length() >> 1) {
            // load 8 bytes
            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);

            // process first "half" only: 16-bit multiply
            Vector<Short> va16 = va8.convert(B2S, 0);
            Vector<Short> vb16 = vb8.convert(B2S, 0);
            Vector<Short> prod16 = va16.mul(vb16);

            // 32-bit add
            acc = acc.add(prod16.convertShape(S2I, INT_SPECIES_128, 0));
        }
        in.seek(offset + limit);
        // reduce
        return acc.reduceLanes(ADD);
    }
}
|
|
@ -11,6 +11,7 @@ package org.elasticsearch.simdvec.internal.vectorization;
|
||||||
|
|
||||||
import org.apache.lucene.store.IndexInput;
|
import org.apache.lucene.store.IndexInput;
|
||||||
import org.apache.lucene.store.MemorySegmentAccessInput;
|
import org.apache.lucene.store.MemorySegmentAccessInput;
|
||||||
|
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
|
||||||
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -39,4 +40,15 @@ final class PanamaESVectorizationProvider extends ESVectorizationProvider {
|
||||||
}
|
}
|
||||||
return new ES91OSQVectorsScorer(input, dimension);
|
return new ES91OSQVectorsScorer(input, dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ES91Int4VectorsScorer newES91Int4VectorsScorer(IndexInput input, int dimension) throws IOException {
|
||||||
|
if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS && input instanceof MemorySegmentAccessInput msai) {
|
||||||
|
MemorySegment ms = msai.segmentSliceOrNull(0, input.length());
|
||||||
|
if (ms != null) {
|
||||||
|
return new MemorySegmentES91Int4VectorsScorer(input, dimension, ms);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return new ES91Int4VectorsScorer(input, dimension);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the "Elastic License
|
||||||
|
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||||
|
* Public License v 1"; you may not use this file except in compliance with, at
|
||||||
|
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||||
|
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.simdvec.internal.vectorization;
|
||||||
|
|
||||||
|
import org.apache.lucene.store.Directory;
|
||||||
|
import org.apache.lucene.store.IOContext;
|
||||||
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.apache.lucene.store.IndexOutput;
|
||||||
|
import org.apache.lucene.store.MMapDirectory;
|
||||||
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
|
||||||
|
|
||||||
|
public class ES91Int4VectorScorerTests extends BaseVectorizationTests {
|
||||||
|
|
||||||
|
public void testInt4DotProduct() throws Exception {
|
||||||
|
// only even dimensions are supported
|
||||||
|
final int dimensions = random().nextInt(1, 1000) * 2;
|
||||||
|
final int numVectors = random().nextInt(1, 100);
|
||||||
|
final byte[] vector = new byte[dimensions];
|
||||||
|
try (Directory dir = new MMapDirectory(createTempDir())) {
|
||||||
|
try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) {
|
||||||
|
for (int i = 0; i < numVectors; i++) {
|
||||||
|
for (int j = 0; j < dimensions; j++) {
|
||||||
|
vector[j] = (byte) random().nextInt(16); // 4-bit quantization
|
||||||
|
}
|
||||||
|
out.writeBytes(vector, 0, dimensions);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
final byte[] query = new byte[dimensions];
|
||||||
|
for (int j = 0; j < dimensions; j++) {
|
||||||
|
query[j] = (byte) random().nextInt(16); // 4-bit quantization
|
||||||
|
}
|
||||||
|
try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) {
|
||||||
|
// Work on a slice that has just the right number of bytes to make the test fail with an
|
||||||
|
// index-out-of-bounds in case the implementation reads more than the allowed number of
|
||||||
|
// padding bytes.
|
||||||
|
final IndexInput slice = in.slice("test", 0, (long) dimensions * numVectors);
|
||||||
|
final IndexInput slice2 = in.slice("test2", 0, (long) dimensions * numVectors);
|
||||||
|
final ES91Int4VectorsScorer defaultScorer = defaultProvider().newES91Int4VectorsScorer(slice, dimensions);
|
||||||
|
final ES91Int4VectorsScorer panamaScorer = maybePanamaProvider().newES91Int4VectorsScorer(slice2, dimensions);
|
||||||
|
for (int i = 0; i < numVectors; i++) {
|
||||||
|
in.readBytes(vector, 0, dimensions);
|
||||||
|
long val = VectorUtil.int4DotProduct(vector, query);
|
||||||
|
assertEquals(val, defaultScorer.int4DotProduct(query));
|
||||||
|
assertEquals(val, panamaScorer.int4DotProduct(query));
|
||||||
|
assertEquals(in.getFilePointer(), slice.getFilePointer());
|
||||||
|
assertEquals(in.getFilePointer(), slice2.getFilePointer());
|
||||||
|
}
|
||||||
|
assertEquals((long) dimensions * numVectors, in.getFilePointer());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -19,6 +19,7 @@ import org.apache.lucene.util.ArrayUtil;
|
||||||
import org.apache.lucene.util.VectorUtil;
|
import org.apache.lucene.util.VectorUtil;
|
||||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||||
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
|
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
|
||||||
|
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
|
||||||
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
|
||||||
import org.elasticsearch.simdvec.ESVectorUtil;
|
import org.elasticsearch.simdvec.ESVectorUtil;
|
||||||
|
|
||||||
|
@ -48,25 +49,23 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
|
||||||
@Override
|
@Override
|
||||||
CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
|
CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(fieldInfo.number);
|
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
|
||||||
float[] globalCentroid = fieldEntry.globalCentroid();
|
final float globalCentroidDp = fieldEntry.globalCentroidDp();
|
||||||
float globalCentroidDp = fieldEntry.globalCentroidDp();
|
final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
|
||||||
OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
|
final byte[] quantized = new byte[targetQuery.length];
|
||||||
byte[] quantized = new byte[targetQuery.length];
|
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
|
||||||
float[] targetScratch = ArrayUtil.copyArray(targetQuery);
|
ArrayUtil.copyArray(targetQuery),
|
||||||
OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
|
|
||||||
targetScratch,
|
|
||||||
quantized,
|
quantized,
|
||||||
(byte) 4,
|
(byte) 4,
|
||||||
globalCentroid
|
fieldEntry.globalCentroid()
|
||||||
);
|
);
|
||||||
|
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
|
||||||
return new CentroidQueryScorer() {
|
return new CentroidQueryScorer() {
|
||||||
int currentCentroid = -1;
|
int currentCentroid = -1;
|
||||||
private final byte[] quantizedCentroid = new byte[fieldInfo.getVectorDimension()];
|
|
||||||
private final float[] centroid = new float[fieldInfo.getVectorDimension()];
|
private final float[] centroid = new float[fieldInfo.getVectorDimension()];
|
||||||
private final float[] centroidCorrectiveValues = new float[3];
|
private final float[] centroidCorrectiveValues = new float[3];
|
||||||
private int quantizedCentroidComponentSum;
|
private final long rawCentroidsOffset = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES);
|
||||||
private final long centroidByteSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;
|
private final long rawCentroidsByteSize = (long) Float.BYTES * fieldInfo.getVectorDimension();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int size() {
|
public int size() {
|
||||||
|
@ -75,68 +74,47 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float[] centroid(int centroidOrdinal) throws IOException {
|
public float[] centroid(int centroidOrdinal) throws IOException {
|
||||||
readQuantizedAndRawCentroid(centroidOrdinal);
|
if (centroidOrdinal != currentCentroid) {
|
||||||
return centroid;
|
centroids.seek(rawCentroidsOffset + rawCentroidsByteSize * centroidOrdinal);
|
||||||
}
|
|
||||||
|
|
||||||
private void readQuantizedAndRawCentroid(int centroidOrdinal) throws IOException {
|
|
||||||
if (centroidOrdinal == currentCentroid) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
centroids.seek(centroidOrdinal * centroidByteSize);
|
|
||||||
quantizedCentroidComponentSum = readQuantizedValue(centroids, quantizedCentroid, centroidCorrectiveValues);
|
|
||||||
centroids.seek(numCentroids * centroidByteSize + (long) Float.BYTES * quantizedCentroid.length * centroidOrdinal);
|
|
||||||
centroids.readFloats(centroid, 0, centroid.length);
|
centroids.readFloats(centroid, 0, centroid.length);
|
||||||
currentCentroid = centroidOrdinal;
|
currentCentroid = centroidOrdinal;
|
||||||
}
|
}
|
||||||
|
return centroid;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
public void bulkScore(NeighborQueue queue) throws IOException {
|
||||||
public float score(int centroidOrdinal) throws IOException {
|
// TODO: bulk score centroids like we do with posting lists
|
||||||
readQuantizedAndRawCentroid(centroidOrdinal);
|
centroids.seek(0L);
|
||||||
|
for (int i = 0; i < numCentroids; i++) {
|
||||||
|
queue.add(i, score());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private float score() throws IOException {
|
||||||
|
final float qcDist = scorer.int4DotProduct(quantized);
|
||||||
|
centroids.readFloats(centroidCorrectiveValues, 0, 3);
|
||||||
|
final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());
|
||||||
return int4QuantizedScore(
|
return int4QuantizedScore(
|
||||||
quantized,
|
qcDist,
|
||||||
queryParams,
|
queryParams,
|
||||||
fieldInfo.getVectorDimension(),
|
fieldInfo.getVectorDimension(),
|
||||||
quantizedCentroid,
|
|
||||||
centroidCorrectiveValues,
|
centroidCorrectiveValues,
|
||||||
quantizedCentroidComponentSum,
|
quantizedCentroidComponentSum,
|
||||||
globalCentroidDp,
|
globalCentroidDp,
|
||||||
fieldInfo.getVectorSimilarityFunction()
|
fieldInfo.getVectorSimilarityFunction()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe)
|
|
||||||
throws IOException {
|
|
||||||
NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true);
|
|
||||||
// TODO Off heap scoring for quantized centroids?
|
|
||||||
for (int centroid = 0; centroid < centroidQueryScorer.size(); centroid++) {
|
|
||||||
neighborQueue.add(centroid, centroidQueryScorer.score(centroid));
|
|
||||||
}
|
|
||||||
return neighborQueue;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring)
|
|
||||||
throws IOException {
|
|
||||||
FieldEntry entry = fields.get(fieldInfo.number);
|
|
||||||
return new MemorySegmentPostingsVisitor(target, indexInput.clone(), entry, fieldInfo, needsScoring);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO can we do this in off-heap blocks?
|
// TODO can we do this in off-heap blocks?
|
||||||
static float int4QuantizedScore(
|
private float int4QuantizedScore(
|
||||||
byte[] quantizedQuery,
|
float qcDist,
|
||||||
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
|
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
|
||||||
int dims,
|
int dims,
|
||||||
byte[] binaryCode,
|
|
||||||
float[] targetCorrections,
|
float[] targetCorrections,
|
||||||
int targetComponentSum,
|
int targetComponentSum,
|
||||||
float centroidDp,
|
float centroidDp,
|
||||||
VectorSimilarityFunction similarityFunction
|
VectorSimilarityFunction similarityFunction
|
||||||
) {
|
) {
|
||||||
float qcDist = VectorUtil.int4DotProduct(quantizedQuery, binaryCode);
|
|
||||||
float ax = targetCorrections[0];
|
float ax = targetCorrections[0];
|
||||||
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
|
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
|
||||||
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
|
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
|
||||||
|
@ -157,6 +135,23 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
|
||||||
return Math.max((1f + score) / 2f, 0);
|
return Math.max((1f + score) / 2f, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe)
|
||||||
|
throws IOException {
|
||||||
|
NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true);
|
||||||
|
centroidQueryScorer.bulkScore(neighborQueue);
|
||||||
|
return neighborQueue;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring)
|
||||||
|
throws IOException {
|
||||||
|
FieldEntry entry = fields.get(fieldInfo.number);
|
||||||
|
return new MemorySegmentPostingsVisitor(target, indexInput.clone(), entry, fieldInfo, needsScoring);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
|
public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
|
||||||
|
@ -356,12 +351,4 @@ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeap
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static int readQuantizedValue(IndexInput indexInput, byte[] binaryValue, float[] corrections) throws IOException {
|
|
||||||
assert corrections.length == 3;
|
|
||||||
indexInput.readBytes(binaryValue, 0, binaryValue.length);
|
|
||||||
corrections[0] = Float.intBitsToFloat(indexInput.readInt());
|
|
||||||
corrections[1] = Float.intBitsToFloat(indexInput.readInt());
|
|
||||||
corrections[2] = Float.intBitsToFloat(indexInput.readInt());
|
|
||||||
return Short.toUnsignedInt(indexInput.readShort());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -332,7 +332,7 @@ public abstract class IVFVectorsReader extends KnnVectorsReader {
|
||||||
|
|
||||||
float[] centroid(int centroidOrdinal) throws IOException;
|
float[] centroid(int centroidOrdinal) throws IOException;
|
||||||
|
|
||||||
float score(int centroidOrdinal) throws IOException;
|
void bulkScore(NeighborQueue queue) throws IOException;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface PostingVisitor {
|
interface PostingVisitor {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue