mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
Adding hamming distance function to painless for dense_vector fields (#109359)
This adds `hamming` distance (the pop-count of the `xor` of two byte vectors) as a first-class citizen in Painless. For byte vectors, this means that we can compute Hamming distances via script_score (i.e., brute force). The implementation of `hamming` is the same as the one available in Lucene; once Lucene 9.11 is merged, we should update our logic where applicable to use it. NOTE: this does not yet add Hamming distance as a metric for indexed vectors. That will come in a future PR after the Lucene 9.11 upgrade.
This commit is contained in:
parent
bbcf73028e
commit
acc99302c6
18 changed files with 438 additions and 7 deletions
|
@ -56,7 +56,7 @@ public class DistanceFunctionBenchmark {
|
|||
@Param({ "96" })
|
||||
private int dims;
|
||||
|
||||
@Param({ "dot", "cosine", "l1", "l2" })
|
||||
@Param({ "dot", "cosine", "l1", "l2", "hamming" })
|
||||
private String function;
|
||||
|
||||
@Param({ "knn", "binary" })
|
||||
|
@ -330,6 +330,18 @@ public class DistanceFunctionBenchmark {
|
|||
}
|
||||
}
|
||||
|
||||
private static class HammingKnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
|
||||
|
||||
private HammingKnnByteBenchmarkFunction(int dims) {
|
||||
super(dims);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(Consumer<Object> consumer) {
|
||||
new ByteKnnDenseVector(docVector).hamming(queryVector);
|
||||
}
|
||||
}
|
||||
|
||||
private static class L1BinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
|
||||
|
||||
private L1BinaryFloatBenchmarkFunction(int dims) {
|
||||
|
@ -354,6 +366,18 @@ public class DistanceFunctionBenchmark {
|
|||
}
|
||||
}
|
||||
|
||||
private static class HammingBinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
|
||||
|
||||
private HammingBinaryByteBenchmarkFunction(int dims) {
|
||||
super(dims);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(Consumer<Object> consumer) {
|
||||
new ByteBinaryDenseVector(vectorValue, docVector, dims).hamming(queryVector);
|
||||
}
|
||||
}
|
||||
|
||||
private static class L2KnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
|
||||
|
||||
private L2KnnFloatBenchmarkFunction(int dims) {
|
||||
|
@ -454,6 +478,11 @@ public class DistanceFunctionBenchmark {
|
|||
case "binary" -> new L2BinaryByteBenchmarkFunction(dims);
|
||||
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
|
||||
};
|
||||
case "hamming" -> benchmarkFunction = switch (type) {
|
||||
case "knn" -> new HammingKnnByteBenchmarkFunction(dims);
|
||||
case "binary" -> new HammingBinaryByteBenchmarkFunction(dims);
|
||||
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
|
||||
};
|
||||
default -> throw new UnsupportedOperationException("unexpected function [" + function + "]");
|
||||
}
|
||||
}
|
||||
|
|
5
docs/changelog/109359.yaml
Normal file
5
docs/changelog/109359.yaml
Normal file
|
@ -0,0 +1,5 @@
|
|||
pr: 109359
|
||||
summary: Adding hamming distance function to painless for `dense_vector` fields
|
||||
area: Vector Search
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -23,6 +23,7 @@ The following methods are directly callable without a class/instance qualifier.
|
|||
* double dotProduct(Object *, String *)
|
||||
* double l1norm(Object *, String *)
|
||||
* double l2norm(Object *, String *)
|
||||
* double hamming(Object *, String *)
|
||||
* double randomScore(int *)
|
||||
* double randomScore(int *, String *)
|
||||
* double saturation(double, double)
|
||||
|
|
|
@ -12,9 +12,10 @@ This is the list of available vector functions and vector access methods:
|
|||
1. <<vector-functions-cosine,`cosineSimilarity`>> – calculates cosine similarity
|
||||
2. <<vector-functions-dot-product,`dotProduct`>> – calculates dot product
|
||||
3. <<vector-functions-l1,`l1norm`>> – calculates L^1^ distance
|
||||
4. <<vector-functions-l2,`l2norm`>> - calculates L^2^ distance
|
||||
5. <<vector-functions-accessing-vectors,`doc[<field>].vectorValue`>> – returns a vector's value as an array of floats
|
||||
6. <<vector-functions-accessing-vectors,`doc[<field>].magnitude`>> – returns a vector's magnitude
|
||||
4. <<vector-functions-hamming,`hamming`>> – calculates Hamming distance
|
||||
5. <<vector-functions-l2,`l2norm`>> – calculates L^2^ distance
|
||||
6. <<vector-functions-accessing-vectors,`doc[<field>].vectorValue`>> – returns a vector's value as an array of floats
|
||||
7. <<vector-functions-accessing-vectors,`doc[<field>].magnitude`>> – returns a vector's magnitude
|
||||
|
||||
NOTE: The recommended way to access dense vectors is through the
|
||||
`cosineSimilarity`, `dotProduct`, `l1norm` or `l2norm` functions. Please note
|
||||
|
@ -35,8 +36,15 @@ PUT my-index-000001
|
|||
"properties": {
|
||||
"my_dense_vector": {
|
||||
"type": "dense_vector",
|
||||
"index": false,
|
||||
"dims": 3
|
||||
},
|
||||
"my_byte_dense_vector": {
|
||||
"type": "dense_vector",
|
||||
"index": false,
|
||||
"dims": 3,
|
||||
"element_type": "byte"
|
||||
},
|
||||
"status" : {
|
||||
"type" : "keyword"
|
||||
}
|
||||
|
@ -47,12 +55,14 @@ PUT my-index-000001
|
|||
PUT my-index-000001/_doc/1
|
||||
{
|
||||
"my_dense_vector": [0.5, 10, 6],
|
||||
"my_byte_dense_vector": [0, 10, 6],
|
||||
"status" : "published"
|
||||
}
|
||||
|
||||
PUT my-index-000001/_doc/2
|
||||
{
|
||||
"my_dense_vector": [-0.5, 10, 10],
|
||||
"my_byte_dense_vector": [0, 10, 10],
|
||||
"status" : "published"
|
||||
}
|
||||
|
||||
|
@ -179,6 +189,40 @@ we reversed the output from `l1norm` and `l2norm`. Also, to avoid
|
|||
division by 0 when a document vector matches the query exactly,
|
||||
we added `1` in the denominator.
|
||||
|
||||
[[vector-functions-hamming]]
|
||||
====== Hamming distance
|
||||
|
||||
The `hamming` function calculates {wikipedia}/Hamming_distance[Hamming distance] between a given query vector and
|
||||
document vectors. It is only available for byte vectors.
|
||||
|
||||
[source,console]
|
||||
--------------------------------------------------
|
||||
GET my-index-000001/_search
|
||||
{
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query" : {
|
||||
"bool" : {
|
||||
"filter" : {
|
||||
"term" : {
|
||||
"status" : "published"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"script": {
|
||||
"source": "(24 - hamming(params.queryVector, 'my_byte_dense_vector')) / 24", <1>
|
||||
"params": {
|
||||
"queryVector": [4, 3, 0]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
--------------------------------------------------
|
||||
|
||||
<1> Calculate the Hamming distance and normalize it by the total number of bits in the vector (3 dimensions × 8 bits = 24) to get a score between 0 and 1.
|
||||
|
||||
[[vector-functions-l2]]
|
||||
====== L^2^ distance (Euclidean distance)
|
||||
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -31,5 +31,6 @@ static_import {
|
|||
double l2norm(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$L2Norm
|
||||
double cosineSimilarity(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$CosineSimilarity
|
||||
double dotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$DotProduct
|
||||
double hamming(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$Hamming
|
||||
}
|
||||
|
||||
|
|
|
@ -219,3 +219,36 @@ setup:
|
|||
|
||||
- match: {hits.hits.2._id: "2"}
|
||||
- close_to: {hits.hits.2._score: {value: 186.34454, error: 0.01}}
|
||||
---
|
||||
"Test hamming distance fails on float":
|
||||
- requires:
|
||||
cluster_features: ["script.hamming"]
|
||||
reason: "support for hamming distance added in 8.15"
|
||||
- do:
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
catch: bad_request
|
||||
search:
|
||||
body:
|
||||
query:
|
||||
script_score:
|
||||
query: {match_all: {} }
|
||||
script:
|
||||
source: "hamming(params.query_vector, 'vector')"
|
||||
params:
|
||||
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
|
||||
|
||||
- do:
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
catch: bad_request
|
||||
search:
|
||||
body:
|
||||
query:
|
||||
script_score:
|
||||
query: {match_all: {} }
|
||||
script:
|
||||
source: "hamming(params.query_vector, 'indexed_vector')"
|
||||
params:
|
||||
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
|
||||
|
||||
|
|
|
@ -0,0 +1,156 @@
|
|||
setup:
|
||||
- requires:
|
||||
cluster_features: ["script.hamming"]
|
||||
reason: "support for hamming distance added in 8.15"
|
||||
test_runner_features: headers
|
||||
|
||||
- do:
|
||||
indices.create:
|
||||
index: test-index
|
||||
body:
|
||||
settings:
|
||||
number_of_replicas: 0
|
||||
mappings:
|
||||
properties:
|
||||
my_dense_vector:
|
||||
index: false
|
||||
type: dense_vector
|
||||
element_type: byte
|
||||
dims: 5
|
||||
my_dense_vector_indexed:
|
||||
index: true
|
||||
type: dense_vector
|
||||
element_type: byte
|
||||
dims: 5
|
||||
|
||||
- do:
|
||||
index:
|
||||
index: test-index
|
||||
id: "1"
|
||||
body:
|
||||
my_dense_vector: [8, 5, -15, 1, -7]
|
||||
my_dense_vector_indexed: [8, 5, -15, 1, -7]
|
||||
|
||||
- do:
|
||||
index:
|
||||
index: test-index
|
||||
id: "2"
|
||||
body:
|
||||
my_dense_vector: [-1, 115, -3, 4, -128]
|
||||
my_dense_vector_indexed: [-1, 115, -3, 4, -128]
|
||||
|
||||
- do:
|
||||
index:
|
||||
index: test-index
|
||||
id: "3"
|
||||
body:
|
||||
my_dense_vector: [2, 18, -5, 0, -124]
|
||||
my_dense_vector_indexed: [2, 18, -5, 0, -124]
|
||||
|
||||
- do:
|
||||
indices.refresh: {}
|
||||
|
||||
---
|
||||
"Hamming distance":
|
||||
- do:
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
search:
|
||||
rest_total_hits_as_int: true
|
||||
body:
|
||||
query:
|
||||
script_score:
|
||||
query: {match_all: {} }
|
||||
script:
|
||||
source: "hamming(params.query_vector, 'my_dense_vector')"
|
||||
params:
|
||||
query_vector: [0, 111, -13, 14, -124]
|
||||
|
||||
- match: {hits.total: 3}
|
||||
|
||||
- match: {hits.hits.0._id: "2"}
|
||||
- match: {hits.hits.0._score: 17.0}
|
||||
|
||||
- match: {hits.hits.1._id: "1"}
|
||||
- match: {hits.hits.1._score: 16.0}
|
||||
|
||||
- match: {hits.hits.2._id: "3"}
|
||||
- match: {hits.hits.2._score: 11.0}
|
||||
|
||||
|
||||
- do:
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
search:
|
||||
rest_total_hits_as_int: true
|
||||
body:
|
||||
query:
|
||||
script_score:
|
||||
query: {match_all: {} }
|
||||
script:
|
||||
source: "hamming(params.query_vector, 'my_dense_vector_indexed')"
|
||||
params:
|
||||
query_vector: [0, 111, -13, 14, -124]
|
||||
|
||||
- match: {hits.total: 3}
|
||||
|
||||
- match: {hits.hits.0._id: "2"}
|
||||
- match: {hits.hits.0._score: 17.0}
|
||||
|
||||
- match: {hits.hits.1._id: "1"}
|
||||
- match: {hits.hits.1._score: 16.0}
|
||||
|
||||
- match: {hits.hits.2._id: "3"}
|
||||
- match: {hits.hits.2._score: 11.0}
|
||||
---
|
||||
"Hamming distance hexidecimal":
|
||||
- do:
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
search:
|
||||
rest_total_hits_as_int: true
|
||||
body:
|
||||
query:
|
||||
script_score:
|
||||
query: {match_all: {} }
|
||||
script:
|
||||
source: "hamming(params.query_vector, 'my_dense_vector')"
|
||||
params:
|
||||
query_vector: "006ff30e84"
|
||||
|
||||
- match: {hits.total: 3}
|
||||
|
||||
- match: {hits.hits.0._id: "2"}
|
||||
- match: {hits.hits.0._score: 17.0}
|
||||
|
||||
- match: {hits.hits.1._id: "1"}
|
||||
- match: {hits.hits.1._score: 16.0}
|
||||
|
||||
- match: {hits.hits.2._id: "3"}
|
||||
- match: {hits.hits.2._score: 11.0}
|
||||
|
||||
|
||||
- do:
|
||||
headers:
|
||||
Content-Type: application/json
|
||||
search:
|
||||
rest_total_hits_as_int: true
|
||||
body:
|
||||
query:
|
||||
script_score:
|
||||
query: {match_all: {} }
|
||||
script:
|
||||
source: "hamming(params.query_vector, 'my_dense_vector_indexed')"
|
||||
params:
|
||||
query_vector: "006ff30e84"
|
||||
|
||||
- match: {hits.total: 3}
|
||||
|
||||
- match: {hits.hits.0._id: "2"}
|
||||
- match: {hits.hits.0._score: 17.0}
|
||||
|
||||
- match: {hits.hits.1._id: "1"}
|
||||
- match: {hits.hits.1._score: 16.0}
|
||||
|
||||
- match: {hits.hits.2._id: "3"}
|
||||
- match: {hits.hits.2._score: 11.0}
|
|
@ -431,6 +431,7 @@ module org.elasticsearch.server {
|
|||
org.elasticsearch.indices.IndicesFeatures,
|
||||
org.elasticsearch.action.admin.cluster.allocation.AllocationStatsFeatures,
|
||||
org.elasticsearch.index.mapper.MapperFeatures,
|
||||
org.elasticsearch.script.ScriptFeatures,
|
||||
org.elasticsearch.search.retriever.RetrieversFeatures,
|
||||
org.elasticsearch.reservedstate.service.FileSettingsFeatures;
|
||||
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0 and the Server Side Public License, v 1; you may not use this file except
|
||||
* in compliance with, at your election, the Elastic License 2.0 or the Server
|
||||
* Side Public License, v 1.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.script;
|
||||
|
||||
import org.elasticsearch.features.FeatureSpecification;
|
||||
import org.elasticsearch.features.NodeFeature;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
public final class ScriptFeatures implements FeatureSpecification {
|
||||
@Override
|
||||
public Set<NodeFeature> getFeatures() {
|
||||
return Set.of(VectorScoreScriptUtils.HAMMING_DISTANCE_FUNCTION);
|
||||
}
|
||||
}
|
|
@ -9,6 +9,8 @@
|
|||
package org.elasticsearch.script;
|
||||
|
||||
import org.elasticsearch.ExceptionsHelper;
|
||||
import org.elasticsearch.features.NodeFeature;
|
||||
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||
import org.elasticsearch.script.field.vectors.DenseVector;
|
||||
import org.elasticsearch.script.field.vectors.DenseVectorDocValuesField;
|
||||
|
||||
|
@ -18,6 +20,8 @@ import java.util.List;
|
|||
|
||||
public class VectorScoreScriptUtils {
|
||||
|
||||
public static final NodeFeature HAMMING_DISTANCE_FUNCTION = new NodeFeature("script.hamming");
|
||||
|
||||
public static class DenseVectorFunction {
|
||||
protected final ScoreScript scoreScript;
|
||||
protected final DenseVectorDocValuesField field;
|
||||
|
@ -187,6 +191,52 @@ public class VectorScoreScriptUtils {
|
|||
}
|
||||
}
|
||||
|
||||
// Calculate Hamming distances between a query's dense vector and documents' dense vectors
|
||||
/** Contract for functions that report a Hamming distance as an {@code int}. */
public interface HammingDistanceInterface {
|
||||
/** @return the Hamming distance computed by the implementing function */
int hamming();
|
||||
}
|
||||
|
||||
public static class ByteHammingDistance extends ByteDenseVectorFunction implements HammingDistanceInterface {
|
||||
|
||||
public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
|
||||
super(scoreScript, field, queryVector);
|
||||
}
|
||||
|
||||
public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
|
||||
super(scoreScript, field, queryVector);
|
||||
}
|
||||
|
||||
public int hamming() {
|
||||
setNextVector();
|
||||
return field.get().hamming(queryVector);
|
||||
}
|
||||
}
|
||||
|
||||
public static final class Hamming {
|
||||
|
||||
private final HammingDistanceInterface function;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public Hamming(ScoreScript scoreScript, Object queryVector, String fieldName) {
|
||||
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
|
||||
if (field.getElementType() != DenseVectorFieldMapper.ElementType.BYTE) {
|
||||
throw new IllegalArgumentException("hamming distance is only supported for byte vectors");
|
||||
}
|
||||
if (queryVector instanceof List) {
|
||||
function = new ByteHammingDistance(scoreScript, field, (List<Number>) queryVector);
|
||||
} else if (queryVector instanceof String s) {
|
||||
byte[] parsedQueryVector = HexFormat.of().parseHex(s);
|
||||
function = new ByteHammingDistance(scoreScript, field, parsedQueryVector);
|
||||
} else {
|
||||
throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName());
|
||||
}
|
||||
}
|
||||
|
||||
public double hamming() {
|
||||
return function.hamming();
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors
|
||||
public interface L2NormInterface {
|
||||
double l2norm();
|
||||
|
|
|
@ -83,6 +83,16 @@ public class BinaryDenseVector implements DenseVector {
|
|||
return l1norm;
|
||||
}
|
||||
|
||||
// Hamming distance is not implemented by this class: the byte[] overload always throws,
// and its message states that float vectors are unsupported.
@Override
|
||||
public int hamming(byte[] queryVector) {
|
||||
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
|
||||
}
|
||||
|
||||
// Hamming distance is not implemented by this class: the List overload always throws,
// and its message states that float vectors are unsupported.
@Override
|
||||
public int hamming(List<Number> queryVector) {
|
||||
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
|
||||
}
|
||||
|
||||
@Override
|
||||
public double l2Norm(byte[] queryVector) {
|
||||
throw new UnsupportedOperationException("use [double l2Norm(float[] queryVector)] instead");
|
||||
|
|
|
@ -100,6 +100,20 @@ public class ByteBinaryDenseVector implements DenseVector {
|
|||
return result;
|
||||
}
|
||||
|
||||
// Hamming distance for a byte[] query: delegates to VectorUtil.xorBitCount over the query
// bytes and this vector's vectorValue.
@Override
|
||||
public int hamming(byte[] queryVector) {
|
||||
return VectorUtil.xorBitCount(queryVector, vectorValue);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hamming(List<Number> queryVector) {
|
||||
int distance = 0;
|
||||
for (int i = 0; i < queryVector.size(); i++) {
|
||||
distance += Integer.bitCount((queryVector.get(i).intValue() ^ vectorValue[i]) & 0xFF);
|
||||
}
|
||||
return distance;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double l2Norm(byte[] queryVector) {
|
||||
return Math.sqrt(VectorUtil.squareDistance(queryVector, vectorValue));
|
||||
|
|
|
@ -101,6 +101,20 @@ public class ByteKnnDenseVector implements DenseVector {
|
|||
return result;
|
||||
}
|
||||
|
||||
// Hamming distance for a byte[] query: delegates to VectorUtil.xorBitCount over the query
// bytes and this vector's docVector.
@Override
|
||||
public int hamming(byte[] queryVector) {
|
||||
return VectorUtil.xorBitCount(queryVector, docVector);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hamming(List<Number> queryVector) {
|
||||
int distance = 0;
|
||||
for (int i = 0; i < queryVector.size(); i++) {
|
||||
distance += Integer.bitCount((queryVector.get(i).intValue() ^ docVector[i]) & 0xFF);
|
||||
}
|
||||
return distance;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double l2Norm(byte[] queryVector) {
|
||||
return Math.sqrt(VectorUtil.squareDistance(docVector, queryVector));
|
||||
|
|
|
@ -14,8 +14,7 @@ import java.util.List;
|
|||
|
||||
/**
|
||||
* DenseVector value type for the painless.
|
||||
*/
|
||||
/* dotProduct, l1Norm, l2Norm, cosineSimilarity have three flavors depending on the type of the queryVector
|
||||
* dotProduct, l1Norm, l2Norm, cosineSimilarity have three flavors depending on the type of the queryVector
|
||||
* 1) float[], this is for the ScoreScriptUtils class bindings which have converted a List based query vector into an array
|
||||
* 2) List, A painless script will typically use Lists since they are easy to pass as params and have an easy
|
||||
* literal syntax. Working with Lists directly, instead of converting to a float[], trades off runtime operations against
|
||||
|
@ -74,6 +73,24 @@ public interface DenseVector {
|
|||
throw new IllegalArgumentException(badQueryVectorType(queryVector));
|
||||
}
|
||||
|
||||
/** Hamming distance between this vector and a query vector given as a byte array. */
int hamming(byte[] queryVector);
|
||||
|
||||
/** Hamming distance between this vector and a query vector given as a list of numbers. */
int hamming(List<Number> queryVector);
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
default int hamming(Object queryVector) {
|
||||
if (queryVector instanceof List<?> list) {
|
||||
checkDimensions(getDims(), list.size());
|
||||
return hamming((List<Number>) list);
|
||||
}
|
||||
if (queryVector instanceof byte[] bytes) {
|
||||
checkDimensions(getDims(), bytes.length);
|
||||
return hamming(bytes);
|
||||
}
|
||||
|
||||
throw new IllegalArgumentException(badQueryVectorType(queryVector));
|
||||
}
|
||||
|
||||
double l2Norm(byte[] queryVector);
|
||||
|
||||
double l2Norm(float[] queryVector);
|
||||
|
@ -231,6 +248,16 @@ public interface DenseVector {
|
|||
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
|
||||
}
|
||||
|
||||
// Always throws IllegalArgumentException carrying MISSING_VECTOR_FIELD_MESSAGE; no distance is computed.
@Override
|
||||
public int hamming(byte[] queryVector) {
|
||||
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
|
||||
}
|
||||
|
||||
// Always throws IllegalArgumentException carrying MISSING_VECTOR_FIELD_MESSAGE; no distance is computed.
@Override
|
||||
public int hamming(List<Number> queryVector) {
|
||||
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double l2Norm(byte[] queryVector) {
|
||||
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
|
||||
|
|
|
@ -85,6 +85,16 @@ public class KnnDenseVector implements DenseVector {
|
|||
return result;
|
||||
}
|
||||
|
||||
// Hamming distance is not implemented by this class: the byte[] overload always throws,
// and its message states that float vectors are unsupported.
@Override
|
||||
public int hamming(byte[] queryVector) {
|
||||
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
|
||||
}
|
||||
|
||||
// Hamming distance is not implemented by this class: the List overload always throws,
// and its message states that float vectors are unsupported.
@Override
|
||||
public int hamming(List<Number> queryVector) {
|
||||
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
|
||||
}
|
||||
|
||||
@Override
|
||||
public double l2Norm(byte[] queryVector) {
|
||||
throw new UnsupportedOperationException("use [double l2Norm(float[] queryVector)] instead");
|
||||
|
|
|
@ -15,4 +15,5 @@ org.elasticsearch.indices.IndicesFeatures
|
|||
org.elasticsearch.action.admin.cluster.allocation.AllocationStatsFeatures
|
||||
org.elasticsearch.index.mapper.MapperFeatures
|
||||
org.elasticsearch.search.retriever.RetrieversFeatures
|
||||
org.elasticsearch.script.ScriptFeatures
|
||||
org.elasticsearch.reservedstate.service.FileSettingsFeatures
|
||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType
|
|||
import org.elasticsearch.index.mapper.vectors.KnnDenseVectorScriptDocValuesTests;
|
||||
import org.elasticsearch.script.VectorScoreScriptUtils.CosineSimilarity;
|
||||
import org.elasticsearch.script.VectorScoreScriptUtils.DotProduct;
|
||||
import org.elasticsearch.script.VectorScoreScriptUtils.Hamming;
|
||||
import org.elasticsearch.script.VectorScoreScriptUtils.L1Norm;
|
||||
import org.elasticsearch.script.VectorScoreScriptUtils.L2Norm;
|
||||
import org.elasticsearch.script.field.vectors.BinaryDenseVectorDocValuesField;
|
||||
|
@ -112,6 +113,12 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|||
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
|
||||
);
|
||||
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, queryVector, fieldName));
|
||||
assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors"));
|
||||
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, invalidQueryVector, fieldName));
|
||||
assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors"));
|
||||
|
||||
// Check scripting infrastructure integration
|
||||
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName);
|
||||
assertEquals(65425.6249, dotProduct.dotProduct(), 0.001);
|
||||
|
@ -199,6 +206,11 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|||
e.getMessage(),
|
||||
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
|
||||
);
|
||||
e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, invalidQueryVector, fieldName));
|
||||
assertThat(
|
||||
e.getMessage(),
|
||||
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
|
||||
);
|
||||
|
||||
// Check scripting infrastructure integration
|
||||
assertEquals(17382.0, new DotProduct(scoreScript, queryVector, fieldName).dotProduct(), 0.001);
|
||||
|
@ -207,6 +219,8 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
|||
assertEquals(135.0, new L1Norm(scoreScript, hexidecimalString, fieldName).l1norm(), 0.001);
|
||||
assertEquals(116.897, new L2Norm(scoreScript, queryVector, fieldName).l2norm(), 0.001);
|
||||
assertEquals(116.897, new L2Norm(scoreScript, hexidecimalString, fieldName).l2norm(), 0.001);
|
||||
assertEquals(13.0, new Hamming(scoreScript, queryVector, fieldName).hamming(), 0.001);
|
||||
assertEquals(13.0, new Hamming(scoreScript, hexidecimalString, fieldName).hamming(), 0.001);
|
||||
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName);
|
||||
when(scoreScript._getDocId()).thenReturn(1);
|
||||
e = expectThrows(IllegalArgumentException.class, dotProduct::dotProduct);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue