Adding hamming distance function to painless for dense_vector fields (#109359)

This adds `hamming` distance, the pop-count of the `xor` of two byte vectors,
as a first-class citizen in painless.

For byte vectors, this means that we can compute hamming distances via
script_score (aka, brute-force).

The implementation of `hamming` is the same as the one available in Lucene,
and when Lucene 9.11 is merged, we should update our logic where
applicable to utilize it.

NOTE: this does not yet add hamming distance as a metric for indexed
vectors. This will be a future PR after the Lucene 9.11 upgrade.
This commit is contained in:
Benjamin Trent 2024-06-17 13:41:20 -04:00 committed by GitHub
parent bbcf73028e
commit acc99302c6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 438 additions and 7 deletions

View file

@ -56,7 +56,7 @@ public class DistanceFunctionBenchmark {
@Param({ "96" })
private int dims;
@Param({ "dot", "cosine", "l1", "l2" })
@Param({ "dot", "cosine", "l1", "l2", "hamming" })
private String function;
@Param({ "knn", "binary" })
@ -330,6 +330,18 @@ public class DistanceFunctionBenchmark {
}
}
/**
 * Benchmark case that measures the Hamming distance of a byte query vector
 * against a {@link ByteKnnDenseVector} wrapping the document vector.
 */
private static class HammingKnnByteBenchmarkFunction extends KnnByteBenchmarkFunction {

    private HammingKnnByteBenchmarkFunction(int dims) {
        super(dims);
    }

    /**
     * Builds a fresh {@link ByteKnnDenseVector} over {@code docVector} and computes its
     * Hamming distance to {@code queryVector}.
     *
     * @param consumer not used by this implementation
     */
    @Override
    public void execute(Consumer<Object> consumer) {
        // NOTE(review): the computed distance is discarded and the consumer is never invoked;
        // confirm this matches the other benchmark functions in this class.
        final ByteKnnDenseVector vector = new ByteKnnDenseVector(docVector);
        vector.hamming(queryVector);
    }
}
private static class L1BinaryFloatBenchmarkFunction extends BinaryFloatBenchmarkFunction {
private L1BinaryFloatBenchmarkFunction(int dims) {
@ -354,6 +366,18 @@ public class DistanceFunctionBenchmark {
}
}
/**
 * Benchmark case that measures the Hamming distance of a byte query vector
 * against a {@link ByteBinaryDenseVector} built from the binary doc-values form.
 */
private static class HammingBinaryByteBenchmarkFunction extends BinaryByteBenchmarkFunction {

    private HammingBinaryByteBenchmarkFunction(int dims) {
        super(dims);
    }

    /**
     * Builds a fresh {@link ByteBinaryDenseVector} from {@code vectorValue}, {@code docVector}
     * and {@code dims}, then computes its Hamming distance to {@code queryVector}.
     *
     * @param consumer not used by this implementation
     */
    @Override
    public void execute(Consumer<Object> consumer) {
        // NOTE(review): the computed distance is discarded and the consumer is never invoked;
        // confirm this matches the other benchmark functions in this class.
        final ByteBinaryDenseVector vector = new ByteBinaryDenseVector(vectorValue, docVector, dims);
        vector.hamming(queryVector);
    }
}
private static class L2KnnFloatBenchmarkFunction extends KnnFloatBenchmarkFunction {
private L2KnnFloatBenchmarkFunction(int dims) {
@ -454,6 +478,11 @@ public class DistanceFunctionBenchmark {
case "binary" -> new L2BinaryByteBenchmarkFunction(dims);
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
};
case "hamming" -> benchmarkFunction = switch (type) {
case "knn" -> new HammingKnnByteBenchmarkFunction(dims);
case "binary" -> new HammingBinaryByteBenchmarkFunction(dims);
default -> throw new UnsupportedOperationException("unexpected type [" + type + "]");
};
default -> throw new UnsupportedOperationException("unexpected function [" + function + "]");
}
}

View file

@ -0,0 +1,5 @@
pr: 109359
summary: Adding hamming distance function to painless for `dense_vector` fields
area: Vector Search
type: enhancement
issues: []

View file

@ -23,6 +23,7 @@ The following methods are directly callable without a class/instance qualifier.
* double dotProduct(Object *, String *)
* double l1norm(Object *, String *)
* double l2norm(Object *, String *)
* double hamming(Object *, String *)
* double randomScore(int *)
* double randomScore(int *, String *)
* double saturation(double, double)

View file

@ -12,9 +12,10 @@ This is the list of available vector functions and vector access methods:
1. <<vector-functions-cosine,`cosineSimilarity`>> calculates cosine similarity
2. <<vector-functions-dot-product,`dotProduct`>> calculates dot product
3. <<vector-functions-l1,`l1norm`>> calculates L^1^ distance
4. <<vector-functions-l2,`l2norm`>> calculates L^2^ distance
5. <<vector-functions-accessing-vectors,`doc[<field>].vectorValue`>> returns a vector's value as an array of floats
6. <<vector-functions-accessing-vectors,`doc[<field>].magnitude`>> returns a vector's magnitude
4. <<vector-functions-hamming,`hamming`>> calculates Hamming distance
5. <<vector-functions-l2,`l2norm`>> calculates L^2^ distance
6. <<vector-functions-accessing-vectors,`doc[<field>].vectorValue`>> returns a vector's value as an array of floats
7. <<vector-functions-accessing-vectors,`doc[<field>].magnitude`>> returns a vector's magnitude
NOTE: The recommended way to access dense vectors is through the
`cosineSimilarity`, `dotProduct`, `l1norm`, `l2norm` or `hamming` functions. Please note
@ -35,8 +36,15 @@ PUT my-index-000001
"properties": {
"my_dense_vector": {
"type": "dense_vector",
"index": false,
"dims": 3
},
"my_byte_dense_vector": {
"type": "dense_vector",
"index": false,
"dims": 3,
"element_type": "byte"
},
"status" : {
"type" : "keyword"
}
@ -47,12 +55,14 @@ PUT my-index-000001
PUT my-index-000001/_doc/1
{
"my_dense_vector": [0.5, 10, 6],
"my_byte_dense_vector": [0, 10, 6],
"status" : "published"
}
PUT my-index-000001/_doc/2
{
"my_dense_vector": [-0.5, 10, 10],
"my_byte_dense_vector": [0, 10, 10],
"status" : "published"
}
@ -179,6 +189,40 @@ we reversed the output from `l1norm` and `l2norm`. Also, to avoid
division by 0 when a document vector matches the query exactly,
we added `1` in the denominator.
[[vector-functions-hamming]]
====== Hamming distance
The `hamming` function calculates {wikipedia}/Hamming_distance[Hamming distance] between a given query vector and
document vectors. It is only available for byte vectors.
[source,console]
--------------------------------------------------
GET my-index-000001/_search
{
"query": {
"script_score": {
"query" : {
"bool" : {
"filter" : {
"term" : {
"status" : "published"
}
}
}
},
"script": {
"source": "(24 - hamming(params.queryVector, 'my_byte_dense_vector')) / 24", <1>
"params": {
"queryVector": [4, 3, 0]
}
}
}
}
}
--------------------------------------------------
<1> Calculate the Hamming distance and normalize it by the total number of bits in the vector (3 dimensions of 8 bits each, 24 in total) to get a score between 0 and 1.
[[vector-functions-l2]]
====== L^2^ distance (Euclidean distance)

File diff suppressed because one or more lines are too long

View file

@ -31,5 +31,6 @@ static_import {
double l2norm(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$L2Norm
double cosineSimilarity(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$CosineSimilarity
double dotProduct(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$DotProduct
double hamming(org.elasticsearch.script.ScoreScript, Object, String) bound_to org.elasticsearch.script.VectorScoreScriptUtils$Hamming
}

View file

@ -219,3 +219,36 @@ setup:
- match: {hits.hits.2._id: "2"}
- close_to: {hits.hits.2._score: {value: 186.34454, error: 0.01}}
---
"Test hamming distance fails on float":
- requires:
cluster_features: ["script.hamming"]
reason: "support for hamming distance added in 8.15"
- do:
headers:
Content-Type: application/json
catch: bad_request
search:
body:
query:
script_score:
query: {match_all: {} }
script:
source: "hamming(params.query_vector, 'vector')"
params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
- do:
headers:
Content-Type: application/json
catch: bad_request
search:
body:
query:
script_score:
query: {match_all: {} }
script:
source: "hamming(params.query_vector, 'indexed_vector')"
params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]

View file

@ -0,0 +1,156 @@
setup:
- requires:
cluster_features: ["script.hamming"]
reason: "support for hamming distance added in 8.15"
test_runner_features: headers
- do:
indices.create:
index: test-index
body:
settings:
number_of_replicas: 0
mappings:
properties:
my_dense_vector:
index: false
type: dense_vector
element_type: byte
dims: 5
my_dense_vector_indexed:
index: true
type: dense_vector
element_type: byte
dims: 5
- do:
index:
index: test-index
id: "1"
body:
my_dense_vector: [8, 5, -15, 1, -7]
my_dense_vector_indexed: [8, 5, -15, 1, -7]
- do:
index:
index: test-index
id: "2"
body:
my_dense_vector: [-1, 115, -3, 4, -128]
my_dense_vector_indexed: [-1, 115, -3, 4, -128]
- do:
index:
index: test-index
id: "3"
body:
my_dense_vector: [2, 18, -5, 0, -124]
my_dense_vector_indexed: [2, 18, -5, 0, -124]
- do:
indices.refresh: {}
---
"Hamming distance":
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "hamming(params.query_vector, 'my_dense_vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "hamming(params.query_vector, 'my_dense_vector_indexed')"
params:
query_vector: [0, 111, -13, 14, -124]
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
---
"Hamming distance hexidecimal":
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "hamming(params.query_vector, 'my_dense_vector')"
params:
query_vector: "006ff30e84"
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "hamming(params.query_vector, 'my_dense_vector_indexed')"
params:
query_vector: "006ff30e84"
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}

View file

@ -431,6 +431,7 @@ module org.elasticsearch.server {
org.elasticsearch.indices.IndicesFeatures,
org.elasticsearch.action.admin.cluster.allocation.AllocationStatsFeatures,
org.elasticsearch.index.mapper.MapperFeatures,
org.elasticsearch.script.ScriptFeatures,
org.elasticsearch.search.retriever.RetrieversFeatures,
org.elasticsearch.reservedstate.service.FileSettingsFeatures;

View file

@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.script;
import org.elasticsearch.features.FeatureSpecification;
import org.elasticsearch.features.NodeFeature;
import java.util.Set;
/**
 * {@link FeatureSpecification} for the {@code org.elasticsearch.script} package.
 * It currently publishes a single node feature: the Hamming distance script function.
 */
public final class ScriptFeatures implements FeatureSpecification {

    /**
     * @return an immutable set containing only {@link VectorScoreScriptUtils#HAMMING_DISTANCE_FUNCTION}
     */
    @Override
    public Set<NodeFeature> getFeatures() {
        final NodeFeature hammingFeature = VectorScoreScriptUtils.HAMMING_DISTANCE_FUNCTION;
        return Set.of(hammingFeature);
    }
}

View file

@ -9,6 +9,8 @@
package org.elasticsearch.script;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.script.field.vectors.DenseVector;
import org.elasticsearch.script.field.vectors.DenseVectorDocValuesField;
@ -18,6 +20,8 @@ import java.util.List;
public class VectorScoreScriptUtils {
/** Node feature, identified as {@code script.hamming}, that marks availability of the {@code hamming} script function. */
public static final NodeFeature HAMMING_DISTANCE_FUNCTION = new NodeFeature("script.hamming");
public static class DenseVectorFunction {
protected final ScoreScript scoreScript;
protected final DenseVectorDocValuesField field;
@ -187,6 +191,52 @@ public class VectorScoreScriptUtils {
}
}
// Calculate Hamming distances between a query's dense vector and documents' dense vectors
/**
 * Contract for objects that can produce a Hamming distance.
 * The distance is an {@code int}; the {@link Hamming} binding widens it to {@code double}.
 */
public interface HammingDistanceInterface {
/** @return the Hamming distance computed by the implementation */
int hamming();
}
public static class ByteHammingDistance extends ByteDenseVectorFunction implements HammingDistanceInterface {
public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
super(scoreScript, field, queryVector);
}
public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
super(scoreScript, field, queryVector);
}
public int hamming() {
setNextVector();
return field.get().hamming(queryVector);
}
}
/**
 * Binding class behind the painless {@code hamming(queryVector, fieldName)} function.
 * Hamming distance is only offered for fields whose element type is {@code BYTE}.
 */
public static final class Hamming {

    private final HammingDistanceInterface function;

    /**
     * Resolves the target field and prepares the distance function for the given query vector.
     *
     * @param scoreScript the score script the function is evaluated in
     * @param queryVector the query vector: either a {@code List} of numbers or a hex-encoded {@code String}
     * @param fieldName   name of the dense vector field to compare against
     * @throws IllegalArgumentException if the field's element type is not {@code BYTE}, if the hex
     *                                  string cannot be parsed, or if {@code queryVector} is
     *                                  {@code null} or of any other type
     */
    @SuppressWarnings("unchecked")
    public Hamming(ScoreScript scoreScript, Object queryVector, String fieldName) {
        DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
        if (field.getElementType() != DenseVectorFieldMapper.ElementType.BYTE) {
            throw new IllegalArgumentException("hamming distance is only supported for byte vectors");
        }
        if (queryVector instanceof List) {
            function = new ByteHammingDistance(scoreScript, field, (List<Number>) queryVector);
        } else if (queryVector instanceof String s) {
            byte[] parsedQueryVector = HexFormat.of().parseHex(s);
            function = new ByteHammingDistance(scoreScript, field, parsedQueryVector);
        } else {
            // A null query vector also ends up here; build the type name null-safely so the caller
            // gets the intended IllegalArgumentException rather than a bare NullPointerException.
            String actualType = queryVector == null ? "null" : queryVector.getClass().getName();
            throw new IllegalArgumentException("Unsupported input object for byte vectors: " + actualType);
        }
    }

    /**
     * @return the Hamming distance for the current document, widened from {@code int} to {@code double}
     */
    public double hamming() {
        return function.hamming();
    }
}
// Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors
public interface L2NormInterface {
double l2norm();

View file

@ -83,6 +83,16 @@ public class BinaryDenseVector implements DenseVector {
return l1norm;
}
/**
 * Not available on this vector type: the method unconditionally rejects the call.
 *
 * @param queryVector ignored
 * @throws UnsupportedOperationException always, because Hamming distance is not supported for float vectors
 */
@Override
public int hamming(byte[] queryVector) {
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
}
/**
 * Not available on this vector type: the method unconditionally rejects the call.
 *
 * @param queryVector ignored
 * @throws UnsupportedOperationException always, because Hamming distance is not supported for float vectors
 */
@Override
public int hamming(List<Number> queryVector) {
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
}
@Override
public double l2Norm(byte[] queryVector) {
throw new UnsupportedOperationException("use [double l2Norm(float[] queryVector)] instead");

View file

@ -100,6 +100,20 @@ public class ByteBinaryDenseVector implements DenseVector {
return result;
}
/**
 * Hamming distance between the query bytes and this document's {@code vectorValue},
 * delegated to {@code VectorUtil.xorBitCount}.
 *
 * @param queryVector the query vector as raw bytes
 * @return the value returned by {@code VectorUtil.xorBitCount(queryVector, vectorValue)}
 */
@Override
public int hamming(byte[] queryVector) {
return VectorUtil.xorBitCount(queryVector, vectorValue);
}
/**
 * Hamming distance between a list-based query vector and this document's {@code vectorValue}.
 * Each query element is reduced to its {@code int} value, XOR-ed with the document byte at the
 * same position, and the set bits in the low 8 bits of the result are counted.
 *
 * @param queryVector the query vector, one number per dimension
 * @return the total number of differing bits over all positions of {@code queryVector}
 */
@Override
public int hamming(List<Number> queryVector) {
    int differingBits = 0;
    int position = 0;
    for (Number component : queryVector) {
        final int xor = component.intValue() ^ vectorValue[position];
        // mask to one byte: sign extension of negative values fills the upper 24 bits
        differingBits += Integer.bitCount(xor & 0xFF);
        position++;
    }
    return differingBits;
}
@Override
public double l2Norm(byte[] queryVector) {
return Math.sqrt(VectorUtil.squareDistance(queryVector, vectorValue));

View file

@ -101,6 +101,20 @@ public class ByteKnnDenseVector implements DenseVector {
return result;
}
/**
 * Hamming distance between the query bytes and this document's {@code docVector},
 * delegated to {@code VectorUtil.xorBitCount}.
 *
 * @param queryVector the query vector as raw bytes
 * @return the value returned by {@code VectorUtil.xorBitCount(queryVector, docVector)}
 */
@Override
public int hamming(byte[] queryVector) {
return VectorUtil.xorBitCount(queryVector, docVector);
}
/**
 * Hamming distance between a list-based query vector and this document's {@code docVector}.
 * Each query element is reduced to its {@code int} value, XOR-ed with the document byte at the
 * same position, and the set bits in the low 8 bits of the result are counted.
 *
 * @param queryVector the query vector, one number per dimension
 * @return the total number of differing bits over all positions of {@code queryVector}
 */
@Override
public int hamming(List<Number> queryVector) {
    int differingBits = 0;
    int position = 0;
    for (Number component : queryVector) {
        final int xor = component.intValue() ^ docVector[position];
        // mask to one byte: sign extension of negative values fills the upper 24 bits
        differingBits += Integer.bitCount(xor & 0xFF);
        position++;
    }
    return differingBits;
}
@Override
public double l2Norm(byte[] queryVector) {
return Math.sqrt(VectorUtil.squareDistance(docVector, queryVector));

View file

@ -14,8 +14,7 @@ import java.util.List;
/**
* DenseVector value type for the painless.
*/
/* dotProduct, l1Norm, l2Norm, cosineSimilarity have three flavors depending on the type of the queryVector
* dotProduct, l1Norm, l2Norm, cosineSimilarity have three flavors depending on the type of the queryVector
* 1) float[], this is for the ScoreScriptUtils class bindings which have converted a List based query vector into an array
* 2) List, A painless script will typically use Lists since they are easy to pass as params and have an easy
* literal syntax. Working with Lists directly, instead of converting to a float[], trades off runtime operations against
@ -74,6 +73,24 @@ public interface DenseVector {
throw new IllegalArgumentException(badQueryVectorType(queryVector));
}
/**
 * Hamming distance between this vector and a query vector given as raw bytes.
 *
 * @param queryVector the query vector
 * @return the Hamming distance
 */
int hamming(byte[] queryVector);
/**
 * Hamming distance between this vector and a query vector given as a list of numbers.
 *
 * @param queryVector the query vector
 * @return the Hamming distance
 */
int hamming(List<Number> queryVector);
/**
 * Type-dispatching entry point for Hamming distance. Accepts a {@code List} or a {@code byte[]},
 * checks its length against {@code getDims()}, and forwards to the matching typed overload.
 *
 * @param queryVector the query vector, as a {@code List} of numbers or a {@code byte[]}
 * @return the Hamming distance computed by the typed overload
 * @throws IllegalArgumentException if {@code queryVector} is of any other type
 */
@SuppressWarnings("unchecked")
default int hamming(Object queryVector) {
    if (queryVector instanceof byte[] rawQuery) {
        checkDimensions(getDims(), rawQuery.length);
        return hamming(rawQuery);
    } else if (queryVector instanceof List<?> listQuery) {
        checkDimensions(getDims(), listQuery.size());
        return hamming((List<Number>) listQuery);
    } else {
        throw new IllegalArgumentException(badQueryVectorType(queryVector));
    }
}
double l2Norm(byte[] queryVector);
double l2Norm(float[] queryVector);
@ -231,6 +248,16 @@ public interface DenseVector {
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
}
/**
 * Always fails with {@code MISSING_VECTOR_FIELD_MESSAGE}; this implementation has no vector to compare against.
 *
 * @param queryVector ignored
 * @throws IllegalArgumentException always
 */
@Override
public int hamming(byte[] queryVector) {
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
}
/**
 * Always fails with {@code MISSING_VECTOR_FIELD_MESSAGE}; this implementation has no vector to compare against.
 *
 * @param queryVector ignored
 * @throws IllegalArgumentException always
 */
@Override
public int hamming(List<Number> queryVector) {
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
}
@Override
public double l2Norm(byte[] queryVector) {
throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);

View file

@ -85,6 +85,16 @@ public class KnnDenseVector implements DenseVector {
return result;
}
/**
 * Not available on this vector type: the method unconditionally rejects the call.
 *
 * @param queryVector ignored
 * @throws UnsupportedOperationException always, because Hamming distance is not supported for float vectors
 */
@Override
public int hamming(byte[] queryVector) {
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
}
/**
 * Not available on this vector type: the method unconditionally rejects the call.
 *
 * @param queryVector ignored
 * @throws UnsupportedOperationException always, because Hamming distance is not supported for float vectors
 */
@Override
public int hamming(List<Number> queryVector) {
throw new UnsupportedOperationException("hamming distance is not supported for float vectors");
}
@Override
public double l2Norm(byte[] queryVector) {
throw new UnsupportedOperationException("use [double l2Norm(float[] queryVector)] instead");

View file

@ -15,4 +15,5 @@ org.elasticsearch.indices.IndicesFeatures
org.elasticsearch.action.admin.cluster.allocation.AllocationStatsFeatures
org.elasticsearch.index.mapper.MapperFeatures
org.elasticsearch.search.retriever.RetrieversFeatures
org.elasticsearch.script.ScriptFeatures
org.elasticsearch.reservedstate.service.FileSettingsFeatures

View file

@ -15,6 +15,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType
import org.elasticsearch.index.mapper.vectors.KnnDenseVectorScriptDocValuesTests;
import org.elasticsearch.script.VectorScoreScriptUtils.CosineSimilarity;
import org.elasticsearch.script.VectorScoreScriptUtils.DotProduct;
import org.elasticsearch.script.VectorScoreScriptUtils.Hamming;
import org.elasticsearch.script.VectorScoreScriptUtils.L1Norm;
import org.elasticsearch.script.VectorScoreScriptUtils.L2Norm;
import org.elasticsearch.script.field.vectors.BinaryDenseVectorDocValuesField;
@ -112,6 +113,12 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
);
e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, queryVector, fieldName));
assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors"));
e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, invalidQueryVector, fieldName));
assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors"));
// Check scripting infrastructure integration
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName);
assertEquals(65425.6249, dotProduct.dotProduct(), 0.001);
@ -199,6 +206,11 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
e.getMessage(),
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
);
e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, invalidQueryVector, fieldName));
assertThat(
e.getMessage(),
containsString("query vector has a different number of dimensions [2] than the document vectors [5]")
);
// Check scripting infrastructure integration
assertEquals(17382.0, new DotProduct(scoreScript, queryVector, fieldName).dotProduct(), 0.001);
@ -207,6 +219,8 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
assertEquals(135.0, new L1Norm(scoreScript, hexidecimalString, fieldName).l1norm(), 0.001);
assertEquals(116.897, new L2Norm(scoreScript, queryVector, fieldName).l2norm(), 0.001);
assertEquals(116.897, new L2Norm(scoreScript, hexidecimalString, fieldName).l2norm(), 0.001);
assertEquals(13.0, new Hamming(scoreScript, queryVector, fieldName).hamming(), 0.001);
assertEquals(13.0, new Hamming(scoreScript, hexidecimalString, fieldName).hamming(), 0.001);
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName);
when(scoreScript._getDocId()).thenReturn(1);
e = expectThrows(IllegalArgumentException.class, dotProduct::dotProduct);