Adds maxSim functions for multi_dense_vector fields (#116993)

This adds `maxSim` functions, specifically dotProduct and InvHamming.
Why these two you might ask? Well, they are the best approximations of
whats possible with Col* late interaction type models. Effectively, you
want a similarity metric where "greater == better". Regular `hamming`
isn't exactly that, but inverting that (just like our `element_type:
bit` index for dense_vectors), is a nice approximation with bit vectors
and multi-vector scoring.

Then, of course, dotProduct is another usage. We will allow dot-product
between like elements (bytes -> bytes, floats -> floats) and of course,
allow `floats -> bit`, where the stored `bit` elements are applied as a
"mask" over the float queries. This allows for some nice asymmetric
interactions.

This is all behind a feature flag, and I need to write a mountain of
docs in a separate PR.
This commit is contained in:
Benjamin Trent 2024-11-20 16:41:55 -05:00 committed by GitHub
parent abcdbf27b2
commit e68f31754c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 1274 additions and 16 deletions

View file

@ -50,5 +50,7 @@ static_import {
double cosineSimilarity(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$CosineSimilarity
double dotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$DotProduct
double hamming(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$Hamming
double maxSimDotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.MultiVectorScoreScriptUtils$MaxSimDotProduct
double maxSimInvHamming(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.MultiVectorScoreScriptUtils$MaxSimInvHamming
}

View file

@ -0,0 +1,206 @@
setup:
- requires:
capabilities:
- method: POST
path: /_search
capabilities: [ multi_dense_vector_script_max_sim ]
test_runner_features: capabilities
reason: "Support for multi dense vector max-sim functions capability required"
- skip:
features: headers
- do:
indices.create:
index: test-index
body:
settings:
number_of_shards: 1
mappings:
properties:
vector:
type: multi_dense_vector
dims: 5
byte_vector:
type: multi_dense_vector
dims: 5
element_type: byte
bit_vector:
type: multi_dense_vector
dims: 40
element_type: bit
- do:
index:
index: test-index
id: "1"
body:
vector: [[230.0, 300.33, -34.8988, 15.555, -200.0], [-0.5, 100.0, -13, 14.8, -156.0]]
byte_vector: [[8, 5, -15, 1, -7], [-1, 115, -3, 4, -128]]
bit_vector: [[8, 5, -15, 1, -7], [-1, 115, -3, 4, -128]]
- do:
index:
index: test-index
id: "3"
body:
vector: [[0.5, 111.3, -13.0, 14.8, -156.0]]
byte_vector: [[2, 18, -5, 0, -124]]
bit_vector: [[2, 18, -5, 0, -124]]
- do:
indices.refresh: {}
---
"Test max-sim dot product scoring":
- skip:
features: close_to
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "maxSimDotProduct(params.query_vector, 'vector')"
params:
query_vector: [[1, 2, 1, 1, 1]]
- match: {hits.total: 2}
- match: {hits.hits.0._id: "1"}
- close_to: {hits.hits.0._score: {value: 611.316, error: 0.01}}
- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score: {value: 68.90001, error: 0.01}}
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "maxSimDotProduct(params.query_vector, 'byte_vector')"
params:
query_vector: [[1, 2, 1, 1, 0]]
- match: {hits.total: 2}
- match: {hits.hits.0._id: "1"}
- close_to: {hits.hits.0._score: {value: 230, error: 0.01}}
- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score: {value: 33, error: 0.01}}
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "maxSimDotProduct(params.query_vector, 'bit_vector')"
params:
query_vector: [[1, 2, 1, 1, 0]]
- match: {hits.total: 2}
- match: {hits.hits.0._id: "1"}
- close_to: {hits.hits.0._score: {value: 3, error: 0.01}}
- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score: {value: 2, error: 0.01}}
# doing max-sim dot product with a vector where the stored bit vectors are used as masks
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "maxSimDotProduct(params.query_vector, 'bit_vector')"
params:
query_vector: [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]]
- match: {hits.total: 2}
- match: {hits.hits.0._id: "1"}
- close_to: {hits.hits.0._score: {value: 190, error: 0.01}}
- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score: {value: 125, error: 0.01}}
---
"Test max-sim inv hamming scoring":
- skip:
features: close_to
# inv hamming doesn't apply to float vectors
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "maxSimInvHamming(params.query_vector, 'vector')"
params:
query_vector: [[1, 2, 1, 1, 1]]
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "maxSimInvHamming(params.query_vector, 'byte_vector')"
params:
query_vector: [[1, 2, 1, 1, 1]]
- match: {hits.total: 2}
- match: {hits.hits.0._id: "3"}
- close_to: {hits.hits.0._score: {value: 0.675, error: 0.01}}
- match: {hits.hits.1._id: "1"}
- close_to: {hits.hits.1._score: {value: 0.65, error: 0.01}}
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "maxSimInvHamming(params.query_vector, 'bit_vector')"
params:
query_vector: [[1, 2, 1, 1, 1]]
- match: {hits.total: 2}
- match: {hits.hits.0._id: "3"}
- close_to: {hits.hits.0._score: {value: 0.675, error: 0.01}}
- match: {hits.hits.1._id: "1"}
- close_to: {hits.hits.1._score: {value: 0.65, error: 0.01}}

View file

@ -40,6 +40,8 @@ public final class SearchCapabilities {
private static final String NESTED_RETRIEVER_INNER_HITS_SUPPORT = "nested_retriever_inner_hits_support";
/** Support multi-dense-vector script field access. */
private static final String MULTI_DENSE_VECTOR_SCRIPT_ACCESS = "multi_dense_vector_script_access";
/** Initial support for multi-dense-vector maxSim functions access. */
private static final String MULTI_DENSE_VECTOR_SCRIPT_MAX_SIM = "multi_dense_vector_script_max_sim";
private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs";
@ -56,6 +58,7 @@ public final class SearchCapabilities {
if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) {
capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER);
capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_ACCESS);
capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_MAX_SIM);
}
if (Build.current().isSnapshot()) {
capabilities.add(KQL_QUERY_SUPPORTED);

View file

@ -0,0 +1,372 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.script;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.script.field.vectors.DenseVector;
import org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField;
import java.io.IOException;
import java.util.HexFormat;
import java.util.List;
public class MultiVectorScoreScriptUtils {
public static class MultiDenseVectorFunction {
protected final ScoreScript scoreScript;
protected final MultiDenseVectorDocValuesField field;
public MultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDocValuesField field) {
this.scoreScript = scoreScript;
this.field = field;
}
void setNextVector() {
try {
field.setNextDocId(scoreScript._getDocId());
} catch (IOException e) {
throw ExceptionsHelper.convertToElastic(e);
}
if (field.isEmpty()) {
throw new IllegalArgumentException("A document doesn't have a value for a multi-vector field!");
}
}
}
public static class ByteMultiDenseVectorFunction extends MultiDenseVectorFunction {
protected final byte[][] queryVector;
/**
* Constructs a dense vector function used for byte-sized vectors.
*
* @param scoreScript The script in which this function was referenced.
* @param field The vector field.
* @param queryVector The query vector.
*/
public ByteMultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List<List<Number>> queryVector) {
super(scoreScript, field);
if (queryVector.isEmpty()) {
throw new IllegalArgumentException("The query vector is empty.");
}
field.getElementType().checkDimensions(field.get().getDims(), queryVector.get(0).size());
this.queryVector = new byte[queryVector.size()][queryVector.get(0).size()];
float[] validateValues = new float[queryVector.size()];
int lastSize = -1;
for (int i = 0; i < queryVector.size(); i++) {
if (lastSize != -1 && lastSize != queryVector.get(i).size()) {
throw new IllegalArgumentException(
"The query vector contains inner vectors which have inconsistent number of dimensions."
);
}
lastSize = queryVector.get(i).size();
for (int j = 0; j < queryVector.get(i).size(); j++) {
final Number number = queryVector.get(i).get(j);
byte value = number.byteValue();
this.queryVector[i][j] = value;
validateValues[i] = number.floatValue();
}
field.getElementType().checkVectorBounds(validateValues);
}
}
/**
* Constructs a dense vector function used for byte-sized vectors.
*
* @param scoreScript The script in which this function was referenced.
* @param field The vector field.
* @param queryVector The query vector.
*/
public ByteMultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) {
super(scoreScript, field);
this.queryVector = queryVector;
}
}
public static class FloatMultiDenseVectorFunction extends MultiDenseVectorFunction {
protected final float[][] queryVector;
/**
* Constructs a dense vector function used for float vectors.
*
* @param scoreScript The script in which this function was referenced.
* @param field The vector field.
* @param queryVector The query vector.
*/
public FloatMultiDenseVectorFunction(
ScoreScript scoreScript,
MultiDenseVectorDocValuesField field,
List<List<Number>> queryVector
) {
super(scoreScript, field);
if (queryVector.isEmpty()) {
throw new IllegalArgumentException("The query vector is empty.");
}
DenseVector.checkDimensions(field.get().getDims(), queryVector.get(0).size());
this.queryVector = new float[queryVector.size()][queryVector.get(0).size()];
int lastSize = -1;
for (int i = 0; i < queryVector.size(); i++) {
if (lastSize != -1 && lastSize != queryVector.get(i).size()) {
throw new IllegalArgumentException(
"The query vector contains inner vectors which have inconsistent number of dimensions."
);
}
lastSize = queryVector.get(i).size();
for (int j = 0; j < queryVector.get(i).size(); j++) {
this.queryVector[i][j] = queryVector.get(i).get(j).floatValue();
}
field.getElementType().checkVectorBounds(this.queryVector[i]);
}
}
}
// Calculate Hamming distances between a query's dense vector and documents' dense vectors
public interface MaxSimInvHammingDistanceInterface {
float maxSimInvHamming();
}
public static class ByteMaxSimInvHammingDistance extends ByteMultiDenseVectorFunction implements MaxSimInvHammingDistanceInterface {
public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List<List<Number>> queryVector) {
super(scoreScript, field, queryVector);
}
public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) {
super(scoreScript, field, queryVector);
}
public float maxSimInvHamming() {
setNextVector();
return field.get().maxSimInvHamming(queryVector);
}
}
private record BytesOrList(byte[][] bytes, List<List<Number>> list) {}
@SuppressWarnings("unchecked")
private static BytesOrList parseBytes(Object queryVector) {
if (queryVector instanceof List) {
// check if its a list of strings or list of lists
if (((List<?>) queryVector).get(0) instanceof List) {
return new BytesOrList(null, ((List<List<Number>>) queryVector));
} else if (((List<?>) queryVector).get(0) instanceof String) {
byte[][] parsedQueryVector = new byte[((List<?>) queryVector).size()][];
int lastSize = -1;
for (int i = 0; i < ((List<?>) queryVector).size(); i++) {
parsedQueryVector[i] = HexFormat.of().parseHex((String) ((List<?>) queryVector).get(i));
if (lastSize != -1 && lastSize != parsedQueryVector[i].length) {
throw new IllegalArgumentException(
"The query vector contains inner vectors which have inconsistent number of dimensions."
);
}
lastSize = parsedQueryVector[i].length;
}
return new BytesOrList(parsedQueryVector, null);
} else {
throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName());
}
} else {
throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName());
}
}
public static final class MaxSimInvHamming {
private final MaxSimInvHammingDistanceInterface function;
public MaxSimInvHamming(ScoreScript scoreScript, Object queryVector, String fieldName) {
MultiDenseVectorDocValuesField field = (MultiDenseVectorDocValuesField) scoreScript.field(fieldName);
if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) {
throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors");
}
BytesOrList bytesOrList = parseBytes(queryVector);
if (bytesOrList.bytes != null) {
this.function = new ByteMaxSimInvHammingDistance(scoreScript, field, bytesOrList.bytes);
} else {
this.function = new ByteMaxSimInvHammingDistance(scoreScript, field, bytesOrList.list);
}
}
public double maxSimInvHamming() {
return function.maxSimInvHamming();
}
}
// Calculate a dot product between a query's dense vector and documents' dense vectors
public interface MaxSimDotProductInterface {
double maxSimDotProduct();
}
public static class MaxSimBitDotProduct extends MultiDenseVectorFunction implements MaxSimDotProductInterface {
private final byte[][] byteQueryVector;
private final float[][] floatQueryVector;
public MaxSimBitDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) {
super(scoreScript, field);
if (field.getElementType() != DenseVectorFieldMapper.ElementType.BIT) {
throw new IllegalArgumentException("Cannot calculate bit dot product for non-bit vectors");
}
int fieldDims = field.get().getDims();
if (fieldDims != queryVector.length * Byte.SIZE && fieldDims != queryVector.length) {
throw new IllegalArgumentException(
"The query vector has an incorrect number of dimensions. Must be ["
+ fieldDims / 8
+ "] for bitwise operations, or ["
+ fieldDims
+ "] for byte wise operations: provided ["
+ queryVector.length
+ "]."
);
}
this.byteQueryVector = queryVector;
this.floatQueryVector = null;
}
public MaxSimBitDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List<List<Number>> queryVector) {
super(scoreScript, field);
if (queryVector.isEmpty()) {
throw new IllegalArgumentException("The query vector is empty.");
}
if (field.getElementType() != DenseVectorFieldMapper.ElementType.BIT) {
throw new IllegalArgumentException("cannot calculate bit dot product for non-bit vectors");
}
float[][] floatQueryVector = new float[queryVector.size()][];
byte[][] byteQueryVector = new byte[queryVector.size()][];
boolean isFloat = false;
int lastSize = -1;
for (int i = 0; i < queryVector.size(); i++) {
if (lastSize != -1 && lastSize != queryVector.get(i).size()) {
throw new IllegalArgumentException(
"The query vector contains inner vectors which have inconsistent number of dimensions."
);
}
lastSize = queryVector.get(i).size();
floatQueryVector[i] = new float[queryVector.get(i).size()];
if (isFloat == false) {
byteQueryVector[i] = new byte[queryVector.get(i).size()];
}
for (int j = 0; j < queryVector.get(i).size(); j++) {
Number number = queryVector.get(i).get(j);
floatQueryVector[i][j] = number.floatValue();
if (isFloat == false) {
byteQueryVector[i][j] = number.byteValue();
}
if (isFloat
|| floatQueryVector[i][j] % 1.0f != 0.0f
|| floatQueryVector[i][j] < Byte.MIN_VALUE
|| floatQueryVector[i][j] > Byte.MAX_VALUE) {
isFloat = true;
}
}
}
int fieldDims = field.get().getDims();
if (isFloat) {
this.floatQueryVector = floatQueryVector;
this.byteQueryVector = null;
if (fieldDims != floatQueryVector[0].length) {
throw new IllegalArgumentException(
"The query vector contains inner vectors which have incorrect number of dimensions. Must be ["
+ fieldDims
+ "] for float wise operations: provided ["
+ floatQueryVector[0].length
+ "]."
);
}
} else {
this.floatQueryVector = null;
this.byteQueryVector = byteQueryVector;
if (fieldDims != byteQueryVector[0].length * Byte.SIZE && fieldDims != byteQueryVector[0].length) {
throw new IllegalArgumentException(
"The query vector contains inner vectors which have incorrect number of dimensions. Must be ["
+ fieldDims / 8
+ "] for bitwise operations, or ["
+ fieldDims
+ "] for byte wise operations: provided ["
+ byteQueryVector[0].length
+ "]."
);
}
}
}
@Override
public double maxSimDotProduct() {
setNextVector();
return byteQueryVector != null ? field.get().maxSimDotProduct(byteQueryVector) : field.get().maxSimDotProduct(floatQueryVector);
}
}
public static class MaxSimByteDotProduct extends ByteMultiDenseVectorFunction implements MaxSimDotProductInterface {
public MaxSimByteDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List<List<Number>> queryVector) {
super(scoreScript, field, queryVector);
}
public MaxSimByteDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) {
super(scoreScript, field, queryVector);
}
public double maxSimDotProduct() {
setNextVector();
return field.get().maxSimDotProduct(queryVector);
}
}
public static class MaxSimFloatDotProduct extends FloatMultiDenseVectorFunction implements MaxSimDotProductInterface {
public MaxSimFloatDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List<List<Number>> queryVector) {
super(scoreScript, field, queryVector);
}
public double maxSimDotProduct() {
setNextVector();
return field.get().maxSimDotProduct(queryVector);
}
}
public static final class MaxSimDotProduct {
private final MaxSimDotProductInterface function;
@SuppressWarnings("unchecked")
public MaxSimDotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) {
MultiDenseVectorDocValuesField field = (MultiDenseVectorDocValuesField) scoreScript.field(fieldName);
function = switch (field.getElementType()) {
case BIT -> {
BytesOrList bytesOrList = parseBytes(queryVector);
if (bytesOrList.bytes != null) {
yield new MaxSimBitDotProduct(scoreScript, field, bytesOrList.bytes);
} else {
yield new MaxSimBitDotProduct(scoreScript, field, bytesOrList.list);
}
}
case BYTE -> {
BytesOrList bytesOrList = parseBytes(queryVector);
if (bytesOrList.bytes != null) {
yield new MaxSimByteDotProduct(scoreScript, field, bytesOrList.bytes);
} else {
yield new MaxSimByteDotProduct(scoreScript, field, bytesOrList.list);
}
}
case FLOAT -> {
if (queryVector instanceof List) {
yield new MaxSimFloatDotProduct(scoreScript, field, (List<List<Number>>) queryVector);
}
throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName());
}
};
}
public double maxSimDotProduct() {
return function.maxSimDotProduct();
}
}
}

View file

@ -10,11 +10,13 @@
package org.elasticsearch.script.field.vectors;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.simdvec.ESVectorUtil;
import java.util.Iterator;
import java.util.Arrays;
public class BitMultiDenseVector extends ByteMultiDenseVector {
public BitMultiDenseVector(Iterator<byte[]> vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) {
public BitMultiDenseVector(VectorIterator<byte[]> vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) {
super(vectorValues, magnitudesBytes, numVecs, dims);
}
@ -31,6 +33,70 @@ public class BitMultiDenseVector extends ByteMultiDenseVector {
}
}
@Override
public float maxSimDotProduct(float[][] query) {
vectorValues.reset();
float[] maxes = new float[query.length];
Arrays.fill(maxes, Float.NEGATIVE_INFINITY);
while (vectorValues.hasNext()) {
byte[] vv = vectorValues.next();
for (int i = 0; i < query.length; i++) {
maxes[i] = Math.max(maxes[i], ESVectorUtil.ipFloatBit(query[i], vv));
}
}
float sums = 0;
for (float m : maxes) {
sums += m;
}
return sums;
}
@Override
public float maxSimDotProduct(byte[][] query) {
vectorValues.reset();
float[] maxes = new float[query.length];
Arrays.fill(maxes, Float.NEGATIVE_INFINITY);
if (query[0].length == dims) {
while (vectorValues.hasNext()) {
byte[] vv = vectorValues.next();
for (int i = 0; i < query.length; i++) {
maxes[i] = Math.max(maxes[i], ESVectorUtil.andBitCount(query[i], vv));
}
}
} else {
while (vectorValues.hasNext()) {
byte[] vv = vectorValues.next();
for (int i = 0; i < query.length; i++) {
maxes[i] = Math.max(maxes[i], ESVectorUtil.ipByteBit(query[i], vv));
}
}
}
float sum = 0;
for (float m : maxes) {
sum += m;
}
return sum;
}
@Override
public float maxSimInvHamming(byte[][] query) {
vectorValues.reset();
int bitCount = this.getDims();
float[] maxes = new float[query.length];
Arrays.fill(maxes, Float.NEGATIVE_INFINITY);
while (vectorValues.hasNext()) {
byte[] vv = vectorValues.next();
for (int i = 0; i < query.length; i++) {
maxes[i] = Math.max(maxes[i], ((bitCount - VectorUtil.xorBitCount(vv, query[i])) / (float) bitCount));
}
}
float sum = 0;
for (float m : maxes) {
sum += m;
}
return sum;
}
@Override
public int getDims() {
return dims * Byte.SIZE;

View file

@ -10,21 +10,22 @@
package org.elasticsearch.script.field.vectors;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;
import java.util.Arrays;
import java.util.Iterator;
public class ByteMultiDenseVector implements MultiDenseVector {
protected final Iterator<byte[]> vectorValues;
protected final VectorIterator<byte[]> vectorValues;
protected final int numVecs;
protected final int dims;
private Iterator<float[]> floatDocVectors;
private float[] magnitudes;
private final BytesRef magnitudesBytes;
public ByteMultiDenseVector(Iterator<byte[]> vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) {
public ByteMultiDenseVector(VectorIterator<byte[]> vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) {
assert magnitudesBytes.length == numVecs * Float.BYTES;
this.vectorValues = vectorValues;
this.numVecs = numVecs;
@ -33,11 +34,50 @@ public class ByteMultiDenseVector implements MultiDenseVector {
}
@Override
public Iterator<float[]> getVectors() {
if (floatDocVectors == null) {
floatDocVectors = new ByteToFloatIteratorWrapper(vectorValues, dims);
public float maxSimDotProduct(float[][] query) {
throw new UnsupportedOperationException("use [float maxSimDotProduct(byte[][] queryVector)] instead");
}
return floatDocVectors;
@Override
public float maxSimDotProduct(byte[][] query) {
vectorValues.reset();
float[] maxes = new float[query.length];
Arrays.fill(maxes, Float.NEGATIVE_INFINITY);
while (vectorValues.hasNext()) {
byte[] vv = vectorValues.next();
for (int i = 0; i < query.length; i++) {
maxes[i] = Math.max(maxes[i], VectorUtil.dotProduct(query[i], vv));
}
}
float sum = 0;
for (float m : maxes) {
sum += m;
}
return sum;
}
@Override
public float maxSimInvHamming(byte[][] query) {
vectorValues.reset();
int bitCount = dims * Byte.SIZE;
float[] maxes = new float[query.length];
Arrays.fill(maxes, Float.NEGATIVE_INFINITY);
while (vectorValues.hasNext()) {
byte[] vv = vectorValues.next();
for (int i = 0; i < query.length; i++) {
maxes[i] = Math.max(maxes[i], ((bitCount - VectorUtil.xorBitCount(vv, query[i])) / (float) bitCount));
}
}
float sum = 0;
for (float m : maxes) {
sum += m;
}
return sum;
}
@Override
public Iterator<float[]> getVectors() {
return new ByteToFloatIteratorWrapper(vectorValues.copy(), dims);
}
@Override

View file

@ -23,7 +23,7 @@ public class ByteMultiDenseVectorDocValuesField extends MultiDenseVectorDocValue
private final BinaryDocValues magnitudes;
protected final int dims;
protected int numVecs;
protected Iterator<byte[]> vectorValue;
protected VectorIterator<byte[]> vectorValue;
protected boolean decoded;
protected BytesRef value;
protected BytesRef magnitudesValue;
@ -111,7 +111,7 @@ public class ByteMultiDenseVectorDocValuesField extends MultiDenseVectorDocValue
return value == null;
}
static class ByteVectorIterator implements Iterator<byte[]> {
static class ByteVectorIterator implements VectorIterator<byte[]> {
private final byte[] buffer;
private final BytesRef vectorValues;
private final int size;
@ -138,5 +138,15 @@ public class ByteMultiDenseVectorDocValuesField extends MultiDenseVectorDocValue
idx++;
return buffer;
}
@Override
public Iterator<byte[]> copy() {
return new ByteVectorIterator(vectorValues, new byte[buffer.length], size);
}
@Override
public void reset() {
idx = 0;
}
}
}

View file

@ -10,7 +10,9 @@
package org.elasticsearch.script.field.vectors;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
import java.util.Arrays;
import java.util.Iterator;
import static org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder.getMultiMagnitudes;
@ -21,19 +23,47 @@ public class FloatMultiDenseVector implements MultiDenseVector {
private float[] magnitudesArray = null;
private final int dims;
private final int numVectors;
private final Iterator<float[]> decodedDocVector;
private final VectorIterator<float[]> vectorValues;
public FloatMultiDenseVector(Iterator<float[]> decodedDocVector, BytesRef magnitudes, int numVectors, int dims) {
public FloatMultiDenseVector(VectorIterator<float[]> decodedDocVector, BytesRef magnitudes, int numVectors, int dims) {
assert magnitudes.length == numVectors * Float.BYTES;
this.decodedDocVector = decodedDocVector;
this.vectorValues = decodedDocVector;
this.magnitudes = magnitudes;
this.numVectors = numVectors;
this.dims = dims;
}
@Override
public float maxSimDotProduct(float[][] query) {
vectorValues.reset();
float[] maxes = new float[query.length];
Arrays.fill(maxes, Float.NEGATIVE_INFINITY);
while (vectorValues.hasNext()) {
float[] vv = vectorValues.next();
for (int i = 0; i < query.length; i++) {
maxes[i] = Math.max(maxes[i], VectorUtil.dotProduct(query[i], vv));
}
}
float sum = 0;
for (float m : maxes) {
sum += m;
}
return sum;
}
@Override
public float maxSimDotProduct(byte[][] query) {
throw new UnsupportedOperationException("use [float maxSimDotProduct(float[][] queryVector)] instead");
}
@Override
public float maxSimInvHamming(byte[][] query) {
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
}
@Override
public Iterator<float[]> getVectors() {
return decodedDocVector;
return vectorValues.copy();
}
@Override

View file

@ -110,14 +110,16 @@ public class FloatMultiDenseVectorDocValuesField extends MultiDenseVectorDocValu
}
}
static class FloatVectorIterator implements Iterator<float[]> {
static class FloatVectorIterator implements VectorIterator<float[]> {
private final float[] buffer;
private final FloatBuffer vectorValues;
private final BytesRef vectorValueBytesRef;
private final int size;
private int idx = 0;
FloatVectorIterator(BytesRef vectorValues, float[] buffer, int size) {
assert vectorValues.length == (buffer.length * Float.BYTES * size);
this.vectorValueBytesRef = vectorValues;
this.vectorValues = ByteBuffer.wrap(vectorValues.bytes, vectorValues.offset, vectorValues.length)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
@ -139,5 +141,16 @@ public class FloatMultiDenseVectorDocValuesField extends MultiDenseVectorDocValu
idx++;
return buffer;
}
@Override
public Iterator<float[]> copy() {
return new FloatVectorIterator(vectorValueBytesRef, new float[buffer.length], size);
}
@Override
public void reset() {
idx = 0;
vectorValues.rewind();
}
}
}

View file

@ -17,6 +17,12 @@ public interface MultiDenseVector {
checkDimensions(getDims(), qvDims);
}
float maxSimDotProduct(float[][] query);
float maxSimDotProduct(byte[][] query);
float maxSimInvHamming(byte[][] query);
Iterator<float[]> getVectors();
float[] getMagnitudes();
@ -63,6 +69,21 @@ public interface MultiDenseVector {
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
}
@Override
public float maxSimDotProduct(float[][] query) {
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
}
@Override
public float maxSimDotProduct(byte[][] query) {
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
}
@Override
public float maxSimInvHamming(byte[][] query) {
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
}
@Override
public int size() {
return 0;

View file

@ -0,0 +1,70 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.script.field.vectors;
import java.util.Iterator;
public interface VectorIterator<E> extends Iterator<E> {
Iterator<E> copy();
void reset();
static VectorIterator<float[]> from(float[][] vectors) {
return new VectorIterator<>() {
private int i = 0;
@Override
public boolean hasNext() {
return i < vectors.length;
}
@Override
public float[] next() {
return vectors[i++];
}
@Override
public Iterator<float[]> copy() {
return from(vectors);
}
@Override
public void reset() {
i = 0;
}
};
}
static VectorIterator<byte[]> from(byte[][] vectors) {
return new VectorIterator<>() {
private int i = 0;
@Override
public boolean hasNext() {
return i < vectors.length;
}
@Override
public byte[] next() {
return vectors[i++];
}
@Override
public Iterator<byte[]> copy() {
return from(vectors);
}
@Override
public void reset() {
i = 0;
}
};
}
}

View file

@ -0,0 +1,342 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.script;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
import org.elasticsearch.index.mapper.vectors.MultiDenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValuesTests;
import org.elasticsearch.script.MultiVectorScoreScriptUtils.MaxSimDotProduct;
import org.elasticsearch.script.MultiVectorScoreScriptUtils.MaxSimInvHamming;
import org.elasticsearch.script.field.vectors.BitMultiDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.ByteMultiDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.FloatMultiDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField;
import org.elasticsearch.test.ESTestCase;
import org.junit.BeforeClass;
import java.io.IOException;
import java.util.Arrays;
import java.util.HexFormat;
import java.util.List;
import static org.hamcrest.Matchers.containsString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class MultiVectorScoreScriptUtilsTests extends ESTestCase {
@BeforeClass
public static void setup() {
assumeTrue("Requires multi-dense vector support", MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled());
}
public void testFloatMultiVectorClassBindings() throws IOException {
String fieldName = "vector";
int dims = 5;
float[][][] docVectors = new float[][][] {
{ { 230.0f, 300.33f, -34.8988f, 15.555f, -200.0f }, { 100.0f, 200.0f, -50.0f, 10.0f, -150.0f } } };
float[][] docMagnitudes = new float[][] { { 0.0f, 0.0f } };
for (int i = 0; i < docVectors.length; i++) {
for (int j = 0; j < docVectors[i].length; j++) {
docMagnitudes[i][j] = (float) Math.sqrt(VectorUtil.dotProduct(docVectors[i][j], docVectors[i][j]));
}
}
List<List<Number>> queryVector = List.of(Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f));
List<List<Number>> invalidQueryVector = List.of(Arrays.asList(0.5, 111.3));
List<MultiDenseVectorDocValuesField> fields = List.of(
new FloatMultiDenseVectorDocValuesField(
MultiDenseVectorScriptDocValuesTests.wrap(docVectors, ElementType.FLOAT),
MultiDenseVectorScriptDocValuesTests.wrap(docMagnitudes),
"test",
ElementType.FLOAT,
dims
),
new FloatMultiDenseVectorDocValuesField(
MultiDenseVectorScriptDocValuesTests.wrap(docVectors, ElementType.FLOAT),
MultiDenseVectorScriptDocValuesTests.wrap(docMagnitudes),
"test",
ElementType.FLOAT,
dims
)
);
for (MultiDenseVectorDocValuesField field : fields) {
field.setNextDocId(0);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript.field("vector")).thenAnswer(mock -> field);
// Test max similarity dot product
MaxSimDotProduct maxSimDotProduct = new MaxSimDotProduct(scoreScript, queryVector, fieldName);
float maxSimDotProductExpected = 65425.625f; // Adjust this value based on expected max similarity
assertEquals(
"maxSimDotProduct result is not equal to the expected value!",
maxSimDotProductExpected,
maxSimDotProduct.maxSimDotProduct(),
0.001
);
// Check each function rejects query vectors with the wrong dimension
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> new MultiVectorScoreScriptUtils.MaxSimDotProduct(scoreScript, invalidQueryVector, fieldName)
);
assertThat(
e.getMessage(),
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
);
e = expectThrows(IllegalArgumentException.class, () -> new MaxSimInvHamming(scoreScript, invalidQueryVector, fieldName));
assertThat(e.getMessage(), containsString("hamming distance is only supported for byte or bit vectors"));
// Check scripting infrastructure integration
assertEquals(65425.6249, new MaxSimDotProduct(scoreScript, queryVector, fieldName).maxSimDotProduct(), 0.001);
when(scoreScript._getDocId()).thenReturn(1);
e = expectThrows(
IllegalArgumentException.class,
() -> new MaxSimDotProduct(scoreScript, queryVector, fieldName).maxSimDotProduct()
);
assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage());
}
}
public void testByteMultiVectorClassBindings() throws IOException {
String fieldName = "vector";
int dims = 5;
float[][] docVector = new float[][] { { 1, 127, -128, 5, -10 } };
float[][] magnitudes = new float[][] { { 0 } };
for (int i = 0; i < docVector.length; i++) {
magnitudes[i][0] = (float) Math.sqrt(VectorUtil.dotProduct(docVector[i], docVector[i]));
}
List<List<Number>> queryVector = List.of(Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4));
List<List<Number>> invalidQueryVector = List.of(Arrays.asList((byte) 1, (byte) 1));
List<String> hexidecimalString = List.of(HexFormat.of().formatHex(new byte[] { 1, 125, -12, 2, 4 }));
List<MultiDenseVectorDocValuesField> fields = List.of(
new ByteMultiDenseVectorDocValuesField(
MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.BYTE),
MultiDenseVectorScriptDocValuesTests.wrap(magnitudes),
"test",
ElementType.BYTE,
dims
)
);
for (MultiDenseVectorDocValuesField field : fields) {
field.setNextDocId(0);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript.field(fieldName)).thenAnswer(mock -> field);
// Check each function rejects query vectors with the wrong dimension
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> new MaxSimDotProduct(scoreScript, invalidQueryVector, fieldName)
);
assertThat(
e.getMessage(),
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
);
e = expectThrows(IllegalArgumentException.class, () -> new MaxSimInvHamming(scoreScript, invalidQueryVector, fieldName));
assertThat(
e.getMessage(),
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
);
// Check scripting infrastructure integration
assertEquals(17382.0, new MaxSimDotProduct(scoreScript, queryVector, fieldName).maxSimDotProduct(), 0.001);
assertEquals(17382.0, new MaxSimDotProduct(scoreScript, hexidecimalString, fieldName).maxSimDotProduct(), 0.001);
assertEquals(0.675, new MaxSimInvHamming(scoreScript, queryVector, fieldName).maxSimInvHamming(), 0.001);
assertEquals(0.675, new MaxSimInvHamming(scoreScript, hexidecimalString, fieldName).maxSimInvHamming(), 0.001);
MaxSimDotProduct maxSimDotProduct = new MaxSimDotProduct(scoreScript, queryVector, fieldName);
when(scoreScript._getDocId()).thenReturn(1);
e = expectThrows(IllegalArgumentException.class, maxSimDotProduct::maxSimDotProduct);
assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage());
}
}
public void testBitMultiVectorClassBindingsDotProduct() throws IOException {
String fieldName = "vector";
int dims = 8;
float[][] docVector = new float[][] { { 124 } };
// 124 in binary is b01111100
List<List<Number>> queryVector = List.of(
Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4, (byte) 1, (byte) 125, (byte) -12)
);
List<List<Number>> floatQueryVector = List.of(Arrays.asList(1.4f, -1.4f, 0.42f, 0.0f, 1f, -1f, -0.42f, 1.2f));
List<List<Number>> invalidQueryVector = List.of(Arrays.asList((byte) 1, (byte) 1));
List<String> hexidecimalString = List.of(HexFormat.of().formatHex(new byte[] { 124 }));
List<MultiDenseVectorDocValuesField> fields = List.of(
new BitMultiDenseVectorDocValuesField(
MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.BIT),
MultiDenseVectorScriptDocValuesTests.wrap(new float[][] { { 5 } }),
"test",
ElementType.BIT,
dims
)
);
for (MultiDenseVectorDocValuesField field : fields) {
field.setNextDocId(0);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript.field(fieldName)).thenAnswer(mock -> field);
MaxSimDotProduct function = new MaxSimDotProduct(scoreScript, queryVector, fieldName);
assertEquals(
"maxSimDotProduct result is not equal to the expected value!",
-12 + 2 + 4 + 1 + 125,
function.maxSimDotProduct(),
0.001
);
function = new MaxSimDotProduct(scoreScript, floatQueryVector, fieldName);
assertEquals(
"maxSimDotProduct result is not equal to the expected value!",
0.42f + 0f + 1f - 1f - 0.42f,
function.maxSimDotProduct(),
0.001
);
function = new MaxSimDotProduct(scoreScript, hexidecimalString, fieldName);
assertEquals(
"maxSimDotProduct result is not equal to the expected value!",
Integer.bitCount(124),
function.maxSimDotProduct(),
0.0
);
// Check each function rejects query vectors with the wrong dimension
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> new MaxSimDotProduct(scoreScript, invalidQueryVector, fieldName)
);
assertThat(
e.getMessage(),
containsString(
"query vector contains inner vectors which have incorrect number of dimensions. "
+ "Must be [1] for bitwise operations, or [8] for byte wise operations: provided [2]."
)
);
}
}
public void testByteVsFloatSimilarity() throws IOException {
int dims = 5;
float[][] docVector = new float[][] { { 1f, 127f, -128f, 5f, -10f } };
float[][] magnitudes = new float[][] { { 0 } };
for (int i = 0; i < docVector.length; i++) {
magnitudes[i][0] = (float) Math.sqrt(VectorUtil.dotProduct(docVector[i], docVector[i]));
}
List<List<Number>> listFloatVector = List.of(Arrays.asList(1f, 125f, -12f, 2f, 4f));
List<List<Number>> listByteVector = List.of(Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4));
float[][] floatVector = new float[][] { { 1f, 125f, -12f, 2f, 4f } };
byte[][] byteVector = new byte[][] { { (byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4 } };
List<MultiDenseVectorDocValuesField> fields = List.of(
new FloatMultiDenseVectorDocValuesField(
MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.FLOAT),
MultiDenseVectorScriptDocValuesTests.wrap(magnitudes),
"field1",
ElementType.FLOAT,
dims
),
new ByteMultiDenseVectorDocValuesField(
MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.BYTE),
MultiDenseVectorScriptDocValuesTests.wrap(magnitudes),
"field3",
ElementType.BYTE,
dims
)
);
for (MultiDenseVectorDocValuesField field : fields) {
field.setNextDocId(0);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript.field("vector")).thenAnswer(mock -> field);
int dotProductExpected = 17382;
MaxSimDotProduct maxSimDotProduct = new MaxSimDotProduct(scoreScript, listFloatVector, "vector");
assertEquals(field.getName(), dotProductExpected, maxSimDotProduct.maxSimDotProduct(), 0.001);
maxSimDotProduct = new MaxSimDotProduct(scoreScript, listByteVector, "vector");
assertEquals(field.getName(), dotProductExpected, maxSimDotProduct.maxSimDotProduct(), 0.001);
switch (field.getElementType()) {
case BYTE -> {
assertEquals(field.getName(), dotProductExpected, field.get().maxSimDotProduct(byteVector), 0.001);
UnsupportedOperationException e = expectThrows(
UnsupportedOperationException.class,
() -> field.get().maxSimDotProduct(floatVector)
);
assertThat(e.getMessage(), containsString("use [float maxSimDotProduct(byte[][] queryVector)] instead"));
}
case FLOAT -> {
assertEquals(field.getName(), dotProductExpected, field.get().maxSimDotProduct(floatVector), 0.001);
UnsupportedOperationException e = expectThrows(
UnsupportedOperationException.class,
() -> field.get().maxSimDotProduct(byteVector)
);
assertThat(e.getMessage(), containsString("use [float maxSimDotProduct(float[][] queryVector)] instead"));
}
}
}
}
public void testByteBoundaries() throws IOException {
String fieldName = "vector";
int dims = 1;
float[] docVector = new float[] { 0 };
List<List<Number>> greaterThanVector = List.of(List.of(128));
List<List<Number>> lessThanVector = List.of(List.of(-129));
List<List<Number>> decimalVector = List.of(List.of(0.5));
List<MultiDenseVectorDocValuesField> fields = List.of(
new ByteMultiDenseVectorDocValuesField(
MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { { docVector } }, ElementType.BYTE),
MultiDenseVectorScriptDocValuesTests.wrap(new float[][] { { 1 } }),
"test",
ElementType.BYTE,
dims
)
);
for (MultiDenseVectorDocValuesField field : fields) {
field.setNextDocId(0);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript.field(fieldName)).thenAnswer(mock -> field);
IllegalArgumentException e;
e = expectThrows(IllegalArgumentException.class, () -> new MaxSimDotProduct(scoreScript, greaterThanVector, fieldName));
assertEquals(
"element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
+ "Preview of invalid vector: [128.0]",
e.getMessage()
);
e = expectThrows(IllegalArgumentException.class, () -> new MaxSimDotProduct(scoreScript, lessThanVector, fieldName));
assertEquals(
e.getMessage(),
"element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
+ "Preview of invalid vector: [-129.0]"
);
e = expectThrows(IllegalArgumentException.class, () -> new MaxSimDotProduct(scoreScript, decimalVector, fieldName));
assertEquals(
e.getMessage(),
"element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
+ "Preview of invalid vector: [0.5]"
);
}
}
public void testDimMismatch() throws IOException {
}
}

View file

@ -0,0 +1,83 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.script.field.vectors;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.index.mapper.vectors.MultiDenseVectorFieldMapper;
import org.elasticsearch.test.ESTestCase;
import org.junit.BeforeClass;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.function.IntFunction;
public class MultiDenseVectorTests extends ESTestCase {
@BeforeClass
public static void setup() {
assumeTrue("Requires multi-dense vector support", MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled());
}
public void testByteUnsupported() {
int count = randomIntBetween(1, 16);
int dims = randomIntBetween(1, 16);
byte[][] docVector = new byte[count][dims];
float[][] queryVector = new float[count][dims];
for (int i = 0; i < docVector.length; i++) {
random().nextBytes(docVector[i]);
for (int j = 0; j < dims; j++) {
queryVector[i][j] = randomFloat();
}
}
MultiDenseVector knn = newByteVector(docVector);
UnsupportedOperationException e;
e = expectThrows(UnsupportedOperationException.class, () -> knn.maxSimDotProduct(queryVector));
assertEquals(e.getMessage(), "use [float maxSimDotProduct(byte[][] queryVector)] instead");
}
public void testFloatUnsupported() {
int count = randomIntBetween(1, 16);
int dims = randomIntBetween(1, 16);
float[][] docVector = new float[count][dims];
byte[][] queryVector = new byte[count][dims];
for (int i = 0; i < docVector.length; i++) {
random().nextBytes(queryVector[i]);
for (int j = 0; j < dims; j++) {
docVector[i][j] = randomFloat();
}
}
MultiDenseVector knn = newFloatVector(docVector);
UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, () -> knn.maxSimDotProduct(queryVector));
assertEquals(e.getMessage(), "use [float maxSimDotProduct(float[][] queryVector)] instead");
}
static MultiDenseVector newFloatVector(float[][] vector) {
BytesRef magnitudes = magnitudes(vector.length, i -> (float) Math.sqrt(VectorUtil.dotProduct(vector[i], vector[i])));
return new FloatMultiDenseVector(VectorIterator.from(vector), magnitudes, vector.length, vector[0].length);
}
static MultiDenseVector newByteVector(byte[][] vector) {
BytesRef magnitudes = magnitudes(vector.length, i -> (float) Math.sqrt(VectorUtil.dotProduct(vector[i], vector[i])));
return new ByteMultiDenseVector(VectorIterator.from(vector), magnitudes, vector.length, vector[0].length);
}
static BytesRef magnitudes(int count, IntFunction<Float> magnitude) {
ByteBuffer magnitudeBuffer = ByteBuffer.allocate(count * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (int i = 0; i < count; i++) {
magnitudeBuffer.putFloat(magnitude.apply(i));
}
return new BytesRef(magnitudeBuffer.array());
}
}