From e68f31754c0bbe2c69b933eaf950e5f9462dedea Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 20 Nov 2024 16:41:55 -0500 Subject: [PATCH] Adds `maxSim` functions for multi_dense_vector fields (#116993) This adds `maxSim` functions, specifically dotProduct and InvHamming. Why these two, you might ask? Well, they are the best approximations of what's possible with Col* late-interaction type models. Effectively, you want a similarity metric where "greater == better". Regular `hamming` isn't exactly that, but inverting it (just like our `element_type: bit` index for dense_vectors) is a nice approximation with bit vectors and multi-vector scoring. Then, of course, dotProduct is another usage. We will allow dot-product between like elements (bytes -> bytes, floats -> floats) and, of course, allow `floats -> bit`, where the stored `bit` elements are applied as a "mask" over the float queries. This allows for some nice asymmetric interactions. This is all behind a feature flag, and I need to write a mountain of docs in a separate PR. 
--- .../org.elasticsearch.script.score.txt | 2 + .../141_multi_dense_vector_max_sim.yml | 206 ++++++++++ .../action/search/SearchCapabilities.java | 3 + .../script/MultiVectorScoreScriptUtils.java | 372 ++++++++++++++++++ .../field/vectors/BitMultiDenseVector.java | 70 +++- .../field/vectors/ByteMultiDenseVector.java | 54 ++- .../ByteMultiDenseVectorDocValuesField.java | 14 +- .../field/vectors/FloatMultiDenseVector.java | 38 +- .../FloatMultiDenseVectorDocValuesField.java | 15 +- .../field/vectors/MultiDenseVector.java | 21 + .../script/field/vectors/VectorIterator.java | 70 ++++ .../MultiVectorScoreScriptUtilsTests.java | 342 ++++++++++++++++ .../field/vectors/MultiDenseVectorTests.java | 83 ++++ 13 files changed, 1274 insertions(+), 16 deletions(-) create mode 100644 modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/141_multi_dense_vector_max_sim.yml create mode 100644 server/src/main/java/org/elasticsearch/script/MultiVectorScoreScriptUtils.java create mode 100644 server/src/main/java/org/elasticsearch/script/field/vectors/VectorIterator.java create mode 100644 server/src/test/java/org/elasticsearch/script/MultiVectorScoreScriptUtilsTests.java create mode 100644 server/src/test/java/org/elasticsearch/script/field/vectors/MultiDenseVectorTests.java diff --git a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt index e76db7cfb1d2..5a1d8c002aa1 100644 --- a/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt +++ b/modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.score.txt @@ -50,5 +50,7 @@ static_import { double cosineSimilarity(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$CosineSimilarity double 
dotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$DotProduct double hamming(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$Hamming + double maxSimDotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.MultiVectorScoreScriptUtils$MaxSimDotProduct + double maxSimInvHamming(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.MultiVectorScoreScriptUtils$MaxSimInvHamming } diff --git a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/141_multi_dense_vector_max_sim.yml b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/141_multi_dense_vector_max_sim.yml new file mode 100644 index 000000000000..caa7c59ab4c4 --- /dev/null +++ b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/141_multi_dense_vector_max_sim.yml @@ -0,0 +1,206 @@ +setup: + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ multi_dense_vector_script_max_sim ] + test_runner_features: capabilities + reason: "Support for multi dense vector max-sim functions capability required" + - skip: + features: headers + + - do: + indices.create: + index: test-index + body: + settings: + number_of_shards: 1 + mappings: + properties: + vector: + type: multi_dense_vector + dims: 5 + byte_vector: + type: multi_dense_vector + dims: 5 + element_type: byte + bit_vector: + type: multi_dense_vector + dims: 40 + element_type: bit + - do: + index: + index: test-index + id: "1" + body: + vector: [[230.0, 300.33, -34.8988, 15.555, -200.0], [-0.5, 100.0, -13, 14.8, -156.0]] + byte_vector: [[8, 5, -15, 1, -7], [-1, 115, -3, 4, -128]] + bit_vector: [[8, 5, -15, 1, -7], [-1, 115, -3, 4, -128]] + + - do: + index: + index: test-index + id: "3" + body: + vector: [[0.5, 111.3, -13.0, 14.8, -156.0]] + byte_vector: [[2, 
18, -5, 0, -124]] + bit_vector: [[2, 18, -5, 0, -124]] + + - do: + indices.refresh: {} +--- +"Test max-sim dot product scoring": + - skip: + features: close_to + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "maxSimDotProduct(params.query_vector, 'vector')" + params: + query_vector: [[1, 2, 1, 1, 1]] + + - match: {hits.total: 2} + + - match: {hits.hits.0._id: "1"} + - close_to: {hits.hits.0._score: {value: 611.316, error: 0.01}} + + - match: {hits.hits.1._id: "3"} + - close_to: {hits.hits.1._score: {value: 68.90001, error: 0.01}} + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "maxSimDotProduct(params.query_vector, 'byte_vector')" + params: + query_vector: [[1, 2, 1, 1, 0]] + + - match: {hits.total: 2} + + - match: {hits.hits.0._id: "1"} + - close_to: {hits.hits.0._score: {value: 230, error: 0.01}} + + - match: {hits.hits.1._id: "3"} + - close_to: {hits.hits.1._score: {value: 33, error: 0.01}} + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "maxSimDotProduct(params.query_vector, 'bit_vector')" + params: + query_vector: [[1, 2, 1, 1, 0]] + + - match: {hits.total: 2} + + - match: {hits.hits.0._id: "1"} + - close_to: {hits.hits.0._score: {value: 3, error: 0.01}} + + - match: {hits.hits.1._id: "3"} + - close_to: {hits.hits.1._score: {value: 2, error: 0.01}} + +# doing max-sim dot product with a vector where the stored bit vectors are used as masks + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "maxSimDotProduct(params.query_vector, 'bit_vector')" + params: + query_vector: 
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]] + - match: {hits.total: 2} + + - match: {hits.hits.0._id: "1"} + - close_to: {hits.hits.0._score: {value: 190, error: 0.01}} + + - match: {hits.hits.1._id: "3"} + - close_to: {hits.hits.1._score: {value: 125, error: 0.01}} +--- +"Test max-sim inv hamming scoring": + - skip: + features: close_to + + # inv hamming doesn't apply to float vectors + - do: + catch: bad_request + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "maxSimInvHamming(params.query_vector, 'vector')" + params: + query_vector: [[1, 2, 1, 1, 1]] + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "maxSimInvHamming(params.query_vector, 'byte_vector')" + params: + query_vector: [[1, 2, 1, 1, 1]] + + - match: {hits.total: 2} + + - match: {hits.hits.0._id: "3"} + - close_to: {hits.hits.0._score: {value: 0.675, error: 0.01}} + + - match: {hits.hits.1._id: "1"} + - close_to: {hits.hits.1._score: {value: 0.65, error: 0.01}} + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "maxSimInvHamming(params.query_vector, 'bit_vector')" + params: + query_vector: [[1, 2, 1, 1, 1]] + + - match: {hits.total: 2} + + - match: {hits.hits.0._id: "3"} + - close_to: {hits.hits.0._score: {value: 0.675, error: 0.01}} + + - match: {hits.hits.1._id: "1"} + - close_to: {hits.hits.1._score: {value: 0.65, error: 0.01}} diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java index 241f30b36778..e5c4826bfce9 100644 
--- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java @@ -40,6 +40,8 @@ public final class SearchCapabilities { private static final String NESTED_RETRIEVER_INNER_HITS_SUPPORT = "nested_retriever_inner_hits_support"; /** Support multi-dense-vector script field access. */ private static final String MULTI_DENSE_VECTOR_SCRIPT_ACCESS = "multi_dense_vector_script_access"; + /** Initial support for multi-dense-vector maxSim functions access. */ + private static final String MULTI_DENSE_VECTOR_SCRIPT_MAX_SIM = "multi_dense_vector_script_max_sim"; private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs"; @@ -56,6 +58,7 @@ public final class SearchCapabilities { if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) { capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER); capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_ACCESS); + capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_MAX_SIM); } if (Build.current().isSnapshot()) { capabilities.add(KQL_QUERY_SUPPORTED); diff --git a/server/src/main/java/org/elasticsearch/script/MultiVectorScoreScriptUtils.java b/server/src/main/java/org/elasticsearch/script/MultiVectorScoreScriptUtils.java new file mode 100644 index 000000000000..136c5e7b57d4 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/MultiVectorScoreScriptUtils.java @@ -0,0 +1,372 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.script; + +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.script.field.vectors.DenseVector; +import org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField; + +import java.io.IOException; +import java.util.HexFormat; +import java.util.List; + +public class MultiVectorScoreScriptUtils { + + public static class MultiDenseVectorFunction { + protected final ScoreScript scoreScript; + protected final MultiDenseVectorDocValuesField field; + + public MultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDocValuesField field) { + this.scoreScript = scoreScript; + this.field = field; + } + + void setNextVector() { + try { + field.setNextDocId(scoreScript._getDocId()); + } catch (IOException e) { + throw ExceptionsHelper.convertToElastic(e); + } + if (field.isEmpty()) { + throw new IllegalArgumentException("A document doesn't have a value for a multi-vector field!"); + } + } + } + + public static class ByteMultiDenseVectorFunction extends MultiDenseVectorFunction { + protected final byte[][] queryVector; + + /** + * Constructs a dense vector function used for byte-sized vectors. + * + * @param scoreScript The script in which this function was referenced. + * @param field The vector field. + * @param queryVector The query vector. 
+ */ + public ByteMultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List> queryVector) { + super(scoreScript, field); + if (queryVector.isEmpty()) { + throw new IllegalArgumentException("The query vector is empty."); + } + field.getElementType().checkDimensions(field.get().getDims(), queryVector.get(0).size()); + this.queryVector = new byte[queryVector.size()][queryVector.get(0).size()]; + float[] validateValues = new float[queryVector.size()]; + int lastSize = -1; + for (int i = 0; i < queryVector.size(); i++) { + if (lastSize != -1 && lastSize != queryVector.get(i).size()) { + throw new IllegalArgumentException( + "The query vector contains inner vectors which have inconsistent number of dimensions." + ); + } + lastSize = queryVector.get(i).size(); + for (int j = 0; j < queryVector.get(i).size(); j++) { + final Number number = queryVector.get(i).get(j); + byte value = number.byteValue(); + this.queryVector[i][j] = value; + validateValues[i] = number.floatValue(); + } + field.getElementType().checkVectorBounds(validateValues); + } + } + + /** + * Constructs a dense vector function used for byte-sized vectors. + * + * @param scoreScript The script in which this function was referenced. + * @param field The vector field. + * @param queryVector The query vector. + */ + public ByteMultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) { + super(scoreScript, field); + this.queryVector = queryVector; + } + } + + public static class FloatMultiDenseVectorFunction extends MultiDenseVectorFunction { + protected final float[][] queryVector; + + /** + * Constructs a dense vector function used for float vectors. + * + * @param scoreScript The script in which this function was referenced. + * @param field The vector field. + * @param queryVector The query vector. 
+ */ + public FloatMultiDenseVectorFunction( + ScoreScript scoreScript, + MultiDenseVectorDocValuesField field, + List> queryVector + ) { + super(scoreScript, field); + if (queryVector.isEmpty()) { + throw new IllegalArgumentException("The query vector is empty."); + } + DenseVector.checkDimensions(field.get().getDims(), queryVector.get(0).size()); + + this.queryVector = new float[queryVector.size()][queryVector.get(0).size()]; + int lastSize = -1; + for (int i = 0; i < queryVector.size(); i++) { + if (lastSize != -1 && lastSize != queryVector.get(i).size()) { + throw new IllegalArgumentException( + "The query vector contains inner vectors which have inconsistent number of dimensions." + ); + } + lastSize = queryVector.get(i).size(); + for (int j = 0; j < queryVector.get(i).size(); j++) { + this.queryVector[i][j] = queryVector.get(i).get(j).floatValue(); + } + field.getElementType().checkVectorBounds(this.queryVector[i]); + } + } + } + + // Calculate Hamming distances between a query's dense vector and documents' dense vectors + public interface MaxSimInvHammingDistanceInterface { + float maxSimInvHamming(); + } + + public static class ByteMaxSimInvHammingDistance extends ByteMultiDenseVectorFunction implements MaxSimInvHammingDistanceInterface { + + public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List> queryVector) { + super(scoreScript, field, queryVector); + } + + public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) { + super(scoreScript, field, queryVector); + } + + public float maxSimInvHamming() { + setNextVector(); + return field.get().maxSimInvHamming(queryVector); + } + } + + private record BytesOrList(byte[][] bytes, List> list) {} + + @SuppressWarnings("unchecked") + private static BytesOrList parseBytes(Object queryVector) { + if (queryVector instanceof List) { + // check if its a list of strings or list of lists + if (((List) 
queryVector).get(0) instanceof List) { + return new BytesOrList(null, ((List>) queryVector)); + } else if (((List) queryVector).get(0) instanceof String) { + byte[][] parsedQueryVector = new byte[((List) queryVector).size()][]; + int lastSize = -1; + for (int i = 0; i < ((List) queryVector).size(); i++) { + parsedQueryVector[i] = HexFormat.of().parseHex((String) ((List) queryVector).get(i)); + if (lastSize != -1 && lastSize != parsedQueryVector[i].length) { + throw new IllegalArgumentException( + "The query vector contains inner vectors which have inconsistent number of dimensions." + ); + } + lastSize = parsedQueryVector[i].length; + } + return new BytesOrList(parsedQueryVector, null); + } else { + throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); + } + } else { + throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName()); + } + } + + public static final class MaxSimInvHamming { + + private final MaxSimInvHammingDistanceInterface function; + + public MaxSimInvHamming(ScoreScript scoreScript, Object queryVector, String fieldName) { + MultiDenseVectorDocValuesField field = (MultiDenseVectorDocValuesField) scoreScript.field(fieldName); + if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) { + throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors"); + } + BytesOrList bytesOrList = parseBytes(queryVector); + if (bytesOrList.bytes != null) { + this.function = new ByteMaxSimInvHammingDistance(scoreScript, field, bytesOrList.bytes); + } else { + this.function = new ByteMaxSimInvHammingDistance(scoreScript, field, bytesOrList.list); + } + } + + public double maxSimInvHamming() { + return function.maxSimInvHamming(); + } + } + + // Calculate a dot product between a query's dense vector and documents' dense vectors + public interface MaxSimDotProductInterface { + double maxSimDotProduct(); + 
} + + public static class MaxSimBitDotProduct extends MultiDenseVectorFunction implements MaxSimDotProductInterface { + private final byte[][] byteQueryVector; + private final float[][] floatQueryVector; + + public MaxSimBitDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) { + super(scoreScript, field); + if (field.getElementType() != DenseVectorFieldMapper.ElementType.BIT) { + throw new IllegalArgumentException("Cannot calculate bit dot product for non-bit vectors"); + } + int fieldDims = field.get().getDims(); + if (fieldDims != queryVector.length * Byte.SIZE && fieldDims != queryVector.length) { + throw new IllegalArgumentException( + "The query vector has an incorrect number of dimensions. Must be [" + + fieldDims / 8 + + "] for bitwise operations, or [" + + fieldDims + + "] for byte wise operations: provided [" + + queryVector.length + + "]." + ); + } + this.byteQueryVector = queryVector; + this.floatQueryVector = null; + } + + public MaxSimBitDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List> queryVector) { + super(scoreScript, field); + if (queryVector.isEmpty()) { + throw new IllegalArgumentException("The query vector is empty."); + } + if (field.getElementType() != DenseVectorFieldMapper.ElementType.BIT) { + throw new IllegalArgumentException("cannot calculate bit dot product for non-bit vectors"); + } + float[][] floatQueryVector = new float[queryVector.size()][]; + byte[][] byteQueryVector = new byte[queryVector.size()][]; + boolean isFloat = false; + int lastSize = -1; + for (int i = 0; i < queryVector.size(); i++) { + if (lastSize != -1 && lastSize != queryVector.get(i).size()) { + throw new IllegalArgumentException( + "The query vector contains inner vectors which have inconsistent number of dimensions." 
+ ); + } + lastSize = queryVector.get(i).size(); + floatQueryVector[i] = new float[queryVector.get(i).size()]; + if (isFloat == false) { + byteQueryVector[i] = new byte[queryVector.get(i).size()]; + } + for (int j = 0; j < queryVector.get(i).size(); j++) { + Number number = queryVector.get(i).get(j); + floatQueryVector[i][j] = number.floatValue(); + if (isFloat == false) { + byteQueryVector[i][j] = number.byteValue(); + } + if (isFloat + || floatQueryVector[i][j] % 1.0f != 0.0f + || floatQueryVector[i][j] < Byte.MIN_VALUE + || floatQueryVector[i][j] > Byte.MAX_VALUE) { + isFloat = true; + } + } + } + int fieldDims = field.get().getDims(); + if (isFloat) { + this.floatQueryVector = floatQueryVector; + this.byteQueryVector = null; + if (fieldDims != floatQueryVector[0].length) { + throw new IllegalArgumentException( + "The query vector contains inner vectors which have incorrect number of dimensions. Must be [" + + fieldDims + + "] for float wise operations: provided [" + + floatQueryVector[0].length + + "]." + ); + } + } else { + this.floatQueryVector = null; + this.byteQueryVector = byteQueryVector; + if (fieldDims != byteQueryVector[0].length * Byte.SIZE && fieldDims != byteQueryVector[0].length) { + throw new IllegalArgumentException( + "The query vector contains inner vectors which have incorrect number of dimensions. Must be [" + + fieldDims / 8 + + "] for bitwise operations, or [" + + fieldDims + + "] for byte wise operations: provided [" + + byteQueryVector[0].length + + "]." + ); + } + } + } + + @Override + public double maxSimDotProduct() { + setNextVector(); + return byteQueryVector != null ? 
field.get().maxSimDotProduct(byteQueryVector) : field.get().maxSimDotProduct(floatQueryVector); + } + } + + public static class MaxSimByteDotProduct extends ByteMultiDenseVectorFunction implements MaxSimDotProductInterface { + + public MaxSimByteDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List> queryVector) { + super(scoreScript, field, queryVector); + } + + public MaxSimByteDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) { + super(scoreScript, field, queryVector); + } + + public double maxSimDotProduct() { + setNextVector(); + return field.get().maxSimDotProduct(queryVector); + } + } + + public static class MaxSimFloatDotProduct extends FloatMultiDenseVectorFunction implements MaxSimDotProductInterface { + + public MaxSimFloatDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List> queryVector) { + super(scoreScript, field, queryVector); + } + + public double maxSimDotProduct() { + setNextVector(); + return field.get().maxSimDotProduct(queryVector); + } + } + + public static final class MaxSimDotProduct { + + private final MaxSimDotProductInterface function; + + @SuppressWarnings("unchecked") + public MaxSimDotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) { + MultiDenseVectorDocValuesField field = (MultiDenseVectorDocValuesField) scoreScript.field(fieldName); + function = switch (field.getElementType()) { + case BIT -> { + BytesOrList bytesOrList = parseBytes(queryVector); + if (bytesOrList.bytes != null) { + yield new MaxSimBitDotProduct(scoreScript, field, bytesOrList.bytes); + } else { + yield new MaxSimBitDotProduct(scoreScript, field, bytesOrList.list); + } + } + case BYTE -> { + BytesOrList bytesOrList = parseBytes(queryVector); + if (bytesOrList.bytes != null) { + yield new MaxSimByteDotProduct(scoreScript, field, bytesOrList.bytes); + } else { + yield new MaxSimByteDotProduct(scoreScript, field, bytesOrList.list); + } + } + case 
FLOAT -> { + if (queryVector instanceof List) { + yield new MaxSimFloatDotProduct(scoreScript, field, (List>) queryVector); + } + throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName()); + } + }; + } + + public double maxSimDotProduct() { + return function.maxSimDotProduct(); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVector.java index 24e19a803ff3..7805816090d5 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BitMultiDenseVector.java @@ -10,11 +10,13 @@ package org.elasticsearch.script.field.vectors; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.simdvec.ESVectorUtil; -import java.util.Iterator; +import java.util.Arrays; public class BitMultiDenseVector extends ByteMultiDenseVector { - public BitMultiDenseVector(Iterator vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) { + public BitMultiDenseVector(VectorIterator vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) { super(vectorValues, magnitudesBytes, numVecs, dims); } @@ -31,6 +33,70 @@ public class BitMultiDenseVector extends ByteMultiDenseVector { } } + @Override + public float maxSimDotProduct(float[][] query) { + vectorValues.reset(); + float[] maxes = new float[query.length]; + Arrays.fill(maxes, Float.NEGATIVE_INFINITY); + while (vectorValues.hasNext()) { + byte[] vv = vectorValues.next(); + for (int i = 0; i < query.length; i++) { + maxes[i] = Math.max(maxes[i], ESVectorUtil.ipFloatBit(query[i], vv)); + } + } + float sums = 0; + for (float m : maxes) { + sums += m; + } + return sums; + } + + @Override + public float maxSimDotProduct(byte[][] query) { + vectorValues.reset(); + float[] maxes = new 
float[query.length]; + Arrays.fill(maxes, Float.NEGATIVE_INFINITY); + if (query[0].length == dims) { + while (vectorValues.hasNext()) { + byte[] vv = vectorValues.next(); + for (int i = 0; i < query.length; i++) { + maxes[i] = Math.max(maxes[i], ESVectorUtil.andBitCount(query[i], vv)); + } + } + } else { + while (vectorValues.hasNext()) { + byte[] vv = vectorValues.next(); + for (int i = 0; i < query.length; i++) { + maxes[i] = Math.max(maxes[i], ESVectorUtil.ipByteBit(query[i], vv)); + } + } + } + float sum = 0; + for (float m : maxes) { + sum += m; + } + return sum; + } + + @Override + public float maxSimInvHamming(byte[][] query) { + vectorValues.reset(); + int bitCount = this.getDims(); + float[] maxes = new float[query.length]; + Arrays.fill(maxes, Float.NEGATIVE_INFINITY); + while (vectorValues.hasNext()) { + byte[] vv = vectorValues.next(); + for (int i = 0; i < query.length; i++) { + maxes[i] = Math.max(maxes[i], ((bitCount - VectorUtil.xorBitCount(vv, query[i])) / (float) bitCount)); + } + } + float sum = 0; + for (float m : maxes) { + sum += m; + } + return sum; + } + @Override public int getDims() { return dims * Byte.SIZE; diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVector.java index e610d10146b2..5e9d3e05746c 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVector.java @@ -10,21 +10,22 @@ package org.elasticsearch.script.field.vectors; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.VectorUtil; import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder; +import java.util.Arrays; import java.util.Iterator; public class ByteMultiDenseVector implements MultiDenseVector { - protected final Iterator vectorValues; + protected final VectorIterator vectorValues; protected 
final int numVecs; protected final int dims; - private Iterator floatDocVectors; private float[] magnitudes; private final BytesRef magnitudesBytes; - public ByteMultiDenseVector(Iterator vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) { + public ByteMultiDenseVector(VectorIterator vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) { assert magnitudesBytes.length == numVecs * Float.BYTES; this.vectorValues = vectorValues; this.numVecs = numVecs; @@ -33,11 +34,50 @@ public class ByteMultiDenseVector implements MultiDenseVector { } @Override - public Iterator getVectors() { - if (floatDocVectors == null) { - floatDocVectors = new ByteToFloatIteratorWrapper(vectorValues, dims); + public float maxSimDotProduct(float[][] query) { + throw new UnsupportedOperationException("use [float maxSimDotProduct(byte[][] queryVector)] instead"); + } + + @Override + public float maxSimDotProduct(byte[][] query) { + vectorValues.reset(); + float[] maxes = new float[query.length]; + Arrays.fill(maxes, Float.NEGATIVE_INFINITY); + while (vectorValues.hasNext()) { + byte[] vv = vectorValues.next(); + for (int i = 0; i < query.length; i++) { + maxes[i] = Math.max(maxes[i], VectorUtil.dotProduct(query[i], vv)); + } } - return floatDocVectors; + float sum = 0; + for (float m : maxes) { + sum += m; + } + return sum; + } + + @Override + public float maxSimInvHamming(byte[][] query) { + vectorValues.reset(); + int bitCount = dims * Byte.SIZE; + float[] maxes = new float[query.length]; + Arrays.fill(maxes, Float.NEGATIVE_INFINITY); + while (vectorValues.hasNext()) { + byte[] vv = vectorValues.next(); + for (int i = 0; i < query.length; i++) { + maxes[i] = Math.max(maxes[i], ((bitCount - VectorUtil.xorBitCount(vv, query[i])) / (float) bitCount)); + } + } + float sum = 0; + for (float m : maxes) { + sum += m; + } + return sum; + } + + @Override + public Iterator getVectors() { + return new ByteToFloatIteratorWrapper(vectorValues.copy(), dims); } @Override diff --git 
a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVectorDocValuesField.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVectorDocValuesField.java index d1e062e0a3de..d45c5b85137f 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVectorDocValuesField.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteMultiDenseVectorDocValuesField.java @@ -23,7 +23,7 @@ public class ByteMultiDenseVectorDocValuesField extends MultiDenseVectorDocValue private final BinaryDocValues magnitudes; protected final int dims; protected int numVecs; - protected Iterator vectorValue; + protected VectorIterator vectorValue; protected boolean decoded; protected BytesRef value; protected BytesRef magnitudesValue; @@ -111,7 +111,7 @@ public class ByteMultiDenseVectorDocValuesField extends MultiDenseVectorDocValue return value == null; } - static class ByteVectorIterator implements Iterator { + static class ByteVectorIterator implements VectorIterator { private final byte[] buffer; private final BytesRef vectorValues; private final int size; @@ -138,5 +138,15 @@ public class ByteMultiDenseVectorDocValuesField extends MultiDenseVectorDocValue idx++; return buffer; } + + @Override + public Iterator copy() { + return new ByteVectorIterator(vectorValues, new byte[buffer.length], size); + } + + @Override + public void reset() { + idx = 0; + } } } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVector.java index 9ffe8b3b970c..9c2f7eb6a86d 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVector.java @@ -10,7 +10,9 @@ package org.elasticsearch.script.field.vectors; import org.apache.lucene.util.BytesRef; +import 
org.apache.lucene.util.VectorUtil; +import java.util.Arrays; import java.util.Iterator; import static org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder.getMultiMagnitudes; @@ -21,19 +23,47 @@ public class FloatMultiDenseVector implements MultiDenseVector { private float[] magnitudesArray = null; private final int dims; private final int numVectors; - private final Iterator decodedDocVector; + private final VectorIterator vectorValues; - public FloatMultiDenseVector(Iterator decodedDocVector, BytesRef magnitudes, int numVectors, int dims) { + public FloatMultiDenseVector(VectorIterator decodedDocVector, BytesRef magnitudes, int numVectors, int dims) { assert magnitudes.length == numVectors * Float.BYTES; - this.decodedDocVector = decodedDocVector; + this.vectorValues = decodedDocVector; this.magnitudes = magnitudes; this.numVectors = numVectors; this.dims = dims; } + @Override + public float maxSimDotProduct(float[][] query) { + vectorValues.reset(); + float[] maxes = new float[query.length]; + Arrays.fill(maxes, Float.NEGATIVE_INFINITY); + while (vectorValues.hasNext()) { + float[] vv = vectorValues.next(); + for (int i = 0; i < query.length; i++) { + maxes[i] = Math.max(maxes[i], VectorUtil.dotProduct(query[i], vv)); + } + } + float sum = 0; + for (float m : maxes) { + sum += m; + } + return sum; + } + + @Override + public float maxSimDotProduct(byte[][] query) { + throw new UnsupportedOperationException("use [float maxSimDotProduct(float[][] queryVector)] instead"); + } + + @Override + public float maxSimInvHamming(byte[][] query) { + throw new UnsupportedOperationException("hamming distance is not supported for float vectors"); + } + @Override public Iterator getVectors() { - return decodedDocVector; + return vectorValues.copy(); } @Override diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVectorDocValuesField.java 
b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVectorDocValuesField.java index 356db58d989c..c7ac7842afd9 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVectorDocValuesField.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/FloatMultiDenseVectorDocValuesField.java @@ -110,14 +110,16 @@ public class FloatMultiDenseVectorDocValuesField extends MultiDenseVectorDocValu } } - static class FloatVectorIterator implements Iterator { + static class FloatVectorIterator implements VectorIterator { private final float[] buffer; private final FloatBuffer vectorValues; + private final BytesRef vectorValueBytesRef; private final int size; private int idx = 0; FloatVectorIterator(BytesRef vectorValues, float[] buffer, int size) { assert vectorValues.length == (buffer.length * Float.BYTES * size); + this.vectorValueBytesRef = vectorValues; this.vectorValues = ByteBuffer.wrap(vectorValues.bytes, vectorValues.offset, vectorValues.length) .order(ByteOrder.LITTLE_ENDIAN) .asFloatBuffer(); @@ -139,5 +141,16 @@ public class FloatMultiDenseVectorDocValuesField extends MultiDenseVectorDocValu idx++; return buffer; } + + @Override + public Iterator copy() { + return new FloatVectorIterator(vectorValueBytesRef, new float[buffer.length], size); + } + + @Override + public void reset() { + idx = 0; + vectorValues.rewind(); + } } } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVector.java index 85c851dbe545..7d948cf5a74f 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/MultiDenseVector.java @@ -17,6 +17,12 @@ public interface MultiDenseVector { checkDimensions(getDims(), qvDims); } + float maxSimDotProduct(float[][] query); + + float maxSimDotProduct(byte[][] query); + + float 
maxSimInvHamming(byte[][] query);
+
     Iterator<float[]> getVectors();
 
     float[] getMagnitudes();
@@ -63,6 +69,21 @@ public interface MultiDenseVector {
             throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
         }
 
+        @Override
+        public float maxSimDotProduct(float[][] query) {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public float maxSimDotProduct(byte[][] query) {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public float maxSimInvHamming(byte[][] query) {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
         @Override
         public int size() {
             return 0;
diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/VectorIterator.java b/server/src/main/java/org/elasticsearch/script/field/vectors/VectorIterator.java
new file mode 100644
index 000000000000..b8615ac87725
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/script/field/vectors/VectorIterator.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.script.field.vectors;
+
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
+/**
+ * An iterator over the vectors of a multi-vector value that can be rewound with
+ * {@link #reset()} and duplicated with {@link #copy()}.
+ *
+ * @param <E> the type of each vector returned by {@link #next()}, e.g. {@code float[]} or {@code byte[]}
+ */
+public interface VectorIterator<E> extends Iterator<E> {
+    /** Returns a new iterator over the same vectors, starting from the first vector. */
+    Iterator<E> copy();
+
+    /** Rewinds this iterator so that iteration starts again from the first vector. */
+    void reset();
+
+    /**
+     * Creates an iterator over the rows of {@code vectors}. The rows are returned as-is, without copying.
+     */
+    static VectorIterator<float[]> from(float[][] vectors) {
+        return new VectorIterator<>() {
+            private int i = 0;
+
+            @Override
+            public boolean hasNext() {
+                return i < vectors.length;
+            }
+
+            @Override
+            public float[] next() {
+                if (hasNext() == false) {
+                    // honor the Iterator contract instead of leaking ArrayIndexOutOfBoundsException
+                    throw new NoSuchElementException();
+                }
+                return vectors[i++];
+            }
+
+            @Override
+            public Iterator<float[]> copy() {
+                return from(vectors);
+            }
+
+            @Override
+            public void reset() {
+                i = 0;
+            }
+        };
+    }
+
+    /**
+     * Creates an iterator over the rows of {@code vectors}. The rows are returned as-is, without copying.
+     */
+    static VectorIterator<byte[]> from(byte[][] vectors) {
+        return new VectorIterator<>() {
+            private int i = 0;
+
+            @Override
+            public boolean hasNext() {
+                return i < vectors.length;
+            }
+
+            @Override
+            public byte[] next() {
+                if (hasNext() == false) {
+                    // honor the Iterator contract instead of leaking ArrayIndexOutOfBoundsException
+                    throw new NoSuchElementException();
+                }
+                return vectors[i++];
+            }
+
+            @Override
+            public Iterator<byte[]> copy() {
+                return from(vectors);
+            }
+
+            @Override
+            public void reset() {
+                i = 0;
+            }
+        };
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/script/MultiVectorScoreScriptUtilsTests.java b/server/src/test/java/org/elasticsearch/script/MultiVectorScoreScriptUtilsTests.java
new file mode 100644
index 000000000000..c4a1699181ef
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/script/MultiVectorScoreScriptUtilsTests.java
@@ -0,0 +1,342 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.script;
+
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
+import org.elasticsearch.index.mapper.vectors.MultiDenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValuesTests;
+import org.elasticsearch.script.MultiVectorScoreScriptUtils.MaxSimDotProduct;
+import org.elasticsearch.script.MultiVectorScoreScriptUtils.MaxSimInvHamming;
+import org.elasticsearch.script.field.vectors.BitMultiDenseVectorDocValuesField;
+import org.elasticsearch.script.field.vectors.ByteMultiDenseVectorDocValuesField;
+import org.elasticsearch.script.field.vectors.FloatMultiDenseVectorDocValuesField;
+import org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField;
+import org.elasticsearch.test.ESTestCase;
+import org.junit.BeforeClass;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HexFormat;
+import java.util.List;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class MultiVectorScoreScriptUtilsTests extends ESTestCase {
+
+    @BeforeClass
+    public static void setup() {
+        assumeTrue("Requires multi-dense vector support", MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled());
+    }
+
+    public void testFloatMultiVectorClassBindings() throws IOException {
+        String fieldName = "vector";
+        int dims = 5;
+        float[][][] docVectors = new float[][][] {
+            { { 230.0f, 300.33f, -34.8988f, 15.555f, -200.0f }, { 100.0f, 200.0f, -50.0f, 10.0f, -150.0f } } };
+        float[][] docMagnitudes = new float[][] { { 0.0f, 0.0f } };
+        for (int i = 0; i < docVectors.length; i++) {
+            for (int j = 0; j < docVectors[i].length; j++) {
+                docMagnitudes[i][j] = (float) Math.sqrt(VectorUtil.dotProduct(docVectors[i][j], docVectors[i][j]));
+            }
+        }
+
+        List<List<Number>> queryVector = List.of(Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f));
+        List<List<Number>> invalidQueryVector = List.of(Arrays.asList(0.5, 111.3));
+
+        List<MultiDenseVectorDocValuesField> fields = List.of(
+            new FloatMultiDenseVectorDocValuesField(
+                MultiDenseVectorScriptDocValuesTests.wrap(docVectors, ElementType.FLOAT),
+                MultiDenseVectorScriptDocValuesTests.wrap(docMagnitudes),
+                "test",
+                ElementType.FLOAT,
+                dims
+            ),
+            new FloatMultiDenseVectorDocValuesField(
+                MultiDenseVectorScriptDocValuesTests.wrap(docVectors, ElementType.FLOAT),
+                MultiDenseVectorScriptDocValuesTests.wrap(docMagnitudes),
+                "test",
+                ElementType.FLOAT,
+                dims
+            )
+        );
+        for (MultiDenseVectorDocValuesField field : fields) {
+            field.setNextDocId(0);
+
+            ScoreScript scoreScript = mock(ScoreScript.class);
+            when(scoreScript.field("vector")).thenAnswer(mock -> field);
+
+            // Test max similarity dot product
+            MaxSimDotProduct maxSimDotProduct = new MaxSimDotProduct(scoreScript, queryVector, fieldName);
+            float maxSimDotProductExpected = 65425.625f; // dot product of the query with the first (best-matching) doc vector
+            assertEquals(
+                "maxSimDotProduct result is not equal to the expected value!",
+                maxSimDotProductExpected,
+                maxSimDotProduct.maxSimDotProduct(),
+                0.001
+            );
+
+            // Check each function rejects query vectors with the wrong dimension
+            IllegalArgumentException e = expectThrows(
+                IllegalArgumentException.class,
+                () -> new MultiVectorScoreScriptUtils.MaxSimDotProduct(scoreScript, invalidQueryVector, fieldName)
+            );
+            assertThat(
+                e.getMessage(),
+                containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
+            );
+            e = expectThrows(IllegalArgumentException.class, () -> new MaxSimInvHamming(scoreScript, invalidQueryVector, fieldName));
+            assertThat(e.getMessage(), containsString("hamming distance is only supported for byte or bit vectors"));
+
+            // Check scripting infrastructure integration
+            assertEquals(65425.6249, new MaxSimDotProduct(scoreScript, queryVector, fieldName).maxSimDotProduct(), 0.001);
+            when(scoreScript._getDocId()).thenReturn(1);
+            e = expectThrows(
+                IllegalArgumentException.class,
+                () -> new MaxSimDotProduct(scoreScript, queryVector, fieldName).maxSimDotProduct()
+            );
+            assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage());
+        }
+    }
+
+    public void testByteMultiVectorClassBindings() throws IOException {
+        String fieldName = "vector";
+        int dims = 5;
+        float[][] docVector = new float[][] { { 1, 127, -128, 5, -10 } };
+        float[][] magnitudes = new float[][] { { 0 } };
+        for (int i = 0; i < docVector.length; i++) {
+            magnitudes[i][0] = (float) Math.sqrt(VectorUtil.dotProduct(docVector[i], docVector[i]));
+        }
+        List<List<Number>> queryVector = List.of(Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4));
+        List<List<Number>> invalidQueryVector = List.of(Arrays.asList((byte) 1, (byte) 1));
+        List<String> hexadecimalString = List.of(HexFormat.of().formatHex(new byte[] { 1, 125, -12, 2, 4 }));
+
+        List<MultiDenseVectorDocValuesField> fields = List.of(
+            new ByteMultiDenseVectorDocValuesField(
+                MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.BYTE),
+                MultiDenseVectorScriptDocValuesTests.wrap(magnitudes),
+                "test",
+                ElementType.BYTE,
+                dims
+            )
+        );
+        for (MultiDenseVectorDocValuesField field : fields) {
+            field.setNextDocId(0);
+
+            ScoreScript scoreScript = mock(ScoreScript.class);
+            when(scoreScript.field(fieldName)).thenAnswer(mock -> field);
+
+            // Check each function rejects query vectors with the wrong dimension
+            IllegalArgumentException e = expectThrows(
+                IllegalArgumentException.class,
+                () -> new MaxSimDotProduct(scoreScript, invalidQueryVector, fieldName)
+            );
+            assertThat(
+                e.getMessage(),
+                containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
+            );
+            e = expectThrows(IllegalArgumentException.class, () -> new MaxSimInvHamming(scoreScript, invalidQueryVector, fieldName));
+            assertThat(
+                e.getMessage(),
+                containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
+            );
+
+            // Check scripting infrastructure integration
+            assertEquals(17382.0, new MaxSimDotProduct(scoreScript, queryVector, fieldName).maxSimDotProduct(), 0.001);
+            assertEquals(17382.0, new MaxSimDotProduct(scoreScript, hexadecimalString, fieldName).maxSimDotProduct(), 0.001);
+            assertEquals(0.675, new MaxSimInvHamming(scoreScript, queryVector, fieldName).maxSimInvHamming(), 0.001);
+            assertEquals(0.675, new MaxSimInvHamming(scoreScript, hexadecimalString, fieldName).maxSimInvHamming(), 0.001);
+            MaxSimDotProduct maxSimDotProduct = new MaxSimDotProduct(scoreScript, queryVector, fieldName);
+            when(scoreScript._getDocId()).thenReturn(1);
+            e = expectThrows(IllegalArgumentException.class, maxSimDotProduct::maxSimDotProduct);
+            assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage());
+        }
+    }
+
+    public void testBitMultiVectorClassBindingsDotProduct() throws IOException {
+        String fieldName = "vector";
+        int dims = 8;
+        float[][] docVector = new float[][] { { 124 } };
+        // 124 in binary is b01111100
+        List<List<Number>> queryVector = List.of(
+            Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4, (byte) 1, (byte) 125, (byte) -12)
+        );
+        List<List<Number>> floatQueryVector = List.of(Arrays.asList(1.4f, -1.4f, 0.42f, 0.0f, 1f, -1f, -0.42f, 1.2f));
+        List<List<Number>> invalidQueryVector = List.of(Arrays.asList((byte) 1, (byte) 1));
+        List<String> hexadecimalString = List.of(HexFormat.of().formatHex(new byte[] { 124 }));
+
+        List<MultiDenseVectorDocValuesField> fields = List.of(
+            new BitMultiDenseVectorDocValuesField(
+                MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.BIT),
+                MultiDenseVectorScriptDocValuesTests.wrap(new float[][] { { 5 } }),
+                "test",
+                ElementType.BIT,
+                dims
+            )
+        );
+        for (MultiDenseVectorDocValuesField field : fields) {
+            field.setNextDocId(0);
+
+            ScoreScript scoreScript = mock(ScoreScript.class);
+            when(scoreScript.field(fieldName)).thenAnswer(mock -> field);
+
+            MaxSimDotProduct function = new MaxSimDotProduct(scoreScript, queryVector, fieldName);
+            assertEquals(
+                "maxSimDotProduct result is not equal to the expected value!",
+                -12 + 2 + 4 + 1 + 125,
+                function.maxSimDotProduct(),
+                0.001
+            );
+
+            function = new MaxSimDotProduct(scoreScript, floatQueryVector, fieldName);
+            assertEquals(
+                "maxSimDotProduct result is not equal to the expected value!",
+                0.42f + 0f + 1f - 1f - 0.42f,
+                function.maxSimDotProduct(),
+                0.001
+            );
+
+            function = new MaxSimDotProduct(scoreScript, hexadecimalString, fieldName);
+            assertEquals(
+                "maxSimDotProduct result is not equal to the expected value!",
+                Integer.bitCount(124),
+                function.maxSimDotProduct(),
+                0.0
+            );
+
+            // Check each function rejects query vectors with the wrong dimension
+            IllegalArgumentException e = expectThrows(
+                IllegalArgumentException.class,
+                () -> new MaxSimDotProduct(scoreScript, invalidQueryVector, fieldName)
+            );
+            assertThat(
+                e.getMessage(),
+                containsString(
+                    "query vector contains inner vectors which have incorrect number of dimensions. "
+                        + "Must be [1] for bitwise operations, or [8] for byte wise operations: provided [2]."
+                )
+            );
+        }
+    }
+
+    public void testByteVsFloatSimilarity() throws IOException {
+        int dims = 5;
+        float[][] docVector = new float[][] { { 1f, 127f, -128f, 5f, -10f } };
+        float[][] magnitudes = new float[][] { { 0 } };
+        for (int i = 0; i < docVector.length; i++) {
+            magnitudes[i][0] = (float) Math.sqrt(VectorUtil.dotProduct(docVector[i], docVector[i]));
+        }
+        List<List<Number>> listFloatVector = List.of(Arrays.asList(1f, 125f, -12f, 2f, 4f));
+        List<List<Number>> listByteVector = List.of(Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4));
+        float[][] floatVector = new float[][] { { 1f, 125f, -12f, 2f, 4f } };
+        byte[][] byteVector = new byte[][] { { (byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4 } };
+
+        List<MultiDenseVectorDocValuesField> fields = List.of(
+            new FloatMultiDenseVectorDocValuesField(
+                MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.FLOAT),
+                MultiDenseVectorScriptDocValuesTests.wrap(magnitudes),
+                "field1",
+                ElementType.FLOAT,
+                dims
+            ),
+            new ByteMultiDenseVectorDocValuesField(
+                MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.BYTE),
+                MultiDenseVectorScriptDocValuesTests.wrap(magnitudes),
+                "field3",
+                ElementType.BYTE,
+                dims
+            )
+        );
+        for (MultiDenseVectorDocValuesField field : fields) {
+            field.setNextDocId(0);
+
+            ScoreScript scoreScript = mock(ScoreScript.class);
+            when(scoreScript.field("vector")).thenAnswer(mock -> field);
+
+            int dotProductExpected = 17382;
+            MaxSimDotProduct maxSimDotProduct = new MaxSimDotProduct(scoreScript, listFloatVector, "vector");
+            assertEquals(field.getName(), dotProductExpected, maxSimDotProduct.maxSimDotProduct(), 0.001);
+            maxSimDotProduct = new MaxSimDotProduct(scoreScript, listByteVector, "vector");
+            assertEquals(field.getName(), dotProductExpected, maxSimDotProduct.maxSimDotProduct(), 0.001);
+            switch (field.getElementType()) {
+                case BYTE -> {
+                    assertEquals(field.getName(), dotProductExpected, field.get().maxSimDotProduct(byteVector), 0.001);
+                    UnsupportedOperationException e = expectThrows(
+                        UnsupportedOperationException.class,
+                        () -> field.get().maxSimDotProduct(floatVector)
+                    );
+                    assertThat(e.getMessage(), containsString("use [float maxSimDotProduct(byte[][] queryVector)] instead"));
+                }
+                case FLOAT -> {
+                    assertEquals(field.getName(), dotProductExpected, field.get().maxSimDotProduct(floatVector), 0.001);
+                    UnsupportedOperationException e = expectThrows(
+                        UnsupportedOperationException.class,
+                        () -> field.get().maxSimDotProduct(byteVector)
+                    );
+                    assertThat(e.getMessage(), containsString("use [float maxSimDotProduct(float[][] queryVector)] instead"));
+                }
+            }
+        }
+    }
+
+    public void testByteBoundaries() throws IOException {
+        String fieldName = "vector";
+        int dims = 1;
+        float[] docVector = new float[] { 0 };
+        List<List<Number>> greaterThanVector = List.of(List.of(128));
+        List<List<Number>> lessThanVector = List.of(List.of(-129));
+        List<List<Number>> decimalVector = List.of(List.of(0.5));
+
+        List<MultiDenseVectorDocValuesField> fields = List.of(
+            new ByteMultiDenseVectorDocValuesField(
+                MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { { docVector } }, ElementType.BYTE),
+                MultiDenseVectorScriptDocValuesTests.wrap(new float[][] { { 1 } }),
+                "test",
+                ElementType.BYTE,
+                dims
+            )
+        );
+
+        for (MultiDenseVectorDocValuesField field : fields) {
+            field.setNextDocId(0);
+
+            ScoreScript scoreScript = mock(ScoreScript.class);
+            when(scoreScript.field(fieldName)).thenAnswer(mock -> field);
+
+            IllegalArgumentException e;
+
+            e = expectThrows(IllegalArgumentException.class, () -> new MaxSimDotProduct(scoreScript, greaterThanVector, fieldName));
+            assertEquals(
+                "element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
+                    + "Preview of invalid vector: [128.0]",
+                e.getMessage()
+            );
+
+            e = expectThrows(IllegalArgumentException.class, () -> new MaxSimDotProduct(scoreScript, lessThanVector, fieldName));
+            assertEquals(
+                "element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
+                    + "Preview of invalid vector: [-129.0]",
+                e.getMessage()
+            );
+            e = expectThrows(IllegalArgumentException.class, () -> new MaxSimDotProduct(scoreScript, decimalVector, fieldName));
+            assertEquals(
+                "element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
+                    + "Preview of invalid vector: [0.5]",
+                e.getMessage()
+            );
+        }
+    }
+
+    public void testDimMismatch() throws IOException {
+        // NOTE(review): empty test body - it asserts nothing; implement the dimension-mismatch checks or remove it.
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/script/field/vectors/MultiDenseVectorTests.java b/server/src/test/java/org/elasticsearch/script/field/vectors/MultiDenseVectorTests.java
new file mode 100644
index 000000000000..12f4b931b4d0
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/script/field/vectors/MultiDenseVectorTests.java
@@ -0,0 +1,83 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.script.field.vectors;
+
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.index.mapper.vectors.MultiDenseVectorFieldMapper;
+import org.elasticsearch.test.ESTestCase;
+import org.junit.BeforeClass;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.function.IntFunction;
+
+public class MultiDenseVectorTests extends ESTestCase {
+
+    @BeforeClass
+    public static void setup() {
+        assumeTrue("Requires multi-dense vector support", MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled());
+    }
+
+    public void testByteUnsupported() {
+        int count = randomIntBetween(1, 16);
+        int dims = randomIntBetween(1, 16);
+        byte[][] docVector = new byte[count][dims];
+        float[][] queryVector = new float[count][dims];
+        for (int i = 0; i < docVector.length; i++) {
+            random().nextBytes(docVector[i]);
+            for (int j = 0; j < dims; j++) {
+                queryVector[i][j] = randomFloat();
+            }
+        }
+
+        MultiDenseVector knn = newByteVector(docVector);
+        UnsupportedOperationException e;
+
+        e = expectThrows(UnsupportedOperationException.class, () -> knn.maxSimDotProduct(queryVector));
+        assertEquals("use [float maxSimDotProduct(byte[][] queryVector)] instead", e.getMessage()); // (expected, actual)
+    }
+
+    public void testFloatUnsupported() {
+        int count = randomIntBetween(1, 16);
+        int dims = randomIntBetween(1, 16);
+        float[][] docVector = new float[count][dims];
+        byte[][] queryVector = new byte[count][dims];
+        for (int i = 0; i < docVector.length; i++) {
+            random().nextBytes(queryVector[i]);
+            for (int j = 0; j < dims; j++) {
+                docVector[i][j] = randomFloat();
+            }
+        }
+
+        MultiDenseVector knn = newFloatVector(docVector);
+
+        UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, () -> knn.maxSimDotProduct(queryVector));
+        assertEquals("use [float maxSimDotProduct(float[][] queryVector)] instead", e.getMessage()); // (expected, actual)
+    }
+
+    static MultiDenseVector newFloatVector(float[][] vector) {
+        BytesRef magnitudes = magnitudes(vector.length, i -> (float) Math.sqrt(VectorUtil.dotProduct(vector[i], vector[i])));
+        return new FloatMultiDenseVector(VectorIterator.from(vector), magnitudes, vector.length, vector[0].length);
+    }
+
+    static MultiDenseVector newByteVector(byte[][] vector) {
+        BytesRef magnitudes = magnitudes(vector.length, i -> (float) Math.sqrt(VectorUtil.dotProduct(vector[i], vector[i])));
+        return new ByteMultiDenseVector(VectorIterator.from(vector), magnitudes, vector.length, vector[0].length);
+    }
+
+    static BytesRef magnitudes(int count, IntFunction<Float> magnitude) {
+        ByteBuffer magnitudeBuffer = ByteBuffer.allocate(count * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); // little-endian float per vector
+        for (int i = 0; i < count; i++) {
+            magnitudeBuffer.putFloat(magnitude.apply(i));
+        }
+        return new BytesRef(magnitudeBuffer.array());
+    }
+}