mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 17:34:17 -04:00
Adding hamming distance function to painless for dense_vector fields (#109359)
This adds `hamming` distance, the pop-count of the `xor` of two byte vectors, as a first-class citizen in Painless. For byte vectors, this means that we can compute Hamming distances via `script_score` (i.e., brute force). The implementation of `hamming` is the same as the one available in Lucene, and once the Lucene 9.11 upgrade is merged, we should update our logic where applicable to use it. NOTE: this does not yet add Hamming distance as a metric for indexed vectors. That will come in a future PR after the Lucene 9.11 upgrade.
This commit is contained in:
parent
bbcf73028e
commit
acc99302c6
18 changed files with 438 additions and 7 deletions
|
@ -56,7 +56,7 @@ public class DistanceFunctionBenchmark {
|
||||||
@Param({ "96" })
|
@Param({ "96" })
|
||||||
private int dims;
|
private int dims;
|
||||||
|
|
||||||
@Param({ "dot", "cosine", "l1", "l2" })
|
@Param({ "dot", "cosine", "l1", "l2", "hamming" })
|
||||||
private String function;
|
private String function;
|
||||||
|
|
||||||
@Param({ "knn", "binary" })
|
@Param({ "knn", "binary" })
|
||||||
|
@ -330,6 +330,18 @@ public class DistanceFunctionBenchmark {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static class HammingKnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {
|
||||||
|
|
||||||
|
private HammingKnnByteBenchmarkFunction(int dims) {
|
||||||
|
super(dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void execute(Consumer<Object> consumer) {
|
||||||
|
new ByteKnnDenseVector(docVector).hamming(queryVector);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static class L1BinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
|
private static class L1BinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
|
||||||
|
|
||||||
private L1BinaryFloatBenchmarkFunction(int dims) {
|
private L1BinaryFloatBenchmarkFunction(int dims) {
|
||||||
|
@ -354,6 +366,18 @@ public class DistanceFunctionBenchmark {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static class HammingBinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {
|
||||||
|
|
||||||
|
private HammingBinaryByteBenchmarkFunction(int dims) {
|
||||||
|
super(dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void execute(Consumer<Object> consumer) {
|
||||||
|
new ByteBinaryDenseVector(vectorValue, docVector, dims).hamming(queryVector);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static class L2KnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
|
private static class L2KnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
|
||||||
|
|
||||||
private L2KnnFloatBenchmarkFunction(int dims) {
|
private L2KnnFloatBenchmarkFunction(int dims) {
|
||||||
|
@ -454,6 +478,11 @@ public class DistanceFunctionBenchmark {
|
||||||
case "binary" -> new L2BinaryByteBenchmarkFunction(dims);
|
case "binary" -> new L2BinaryByteBenchmarkFunction(dims);
|
||||||
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
|
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
|
||||||
};
|
};
|
||||||
|
case "hamming" -> benchmarkFunction = switch (type) {
|
||||||
|
case "knn" -> new HammingKnnByteBenchmarkFunction(dims);
|
||||||
|
case "binary" -> new HammingBinaryByteBenchmarkFunction(dims);
|
||||||
|
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
|
||||||
|
};
|
||||||
default -> throw new UnsupportedOperationException("unexpected function [" + function + "]");
|
default -> throw new UnsupportedOperationException("unexpected function [" + function + "]");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
5
docs/changelog/109359.yaml
Normal file
5
docs/changelog/109359.yaml
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
pr: 109359
|
||||||
|
summary: Adding hamming distance function to painless for `dense_vector` fields
|
||||||
|
area: Vector Search
|
||||||
|
type: enhancement
|
||||||
|
issues: []
|
|
@ -23,6 +23,7 @@ The following methods are directly callable without a class/instance qualifier.
|
||||||
* double dotProduct(Object *, String *)
|
* double dotProduct(Object *, String *)
|
||||||
* double l1norm(Object *, String *)
|
* double l1norm(Object *, String *)
|
||||||
* double l2norm(Object *, String *)
|
* double l2norm(Object *, String *)
|
||||||
|
* double hamming(Object *, String *)
|
||||||
* double randomScore(int *)
|
* double randomScore(int *)
|
||||||
* double randomScore(int *, String *)
|
* double randomScore(int *, String *)
|
||||||
* double saturation(double, double)
|
* double saturation(double, double)
|
||||||
|
|
|
@ -12,9 +12,10 @@ This is the list of available vector functions and vector access methods:
|
||||||
1. <<vector-functions-cosine,`cosineSimilarity`>> – calculates cosine similarity
|
1. <<vector-functions-cosine,`cosineSimilarity`>> – calculates cosine similarity
|
||||||
2. <<vector-functions-dot-product,`dotProduct`>> – calculates dot product
|
2. <<vector-functions-dot-product,`dotProduct`>> – calculates dot product
|
||||||
3. <<vector-functions-l1,`l1norm`>> – calculates L^1^ distance
|
3. <<vector-functions-l1,`l1norm`>> – calculates L^1^ distance
|
||||||
4. <<vector-functions-l2,`l2norm`>> - calculates L^2^ distance
|
4. <<vector-functions-hamming,`hamming`>> – calculates Hamming distance
|
||||||
5. <<vector-functions-accessing-vectors,`doc[<field>].vectorValue`>> – returns a vector's value as an array of floats
|
5. <<vector-functions-l2,`l2norm`>> - calculates L^2^ distance
|
||||||
6. <<vector-functions-accessing-vectors,`doc[<field>].magnitude`>> – returns a vector's magnitude
|
6. <<vector-functions-accessing-vectors,`doc[<field>].vectorValue`>> – returns a vector's value as an array of floats
|
||||||
|
7. <<vector-functions-accessing-vectors,`doc[<field>].magnitude`>> – returns a vector's magnitude
|
||||||
|
|
||||||
NOTE: The recommended way to access dense vectors is through the
|
NOTE: The recommended way to access dense vectors is through the
|
||||||
`cosineSimilarity`, `dotProduct`, `l1norm` or `l2norm` functions. Please note
|
`cosineSimilarity`, `dotProduct`, `l1norm` or `l2norm` functions. Please note
|
||||||
|
@ -35,8 +36,15 @@ PUT my-index-000001
|
||||||
"properties": {
|
"properties": {
|
||||||
"my_dense_vector": {
|
"my_dense_vector": {
|
||||||
"type": "dense_vector",
|
"type": "dense_vector",
|
||||||
|
"index": false,
|
||||||
"dims": 3
|
"dims": 3
|
||||||
},
|
},
|
||||||
|
"my_byte_dense_vector": {
|
||||||
|
"type": "dense_vector",
|
||||||
|
"index": false,
|
||||||
|
"dims": 3,
|
||||||
|
"element_type": "byte"
|
||||||
|
},
|
||||||
"status" : {
|
"status" : {
|
||||||
"type" : "keyword"
|
"type" : "keyword"
|
||||||
}
|
}
|
||||||
|
@ -47,12 +55,14 @@ PUT my-index-000001
|
||||||
PUT my-index-000001/_doc/1
|
PUT my-index-000001/_doc/1
|
||||||
{
|
{
|
||||||
"my_dense_vector": [0.5, 10, 6],
|
"my_dense_vector": [0.5, 10, 6],
|
||||||
|
"my_byte_dense_vector": [0, 10, 6],
|
||||||
"status" : "published"
|
"status" : "published"
|
||||||
}
|
}
|
||||||
|
|
||||||
PUT my-index-000001/_doc/2
|
PUT my-index-000001/_doc/2
|
||||||
{
|
{
|
||||||
"my_dense_vector": [-0.5, 10, 10],
|
"my_dense_vector": [-0.5, 10, 10],
|
||||||
|
"my_byte_dense_vector": [0, 10, 10],
|
||||||
"status" : "published"
|
"status" : "published"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,6 +189,40 @@ we reversed the output from `l1norm` and `l2norm`. Also, to avoid
|
||||||
division by 0 when a document vector matches the query exactly,
|
division by 0 when a document vector matches the query exactly,
|
||||||
we added `1` in the denominator.
|
we added `1` in the denominator.
|
||||||
|
|
||||||
|
[[vector-functions-hamming]]
|
||||||
|
====== Hamming distance
|
||||||
|
|
||||||
|
The `hamming` function calculates {wikipedia}/Hamming_distance[Hamming distance] between a given query vector and
|
||||||
|
document vectors. It is only available for byte vectors.
|
||||||
|
|
||||||
|
[source,console]
|
||||||
|
--------------------------------------------------
|
||||||
|
GET my-index-000001/_search
|
||||||
|
{
|
||||||
|
"query": {
|
||||||
|
"script_score": {
|
||||||
|
"query" : {
|
||||||
|
"bool" : {
|
||||||
|
"filter" : {
|
||||||
|
"term" : {
|
||||||
|
"status" : "published"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"script": {
|
||||||
|
"source": "(24 - hamming(params.queryVector, 'my_byte_dense_vector')) / 24", <1>
|
||||||
|
"params": {
|
||||||
|
"queryVector": [4, 3, 0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
--------------------------------------------------
|
||||||
|
|
||||||
|
<1> Calculate the Hamming distance and normalize it by the total number of bits in the vector (3 dimensions of 8 bits each, 24 in total) to get a score between 0 and 1.
|
||||||
|
|
||||||
[[vector-functions-l2]]
|
[[vector-functions-l2]]
|
||||||
====== L^2^ distance (Euclidean distance)
|
====== L^2^ distance (Euclidean distance)
|
||||||
|
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -31,5 +31,6 @@ static_import {
|
||||||
double l2norm(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$L2Norm
|
double l2norm(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$L2Norm
|
||||||
double cosineSimilarity(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$CosineSimilarity
|
double cosineSimilarity(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$CosineSimilarity
|
||||||
double dotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$DotProduct
|
double dotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$DotProduct
|
||||||
|
double hamming(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$Hamming
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -219,3 +219,36 @@ setup:
|
||||||
|
|
||||||
- match: {hits.hits.2._id: "2"}
|
- match: {hits.hits.2._id: "2"}
|
||||||
- close_to: {hits.hits.2._score: {value: 186.34454, error: 0.01}}
|
- close_to: {hits.hits.2._score: {value: 186.34454, error: 0.01}}
|
||||||
|
---
|
||||||
|
"Test hamming distance fails on float":
|
||||||
|
- requires:
|
||||||
|
cluster_features: ["script.hamming"]
|
||||||
|
reason: "support for hamming distance added in 8.15"
|
||||||
|
- do:
|
||||||
|
headers:
|
||||||
|
Content-Type: application/json
|
||||||
|
catch: bad_request
|
||||||
|
search:
|
||||||
|
body:
|
||||||
|
query:
|
||||||
|
script_score:
|
||||||
|
query: {match_all: {} }
|
||||||
|
script:
|
||||||
|
source: "hamming(params.query_vector, 'vector')"
|
||||||
|
params:
|
||||||
|
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
|
||||||
|
|
||||||
|
- do:
|
||||||
|
headers:
|
||||||
|
Content-Type: application/json
|
||||||
|
catch: bad_request
|
||||||
|
search:
|
||||||
|
body:
|
||||||
|
query:
|
||||||
|
script_score:
|
||||||
|
query: {match_all: {} }
|
||||||
|
script:
|
||||||
|
source: "hamming(params.query_vector, 'indexed_vector')"
|
||||||
|
params:
|
||||||
|
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,156 @@
|
||||||
|
setup:
|
||||||
|
- requires:
|
||||||
|
cluster_features: ["script.hamming"]
|
||||||
|
reason: "support for hamming distance added in 8.15"
|
||||||
|
test_runner_features: headers
|
||||||
|
|
||||||
|
- do:
|
||||||
|
indices.create:
|
||||||
|
index: test-index
|
||||||
|
body:
|
||||||
|
settings:
|
||||||
|
number_of_replicas: 0
|
||||||
|
mappings:
|
||||||
|
properties:
|
||||||
|
my_dense_vector:
|
||||||
|
index: false
|
||||||
|
type: dense_vector
|
||||||
|
element_type: byte
|
||||||
|
dims: 5
|
||||||
|
my_dense_vector_indexed:
|
||||||
|
index: true
|
||||||
|
type: dense_vector
|
||||||
|
element_type: byte
|
||||||
|
dims: 5
|
||||||
|
|
||||||
|
- do:
|
||||||
|
index:
|
||||||
|
index: test-index
|
||||||
|
id: "1"
|
||||||
|
body:
|
||||||
|
my_dense_vector: [8, 5, -15, 1, -7]
|
||||||
|
my_dense_vector_indexed: [8, 5, -15, 1, -7]
|
||||||
|
|
||||||
|
- do:
|
||||||
|
index:
|
||||||
|
index: test-index
|
||||||
|
id: "2"
|
||||||
|
body:
|
||||||
|
my_dense_vector: [-1, 115, -3, 4, -128]
|
||||||
|
my_dense_vector_indexed: [-1, 115, -3, 4, -128]
|
||||||
|
|
||||||
|
- do:
|
||||||
|
index:
|
||||||
|
index: test-index
|
||||||
|
id: "3"
|
||||||
|
body:
|
||||||
|
my_dense_vector: [2, 18, -5, 0, -124]
|
||||||
|
my_dense_vector_indexed: [2, 18, -5, 0, -124]
|
||||||
|
|
||||||
|
- do:
|
||||||
|
indices.refresh: {}
|
||||||
|
|
||||||
|
---
|
||||||
|
"Hamming distance":
|
||||||
|
- do:
|
||||||
|
headers:
|
||||||
|
Content-Type: application/json
|
||||||
|
search:
|
||||||
|
rest_total_hits_as_int: true
|
||||||
|
body:
|
||||||
|
query:
|
||||||
|
script_score:
|
||||||
|
query: {match_all: {} }
|
||||||
|
script:
|
||||||
|
source: "hamming(params.query_vector, 'my_dense_vector')"
|
||||||
|
params:
|
||||||
|
query_vector: [0, 111, -13, 14, -124]
|
||||||
|
|
||||||
|
- match: {hits.total: 3}
|
||||||
|
|
||||||
|
- match: {hits.hits.0._id: "2"}
|
||||||
|
- match: {hits.hits.0._score: 17.0}
|
||||||
|
|
||||||
|
- match: {hits.hits.1._id: "1"}
|
||||||
|
- match: {hits.hits.1._score: 16.0}
|
||||||
|
|
||||||
|
- match: {hits.hits.2._id: "3"}
|
||||||
|
- match: {hits.hits.2._score: 11.0}
|
||||||
|
|
||||||
|
|
||||||
|
- do:
|
||||||
|
headers:
|
||||||
|
Content-Type: application/json
|
||||||
|
search:
|
||||||
|
rest_total_hits_as_int: true
|
||||||
|
body:
|
||||||
|
query:
|
||||||
|
script_score:
|
||||||
|
query: {match_all: {} }
|
||||||
|
script:
|
||||||
|
source: "hamming(params.query_vector, 'my_dense_vector_indexed')"
|
||||||
|
params:
|
||||||
|
query_vector: [0, 111, -13, 14, -124]
|
||||||
|
|
||||||
|
- match: {hits.total: 3}
|
||||||
|
|
||||||
|
- match: {hits.hits.0._id: "2"}
|
||||||
|
- match: {hits.hits.0._score: 17.0}
|
||||||
|
|
||||||
|
- match: {hits.hits.1._id: "1"}
|
||||||
|
- match: {hits.hits.1._score: 16.0}
|
||||||
|
|
||||||
|
- match: {hits.hits.2._id: "3"}
|
||||||
|
- match: {hits.hits.2._score: 11.0}
|
||||||
|
---
|
||||||
|
"Hamming distance hexidecimal":
|
||||||
|
- do:
|
||||||
|
headers:
|
||||||
|
Content-Type: application/json
|
||||||
|
search:
|
||||||
|
rest_total_hits_as_int: true
|
||||||
|
body:
|
||||||
|
query:
|
||||||
|
script_score:
|
||||||
|
query: {match_all: {} }
|
||||||
|
script:
|
||||||
|
source: "hamming(params.query_vector, 'my_dense_vector')"
|
||||||
|
params:
|
||||||
|
query_vector: "006ff30e84"
|
||||||
|
|
||||||
|
- match: {hits.total: 3}
|
||||||
|
|
||||||
|
- match: {hits.hits.0._id: "2"}
|
||||||
|
- match: {hits.hits.0._score: 17.0}
|
||||||
|
|
||||||
|
- match: {hits.hits.1._id: "1"}
|
||||||
|
- match: {hits.hits.1._score: 16.0}
|
||||||
|
|
||||||
|
- match: {hits.hits.2._id: "3"}
|
||||||
|
- match: {hits.hits.2._score: 11.0}
|
||||||
|
|
||||||
|
|
||||||
|
- do:
|
||||||
|
headers:
|
||||||
|
Content-Type: application/json
|
||||||
|
search:
|
||||||
|
rest_total_hits_as_int: true
|
||||||
|
body:
|
||||||
|
query:
|
||||||
|
script_score:
|
||||||
|
query: {match_all: {} }
|
||||||
|
script:
|
||||||
|
source: "hamming(params.query_vector, 'my_dense_vector_indexed')"
|
||||||
|
params:
|
||||||
|
query_vector: "006ff30e84"
|
||||||
|
|
||||||
|
- match: {hits.total: 3}
|
||||||
|
|
||||||
|
- match: {hits.hits.0._id: "2"}
|
||||||
|
- match: {hits.hits.0._score: 17.0}
|
||||||
|
|
||||||
|
- match: {hits.hits.1._id: "1"}
|
||||||
|
- match: {hits.hits.1._score: 16.0}
|
||||||
|
|
||||||
|
- match: {hits.hits.2._id: "3"}
|
||||||
|
- match: {hits.hits.2._score: 11.0}
|
|
@ -431,6 +431,7 @@ module org.elasticsearch.server {
|
||||||
org.elasticsearch.indices.IndicesFeatures,
|
org.elasticsearch.indices.IndicesFeatures,
|
||||||
org.elasticsearch.action.admin.cluster.allocation.AllocationStatsFeatures,
|
org.elasticsearch.action.admin.cluster.allocation.AllocationStatsFeatures,
|
||||||
org.elasticsearch.index.mapper.MapperFeatures,
|
org.elasticsearch.index.mapper.MapperFeatures,
|
||||||
|
org.elasticsearch.script.ScriptFeatures,
|
||||||
org.elasticsearch.search.retriever.RetrieversFeatures,
|
org.elasticsearch.search.retriever.RetrieversFeatures,
|
||||||
org.elasticsearch.reservedstate.service.FileSettingsFeatures;
|
org.elasticsearch.reservedstate.service.FileSettingsFeatures;
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0 and the Server Side Public License, v 1; you may not use this file except
|
||||||
|
* in compliance with, at your election, the Elastic License 2.0 or the Server
|
||||||
|
* Side Public License, v 1.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.script;
|
||||||
|
|
||||||
|
import org.elasticsearch.features.FeatureSpecification;
|
||||||
|
import org.elasticsearch.features.NodeFeature;
|
||||||
|
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
public final class ScriptFeatures implements FeatureSpecification {
|
||||||
|
@Override
|
||||||
|
public Set<NodeFeature> getFeatures() {
|
||||||
|
return Set.of(VectorScoreScriptUtils.HAMMING_DISTANCE_FUNCTION);
|
||||||
|
}
|
||||||
|
}
|
|
@ -9,6 +9,8 @@
|
||||||
package org.elasticsearch.script;
|
package org.elasticsearch.script;
|
||||||
|
|
||||||
import org.elasticsearch.ExceptionsHelper;
|
import org.elasticsearch.ExceptionsHelper;
|
||||||
|
import org.elasticsearch.features.NodeFeature;
|
||||||
|
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||||
import org.elasticsearch.script.field.vectors.DenseVector;
|
import org.elasticsearch.script.field.vectors.DenseVector;
|
||||||
import org.elasticsearch.script.field.vectors.DenseVectorDocValuesField;
|
import org.elasticsearch.script.field.vectors.DenseVectorDocValuesField;
|
||||||
|
|
||||||
|
@ -18,6 +20,8 @@ import java.util.List;
|
||||||
|
|
||||||
public class VectorScoreScriptUtils {
|
public class VectorScoreScriptUtils {
|
||||||
|
|
||||||
|
public static final NodeFeature HAMMING_DISTANCE_FUNCTION = new NodeFeature("script.hamming");
|
||||||
|
|
||||||
public static class DenseVectorFunction {
|
public static class DenseVectorFunction {
|
||||||
protected final ScoreScript scoreScript;
|
protected final ScoreScript scoreScript;
|
||||||
protected final DenseVectorDocValuesField field;
|
protected final DenseVectorDocValuesField field;
|
||||||
|
@ -187,6 +191,52 @@ public class VectorScoreScriptUtils {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Calculate Hamming distances between a query's dense vector and documents' dense vectors
|
||||||
|
public interface HammingDistanceInterface {
|
||||||
|
int hamming();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class ByteHammingDistance extends ByteDenseVectorFunction implements HammingDistanceInterface {
|
||||||
|
|
||||||
|
public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
|
||||||
|
super(scoreScript, field, queryVector);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
|
||||||
|
super(scoreScript, field, queryVector);
|
||||||
|
}
|
||||||
|
|
||||||
|
public int hamming() {
|
||||||
|
setNextVector();
|
||||||
|
return field.get().hamming(queryVector);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static final class Hamming {
|
||||||
|
|
||||||
|
private final HammingDistanceInterface function;
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public Hamming(ScoreScript scoreScript, Object queryVector, String fieldName) {
|
||||||
|
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
|
||||||
|
if (field.getElementType() != DenseVectorFieldMapper.ElementType.BYTE) {
|
||||||
|
throw new IllegalArgumentException("hamming distance is only supported for byte vectors");
|
||||||
|
}
|
||||||
|
if (queryVector instanceof List) {
|
||||||
|
function = new ByteHammingDistance(scoreScript, field, (List<Number>) queryVector);
|
||||||
|
} else if (queryVector instanceof String s) {
|
||||||
|
byte[] parsedQueryVector = HexFormat.of().parseHex(s);
|
||||||
|
function = new ByteHammingDistance(scoreScript, field, parsedQueryVector);
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public double hamming() {
|
||||||
|
return function.hamming();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors
|
// Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors
|
||||||
public interface L2NormInterface {
|
public interface L2NormInterface {
|
||||||
double l2norm();
|
double l2norm();
|
||||||
|
|
|
@ -83,6 +83,16 @@ public class BinaryDenseVector implements DenseVector {
|
||||||
return l1norm;
|
return l1norm;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hamming(byte[] queryVector) {
|
||||||
|
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hamming(List<Number> queryVector) {
|
||||||
|
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double l2Norm(byte[] queryVector) {
|
public double l2Norm(byte[] queryVector) {
|
||||||
throw new UnsupportedOperationException("use [double l2Norm(float[] queryVector)] instead");
|
throw new UnsupportedOperationException("use [double l2Norm(float[] queryVector)] instead");
|
||||||
|
|
|
@ -100,6 +100,20 @@ public class ByteBinaryDenseVector implements DenseVector {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hamming(byte[] queryVector) {
|
||||||
|
return VectorUtil.xorBitCount(queryVector, vectorValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hamming(List<Number> queryVector) {
|
||||||
|
int distance = 0;
|
||||||
|
for (int i = 0; i < queryVector.size(); i++) {
|
||||||
|
distance += Integer.bitCount((queryVector.get(i).intValue() ^ vectorValue[i]) & 0xFF);
|
||||||
|
}
|
||||||
|
return distance;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double l2Norm(byte[] queryVector) {
|
public double l2Norm(byte[] queryVector) {
|
||||||
return Math.sqrt(VectorUtil.squareDistance(queryVector, vectorValue));
|
return Math.sqrt(VectorUtil.squareDistance(queryVector, vectorValue));
|
||||||
|
|
|
@ -101,6 +101,20 @@ public class ByteKnnDenseVector implements DenseVector {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hamming(byte[] queryVector) {
|
||||||
|
return VectorUtil.xorBitCount(queryVector, docVector);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hamming(List<Number> queryVector) {
|
||||||
|
int distance = 0;
|
||||||
|
for (int i = 0; i < queryVector.size(); i++) {
|
||||||
|
distance += Integer.bitCount((queryVector.get(i).intValue() ^ docVector[i]) & 0xFF);
|
||||||
|
}
|
||||||
|
return distance;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double l2Norm(byte[] queryVector) {
|
public double l2Norm(byte[] queryVector) {
|
||||||
return Math.sqrt(VectorUtil.squareDistance(docVector, queryVector));
|
return Math.sqrt(VectorUtil.squareDistance(docVector, queryVector));
|
||||||
|
|
|
@ -14,8 +14,7 @@ import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* DenseVector value type for the painless.
|
* DenseVector value type for the painless.
|
||||||
*/
|
* dotProduct, l1Norm, l2Norm, cosineSimilarity have three flavors depending on the type of the queryVector
|
||||||
/* dotProduct, l1Norm, l2Norm, cosineSimilarity have three flavors depending on the type of the queryVector
|
|
||||||
* 1) float[], this is for the ScoreScriptUtils class bindings which have converted a List based query vector into an array
|
* 1) float[], this is for the ScoreScriptUtils class bindings which have converted a List based query vector into an array
|
||||||
* 2) List, A painless script will typically use Lists since they are easy to pass as params and have an easy
|
* 2) List, A painless script will typically use Lists since they are easy to pass as params and have an easy
|
||||||
* literal syntax. Working with Lists directly, instead of converting to a float[], trades off runtime operations against
|
* literal syntax. Working with Lists directly, instead of converting to a float[], trades off runtime operations against
|
||||||
|
@ -74,6 +73,24 @@ public interface DenseVector {
|
||||||
throw new IllegalArgumentException(badQueryVectorType(queryVector));
|
throw new IllegalArgumentException(badQueryVectorType(queryVector));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int hamming(byte[] queryVector);
|
||||||
|
|
||||||
|
int hamming(List<Number> queryVector);
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
default int hamming(Object queryVector) {
|
||||||
|
if (queryVector instanceof List<?> list) {
|
||||||
|
checkDimensions(getDims(), list.size());
|
||||||
|
return hamming((List<Number>) list);
|
||||||
|
}
|
||||||
|
if (queryVector instanceof byte[] bytes) {
|
||||||
|
checkDimensions(getDims(), bytes.length);
|
||||||
|
return hamming(bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new IllegalArgumentException(badQueryVectorType(queryVector));
|
||||||
|
}
|
||||||
|
|
||||||
double l2Norm(byte[] queryVector);
|
double l2Norm(byte[] queryVector);
|
||||||
|
|
||||||
double l2Norm(float[] queryVector);
|
double l2Norm(float[] queryVector);
|
||||||
|
@ -231,6 +248,16 @@ public interface DenseVector {
|
||||||
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
|
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hamming(byte[] queryVector) {
|
||||||
|
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hamming(List<Number> queryVector) {
|
||||||
|
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double l2Norm(byte[] queryVector) {
|
public double l2Norm(byte[] queryVector) {
|
||||||
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
|
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
|
||||||
|
|
|
@ -85,6 +85,16 @@ public class KnnDenseVector implements DenseVector {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hamming(byte[] queryVector) {
|
||||||
|
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hamming(List<Number> queryVector) {
|
||||||
|
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double l2Norm(byte[] queryVector) {
|
public double l2Norm(byte[] queryVector) {
|
||||||
throw new UnsupportedOperationException("use [double l2Norm(float[] queryVector)] instead");
|
throw new UnsupportedOperationException("use [double l2Norm(float[] queryVector)] instead");
|
||||||
|
|
|
@ -15,4 +15,5 @@ org.elasticsearch.indices.IndicesFeatures
|
||||||
org.elasticsearch.action.admin.cluster.allocation.AllocationStatsFeatures
|
org.elasticsearch.action.admin.cluster.allocation.AllocationStatsFeatures
|
||||||
org.elasticsearch.index.mapper.MapperFeatures
|
org.elasticsearch.index.mapper.MapperFeatures
|
||||||
org.elasticsearch.search.retriever.RetrieversFeatures
|
org.elasticsearch.search.retriever.RetrieversFeatures
|
||||||
|
org.elasticsearch.script.ScriptFeatures
|
||||||
org.elasticsearch.reservedstate.service.FileSettingsFeatures
|
org.elasticsearch.reservedstate.service.FileSettingsFeatures
|
||||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType
|
||||||
import org.elasticsearch.index.mapper.vectors.KnnDenseVectorScriptDocValuesTests;
|
import org.elasticsearch.index.mapper.vectors.KnnDenseVectorScriptDocValuesTests;
|
||||||
import org.elasticsearch.script.VectorScoreScriptUtils.CosineSimilarity;
|
import org.elasticsearch.script.VectorScoreScriptUtils.CosineSimilarity;
|
||||||
import org.elasticsearch.script.VectorScoreScriptUtils.DotProduct;
|
import org.elasticsearch.script.VectorScoreScriptUtils.DotProduct;
|
||||||
|
import org.elasticsearch.script.VectorScoreScriptUtils.Hamming;
|
||||||
import org.elasticsearch.script.VectorScoreScriptUtils.L1Norm;
|
import org.elasticsearch.script.VectorScoreScriptUtils.L1Norm;
|
||||||
import org.elasticsearch.script.VectorScoreScriptUtils.L2Norm;
|
import org.elasticsearch.script.VectorScoreScriptUtils.L2Norm;
|
||||||
import org.elasticsearch.script.field.vectors.BinaryDenseVectorDocValuesField;
|
import org.elasticsearch.script.field.vectors.BinaryDenseVectorDocValuesField;
|
||||||
|
@ -112,6 +113,12 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
||||||
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
|
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
|
||||||
);
|
);
|
||||||
|
|
||||||
|
e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, queryVector, fieldName));
|
||||||
|
assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors"));
|
||||||
|
|
||||||
|
e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, invalidQueryVector, fieldName));
|
||||||
|
assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors"));
|
||||||
|
|
||||||
// Check scripting infrastructure integration
|
// Check scripting infrastructure integration
|
||||||
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName);
|
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName);
|
||||||
assertEquals(65425.6249, dotProduct.dotProduct(), 0.001);
|
assertEquals(65425.6249, dotProduct.dotProduct(), 0.001);
|
||||||
|
@ -199,6 +206,11 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
||||||
e.getMessage(),
|
e.getMessage(),
|
||||||
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
|
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
|
||||||
);
|
);
|
||||||
|
e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, invalidQueryVector, fieldName));
|
||||||
|
assertThat(
|
||||||
|
e.getMessage(),
|
||||||
|
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
|
||||||
|
);
|
||||||
|
|
||||||
// Check scripting infrastructure integration
|
// Check scripting infrastructure integration
|
||||||
assertEquals(17382.0, new DotProduct(scoreScript, queryVector, fieldName).dotProduct(), 0.001);
|
assertEquals(17382.0, new DotProduct(scoreScript, queryVector, fieldName).dotProduct(), 0.001);
|
||||||
|
@ -207,6 +219,8 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
|
||||||
assertEquals(135.0, new L1Norm(scoreScript, hexidecimalString, fieldName).l1norm(), 0.001);
|
assertEquals(135.0, new L1Norm(scoreScript, hexidecimalString, fieldName).l1norm(), 0.001);
|
||||||
assertEquals(116.897, new L2Norm(scoreScript, queryVector, fieldName).l2norm(), 0.001);
|
assertEquals(116.897, new L2Norm(scoreScript, queryVector, fieldName).l2norm(), 0.001);
|
||||||
assertEquals(116.897, new L2Norm(scoreScript, hexidecimalString, fieldName).l2norm(), 0.001);
|
assertEquals(116.897, new L2Norm(scoreScript, hexidecimalString, fieldName).l2norm(), 0.001);
|
||||||
|
assertEquals(13.0, new Hamming(scoreScript, queryVector, fieldName).hamming(), 0.001);
|
||||||
|
assertEquals(13.0, new Hamming(scoreScript, hexidecimalString, fieldName).hamming(), 0.001);
|
||||||
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName);
|
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName);
|
||||||
when(scoreScript._getDocId()).thenReturn(1);
|
when(scoreScript._getDocId()).thenReturn(1);
|
||||||
e = expectThrows(IllegalArgumentException.class, dotProduct::dotProduct);
|
e = expectThrows(IllegalArgumentException.class, dotProduct::dotProduct);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue