mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 17:34:17 -04:00
Adds maxSim functions for multi_dense_vector fields (#116993)
This adds `maxSim` functions, specifically dotProduct and InvHamming. Why these two, you might ask? Well, they are the best approximations of what's possible with Col* late-interaction type models. Effectively, you want a similarity metric where "greater == better". Regular `hamming` isn't exactly that, but inverting it (just like our `element_type: bit` index for dense_vectors) is a nice approximation for bit vectors and multi-vector scoring. Then, of course, dotProduct is another usage. We will allow dot-product between like elements (bytes -> bytes, floats -> floats) and, of course, allow `floats -> bit`, where the stored `bit` elements are applied as a "mask" over the float queries. This allows for some nice asymmetric interactions. This is all behind a feature flag, and I need to write a mountain of docs in a separate PR.
This commit is contained in:
parent
abcdbf27b2
commit
e68f31754c
13 changed files with 1274 additions and 16 deletions
|
@ -50,5 +50,7 @@ static_import {
|
|||
double cosineSimilarity(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$CosineSimilarity
|
||||
double dotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$DotProduct
|
||||
double hamming(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$Hamming
|
||||
double maxSimDotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.MultiVectorScoreScriptUtils$MaxSimDotProduct
|
||||
double maxSimInvHamming(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.MultiVectorScoreScriptUtils$MaxSimInvHamming
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,206 @@
|
|||
setup:
|
||||
- requires:
|
||||
capabilities:
|
||||
- method: POST
|
||||
path: /_search
|
||||
capabilities: [ multi_dense_vector_script_max_sim ]
|
||||
test_runner_features: capabilities
|
||||
reason: "Support for multi dense vector max-sim functions capability required"
|
||||
- skip:
|
||||
features: headers
|
||||
|
||||
- do:
|
||||
indices.create:
|
||||
index: test-index
|
||||
body:
|
||||
settings:
|
||||
number_of_shards: 1
|
||||
mappings:
|
||||
properties:
|
||||
vector:
|
||||
type: multi_dense_vector
|
||||
dims: 5
|
||||
byte_vector:
|
||||
type: multi_dense_vector
|
||||
dims: 5
|
||||
element_type: byte
|
||||
bit_vector:
|
||||
type: multi_dense_vector
|
||||
dims: 40
|
||||
element_type: bit
|
||||
- do:
|
||||
index:
|
||||
index: test-index
|
||||
id: "1"
|
||||
body:
|
||||
vector: [[230.0, 300.33, -34.8988, 15.555, -200.0], [-0.5, 100.0, -13, 14.8, -156.0]]
|
||||
byte_vector: [[8, 5, -15, 1, -7], [-1, 115, -3, 4, -128]]
|
||||
bit_vector: [[8, 5, -15, 1, -7], [-1, 115, -3, 4, -128]]
|
||||
|
||||
- do:
|
||||
index:
|
||||
index: test-index
|
||||
id: "3"
|
||||
body:
|
||||
vector: [[0.5, 111.3, -13.0, 14.8, -156.0]]
|
||||
byte_vector: [[2, 18, -5, 0, -124]]
|
||||
bit_vector: [[2, 18, -5, 0, -124]]
|
||||
|
||||
- do:
|
||||
indices.refresh: {}
|
||||
---
|
||||
"Test max-sim dot product scoring":
|
||||
- skip:
|
||||
features: close_to
|
||||
|
||||
- do:
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
search:
|
||||
rest_total_hits_as_int: true
|
||||
body:
|
||||
query:
|
||||
script_score:
|
||||
query: {match_all: {} }
|
||||
script:
|
||||
source: "maxSimDotProduct(params.query_vector, 'vector')"
|
||||
params:
|
||||
query_vector: [[1, 2, 1, 1, 1]]
|
||||
|
||||
- match: {hits.total: 2}
|
||||
|
||||
- match: {hits.hits.0._id: "1"}
|
||||
- close_to: {hits.hits.0._score: {value: 611.316, error: 0.01}}
|
||||
|
||||
- match: {hits.hits.1._id: "3"}
|
||||
- close_to: {hits.hits.1._score: {value: 68.90001, error: 0.01}}
|
||||
|
||||
- do:
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
search:
|
||||
rest_total_hits_as_int: true
|
||||
body:
|
||||
query:
|
||||
script_score:
|
||||
query: {match_all: {} }
|
||||
script:
|
||||
source: "maxSimDotProduct(params.query_vector, 'byte_vector')"
|
||||
params:
|
||||
query_vector: [[1, 2, 1, 1, 0]]
|
||||
|
||||
- match: {hits.total: 2}
|
||||
|
||||
- match: {hits.hits.0._id: "1"}
|
||||
- close_to: {hits.hits.0._score: {value: 230, error: 0.01}}
|
||||
|
||||
- match: {hits.hits.1._id: "3"}
|
||||
- close_to: {hits.hits.1._score: {value: 33, error: 0.01}}
|
||||
|
||||
- do:
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
search:
|
||||
rest_total_hits_as_int: true
|
||||
body:
|
||||
query:
|
||||
script_score:
|
||||
query: {match_all: {} }
|
||||
script:
|
||||
source: "maxSimDotProduct(params.query_vector, 'bit_vector')"
|
||||
params:
|
||||
query_vector: [[1, 2, 1, 1, 0]]
|
||||
|
||||
- match: {hits.total: 2}
|
||||
|
||||
- match: {hits.hits.0._id: "1"}
|
||||
- close_to: {hits.hits.0._score: {value: 3, error: 0.01}}
|
||||
|
||||
- match: {hits.hits.1._id: "3"}
|
||||
- close_to: {hits.hits.1._score: {value: 2, error: 0.01}}
|
||||
|
||||
# doing max-sim dot product with a vector where the stored bit vectors are used as masks
|
||||
- do:
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
search:
|
||||
rest_total_hits_as_int: true
|
||||
body:
|
||||
query:
|
||||
script_score:
|
||||
query: {match_all: {} }
|
||||
script:
|
||||
source: "maxSimDotProduct(params.query_vector, 'bit_vector')"
|
||||
params:
|
||||
query_vector: [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]]
|
||||
- match: {hits.total: 2}
|
||||
|
||||
- match: {hits.hits.0._id: "1"}
|
||||
- close_to: {hits.hits.0._score: {value: 190, error: 0.01}}
|
||||
|
||||
- match: {hits.hits.1._id: "3"}
|
||||
- close_to: {hits.hits.1._score: {value: 125, error: 0.01}}
|
||||
---
|
||||
"Test max-sim inv hamming scoring":
|
||||
- skip:
|
||||
features: close_to
|
||||
|
||||
# inv hamming doesn't apply to float vectors
|
||||
- do:
|
||||
catch: bad_request
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
search:
|
||||
rest_total_hits_as_int: true
|
||||
body:
|
||||
query:
|
||||
script_score:
|
||||
query: {match_all: {} }
|
||||
script:
|
||||
source: "maxSimInvHamming(params.query_vector, 'vector')"
|
||||
params:
|
||||
query_vector: [[1, 2, 1, 1, 1]]
|
||||
|
||||
- do:
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
search:
|
||||
rest_total_hits_as_int: true
|
||||
body:
|
||||
query:
|
||||
script_score:
|
||||
query: {match_all: {} }
|
||||
script:
|
||||
source: "maxSimInvHamming(params.query_vector, 'byte_vector')"
|
||||
params:
|
||||
query_vector: [[1, 2, 1, 1, 1]]
|
||||
|
||||
- match: {hits.total: 2}
|
||||
|
||||
- match: {hits.hits.0._id: "3"}
|
||||
- close_to: {hits.hits.0._score: {value: 0.675, error: 0.01}}
|
||||
|
||||
- match: {hits.hits.1._id: "1"}
|
||||
- close_to: {hits.hits.1._score: {value: 0.65, error: 0.01}}
|
||||
|
||||
- do:
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
search:
|
||||
rest_total_hits_as_int: true
|
||||
body:
|
||||
query:
|
||||
script_score:
|
||||
query: {match_all: {} }
|
||||
script:
|
||||
source: "maxSimInvHamming(params.query_vector, 'bit_vector')"
|
||||
params:
|
||||
query_vector: [[1, 2, 1, 1, 1]]
|
||||
|
||||
- match: {hits.total: 2}
|
||||
|
||||
- match: {hits.hits.0._id: "3"}
|
||||
- close_to: {hits.hits.0._score: {value: 0.675, error: 0.01}}
|
||||
|
||||
- match: {hits.hits.1._id: "1"}
|
||||
- close_to: {hits.hits.1._score: {value: 0.65, error: 0.01}}
|
|
@ -40,6 +40,8 @@ public final class SearchCapabilities {
|
|||
private static final String NESTED_RETRIEVER_INNER_HITS_SUPPORT = "nested_retriever_inner_hits_support";
|
||||
/** Support multi-dense-vector script field access. */
|
||||
private static final String MULTI_DENSE_VECTOR_SCRIPT_ACCESS = "multi_dense_vector_script_access";
|
||||
/** Initial support for multi-dense-vector maxSim functions access. */
|
||||
private static final String MULTI_DENSE_VECTOR_SCRIPT_MAX_SIM = "multi_dense_vector_script_max_sim";
|
||||
|
||||
private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs";
|
||||
|
||||
|
@ -56,6 +58,7 @@ public final class SearchCapabilities {
|
|||
if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) {
|
||||
capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER);
|
||||
capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_ACCESS);
|
||||
capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_MAX_SIM);
|
||||
}
|
||||
if (Build.current().isSnapshot()) {
|
||||
capabilities.add(KQL_QUERY_SUPPORTED);
|
||||
|
|
|
@ -0,0 +1,372 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.script;
|
||||
|
||||
import org.elasticsearch.ExceptionsHelper;
|
||||
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||
import org.elasticsearch.script.field.vectors.DenseVector;
|
||||
import org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HexFormat;
|
||||
import java.util.List;
|
||||
|
||||
public class MultiVectorScoreScriptUtils {
|
||||
|
||||
public static class MultiDenseVectorFunction {
|
||||
protected final ScoreScript scoreScript;
|
||||
protected final MultiDenseVectorDocValuesField field;
|
||||
|
||||
public MultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDocValuesField field) {
|
||||
this.scoreScript = scoreScript;
|
||||
this.field = field;
|
||||
}
|
||||
|
||||
void setNextVector() {
|
||||
try {
|
||||
field.setNextDocId(scoreScript._getDocId());
|
||||
} catch (IOException e) {
|
||||
throw ExceptionsHelper.convertToElastic(e);
|
||||
}
|
||||
if (field.isEmpty()) {
|
||||
throw new IllegalArgumentException("A document doesn't have a value for a multi-vector field!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static class ByteMultiDenseVectorFunction extends MultiDenseVectorFunction {
|
||||
protected final byte[][] queryVector;
|
||||
|
||||
/**
|
||||
* Constructs a dense vector function used for byte-sized vectors.
|
||||
*
|
||||
* @param scoreScript The script in which this function was referenced.
|
||||
* @param field The vector field.
|
||||
* @param queryVector The query vector.
|
||||
*/
|
||||
public ByteMultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List<List<Number>> queryVector) {
|
||||
super(scoreScript, field);
|
||||
if (queryVector.isEmpty()) {
|
||||
throw new IllegalArgumentException("The query vector is empty.");
|
||||
}
|
||||
field.getElementType().checkDimensions(field.get().getDims(), queryVector.get(0).size());
|
||||
this.queryVector = new byte[queryVector.size()][queryVector.get(0).size()];
|
||||
float[] validateValues = new float[queryVector.size()];
|
||||
int lastSize = -1;
|
||||
for (int i = 0; i < queryVector.size(); i++) {
|
||||
if (lastSize != -1 && lastSize != queryVector.get(i).size()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The query vector contains inner vectors which have inconsistent number of dimensions."
|
||||
);
|
||||
}
|
||||
lastSize = queryVector.get(i).size();
|
||||
for (int j = 0; j < queryVector.get(i).size(); j++) {
|
||||
final Number number = queryVector.get(i).get(j);
|
||||
byte value = number.byteValue();
|
||||
this.queryVector[i][j] = value;
|
||||
validateValues[i] = number.floatValue();
|
||||
}
|
||||
field.getElementType().checkVectorBounds(validateValues);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a dense vector function used for byte-sized vectors.
|
||||
*
|
||||
* @param scoreScript The script in which this function was referenced.
|
||||
* @param field The vector field.
|
||||
* @param queryVector The query vector.
|
||||
*/
|
||||
public ByteMultiDenseVectorFunction(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) {
|
||||
super(scoreScript, field);
|
||||
this.queryVector = queryVector;
|
||||
}
|
||||
}
|
||||
|
||||
public static class FloatMultiDenseVectorFunction extends MultiDenseVectorFunction {
|
||||
protected final float[][] queryVector;
|
||||
|
||||
/**
|
||||
* Constructs a dense vector function used for float vectors.
|
||||
*
|
||||
* @param scoreScript The script in which this function was referenced.
|
||||
* @param field The vector field.
|
||||
* @param queryVector The query vector.
|
||||
*/
|
||||
public FloatMultiDenseVectorFunction(
|
||||
ScoreScript scoreScript,
|
||||
MultiDenseVectorDocValuesField field,
|
||||
List<List<Number>> queryVector
|
||||
) {
|
||||
super(scoreScript, field);
|
||||
if (queryVector.isEmpty()) {
|
||||
throw new IllegalArgumentException("The query vector is empty.");
|
||||
}
|
||||
DenseVector.checkDimensions(field.get().getDims(), queryVector.get(0).size());
|
||||
|
||||
this.queryVector = new float[queryVector.size()][queryVector.get(0).size()];
|
||||
int lastSize = -1;
|
||||
for (int i = 0; i < queryVector.size(); i++) {
|
||||
if (lastSize != -1 && lastSize != queryVector.get(i).size()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The query vector contains inner vectors which have inconsistent number of dimensions."
|
||||
);
|
||||
}
|
||||
lastSize = queryVector.get(i).size();
|
||||
for (int j = 0; j < queryVector.get(i).size(); j++) {
|
||||
this.queryVector[i][j] = queryVector.get(i).get(j).floatValue();
|
||||
}
|
||||
field.getElementType().checkVectorBounds(this.queryVector[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * A function that produces a max-sim inverse-hamming score between a set of query
 * vectors and the vectors of the current document.
 */
public interface MaxSimInvHammingDistanceInterface {
    float maxSimInvHamming();
}
|
||||
|
||||
public static class ByteMaxSimInvHammingDistance extends ByteMultiDenseVectorFunction implements MaxSimInvHammingDistanceInterface {
|
||||
|
||||
public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List<List<Number>> queryVector) {
|
||||
super(scoreScript, field, queryVector);
|
||||
}
|
||||
|
||||
public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) {
|
||||
super(scoreScript, field, queryVector);
|
||||
}
|
||||
|
||||
public float maxSimInvHamming() {
|
||||
setNextVector();
|
||||
return field.get().maxSimInvHamming(queryVector);
|
||||
}
|
||||
}
|
||||
|
||||
private record BytesOrList(byte[][] bytes, List<List<Number>> list) {}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static BytesOrList parseBytes(Object queryVector) {
|
||||
if (queryVector instanceof List) {
|
||||
// check if its a list of strings or list of lists
|
||||
if (((List<?>) queryVector).get(0) instanceof List) {
|
||||
return new BytesOrList(null, ((List<List<Number>>) queryVector));
|
||||
} else if (((List<?>) queryVector).get(0) instanceof String) {
|
||||
byte[][] parsedQueryVector = new byte[((List<?>) queryVector).size()][];
|
||||
int lastSize = -1;
|
||||
for (int i = 0; i < ((List<?>) queryVector).size(); i++) {
|
||||
parsedQueryVector[i] = HexFormat.of().parseHex((String) ((List<?>) queryVector).get(i));
|
||||
if (lastSize != -1 && lastSize != parsedQueryVector[i].length) {
|
||||
throw new IllegalArgumentException(
|
||||
"The query vector contains inner vectors which have inconsistent number of dimensions."
|
||||
);
|
||||
}
|
||||
lastSize = parsedQueryVector[i].length;
|
||||
}
|
||||
return new BytesOrList(parsedQueryVector, null);
|
||||
} else {
|
||||
throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName());
|
||||
}
|
||||
} else {
|
||||
throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName());
|
||||
}
|
||||
}
|
||||
|
||||
public static final class MaxSimInvHamming {
|
||||
|
||||
private final MaxSimInvHammingDistanceInterface function;
|
||||
|
||||
public MaxSimInvHamming(ScoreScript scoreScript, Object queryVector, String fieldName) {
|
||||
MultiDenseVectorDocValuesField field = (MultiDenseVectorDocValuesField) scoreScript.field(fieldName);
|
||||
if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) {
|
||||
throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors");
|
||||
}
|
||||
BytesOrList bytesOrList = parseBytes(queryVector);
|
||||
if (bytesOrList.bytes != null) {
|
||||
this.function = new ByteMaxSimInvHammingDistance(scoreScript, field, bytesOrList.bytes);
|
||||
} else {
|
||||
this.function = new ByteMaxSimInvHammingDistance(scoreScript, field, bytesOrList.list);
|
||||
}
|
||||
}
|
||||
|
||||
public double maxSimInvHamming() {
|
||||
return function.maxSimInvHamming();
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * A function that produces a max-sim dot-product score between a set of query vectors
 * and the vectors of the current document.
 */
public interface MaxSimDotProductInterface {
    double maxSimDotProduct();
}
|
||||
|
||||
public static class MaxSimBitDotProduct extends MultiDenseVectorFunction implements MaxSimDotProductInterface {
|
||||
private final byte[][] byteQueryVector;
|
||||
private final float[][] floatQueryVector;
|
||||
|
||||
public MaxSimBitDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) {
|
||||
super(scoreScript, field);
|
||||
if (field.getElementType() != DenseVectorFieldMapper.ElementType.BIT) {
|
||||
throw new IllegalArgumentException("Cannot calculate bit dot product for non-bit vectors");
|
||||
}
|
||||
int fieldDims = field.get().getDims();
|
||||
if (fieldDims != queryVector.length * Byte.SIZE && fieldDims != queryVector.length) {
|
||||
throw new IllegalArgumentException(
|
||||
"The query vector has an incorrect number of dimensions. Must be ["
|
||||
+ fieldDims / 8
|
||||
+ "] for bitwise operations, or ["
|
||||
+ fieldDims
|
||||
+ "] for byte wise operations: provided ["
|
||||
+ queryVector.length
|
||||
+ "]."
|
||||
);
|
||||
}
|
||||
this.byteQueryVector = queryVector;
|
||||
this.floatQueryVector = null;
|
||||
}
|
||||
|
||||
public MaxSimBitDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List<List<Number>> queryVector) {
|
||||
super(scoreScript, field);
|
||||
if (queryVector.isEmpty()) {
|
||||
throw new IllegalArgumentException("The query vector is empty.");
|
||||
}
|
||||
if (field.getElementType() != DenseVectorFieldMapper.ElementType.BIT) {
|
||||
throw new IllegalArgumentException("cannot calculate bit dot product for non-bit vectors");
|
||||
}
|
||||
float[][] floatQueryVector = new float[queryVector.size()][];
|
||||
byte[][] byteQueryVector = new byte[queryVector.size()][];
|
||||
boolean isFloat = false;
|
||||
int lastSize = -1;
|
||||
for (int i = 0; i < queryVector.size(); i++) {
|
||||
if (lastSize != -1 && lastSize != queryVector.get(i).size()) {
|
||||
throw new IllegalArgumentException(
|
||||
"The query vector contains inner vectors which have inconsistent number of dimensions."
|
||||
);
|
||||
}
|
||||
lastSize = queryVector.get(i).size();
|
||||
floatQueryVector[i] = new float[queryVector.get(i).size()];
|
||||
if (isFloat == false) {
|
||||
byteQueryVector[i] = new byte[queryVector.get(i).size()];
|
||||
}
|
||||
for (int j = 0; j < queryVector.get(i).size(); j++) {
|
||||
Number number = queryVector.get(i).get(j);
|
||||
floatQueryVector[i][j] = number.floatValue();
|
||||
if (isFloat == false) {
|
||||
byteQueryVector[i][j] = number.byteValue();
|
||||
}
|
||||
if (isFloat
|
||||
|| floatQueryVector[i][j] % 1.0f != 0.0f
|
||||
|| floatQueryVector[i][j] < Byte.MIN_VALUE
|
||||
|| floatQueryVector[i][j] > Byte.MAX_VALUE) {
|
||||
isFloat = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
int fieldDims = field.get().getDims();
|
||||
if (isFloat) {
|
||||
this.floatQueryVector = floatQueryVector;
|
||||
this.byteQueryVector = null;
|
||||
if (fieldDims != floatQueryVector[0].length) {
|
||||
throw new IllegalArgumentException(
|
||||
"The query vector contains inner vectors which have incorrect number of dimensions. Must be ["
|
||||
+ fieldDims
|
||||
+ "] for float wise operations: provided ["
|
||||
+ floatQueryVector[0].length
|
||||
+ "]."
|
||||
);
|
||||
}
|
||||
} else {
|
||||
this.floatQueryVector = null;
|
||||
this.byteQueryVector = byteQueryVector;
|
||||
if (fieldDims != byteQueryVector[0].length * Byte.SIZE && fieldDims != byteQueryVector[0].length) {
|
||||
throw new IllegalArgumentException(
|
||||
"The query vector contains inner vectors which have incorrect number of dimensions. Must be ["
|
||||
+ fieldDims / 8
|
||||
+ "] for bitwise operations, or ["
|
||||
+ fieldDims
|
||||
+ "] for byte wise operations: provided ["
|
||||
+ byteQueryVector[0].length
|
||||
+ "]."
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public double maxSimDotProduct() {
|
||||
setNextVector();
|
||||
return byteQueryVector != null ? field.get().maxSimDotProduct(byteQueryVector) : field.get().maxSimDotProduct(floatQueryVector);
|
||||
}
|
||||
}
|
||||
|
||||
public static class MaxSimByteDotProduct extends ByteMultiDenseVectorFunction implements MaxSimDotProductInterface {
|
||||
|
||||
public MaxSimByteDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List<List<Number>> queryVector) {
|
||||
super(scoreScript, field, queryVector);
|
||||
}
|
||||
|
||||
public MaxSimByteDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, byte[][] queryVector) {
|
||||
super(scoreScript, field, queryVector);
|
||||
}
|
||||
|
||||
public double maxSimDotProduct() {
|
||||
setNextVector();
|
||||
return field.get().maxSimDotProduct(queryVector);
|
||||
}
|
||||
}
|
||||
|
||||
public static class MaxSimFloatDotProduct extends FloatMultiDenseVectorFunction implements MaxSimDotProductInterface {
|
||||
|
||||
public MaxSimFloatDotProduct(ScoreScript scoreScript, MultiDenseVectorDocValuesField field, List<List<Number>> queryVector) {
|
||||
super(scoreScript, field, queryVector);
|
||||
}
|
||||
|
||||
public double maxSimDotProduct() {
|
||||
setNextVector();
|
||||
return field.get().maxSimDotProduct(queryVector);
|
||||
}
|
||||
}
|
||||
|
||||
public static final class MaxSimDotProduct {
|
||||
|
||||
private final MaxSimDotProductInterface function;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public MaxSimDotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) {
|
||||
MultiDenseVectorDocValuesField field = (MultiDenseVectorDocValuesField) scoreScript.field(fieldName);
|
||||
function = switch (field.getElementType()) {
|
||||
case BIT -> {
|
||||
BytesOrList bytesOrList = parseBytes(queryVector);
|
||||
if (bytesOrList.bytes != null) {
|
||||
yield new MaxSimBitDotProduct(scoreScript, field, bytesOrList.bytes);
|
||||
} else {
|
||||
yield new MaxSimBitDotProduct(scoreScript, field, bytesOrList.list);
|
||||
}
|
||||
}
|
||||
case BYTE -> {
|
||||
BytesOrList bytesOrList = parseBytes(queryVector);
|
||||
if (bytesOrList.bytes != null) {
|
||||
yield new MaxSimByteDotProduct(scoreScript, field, bytesOrList.bytes);
|
||||
} else {
|
||||
yield new MaxSimByteDotProduct(scoreScript, field, bytesOrList.list);
|
||||
}
|
||||
}
|
||||
case FLOAT -> {
|
||||
if (queryVector instanceof List) {
|
||||
yield new MaxSimFloatDotProduct(scoreScript, field, (List<List<Number>>) queryVector);
|
||||
}
|
||||
throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName());
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
public double maxSimDotProduct() {
|
||||
return function.maxSimDotProduct();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -10,11 +10,13 @@
|
|||
package org.elasticsearch.script.field.vectors;
|
||||
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.elasticsearch.simdvec.ESVectorUtil;
|
||||
|
||||
import java.util.Iterator;
|
||||
import java.util.Arrays;
|
||||
|
||||
public class BitMultiDenseVector extends ByteMultiDenseVector {
|
||||
public BitMultiDenseVector(Iterator<byte[]> vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) {
|
||||
public BitMultiDenseVector(VectorIterator<byte[]> vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) {
|
||||
super(vectorValues, magnitudesBytes, numVecs, dims);
|
||||
}
|
||||
|
||||
|
@ -31,6 +33,70 @@ public class BitMultiDenseVector extends ByteMultiDenseVector {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public float maxSimDotProduct(float[][] query) {
|
||||
vectorValues.reset();
|
||||
float[] maxes = new float[query.length];
|
||||
Arrays.fill(maxes, Float.NEGATIVE_INFINITY);
|
||||
while (vectorValues.hasNext()) {
|
||||
byte[] vv = vectorValues.next();
|
||||
for (int i = 0; i < query.length; i++) {
|
||||
maxes[i] = Math.max(maxes[i], ESVectorUtil.ipFloatBit(query[i], vv));
|
||||
}
|
||||
}
|
||||
float sums = 0;
|
||||
for (float m : maxes) {
|
||||
sums += m;
|
||||
}
|
||||
return sums;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float maxSimDotProduct(byte[][] query) {
|
||||
vectorValues.reset();
|
||||
float[] maxes = new float[query.length];
|
||||
Arrays.fill(maxes, Float.NEGATIVE_INFINITY);
|
||||
if (query[0].length == dims) {
|
||||
while (vectorValues.hasNext()) {
|
||||
byte[] vv = vectorValues.next();
|
||||
for (int i = 0; i < query.length; i++) {
|
||||
maxes[i] = Math.max(maxes[i], ESVectorUtil.andBitCount(query[i], vv));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
while (vectorValues.hasNext()) {
|
||||
byte[] vv = vectorValues.next();
|
||||
for (int i = 0; i < query.length; i++) {
|
||||
maxes[i] = Math.max(maxes[i], ESVectorUtil.ipByteBit(query[i], vv));
|
||||
}
|
||||
}
|
||||
}
|
||||
float sum = 0;
|
||||
for (float m : maxes) {
|
||||
sum += m;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float maxSimInvHamming(byte[][] query) {
|
||||
vectorValues.reset();
|
||||
int bitCount = this.getDims();
|
||||
float[] maxes = new float[query.length];
|
||||
Arrays.fill(maxes, Float.NEGATIVE_INFINITY);
|
||||
while (vectorValues.hasNext()) {
|
||||
byte[] vv = vectorValues.next();
|
||||
for (int i = 0; i < query.length; i++) {
|
||||
maxes[i] = Math.max(maxes[i], ((bitCount - VectorUtil.xorBitCount(vv, query[i])) / (float) bitCount));
|
||||
}
|
||||
}
|
||||
float sum = 0;
|
||||
for (float m : maxes) {
|
||||
sum += m;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getDims() {
|
||||
return dims * Byte.SIZE;
|
||||
|
|
|
@ -10,21 +10,22 @@
|
|||
package org.elasticsearch.script.field.vectors;
|
||||
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Iterator;
|
||||
|
||||
public class ByteMultiDenseVector implements MultiDenseVector {
|
||||
|
||||
protected final Iterator<byte[]> vectorValues;
|
||||
protected final VectorIterator<byte[]> vectorValues;
|
||||
protected final int numVecs;
|
||||
protected final int dims;
|
||||
|
||||
private Iterator<float[]> floatDocVectors;
|
||||
private float[] magnitudes;
|
||||
private final BytesRef magnitudesBytes;
|
||||
|
||||
public ByteMultiDenseVector(Iterator<byte[]> vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) {
|
||||
public ByteMultiDenseVector(VectorIterator<byte[]> vectorValues, BytesRef magnitudesBytes, int numVecs, int dims) {
|
||||
assert magnitudesBytes.length == numVecs * Float.BYTES;
|
||||
this.vectorValues = vectorValues;
|
||||
this.numVecs = numVecs;
|
||||
|
@ -33,11 +34,50 @@ public class ByteMultiDenseVector implements MultiDenseVector {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Iterator<float[]> getVectors() {
|
||||
if (floatDocVectors == null) {
|
||||
floatDocVectors = new ByteToFloatIteratorWrapper(vectorValues, dims);
|
||||
public float maxSimDotProduct(float[][] query) {
|
||||
throw new UnsupportedOperationException("use [float maxSimDotProduct(byte[][] queryVector)] instead");
|
||||
}
|
||||
|
||||
@Override
|
||||
public float maxSimDotProduct(byte[][] query) {
|
||||
vectorValues.reset();
|
||||
float[] maxes = new float[query.length];
|
||||
Arrays.fill(maxes, Float.NEGATIVE_INFINITY);
|
||||
while (vectorValues.hasNext()) {
|
||||
byte[] vv = vectorValues.next();
|
||||
for (int i = 0; i < query.length; i++) {
|
||||
maxes[i] = Math.max(maxes[i], VectorUtil.dotProduct(query[i], vv));
|
||||
}
|
||||
}
|
||||
return floatDocVectors;
|
||||
float sum = 0;
|
||||
for (float m : maxes) {
|
||||
sum += m;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float maxSimInvHamming(byte[][] query) {
|
||||
vectorValues.reset();
|
||||
int bitCount = dims * Byte.SIZE;
|
||||
float[] maxes = new float[query.length];
|
||||
Arrays.fill(maxes, Float.NEGATIVE_INFINITY);
|
||||
while (vectorValues.hasNext()) {
|
||||
byte[] vv = vectorValues.next();
|
||||
for (int i = 0; i < query.length; i++) {
|
||||
maxes[i] = Math.max(maxes[i], ((bitCount - VectorUtil.xorBitCount(vv, query[i])) / (float) bitCount));
|
||||
}
|
||||
}
|
||||
float sum = 0;
|
||||
for (float m : maxes) {
|
||||
sum += m;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<float[]> getVectors() {
|
||||
return new ByteToFloatIteratorWrapper(vectorValues.copy(), dims);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -23,7 +23,7 @@ public class ByteMultiDenseVectorDocValuesField extends MultiDenseVectorDocValue
|
|||
private final BinaryDocValues magnitudes;
|
||||
protected final int dims;
|
||||
protected int numVecs;
|
||||
protected Iterator<byte[]> vectorValue;
|
||||
protected VectorIterator<byte[]> vectorValue;
|
||||
protected boolean decoded;
|
||||
protected BytesRef value;
|
||||
protected BytesRef magnitudesValue;
|
||||
|
@ -111,7 +111,7 @@ public class ByteMultiDenseVectorDocValuesField extends MultiDenseVectorDocValue
|
|||
return value == null;
|
||||
}
|
||||
|
||||
static class ByteVectorIterator implements Iterator<byte[]> {
|
||||
static class ByteVectorIterator implements VectorIterator<byte[]> {
|
||||
private final byte[] buffer;
|
||||
private final BytesRef vectorValues;
|
||||
private final int size;
|
||||
|
@ -138,5 +138,15 @@ public class ByteMultiDenseVectorDocValuesField extends MultiDenseVectorDocValue
|
|||
idx++;
|
||||
return buffer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<byte[]> copy() {
|
||||
return new ByteVectorIterator(vectorValues, new byte[buffer.length], size);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
idx = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,7 +10,9 @@
|
|||
package org.elasticsearch.script.field.vectors;
|
||||
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Iterator;
|
||||
|
||||
import static org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder.getMultiMagnitudes;
|
||||
|
@ -21,19 +23,47 @@ public class FloatMultiDenseVector implements MultiDenseVector {
|
|||
private float[] magnitudesArray = null;
|
||||
private final int dims;
|
||||
private final int numVectors;
|
||||
private final Iterator<float[]> decodedDocVector;
|
||||
private final VectorIterator<float[]> vectorValues;
|
||||
|
||||
public FloatMultiDenseVector(Iterator<float[]> decodedDocVector, BytesRef magnitudes, int numVectors, int dims) {
|
||||
public FloatMultiDenseVector(VectorIterator<float[]> decodedDocVector, BytesRef magnitudes, int numVectors, int dims) {
|
||||
assert magnitudes.length == numVectors * Float.BYTES;
|
||||
this.decodedDocVector = decodedDocVector;
|
||||
this.vectorValues = decodedDocVector;
|
||||
this.magnitudes = magnitudes;
|
||||
this.numVectors = numVectors;
|
||||
this.dims = dims;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float maxSimDotProduct(float[][] query) {
|
||||
vectorValues.reset();
|
||||
float[] maxes = new float[query.length];
|
||||
Arrays.fill(maxes, Float.NEGATIVE_INFINITY);
|
||||
while (vectorValues.hasNext()) {
|
||||
float[] vv = vectorValues.next();
|
||||
for (int i = 0; i < query.length; i++) {
|
||||
maxes[i] = Math.max(maxes[i], VectorUtil.dotProduct(query[i], vv));
|
||||
}
|
||||
}
|
||||
float sum = 0;
|
||||
for (float m : maxes) {
|
||||
sum += m;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float maxSimDotProduct(byte[][] query) {
|
||||
throw new UnsupportedOperationException("use [float maxSimDotProduct(float[][] queryVector)] instead");
|
||||
}
|
||||
|
||||
@Override
|
||||
public float maxSimInvHamming(byte[][] query) {
|
||||
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<float[]> getVectors() {
|
||||
return decodedDocVector;
|
||||
return vectorValues.copy();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -110,14 +110,16 @@ public class FloatMultiDenseVectorDocValuesField extends MultiDenseVectorDocValu
|
|||
}
|
||||
}
|
||||
|
||||
static class FloatVectorIterator implements Iterator<float[]> {
|
||||
static class FloatVectorIterator implements VectorIterator<float[]> {
|
||||
private final float[] buffer;
|
||||
private final FloatBuffer vectorValues;
|
||||
private final BytesRef vectorValueBytesRef;
|
||||
private final int size;
|
||||
private int idx = 0;
|
||||
|
||||
FloatVectorIterator(BytesRef vectorValues, float[] buffer, int size) {
|
||||
assert vectorValues.length == (buffer.length * Float.BYTES * size);
|
||||
this.vectorValueBytesRef = vectorValues;
|
||||
this.vectorValues = ByteBuffer.wrap(vectorValues.bytes, vectorValues.offset, vectorValues.length)
|
||||
.order(ByteOrder.LITTLE_ENDIAN)
|
||||
.asFloatBuffer();
|
||||
|
@ -139,5 +141,16 @@ public class FloatMultiDenseVectorDocValuesField extends MultiDenseVectorDocValu
|
|||
idx++;
|
||||
return buffer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<float[]> copy() {
|
||||
return new FloatVectorIterator(vectorValueBytesRef, new float[buffer.length], size);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
idx = 0;
|
||||
vectorValues.rewind();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,12 @@ public interface MultiDenseVector {
|
|||
checkDimensions(getDims(), qvDims);
|
||||
}
|
||||
|
||||
float maxSimDotProduct(float[][] query);
|
||||
|
||||
float maxSimDotProduct(byte[][] query);
|
||||
|
||||
float maxSimInvHamming(byte[][] query);
|
||||
|
||||
Iterator<float[]> getVectors();
|
||||
|
||||
float[] getMagnitudes();
|
||||
|
@ -63,6 +69,21 @@ public interface MultiDenseVector {
|
|||
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float maxSimDotProduct(float[][] query) {
|
||||
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float maxSimDotProduct(byte[][] query) {
|
||||
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float maxSimInvHamming(byte[][] query) {
|
||||
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return 0;
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.script.field.vectors;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
 * An {@link Iterator} over vectors that can additionally be rewound to its first element and
 * duplicated.
 *
 * @param <E> the array type of a single vector, e.g. {@code float[]} or {@code byte[]}
 */
public interface VectorIterator<E> extends Iterator<E> {

    /**
     * Creates another iterator over the same vectors. The iterators built by the {@code from}
     * factories below return a fresh iterator positioned at the first vector.
     *
     * @return a new iterator over the same vectors
     */
    Iterator<E> copy();

    /**
     * Rewinds this iterator so that iteration starts over from the first vector.
     */
    void reset();

    /**
     * Wraps an in-memory array of float vectors. The array is not copied, so the iterator hands out
     * the caller's own rows.
     *
     * @param vectors the vectors to iterate over, one row per vector
     * @return an iterator positioned at the first row
     */
    static VectorIterator<float[]> from(float[][] vectors) {
        return new VectorIterator<>() {
            private int i = 0;

            @Override
            public boolean hasNext() {
                return i < vectors.length;
            }

            @Override
            public float[] next() {
                // Honor the Iterator contract instead of leaking ArrayIndexOutOfBoundsException.
                if (i >= vectors.length) {
                    throw new java.util.NoSuchElementException();
                }
                return vectors[i++];
            }

            @Override
            public Iterator<float[]> copy() {
                return VectorIterator.from(vectors);
            }

            @Override
            public void reset() {
                i = 0;
            }
        };
    }

    /**
     * Wraps an in-memory array of byte vectors. The array is not copied, so the iterator hands out
     * the caller's own rows.
     *
     * @param vectors the vectors to iterate over, one row per vector
     * @return an iterator positioned at the first row
     */
    static VectorIterator<byte[]> from(byte[][] vectors) {
        return new VectorIterator<>() {
            private int i = 0;

            @Override
            public boolean hasNext() {
                return i < vectors.length;
            }

            @Override
            public byte[] next() {
                // Honor the Iterator contract instead of leaking ArrayIndexOutOfBoundsException.
                if (i >= vectors.length) {
                    throw new java.util.NoSuchElementException();
                }
                return vectors[i++];
            }

            @Override
            public Iterator<byte[]> copy() {
                return VectorIterator.from(vectors);
            }

            @Override
            public void reset() {
                i = 0;
            }
        };
    }
}
|
|
@ -0,0 +1,342 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.script;
|
||||
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
|
||||
import org.elasticsearch.index.mapper.vectors.MultiDenseVectorFieldMapper;
|
||||
import org.elasticsearch.index.mapper.vectors.MultiDenseVectorScriptDocValuesTests;
|
||||
import org.elasticsearch.script.MultiVectorScoreScriptUtils.MaxSimDotProduct;
|
||||
import org.elasticsearch.script.MultiVectorScoreScriptUtils.MaxSimInvHamming;
|
||||
import org.elasticsearch.script.field.vectors.BitMultiDenseVectorDocValuesField;
|
||||
import org.elasticsearch.script.field.vectors.ByteMultiDenseVectorDocValuesField;
|
||||
import org.elasticsearch.script.field.vectors.FloatMultiDenseVectorDocValuesField;
|
||||
import org.elasticsearch.script.field.vectors.MultiDenseVectorDocValuesField;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.junit.BeforeClass;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HexFormat;
|
||||
import java.util.List;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class MultiVectorScoreScriptUtilsTests extends ESTestCase {
|
||||
|
||||
@BeforeClass
|
||||
public static void setup() {
|
||||
assumeTrue("Requires multi-dense vector support", MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled());
|
||||
}
|
||||
|
||||
public void testFloatMultiVectorClassBindings() throws IOException {
|
||||
String fieldName = "vector";
|
||||
int dims = 5;
|
||||
float[][][] docVectors = new float[][][] {
|
||||
{ { 230.0f, 300.33f, -34.8988f, 15.555f, -200.0f }, { 100.0f, 200.0f, -50.0f, 10.0f, -150.0f } } };
|
||||
float[][] docMagnitudes = new float[][] { { 0.0f, 0.0f } };
|
||||
for (int i = 0; i < docVectors.length; i++) {
|
||||
for (int j = 0; j < docVectors[i].length; j++) {
|
||||
docMagnitudes[i][j] = (float) Math.sqrt(VectorUtil.dotProduct(docVectors[i][j], docVectors[i][j]));
|
||||
}
|
||||
}
|
||||
|
||||
List<List<Number>> queryVector = List.of(Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f));
|
||||
List<List<Number>> invalidQueryVector = List.of(Arrays.asList(0.5, 111.3));
|
||||
|
||||
List<MultiDenseVectorDocValuesField> fields = List.of(
|
||||
new FloatMultiDenseVectorDocValuesField(
|
||||
MultiDenseVectorScriptDocValuesTests.wrap(docVectors, ElementType.FLOAT),
|
||||
MultiDenseVectorScriptDocValuesTests.wrap(docMagnitudes),
|
||||
"test",
|
||||
ElementType.FLOAT,
|
||||
dims
|
||||
),
|
||||
new FloatMultiDenseVectorDocValuesField(
|
||||
MultiDenseVectorScriptDocValuesTests.wrap(docVectors, ElementType.FLOAT),
|
||||
MultiDenseVectorScriptDocValuesTests.wrap(docMagnitudes),
|
||||
"test",
|
||||
ElementType.FLOAT,
|
||||
dims
|
||||
)
|
||||
);
|
||||
for (MultiDenseVectorDocValuesField field : fields) {
|
||||
field.setNextDocId(0);
|
||||
|
||||
ScoreScript scoreScript = mock(ScoreScript.class);
|
||||
when(scoreScript.field("vector")).thenAnswer(mock -> field);
|
||||
|
||||
// Test max similarity dot product
|
||||
MaxSimDotProduct maxSimDotProduct = new MaxSimDotProduct(scoreScript, queryVector, fieldName);
|
||||
float maxSimDotProductExpected = 65425.625f; // Adjust this value based on expected max similarity
|
||||
assertEquals(
|
||||
"maxSimDotProduct result is not equal to the expected value!",
|
||||
maxSimDotProductExpected,
|
||||
maxSimDotProduct.maxSimDotProduct(),
|
||||
0.001
|
||||
);
|
||||
|
||||
// Check each function rejects query vectors with the wrong dimension
|
||||
IllegalArgumentException e = expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new MultiVectorScoreScriptUtils.MaxSimDotProduct(scoreScript, invalidQueryVector, fieldName)
|
||||
);
|
||||
assertThat(
|
||||
e.getMessage(),
|
||||
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
|
||||
);
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new MaxSimInvHamming(scoreScript, invalidQueryVector, fieldName));
|
||||
assertThat(e.getMessage(), containsString("hamming distance is only supported for byte or bit vectors"));
|
||||
|
||||
// Check scripting infrastructure integration
|
||||
assertEquals(65425.6249, new MaxSimDotProduct(scoreScript, queryVector, fieldName).maxSimDotProduct(), 0.001);
|
||||
when(scoreScript._getDocId()).thenReturn(1);
|
||||
e = expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new MaxSimDotProduct(scoreScript, queryVector, fieldName).maxSimDotProduct()
|
||||
);
|
||||
assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public void testByteMultiVectorClassBindings() throws IOException {
|
||||
String fieldName = "vector";
|
||||
int dims = 5;
|
||||
float[][] docVector = new float[][] { { 1, 127, -128, 5, -10 } };
|
||||
float[][] magnitudes = new float[][] { { 0 } };
|
||||
for (int i = 0; i < docVector.length; i++) {
|
||||
magnitudes[i][0] = (float) Math.sqrt(VectorUtil.dotProduct(docVector[i], docVector[i]));
|
||||
}
|
||||
List<List<Number>> queryVector = List.of(Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4));
|
||||
List<List<Number>> invalidQueryVector = List.of(Arrays.asList((byte) 1, (byte) 1));
|
||||
List<String> hexidecimalString = List.of(HexFormat.of().formatHex(new byte[] { 1, 125, -12, 2, 4 }));
|
||||
|
||||
List<MultiDenseVectorDocValuesField> fields = List.of(
|
||||
new ByteMultiDenseVectorDocValuesField(
|
||||
MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.BYTE),
|
||||
MultiDenseVectorScriptDocValuesTests.wrap(magnitudes),
|
||||
"test",
|
||||
ElementType.BYTE,
|
||||
dims
|
||||
)
|
||||
);
|
||||
for (MultiDenseVectorDocValuesField field : fields) {
|
||||
field.setNextDocId(0);
|
||||
|
||||
ScoreScript scoreScript = mock(ScoreScript.class);
|
||||
when(scoreScript.field(fieldName)).thenAnswer(mock -> field);
|
||||
|
||||
// Check each function rejects query vectors with the wrong dimension
|
||||
IllegalArgumentException e = expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new MaxSimDotProduct(scoreScript, invalidQueryVector, fieldName)
|
||||
);
|
||||
assertThat(
|
||||
e.getMessage(),
|
||||
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
|
||||
);
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new MaxSimInvHamming(scoreScript, invalidQueryVector, fieldName));
|
||||
assertThat(
|
||||
e.getMessage(),
|
||||
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
|
||||
);
|
||||
|
||||
// Check scripting infrastructure integration
|
||||
assertEquals(17382.0, new MaxSimDotProduct(scoreScript, queryVector, fieldName).maxSimDotProduct(), 0.001);
|
||||
assertEquals(17382.0, new MaxSimDotProduct(scoreScript, hexidecimalString, fieldName).maxSimDotProduct(), 0.001);
|
||||
assertEquals(0.675, new MaxSimInvHamming(scoreScript, queryVector, fieldName).maxSimInvHamming(), 0.001);
|
||||
assertEquals(0.675, new MaxSimInvHamming(scoreScript, hexidecimalString, fieldName).maxSimInvHamming(), 0.001);
|
||||
MaxSimDotProduct maxSimDotProduct = new MaxSimDotProduct(scoreScript, queryVector, fieldName);
|
||||
when(scoreScript._getDocId()).thenReturn(1);
|
||||
e = expectThrows(IllegalArgumentException.class, maxSimDotProduct::maxSimDotProduct);
|
||||
assertEquals("A document doesn't have a value for a multi-vector field!", e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public void testBitMultiVectorClassBindingsDotProduct() throws IOException {
|
||||
String fieldName = "vector";
|
||||
int dims = 8;
|
||||
float[][] docVector = new float[][] { { 124 } };
|
||||
// 124 in binary is b01111100
|
||||
List<List<Number>> queryVector = List.of(
|
||||
Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4, (byte) 1, (byte) 125, (byte) -12)
|
||||
);
|
||||
List<List<Number>> floatQueryVector = List.of(Arrays.asList(1.4f, -1.4f, 0.42f, 0.0f, 1f, -1f, -0.42f, 1.2f));
|
||||
List<List<Number>> invalidQueryVector = List.of(Arrays.asList((byte) 1, (byte) 1));
|
||||
List<String> hexidecimalString = List.of(HexFormat.of().formatHex(new byte[] { 124 }));
|
||||
|
||||
List<MultiDenseVectorDocValuesField> fields = List.of(
|
||||
new BitMultiDenseVectorDocValuesField(
|
||||
MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.BIT),
|
||||
MultiDenseVectorScriptDocValuesTests.wrap(new float[][] { { 5 } }),
|
||||
"test",
|
||||
ElementType.BIT,
|
||||
dims
|
||||
)
|
||||
);
|
||||
for (MultiDenseVectorDocValuesField field : fields) {
|
||||
field.setNextDocId(0);
|
||||
|
||||
ScoreScript scoreScript = mock(ScoreScript.class);
|
||||
when(scoreScript.field(fieldName)).thenAnswer(mock -> field);
|
||||
|
||||
MaxSimDotProduct function = new MaxSimDotProduct(scoreScript, queryVector, fieldName);
|
||||
assertEquals(
|
||||
"maxSimDotProduct result is not equal to the expected value!",
|
||||
-12 + 2 + 4 + 1 + 125,
|
||||
function.maxSimDotProduct(),
|
||||
0.001
|
||||
);
|
||||
|
||||
function = new MaxSimDotProduct(scoreScript, floatQueryVector, fieldName);
|
||||
assertEquals(
|
||||
"maxSimDotProduct result is not equal to the expected value!",
|
||||
0.42f + 0f + 1f - 1f - 0.42f,
|
||||
function.maxSimDotProduct(),
|
||||
0.001
|
||||
);
|
||||
|
||||
function = new MaxSimDotProduct(scoreScript, hexidecimalString, fieldName);
|
||||
assertEquals(
|
||||
"maxSimDotProduct result is not equal to the expected value!",
|
||||
Integer.bitCount(124),
|
||||
function.maxSimDotProduct(),
|
||||
0.0
|
||||
);
|
||||
|
||||
// Check each function rejects query vectors with the wrong dimension
|
||||
IllegalArgumentException e = expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new MaxSimDotProduct(scoreScript, invalidQueryVector, fieldName)
|
||||
);
|
||||
assertThat(
|
||||
e.getMessage(),
|
||||
containsString(
|
||||
"query vector contains inner vectors which have incorrect number of dimensions. "
|
||||
+ "Must be [1] for bitwise operations, or [8] for byte wise operations: provided [2]."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public void testByteVsFloatSimilarity() throws IOException {
|
||||
int dims = 5;
|
||||
float[][] docVector = new float[][] { { 1f, 127f, -128f, 5f, -10f } };
|
||||
float[][] magnitudes = new float[][] { { 0 } };
|
||||
for (int i = 0; i < docVector.length; i++) {
|
||||
magnitudes[i][0] = (float) Math.sqrt(VectorUtil.dotProduct(docVector[i], docVector[i]));
|
||||
}
|
||||
List<List<Number>> listFloatVector = List.of(Arrays.asList(1f, 125f, -12f, 2f, 4f));
|
||||
List<List<Number>> listByteVector = List.of(Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4));
|
||||
float[][] floatVector = new float[][] { { 1f, 125f, -12f, 2f, 4f } };
|
||||
byte[][] byteVector = new byte[][] { { (byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4 } };
|
||||
|
||||
List<MultiDenseVectorDocValuesField> fields = List.of(
|
||||
new FloatMultiDenseVectorDocValuesField(
|
||||
MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.FLOAT),
|
||||
MultiDenseVectorScriptDocValuesTests.wrap(magnitudes),
|
||||
"field1",
|
||||
ElementType.FLOAT,
|
||||
dims
|
||||
),
|
||||
new ByteMultiDenseVectorDocValuesField(
|
||||
MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { docVector }, ElementType.BYTE),
|
||||
MultiDenseVectorScriptDocValuesTests.wrap(magnitudes),
|
||||
"field3",
|
||||
ElementType.BYTE,
|
||||
dims
|
||||
)
|
||||
);
|
||||
for (MultiDenseVectorDocValuesField field : fields) {
|
||||
field.setNextDocId(0);
|
||||
|
||||
ScoreScript scoreScript = mock(ScoreScript.class);
|
||||
when(scoreScript.field("vector")).thenAnswer(mock -> field);
|
||||
|
||||
int dotProductExpected = 17382;
|
||||
MaxSimDotProduct maxSimDotProduct = new MaxSimDotProduct(scoreScript, listFloatVector, "vector");
|
||||
assertEquals(field.getName(), dotProductExpected, maxSimDotProduct.maxSimDotProduct(), 0.001);
|
||||
maxSimDotProduct = new MaxSimDotProduct(scoreScript, listByteVector, "vector");
|
||||
assertEquals(field.getName(), dotProductExpected, maxSimDotProduct.maxSimDotProduct(), 0.001);
|
||||
switch (field.getElementType()) {
|
||||
case BYTE -> {
|
||||
assertEquals(field.getName(), dotProductExpected, field.get().maxSimDotProduct(byteVector), 0.001);
|
||||
UnsupportedOperationException e = expectThrows(
|
||||
UnsupportedOperationException.class,
|
||||
() -> field.get().maxSimDotProduct(floatVector)
|
||||
);
|
||||
assertThat(e.getMessage(), containsString("use [float maxSimDotProduct(byte[][] queryVector)] instead"));
|
||||
}
|
||||
case FLOAT -> {
|
||||
assertEquals(field.getName(), dotProductExpected, field.get().maxSimDotProduct(floatVector), 0.001);
|
||||
UnsupportedOperationException e = expectThrows(
|
||||
UnsupportedOperationException.class,
|
||||
() -> field.get().maxSimDotProduct(byteVector)
|
||||
);
|
||||
assertThat(e.getMessage(), containsString("use [float maxSimDotProduct(float[][] queryVector)] instead"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testByteBoundaries() throws IOException {
|
||||
String fieldName = "vector";
|
||||
int dims = 1;
|
||||
float[] docVector = new float[] { 0 };
|
||||
List<List<Number>> greaterThanVector = List.of(List.of(128));
|
||||
List<List<Number>> lessThanVector = List.of(List.of(-129));
|
||||
List<List<Number>> decimalVector = List.of(List.of(0.5));
|
||||
|
||||
List<MultiDenseVectorDocValuesField> fields = List.of(
|
||||
new ByteMultiDenseVectorDocValuesField(
|
||||
MultiDenseVectorScriptDocValuesTests.wrap(new float[][][] { { docVector } }, ElementType.BYTE),
|
||||
MultiDenseVectorScriptDocValuesTests.wrap(new float[][] { { 1 } }),
|
||||
"test",
|
||||
ElementType.BYTE,
|
||||
dims
|
||||
)
|
||||
);
|
||||
|
||||
for (MultiDenseVectorDocValuesField field : fields) {
|
||||
field.setNextDocId(0);
|
||||
|
||||
ScoreScript scoreScript = mock(ScoreScript.class);
|
||||
when(scoreScript.field(fieldName)).thenAnswer(mock -> field);
|
||||
|
||||
IllegalArgumentException e;
|
||||
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new MaxSimDotProduct(scoreScript, greaterThanVector, fieldName));
|
||||
assertEquals(
|
||||
"element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
|
||||
+ "Preview of invalid vector: [128.0]",
|
||||
e.getMessage()
|
||||
);
|
||||
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new MaxSimDotProduct(scoreScript, lessThanVector, fieldName));
|
||||
assertEquals(
|
||||
e.getMessage(),
|
||||
"element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
|
||||
+ "Preview of invalid vector: [-129.0]"
|
||||
);
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new MaxSimDotProduct(scoreScript, decimalVector, fieldName));
|
||||
assertEquals(
|
||||
e.getMessage(),
|
||||
"element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
|
||||
+ "Preview of invalid vector: [0.5]"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE(review): this test has an empty body, so it passes without asserting anything about
// dimension mismatches. TODO: add the intended assertions or remove the placeholder.
public void testDimMismatch() throws IOException {
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
package org.elasticsearch.script.field.vectors;
|
||||
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.elasticsearch.index.mapper.vectors.MultiDenseVectorFieldMapper;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.junit.BeforeClass;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.function.IntFunction;
|
||||
|
||||
public class MultiDenseVectorTests extends ESTestCase {
|
||||
|
||||
@BeforeClass
|
||||
public static void setup() {
|
||||
assumeTrue("Requires multi-dense vector support", MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled());
|
||||
}
|
||||
|
||||
public void testByteUnsupported() {
|
||||
int count = randomIntBetween(1, 16);
|
||||
int dims = randomIntBetween(1, 16);
|
||||
byte[][] docVector = new byte[count][dims];
|
||||
float[][] queryVector = new float[count][dims];
|
||||
for (int i = 0; i < docVector.length; i++) {
|
||||
random().nextBytes(docVector[i]);
|
||||
for (int j = 0; j < dims; j++) {
|
||||
queryVector[i][j] = randomFloat();
|
||||
}
|
||||
}
|
||||
|
||||
MultiDenseVector knn = newByteVector(docVector);
|
||||
UnsupportedOperationException e;
|
||||
|
||||
e = expectThrows(UnsupportedOperationException.class, () -> knn.maxSimDotProduct(queryVector));
|
||||
assertEquals(e.getMessage(), "use [float maxSimDotProduct(byte[][] queryVector)] instead");
|
||||
}
|
||||
|
||||
public void testFloatUnsupported() {
|
||||
int count = randomIntBetween(1, 16);
|
||||
int dims = randomIntBetween(1, 16);
|
||||
float[][] docVector = new float[count][dims];
|
||||
byte[][] queryVector = new byte[count][dims];
|
||||
for (int i = 0; i < docVector.length; i++) {
|
||||
random().nextBytes(queryVector[i]);
|
||||
for (int j = 0; j < dims; j++) {
|
||||
docVector[i][j] = randomFloat();
|
||||
}
|
||||
}
|
||||
|
||||
MultiDenseVector knn = newFloatVector(docVector);
|
||||
|
||||
UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, () -> knn.maxSimDotProduct(queryVector));
|
||||
assertEquals(e.getMessage(), "use [float maxSimDotProduct(float[][] queryVector)] instead");
|
||||
}
|
||||
|
||||
static MultiDenseVector newFloatVector(float[][] vector) {
|
||||
BytesRef magnitudes = magnitudes(vector.length, i -> (float) Math.sqrt(VectorUtil.dotProduct(vector[i], vector[i])));
|
||||
return new FloatMultiDenseVector(VectorIterator.from(vector), magnitudes, vector.length, vector[0].length);
|
||||
}
|
||||
|
||||
static MultiDenseVector newByteVector(byte[][] vector) {
|
||||
BytesRef magnitudes = magnitudes(vector.length, i -> (float) Math.sqrt(VectorUtil.dotProduct(vector[i], vector[i])));
|
||||
return new ByteMultiDenseVector(VectorIterator.from(vector), magnitudes, vector.length, vector[0].length);
|
||||
}
|
||||
|
||||
static BytesRef magnitudes(int count, IntFunction<Float> magnitude) {
|
||||
ByteBuffer magnitudeBuffer = ByteBuffer.allocate(count * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
for (int i = 0; i < count; i++) {
|
||||
magnitudeBuffer.putFloat(magnitude.apply(i));
|
||||
}
|
||||
return new BytesRef(magnitudeBuffer.array());
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue