Adds new bit element_type for dense_vectors (#110059)

This commit adds `bit` vector support by adding `element_type: bit` for
vectors. This new element type works for indexed and non-indexed
vectors, and with both `hnsw` and `flat` index types. No
quantization-based codec works with this element type; this is
consistent with `byte` vectors.

`bit` vectors accept up to `32768` dimensions in size and expect vectors
that are being indexed to be encoded either as a hexadecimal string or a
`byte[]` array where each element of the `byte` array represents `8`
bits of the vector.
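For illustration, the two encodings describe the same bits; a minimal Python sketch (not part of this change) showing how a 5-element `byte[]` maps to its hexadecimal form:

```python
# A 40-dimension bit vector is supplied as 40 / 8 = 5 byte values.
# Signed byte values (as in a Java byte[]) carry the same underlying
# bits as their unsigned counterparts, so -127 encodes as 0x81.
vec = [-127, 0, 1, 42, 127]
hex_form = bytes(b & 0xFF for b in vec).hex()
print(hex_form)  # 8100012a7f
```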

`bit` vectors support script usage and regular query usage. When
indexed, all comparisons are done via `xor` and `popcount` summations
(i.e., Hamming distance), and the scores are transformed and normalized
given the vector dimensions. Note that indexed bit vectors require
`l2_norm` as the similarity.

For scripts, `l1norm` is the same as `hamming` distance and `l2norm` is
`sqrt(l1norm)`. `dotProduct` and `cosineSimilarity` are not supported.
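To make that relationship concrete, here is a small Python sketch of the equivalent computations (an illustrative model of the script functions, not Elasticsearch code):

```python
import math

def hamming(a, b):
    # XOR each byte pair and count the set bits in the result
    return sum(bin((x ^ y) & 0xFF).count("1") for x, y in zip(a, b))

a = [-127, 0, 1, 42, 127]
b = [127, -127, 0, 1, 42]
l1 = hamming(a, b)      # l1norm == hamming distance
l2 = math.sqrt(l1)      # l2norm == sqrt(l1norm)
```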

Note that the dimensions for this `element_type` must always be
divisible by `8`, and the `byte[]` vectors provided for indexing must
have size `dim/8`, where each byte element represents `8` bits of
the vector.

closes: https://github.com/elastic/elasticsearch/issues/48322
Commit 5add44d7d1 (parent 97651dfb9f) by Benjamin Trent, 2024-06-26 14:48:41 -04:00, committed via GitHub.
38 changed files with 2713 additions and 187 deletions


@ -0,0 +1,32 @@
pr: 110059
summary: Adds new `bit` `element_type` for `dense_vectors`
area: Vector Search
type: feature
issues: []
highlight:
title: Adds new `bit` `element_type` for `dense_vectors`
body: |-
This adds `bit` vector support by adding `element_type: bit` for
vectors. This new element type works for indexed and non-indexed
vectors, and with both `hnsw` and `flat` index types. No
quantization-based codec works with this element type; this is
consistent with `byte` vectors.
`bit` vectors accept up to `32768` dimensions in size and expect vectors
that are being indexed to be encoded either as a hexadecimal string or a
`byte[]` array where each element of the `byte` array represents `8`
bits of the vector.
`bit` vectors support script usage and regular query usage. When
indexed, all comparisons done are `xor` and `popcount` summations (aka,
hamming distance), and the scores are transformed and normalized given
the vector dimensions.
For scripts, `l1norm` is the same as `hamming` distance and `l2norm` is
`sqrt(l1norm)`. `dotProduct` and `cosineSimilarity` are not supported.
Note that the dimensions for this `element_type` must always be
divisible by `8`, and the `byte[]` vectors provided for indexing must
have size `dim/8`, where each byte element represents `8` bits of
the vector.
notable: true


@ -183,11 +183,23 @@ The following mapping parameters are accepted:
`element_type`::
(Optional, string)
The data type used to encode vectors. The supported data types are
`float` (default), `byte`, and `bit`.

.Valid values for `element_type`
[%collapsible%open]
====
`float`:::
indexes a 4-byte floating-point value per dimension. This is the default value.
`byte`:::
indexes a 1-byte integer value per dimension.
`bit`:::
indexes a single bit per dimension. Useful for very high-dimensional vectors or models that specifically support bit vectors.

NOTE: when using `bit`, the number of dimensions must be a multiple of 8 and must represent the number of bits.
====
`dims`::
(Optional, integer)
@ -205,7 +217,11 @@ API>>. Defaults to `true`.
The vector similarity metric to use in kNN search. Documents are ranked by
their vector field's similarity to the query vector. The `_score` of each
document will be derived from the similarity, in a way that ensures scores are
positive and that a larger score corresponds to a higher ranking.
Defaults to `l2_norm` when `element_type: bit`, otherwise defaults to `cosine`.

NOTE: `bit` vectors only support `l2_norm` as their similarity metric.
+
^*^ This parameter can only be specified when `index` is `true`.
+
@ -217,6 +233,9 @@ Computes similarity based on the L^2^ distance (also known as Euclidean
distance) between the vectors. The document `_score` is computed as
`1 / (1 + l2_norm(query, vector)^2)`.
For `bit` vectors, instead of using `l2_norm`, the `hamming` distance between the vectors is used. The `_score`
transformation is `(numBits - hamming(a, b)) / numBits`.
`dot_product`:::
Computes the dot product of two unit vectors. This option provides an optimized way
to perform cosine similarity. The constraints and computed score are defined
@ -320,3 +339,112 @@ any issues, but features in technical preview are not subject to the support SLA
of official GA features.
`dense_vector` fields support <<synthetic-source,synthetic `_source`>>.
[[dense-vector-index-bit]]
==== Indexing & Searching bit vectors
When using `element_type: bit`, all vectors are treated as bit vectors. Bit vectors utilize only a single
bit per dimension and are internally encoded as bytes. This can be useful for very high-dimensional vectors or models.
When using `bit`, the number of dimensions must be a multiple of 8 and must represent the number of bits. Additionally,
with `bit` vectors, the typical vector similarity values are effectively all scored the same, e.g. with `hamming` distance.

Let's compare two `byte[]` arrays, each representing 40 individual bits.

`[-127, 0, 1, 42, 127]` in bits `1000000100000000000000010010101001111111`

`[127, -127, 0, 1, 42]` in bits `0111111110000001000000000000000100101010`

When comparing these two bit vectors, we first take the {wikipedia}/Hamming_distance[`hamming` distance].
`xor` result:
```
1000000100000000000000010010101001111111
^
0111111110000001000000000000000100101010
=
1111111010000001000000010010101101010101
```
Then, we gather the count of `1` bits in the `xor` result: `18`. To scale for scoring, we subtract from the total number
of bits and divide by the total number of bits: `(40 - 18) / 40 = 0.55`. This would be the `_score` between these two
vectors.
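The arithmetic above can be checked directly; a short Python sketch of the same `xor`/`popcount` scoring:

```python
a = 0b1000000100000000000000010010101001111111  # [-127, 0, 1, 42, 127]
b = 0b0111111110000001000000000000000100101010  # [127, -127, 0, 1, 42]
num_bits = 40
hamming = bin(a ^ b).count("1")          # popcount of the xor result
score = (num_bits - hamming) / num_bits  # (40 - 18) / 40 = 0.55
```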
Here is an example of indexing and searching bit vectors:
[source,console]
--------------------------------------------------
PUT my-bit-vectors
{
"mappings": {
"properties": {
"my_vector": {
"type": "dense_vector",
"dims": 40, <1>
"element_type": "bit"
}
}
}
}
--------------------------------------------------
<1> The number of dimensions that represents the number of bits
[source,console]
--------------------------------------------------
POST /my-bit-vectors/_bulk?refresh
{"index": {"_id" : "1"}}
{"my_vector": [127, -127, 0, 1, 42]} <1>
{"index": {"_id" : "2"}}
{"my_vector": "8100012a7f"} <2>
--------------------------------------------------
// TEST[continued]
<1> 5 bytes representing the 40 bit dimensioned vector
<2> A hexadecimal string representing the 40 bit dimensioned vector
Then, when searching, you can use the `knn` query to search for similar bit vectors:
[source,console]
--------------------------------------------------
POST /my-bit-vectors/_search?filter_path=hits.hits
{
"query": {
"knn": {
"query_vector": [127, -127, 0, 1, 42],
"field": "my_vector"
}
}
}
--------------------------------------------------
// TEST[continued]
[source,console-result]
----
{
"hits": {
"hits": [
{
"_index": "my-bit-vectors",
"_id": "1",
"_score": 1.0,
"_source": {
"my_vector": [
127,
-127,
0,
1,
42
]
}
},
{
"_index": "my-bit-vectors",
"_id": "2",
"_score": 0.55,
"_source": {
"my_vector": "8100012a7f"
}
}
]
}
}
----


@ -1,4 +1,3 @@
[[vector-functions]]
===== Functions for vector fields
@ -17,6 +16,8 @@ This is the list of available vector functions and vector access methods:
6. <<vector-functions-accessing-vectors,`doc[<field>].vectorValue`>> returns a vector's value as an array of floats
7. <<vector-functions-accessing-vectors,`doc[<field>].magnitude`>> returns a vector's magnitude
NOTE: The `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
NOTE: The recommended way to access dense vectors is through the
`cosineSimilarity`, `dotProduct`, `l1norm` or `l2norm` functions. Please note
however, that you should call these functions only once per script. For example,
@ -193,7 +194,7 @@ we added `1` in the denominator.
====== Hamming distance
The `hamming` function calculates {wikipedia}/Hamming_distance[Hamming distance] between a given query vector and
document vectors. It is only available for byte and bit vectors.
[source,console]
--------------------------------------------------
@ -278,10 +279,14 @@ You can access vector values directly through the following functions:
- `doc[<field>].vectorValue` returns a vector's value as an array of floats
NOTE: For `bit` vectors, this returns a `float[]`, where each element represents 8 bits.
- `doc[<field>].magnitude` returns a vector's magnitude as a float
(for vectors created prior to version 7.5 the magnitude is not stored.
So this function calculates it anew every time it is called).
NOTE: For `bit` vectors, this is just the square root of the sum of `1` bits.
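For instance, for a 40-bit vector supplied as the signed bytes `[-1, 115, -3, 4, -128]`, the magnitude works out as follows (an illustrative Python sketch, not Elasticsearch code):

```python
import math

vec = [-1, 115, -3, 4, -128]  # 40-bit vector as signed bytes
# Count the 1 bits across all bytes, then take the square root
ones = sum(bin(b & 0xFF).count("1") for b in vec)
magnitude = math.sqrt(ones)
```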
For example, the script below implements a cosine similarity using these
two functions:
@ -319,3 +324,14 @@ GET my-index-000001/_search
}
}
--------------------------------------------------
[[vector-functions-bit-vectors]]
====== Bit vectors and vector functions
When using `bit` vectors, not all the vector functions are available. The supported functions are:
* <<vector-functions-hamming,`hamming`>> calculates Hamming distance, the sum of the bitwise XOR of the two vectors
* <<vector-functions-l1,`l1norm`>> calculates L^1^ distance; this is simply the `hamming` distance
* <<vector-functions-l2,`l2norm`>> calculates L^2^ distance; this is the square root of the `hamming` distance

Currently, the `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.


@ -0,0 +1,392 @@
setup:
- requires:
cluster_features: ["mapper.vectors.bit_vectors"]
reason: "support for bit vectors added in 8.15"
test_runner_features: headers
- do:
indices.create:
index: test-index
body:
mappings:
properties:
vector:
type: dense_vector
index: false
element_type: bit
dims: 40
indexed_vector:
type: dense_vector
element_type: bit
dims: 40
index: true
similarity: l2_norm
- do:
index:
index: test-index
id: "1"
body:
vector: [8, 5, -15, 1, -7]
indexed_vector: [8, 5, -15, 1, -7]
- do:
index:
index: test-index
id: "2"
body:
vector: [-1, 115, -3, 4, -128]
indexed_vector: [-1, 115, -3, 4, -128]
- do:
index:
index: test-index
id: "3"
body:
vector: [2, 18, -5, 0, -124]
indexed_vector: [2, 18, -5, 0, -124]
- do:
indices.refresh: {}
---
"Test vector magnitude equality":
- skip:
features: close_to
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "doc['vector'].magnitude"
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 4.690416, error: 0.01}}
- match: {hits.hits.1._id: "1"}
- close_to: {hits.hits.1._score: {value: 3.8729835, error: 0.01}}
- match: {hits.hits.2._id: "3"}
- close_to: {hits.hits.2._score: {value: 3.4641016, error: 0.01}}
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "doc['indexed_vector'].magnitude"
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 4.690416, error: 0.01}}
- match: {hits.hits.1._id: "1"}
- close_to: {hits.hits.1._score: {value: 3.8729835, error: 0.01}}
- match: {hits.hits.2._id: "3"}
- close_to: {hits.hits.2._score: {value: 3.4641016, error: 0.01}}
---
"Dot Product is not supported":
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: "006ff30e84"
---
"Cosine Similarity is not supported":
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, 'vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, 'vector')"
params:
query_vector: "006ff30e84"
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, 'indexed_vector')"
params:
query_vector: [0, 111, -13, 14, -124]
---
"L1 norm":
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "l1norm(params.query_vector, 'vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
---
"L1 norm hexadecimal":
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "l1norm(params.query_vector, 'vector')"
params:
query_vector: "006ff30e84"
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
---
"L2 norm":
- requires:
test_runner_features: close_to
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "l2norm(params.query_vector, 'vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 4.123, error: 0.001}}
- match: {hits.hits.1._id: "1"}
- close_to: {hits.hits.1._score: {value: 4, error: 0.001}}
- match: {hits.hits.2._id: "3"}
- close_to: {hits.hits.2._score: {value: 3.316, error: 0.001}}
---
"L2 norm hexadecimal":
- requires:
test_runner_features: close_to
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "l2norm(params.query_vector, 'vector')"
params:
query_vector: "006ff30e84"
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 4.123, error: 0.001}}
- match: {hits.hits.1._id: "1"}
- close_to: {hits.hits.1._score: {value: 4, error: 0.001}}
- match: {hits.hits.2._id: "3"}
- close_to: {hits.hits.2._score: {value: 3.316, error: 0.001}}
---
"Hamming distance":
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "hamming(params.query_vector, 'vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "hamming(params.query_vector, 'indexed_vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
---
"Hamming distance hexadecimal":
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "hamming(params.query_vector, 'vector')"
params:
query_vector: "006ff30e84"
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "hamming(params.query_vector, 'indexed_vector')"
params:
query_vector: "006ff30e84"
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}


@ -0,0 +1,301 @@
setup:
- requires:
cluster_features: "mapper.vectors.bit_vectors"
test_runner_features: close_to
reason: 'bit vectors added in 8.15'
- do:
indices.create:
index: test
body:
settings:
index:
number_of_shards: 2
mappings:
properties:
name:
type: keyword
nested:
type: nested
properties:
paragraph_id:
type: keyword
vector:
type: dense_vector
dims: 40
index: true
element_type: bit
similarity: l2_norm
- do:
index:
index: test
id: "1"
body:
name: cow.jpg
nested:
- paragraph_id: 0
vector: [100, 20, -34, 15, -100]
- paragraph_id: 1
vector: [40, 30, -3, 1, -20]
- do:
index:
index: test
id: "2"
body:
name: moose.jpg
nested:
- paragraph_id: 0
vector: [-1, 100, -13, 14, -127]
- paragraph_id: 2
vector: [0, 100, 0, 15, -127]
- paragraph_id: 3
vector: [0, 1, 0, 2, -15]
- do:
index:
index: test
id: "3"
body:
name: rabbit.jpg
nested:
- paragraph_id: 0
vector: [1, 111, -13, 14, -1]
- do:
indices.refresh: {}
---
"nested kNN search only":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 14, -127]
k: 2
num_candidates: 3
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0.fields.name.0: "moose.jpg"}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1.fields.name.0: "cow.jpg"}
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 14, -127]
k: 2
num_candidates: 3
inner_hits: {size: 1, "fields": ["nested.paragraph_id"], _source: false}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0.fields.name.0: "moose.jpg"}
- match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1.fields.name.0: "cow.jpg"}
- match: {hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}
---
"nested kNN search filtered":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 14, -127]
k: 2
num_candidates: 3
filter: {term: {name: "rabbit.jpg"}}
- match: {hits.total.value: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 14, -127]
k: 3
num_candidates: 3
filter: {term: {name: "rabbit.jpg"}}
inner_hits: {size: 1, fields: ["nested.paragraph_id"], _source: false}
- match: {hits.total.value: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}
---
"nested kNN search inner_hits size > 1":
- do:
index:
index: test
id: "4"
body:
name: moose.jpg
nested:
- paragraph_id: 0
vector: [-1, 90, -10, 14, -127]
- paragraph_id: 2
vector: [ 0, 100.0, 0, 14, -127 ]
- paragraph_id: 3
vector: [ 0, 1.0, 0, 2, -15 ]
- do:
index:
index: test
id: "5"
body:
name: moose.jpg
nested:
- paragraph_id: 0
vector: [ -1, 100, -13, 14, -127 ]
- paragraph_id: 2
vector: [ 0, 100, 0, 15, -127 ]
- paragraph_id: 3
vector: [ 0, 1, 0, 2, -15 ]
- do:
index:
index: test
id: "6"
body:
name: moose.jpg
nested:
- paragraph_id: 0
vector: [ -1, 100, -13, 15, -127 ]
- paragraph_id: 2
vector: [ 0, 100, 0, 15, -127 ]
- paragraph_id: 3
vector: [ 0, 1, 0, 2, -15 ]
- do:
indices.refresh: { }
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 15, -127]
k: 3
num_candidates: 5
inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}
- match: {hits.total.value: 3}
- length: { hits.hits.0.inner_hits.nested.hits.hits: 2 }
- length: { hits.hits.1.inner_hits.nested.hits.hits: 2 }
- length: { hits.hits.2.inner_hits.nested.hits.hits: 2 }
- match: { hits.hits.0.fields.name.0: "moose.jpg" }
- match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 15, -127]
k: 5
num_candidates: 5
inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}
- match: {hits.total.value: 5}
# All these initial matches are "moose.jpg", which has 3 nested vectors, but two are closest
- match: {hits.hits.0.fields.name.0: "moose.jpg"}
- length: { hits.hits.0.inner_hits.nested.hits.hits: 2 }
- match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
- match: {hits.hits.1.fields.name.0: "moose.jpg"}
- length: { hits.hits.1.inner_hits.nested.hits.hits: 2 }
- match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.1.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
- match: {hits.hits.2.fields.name.0: "moose.jpg"}
- length: { hits.hits.2.inner_hits.nested.hits.hits: 2 }
- match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.2.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
- match: {hits.hits.3.fields.name.0: "moose.jpg"}
- length: { hits.hits.3.inner_hits.nested.hits.hits: 2 }
- match: { hits.hits.3.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.3.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
# Rabbit only has one passage vector
- match: {hits.hits.4.fields.name.0: "cow.jpg"}
- length: { hits.hits.4.inner_hits.nested.hits.hits: 2 }
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [ -1, 90, -10, 15, -127 ]
k: 3
num_candidates: 3
filter: {term: {name: "cow.jpg"}}
inner_hits: {size: 3, fields: ["nested.paragraph_id"], _source: false}
- match: {hits.total.value: 1}
- match: { hits.hits.0._id: "1" }
- length: { hits.hits.0.inner_hits.nested.hits.hits: 2 }
- match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "1" }
---
"nested kNN search inner_hits & boosting":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 15, -127]
k: 3
num_candidates: 5
inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}
- close_to: { hits.hits.0._score: {value: 0.8, error: 0.00001} }
- close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: {value: 0.8, error: 0.00001} }
- close_to: { hits.hits.1._score: {value: 0.625, error: 0.00001} }
- close_to: { hits.hits.1.inner_hits.nested.hits.hits.0._score: {value: 0.625, error: 0.00001} }
- close_to: { hits.hits.2._score: {value: 0.5, error: 0.00001} }
- close_to: { hits.hits.2.inner_hits.nested.hits.hits.0._score: {value: 0.5, error: 0.00001} }
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 15, -127]
k: 3
num_candidates: 5
boost: 2
inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}
- close_to: { hits.hits.0._score: {value: 1.6, error: 0.00001} }
- close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: {value: 1.6, error: 0.00001} }
- close_to: { hits.hits.1._score: {value: 1.25, error: 0.00001} }
- close_to: { hits.hits.1.inner_hits.nested.hits.hits.0._score: {value: 1.25, error: 0.00001} }
- close_to: { hits.hits.2._score: {value: 1, error: 0.00001} }
- close_to: { hits.hits.2.inner_hits.nested.hits.hits.0._score: {value: 1.0, error: 0.00001} }


@ -116,8 +116,9 @@ setup:
---
"Knn search with hex string for byte field - dimensions mismatch" :
# [64, 10, -30, 10] - is encoded as '400ae20a'
# the error message has been adjusted in later versions
- do:
catch: /dimension|dimensions \[4\] than the document|index vectors \[3\]/
search:
index: knn_hex_vector_index
body:


@ -116,8 +116,9 @@ setup:
---
"Knn query with hex string for byte field - dimensions mismatch" :
# [64, 10, -30, 10] - is encoded as '400ae20a'
# the error message has been adjusted in later versions
- do:
catch: /dimension|dimensions \[4\] than the document|index vectors \[3\]/
search:
index: knn_hex_vector_index
body:


@ -0,0 +1,356 @@
setup:
- requires:
cluster_features: "mapper.vectors.bit_vectors"
reason: 'mapper.vectors.bit_vectors'
- do:
indices.create:
index: test
body:
mappings:
properties:
name:
type: keyword
vector:
type: dense_vector
element_type: bit
dims: 40
index: true
similarity: l2_norm
- do:
index:
index: test
id: "1"
body:
name: cow.jpg
vector: [2, -1, 1, 4, -3]
- do:
index:
index: test
id: "2"
body:
name: moose.jpg
vector: [127.0, -128.0, 0.0, 1.0, -1.0]
- do:
index:
index: test
id: "3"
body:
name: rabbit.jpg
vector: [5, 4.0, 3, 2.0, 127]
- do:
indices.refresh: {}
---
"kNN search only":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [127, 127, -128, -128, 127]
k: 2
num_candidates: 3
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0.fields.name.0: "moose.jpg"}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1.fields.name.0: "cow.jpg"}
---
"kNN search plus query":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [127.0, -128.0, 0.0, 1.0, -1.0]
k: 2
num_candidates: 3
query:
term:
name: rabbit.jpg
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- match: {hits.hits.1._id: "2"}
- match: {hits.hits.1.fields.name.0: "moose.jpg"}
---
"kNN search with filter":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [5.0, 4, 3.0, 2, 127.0]
k: 2
num_candidates: 3
filter:
term:
name: "rabbit.jpg"
- match: {hits.total.value: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [2, -1, 1, 4, -3]
k: 2
num_candidates: 3
filter:
- term:
name: "rabbit.jpg"
- term:
_id: 2
- match: {hits.total.value: 0}
---
"Vector similarity search only":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
num_candidates: 3
k: 3
field: vector
similarity: 0.98
query_vector: [5, 4.0, 3, 2.0, 127]
- length: {hits.hits: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
---
"Vector similarity with filter only":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
num_candidates: 3
k: 3
field: vector
similarity: 0.98
query_vector: [5, 4.0, 3, 2.0, 127]
filter: {"term": {"name": "rabbit.jpg"}}
- length: {hits.hits: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
num_candidates: 3
k: 3
field: vector
similarity: 0.98
query_vector: [5, 4.0, 3, 2.0, 127]
filter: {"term": {"name": "cow.jpg"}}
- length: {hits.hits: 0}
---
"dim mismatch":
- do:
catch: bad_request
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [1, 2, 3, 4, 5, 6]
k: 2
num_candidates: 3
---
"disallow quantized vector types":
- do:
catch: bad_request
indices.create:
index: test
body:
mappings:
properties:
name:
type: keyword
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int8_flat
- do:
catch: bad_request
indices.create:
index: test
body:
mappings:
properties:
name:
type: keyword
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int4_flat
- do:
catch: bad_request
indices.create:
index: test
body:
mappings:
properties:
name:
type: keyword
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int8_hnsw
- do:
catch: bad_request
indices.create:
index: test
body:
mappings:
properties:
name:
type: keyword
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int4_hnsw
---
"disallow vector index type change to quantized type":
- do:
catch: bad_request
indices.put_mapping:
index: test
body:
properties:
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int4_hnsw
- do:
catch: bad_request
indices.put_mapping:
index: test
body:
properties:
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int8_hnsw
---
"Defaults to l2_norm with bit vectors":
- do:
indices.create:
index: default_to_l2_norm_bit
body:
mappings:
properties:
vector:
type: dense_vector
element_type: bit
dims: 40
index: true
- do:
indices.get_mapping:
index: default_to_l2_norm_bit
- match: { default_to_l2_norm_bit.mappings.properties.vector.similarity: l2_norm }
---
"Only allow l2_norm with bit vectors":
- do:
catch: bad_request
indices.create:
index: dot_product_fails_for_bits
body:
mappings:
properties:
vector:
type: dense_vector
element_type: bit
dims: 40
index: true
similarity: dot_product
- do:
catch: bad_request
indices.create:
index: cosine_product_fails_for_bits
body:
mappings:
properties:
vector:
type: dense_vector
element_type: bit
dims: 40
index: true
similarity: cosine
- do:
catch: bad_request
indices.create:
index: max_inner_product_fails_for_bits
body:
mappings:
properties:
vector:
type: dense_vector
element_type: bit
dims: 40
index: true
similarity: max_inner_product
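The failing creates above all exercise one mapper rule: `bit` vectors accept only the `l2_norm` similarity. A minimal standalone sketch of that check (the class and method names are illustrative, not the actual mapper API):

```java
// Illustrative sketch only: mirrors the rule the tests above assert,
// i.e. element_type "bit" rejects every similarity except l2_norm.
final class BitSimilarityCheck {
    static void validate(String elementType, String similarity) {
        if ("bit".equals(elementType) && !"l2_norm".equals(similarity)) {
            throw new IllegalArgumentException(
                "The [l2_norm] similarity is the only supported similarity for bit vectors"
            );
        }
    }
}
```

The real validation lives in the `similarity` parameter's validator in `DenseVectorFieldMapper`; this sketch only restates the accepted/rejected combinations.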


@ -0,0 +1,223 @@
setup:
- requires:
cluster_features: "mapper.vectors.bit_vectors"
reason: 'mapper.vectors.bit_vectors'
- do:
indices.create:
index: test
body:
mappings:
properties:
name:
type: keyword
vector:
type: dense_vector
element_type: bit
dims: 40
index: true
similarity: l2_norm
index_options:
type: flat
- do:
index:
index: test
id: "1"
body:
name: cow.jpg
vector: [2, -1, 1, 4, -3]
- do:
index:
index: test
id: "2"
body:
name: moose.jpg
vector: [127.0, -128.0, 0.0, 1.0, -1.0]
- do:
index:
index: test
id: "3"
body:
name: rabbit.jpg
vector: [5, 4.0, 3, 2.0, 127]
- do:
indices.refresh: {}
---
"kNN search only":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [127, 127, -128, -128, 127]
k: 2
num_candidates: 3
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0.fields.name.0: "moose.jpg"}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1.fields.name.0: "cow.jpg"}
---
"kNN search plus query":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [127.0, -128.0, 0.0, 1.0, -1.0]
k: 2
num_candidates: 3
query:
term:
name: rabbit.jpg
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- match: {hits.hits.1._id: "2"}
- match: {hits.hits.1.fields.name.0: "moose.jpg"}
---
"kNN search with filter":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [5.0, 4, 3.0, 2, 127.0]
k: 2
num_candidates: 3
filter:
term:
name: "rabbit.jpg"
- match: {hits.total.value: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [2, -1, 1, 4, -3]
k: 2
num_candidates: 3
filter:
- term:
name: "rabbit.jpg"
- term:
_id: 2
- match: {hits.total.value: 0}
---
"Vector similarity search only":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
num_candidates: 3
k: 3
field: vector
similarity: 0.98
query_vector: [5, 4.0, 3, 2.0, 127]
- length: {hits.hits: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
---
"Vector similarity with filter only":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
num_candidates: 3
k: 3
field: vector
similarity: 0.98
query_vector: [5, 4.0, 3, 2.0, 127]
filter: {"term": {"name": "rabbit.jpg"}}
- length: {hits.hits: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
num_candidates: 3
k: 3
field: vector
similarity: 0.98
query_vector: [5, 4.0, 3, 2.0, 127]
filter: {"term": {"name": "cow.jpg"}}
- length: {hits.hits: 0}
---
"dim mismatch":
- do:
catch: bad_request
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [1, 2, 3, 4, 5, 6]
k: 2
num_candidates: 3
---
"disallow vector index type change to quantized type":
- do:
catch: bad_request
indices.put_mapping:
index: test
body:
properties:
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int4_hnsw
- do:
catch: bad_request
indices.put_mapping:
index: test
body:
properties:
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int8_hnsw
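The setup above indexes 40-dim bit vectors as five signed bytes each; the same vectors could equally be supplied as hexadecimal strings. A small sketch of the packed-byte/hex equivalence (helper names are illustrative; `HexFormat` is the same JDK class the mapper uses for parsing):

```java
import java.util.HexFormat;

// A bit vector of `dims` dimensions is packed into dims / 8 bytes, so the
// 40-dim document vector [2, -1, 1, 4, -3] is 5 bytes; its hex string form
// encodes the identical bits. Helper names here are illustrative.
final class BitVectorEncoding {
    static String toHex(byte[] packed) {
        return HexFormat.of().formatHex(packed);
    }

    static byte[] fromHex(String hex) {
        return HexFormat.of().parseHex(hex);
    }

    // Each byte carries 8 bits of the vector.
    static int dims(byte[] packed) {
        return packed.length * Byte.SIZE;
    }
}
```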


@ -449,7 +449,10 @@ module org.elasticsearch.server {
        with
            org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat,
            org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat,
-           org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
+           org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat,
+           org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat,
+           org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat;

    provides org.apache.lucene.codecs.Codec with Elasticsearch814Codec;
    provides org.apache.logging.log4j.core.util.ContextDataProvider with org.elasticsearch.common.logging.DynamicContextDataProvider;


@ -54,11 +54,11 @@ public class ES813FlatVectorFormat extends KnnVectorsFormat {
        return new ES813FlatVectorReader(format.fieldsReader(state));
    }

-   public static class ES813FlatVectorWriter extends KnnVectorsWriter {
+   static class ES813FlatVectorWriter extends KnnVectorsWriter {

        private final FlatVectorsWriter writer;

-       public ES813FlatVectorWriter(FlatVectorsWriter writer) {
+       ES813FlatVectorWriter(FlatVectorsWriter writer) {
            super();
            this.writer = writer;
        }
@ -94,11 +94,11 @@ public class ES813FlatVectorFormat extends KnnVectorsFormat {
        }
    }

-   public static class ES813FlatVectorReader extends KnnVectorsReader {
+   static class ES813FlatVectorReader extends KnnVectorsReader {

        private final FlatVectorsReader reader;

-       public ES813FlatVectorReader(FlatVectorsReader reader) {
+       ES813FlatVectorReader(FlatVectorsReader reader) {
            super();
            this.reader = reader;
        }


@ -0,0 +1,47 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import java.io.IOException;
public class ES815BitFlatVectorFormat extends KnnVectorsFormat {
static final String NAME = "ES815BitFlatVectorFormat";
private final FlatVectorsFormat format = new ES815BitFlatVectorsFormat();
/**
* Sole constructor
*/
public ES815BitFlatVectorFormat() {
super(NAME);
}
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new ES813FlatVectorFormat.ES813FlatVectorWriter(format.fieldsWriter(state));
}
@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new ES813FlatVectorFormat.ES813FlatVectorReader(format.fieldsReader(state));
}
@Override
public String toString() {
return NAME;
}
}

View file

@ -0,0 +1,143 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import java.io.IOException;
class ES815BitFlatVectorsFormat extends FlatVectorsFormat {
private final FlatVectorsFormat delegate = new Lucene99FlatVectorsFormat(FlatBitVectorScorer.INSTANCE);
@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState segmentWriteState) throws IOException {
return delegate.fieldsWriter(segmentWriteState);
}
@Override
public FlatVectorsReader fieldsReader(SegmentReadState segmentReadState) throws IOException {
return delegate.fieldsReader(segmentReadState);
}
static class FlatBitVectorScorer implements FlatVectorsScorer {
static final FlatBitVectorScorer INSTANCE = new FlatBitVectorScorer();
static void checkDimensions(int queryLen, int fieldLen) {
if (queryLen != fieldLen) {
throw new IllegalArgumentException("vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen);
}
}
@Override
public String toString() {
return super.toString();
}
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues
) throws IOException {
assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes;
assert vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN;
if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes randomAccessVectorValuesBytes) {
assert randomAccessVectorValues instanceof RandomAccessQuantizedByteVectorValues == false;
return switch (vectorSimilarityFunction) {
case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT, COSINE, EUCLIDEAN -> new HammingScorerSupplier(randomAccessVectorValuesBytes);
};
}
throw new IllegalArgumentException("Unsupported vector type or similarity function");
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues,
byte[] bytes
) {
assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes;
assert vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN;
if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes randomAccessVectorValuesBytes) {
checkDimensions(bytes.length, randomAccessVectorValuesBytes.dimension());
return switch (vectorSimilarityFunction) {
case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT, COSINE, EUCLIDEAN -> new HammingVectorScorer(
randomAccessVectorValuesBytes,
bytes
);
};
}
throw new IllegalArgumentException("Unsupported vector type or similarity function");
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues,
float[] floats
) {
throw new IllegalArgumentException("Unsupported vector type");
}
}
static float hammingScore(byte[] a, byte[] b) {
return ((a.length * Byte.SIZE) - VectorUtil.xorBitCount(a, b)) / (float) (a.length * Byte.SIZE);
}
static class HammingVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
private final byte[] query;
private final RandomAccessVectorValues.Bytes byteValues;
HammingVectorScorer(RandomAccessVectorValues.Bytes byteValues, byte[] query) {
super(byteValues);
this.query = query;
this.byteValues = byteValues;
}
@Override
public float score(int i) throws IOException {
return hammingScore(byteValues.vectorValue(i), query);
}
}
static class HammingScorerSupplier implements RandomVectorScorerSupplier {
private final RandomAccessVectorValues.Bytes byteValues, byteValues1, byteValues2;
HammingScorerSupplier(RandomAccessVectorValues.Bytes byteValues) throws IOException {
this.byteValues = byteValues;
this.byteValues1 = byteValues.copy();
this.byteValues2 = byteValues.copy();
}
@Override
public RandomVectorScorer scorer(int i) throws IOException {
byte[] query = byteValues1.vectorValue(i);
return new HammingVectorScorer(byteValues2, query);
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new HammingScorerSupplier(byteValues);
}
}
}
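`hammingScore` above normalizes the xor popcount into [0, 1]: identical vectors score 1.0 and fully inverted vectors score 0.0. A dependency-free sketch of the same arithmetic, with Lucene's `VectorUtil.xorBitCount` replaced by an explicit loop (the class name is illustrative):

```java
// Same arithmetic as hammingScore: (totalBits - xorBitCount) / totalBits.
// VectorUtil.xorBitCount is replaced here by an explicit popcount loop.
final class HammingScoreSketch {
    static float score(byte[] a, byte[] b) {
        int xorBitCount = 0;
        for (int i = 0; i < a.length; i++) {
            // Mask to 8 bits so sign extension of the xor result is ignored.
            xorBitCount += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        int totalBits = a.length * Byte.SIZE;
        return (totalBits - xorBitCount) / (float) totalBits;
    }
}
```

Because the score grows as the hamming distance shrinks, a larger value always means a closer bit vector, which is what HNSW graph search expects.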


@ -0,0 +1,74 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import java.io.IOException;
public class ES815HnswBitVectorsFormat extends KnnVectorsFormat {
static final String NAME = "ES815HnswBitVectorsFormat";
static final int MAXIMUM_MAX_CONN = 512;
static final int MAXIMUM_BEAM_WIDTH = 3200;
private final int maxConn;
private final int beamWidth;
private final FlatVectorsFormat flatVectorsFormat = new ES815BitFlatVectorsFormat();
public ES815HnswBitVectorsFormat() {
this(16, 100);
}
public ES815HnswBitVectorsFormat(int maxConn, int beamWidth) {
super(NAME);
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException(
"maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
);
}
if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
throw new IllegalArgumentException(
"beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
);
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
}
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null);
}
@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
}
@Override
public String toString() {
return "ES815HnswBitVectorsFormat(name=ES815HnswBitVectorsFormat, maxConn="
+ maxConn
+ ", beamWidth="
+ beamWidth
+ ", flatVectorFormat="
+ flatVectorsFormat
+ ")";
}
}


@ -25,7 +25,8 @@ public class MapperFeatures implements FeatureSpecification {
            PassThroughObjectMapper.PASS_THROUGH_PRIORITY,
            RangeFieldMapper.NULL_VALUES_OFF_BY_ONE_FIX,
            SourceFieldMapper.SYNTHETIC_SOURCE_FALLBACK,
-           DenseVectorFieldMapper.INT4_QUANTIZATION
+           DenseVectorFieldMapper.INT4_QUANTIZATION,
+           DenseVectorFieldMapper.BIT_VECTORS
        );
    }
}


@ -31,6 +31,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
+import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.common.ParsingException;
@ -41,6 +42,8 @@ import org.elasticsearch.index.IndexVersions;
import org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat;
import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat;
import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
+import org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat;
+import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat;
import org.elasticsearch.index.fielddata.FieldDataContext;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.ArraySourceValueFetcher;
@ -100,6 +103,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
    }

    public static final NodeFeature INT4_QUANTIZATION = new NodeFeature("mapper.vectors.int4_quantization");
+   public static final NodeFeature BIT_VECTORS = new NodeFeature("mapper.vectors.bit_vectors");

    public static final IndexVersion MAGNITUDE_STORED_INDEX_VERSION = IndexVersions.V_7_5_0;
    public static final IndexVersion INDEXED_BY_DEFAULT_INDEX_VERSION = IndexVersions.FIRST_DETACHED_INDEX_VERSION;
@ -109,6 +113,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
    public static final String CONTENT_TYPE = "dense_vector";

    public static short MAX_DIMS_COUNT = 4096; // maximum allowed number of dimensions
+   public static int MAX_DIMS_COUNT_BIT = 4096 * Byte.SIZE; // maximum allowed number of dimensions
    public static short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to vector
    public static final int MAGNITUDE_BYTES = 4;
@ -134,17 +139,28 @@ public class DenseVectorFieldMapper extends FieldMapper {
                throw new MapperParsingException("Property [dims] on field [" + n + "] must be an integer but got [" + o + "]");
            }
            int dims = XContentMapValues.nodeIntegerValue(o);
-           if (dims < 1 || dims > MAX_DIMS_COUNT) {
+           int maxDims = elementType.getValue() == ElementType.BIT ? MAX_DIMS_COUNT_BIT : MAX_DIMS_COUNT;
+           int minDims = elementType.getValue() == ElementType.BIT ? Byte.SIZE : 1;
+           if (dims < minDims || dims > maxDims) {
                throw new MapperParsingException(
                    "The number of dimensions for field ["
                        + n
-                       + "] should be in the range [1, "
-                       + MAX_DIMS_COUNT
+                       + "] should be in the range ["
+                       + minDims
+                       + ", "
+                       + maxDims
                        + "] but was ["
                        + dims
                        + "]"
                );
            }
+           if (elementType.getValue() == ElementType.BIT) {
+               if (dims % Byte.SIZE != 0) {
+                   throw new MapperParsingException(
+                       "The number of dimensions for field [" + n + "] should be a multiple of 8 but was [" + dims + "]"
+                   );
+               }
+           }
            return dims;
        }, m -> toType(m).fieldType().dims, XContentBuilder::field, Object::toString).setSerializerCheck((id, ic, v) -> v != null)
            .setMergeValidator((previous, current, c) -> previous == null || Objects.equals(previous, current));
@ -171,13 +187,27 @@ public class DenseVectorFieldMapper extends FieldMapper {
            "similarity",
            false,
            m -> toType(m).fieldType().similarity,
-           (Supplier<VectorSimilarity>) () -> indexedByDefault && indexed.getValue() ? VectorSimilarity.COSINE : null,
+           (Supplier<VectorSimilarity>) () -> {
+               if (indexedByDefault && indexed.getValue()) {
+                   return elementType.getValue() == ElementType.BIT ? VectorSimilarity.L2_NORM : VectorSimilarity.COSINE;
+               }
+               return null;
+           },
            VectorSimilarity.class
-       ).acceptsNull().setSerializerCheck((id, ic, v) -> v != null);
+       ).acceptsNull().setSerializerCheck((id, ic, v) -> v != null).addValidator(vectorSim -> {
+           if (vectorSim == null) {
+               return;
+           }
+           if (elementType.getValue() == ElementType.BIT && vectorSim != VectorSimilarity.L2_NORM) {
+               throw new IllegalArgumentException(
+                   "The [" + VectorSimilarity.L2_NORM + "] similarity is the only supported similarity for bit vectors"
+               );
+           }
+       });
        this.indexOptions = new Parameter<>(
            "index_options",
            true,
-           () -> defaultInt8Hnsw && elementType.getValue() != ElementType.BYTE && this.indexed.getValue()
+           () -> defaultInt8Hnsw && elementType.getValue() == ElementType.FLOAT && this.indexed.getValue()
                ? new Int8HnswIndexOptions(
                    Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
                    Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
@ -266,7 +296,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
    public enum ElementType {

-       BYTE(1) {
+       BYTE {

            @Override
            public String toString() {
@ -371,7 +401,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
            }

            @Override
-           public double computeDotProduct(VectorData vectorData) {
+           public double computeSquaredMagnitude(VectorData vectorData) {
                return VectorUtil.dotProduct(vectorData.asByteVector(), vectorData.asByteVector());
            }
@ -428,7 +458,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
                byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
                fieldMapper.checkDimensionMatches(decodedVector.length, context);
                VectorData vectorData = VectorData.fromBytes(decodedVector);
-               double squaredMagnitude = computeDotProduct(vectorData);
+               double squaredMagnitude = computeSquaredMagnitude(vectorData);
                checkVectorMagnitude(
                    fieldMapper.fieldType().similarity,
                    errorByteElementsAppender(decodedVector),
@ -463,7 +493,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
            @Override
            int getNumBytes(int dimensions) {
-               return dimensions * elementBytes;
+               return dimensions;
            }

            @Override
@ -494,7 +524,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
            }
        },

-       FLOAT(4) {
+       FLOAT {

            @Override
            public String toString() {
@ -596,7 +626,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
            }

            @Override
-           public double computeDotProduct(VectorData vectorData) {
+           public double computeSquaredMagnitude(VectorData vectorData) {
                return VectorUtil.dotProduct(vectorData.asFloatVector(), vectorData.asFloatVector());
            }
@ -656,7 +686,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
            @Override
            int getNumBytes(int dimensions) {
-               return dimensions * elementBytes;
+               return dimensions * Float.BYTES;
            }

            @Override
@ -665,14 +695,250 @@ public class DenseVectorFieldMapper extends FieldMapper {
                    ? ByteBuffer.wrap(new byte[numBytes]).order(ByteOrder.LITTLE_ENDIAN)
                    : ByteBuffer.wrap(new byte[numBytes]);
            }
-       };
+       },

-       final int elementBytes;
-
-       ElementType(int elementBytes) {
-           this.elementBytes = elementBytes;
-       }

        BIT {
            @Override
            public String toString() {
                return "bit";
            }
@Override
public void writeValue(ByteBuffer byteBuffer, float value) {
byteBuffer.put((byte) value);
}
@Override
public void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException {
b.value(byteBuffer.get());
}
private KnnByteVectorField createKnnVectorField(String name, byte[] vector, VectorSimilarityFunction function) {
if (vector == null) {
throw new IllegalArgumentException("vector value must not be null");
}
FieldType denseVectorFieldType = new FieldType();
denseVectorFieldType.setVectorAttributes(vector.length, VectorEncoding.BYTE, function);
denseVectorFieldType.freeze();
return new KnnByteVectorField(name, vector, denseVectorFieldType);
}
@Override
IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldType, FieldDataContext fieldDataContext) {
return new VectorIndexFieldData.Builder(
denseVectorFieldType.name(),
CoreValuesSourceType.KEYWORD,
denseVectorFieldType.indexVersionCreated,
this,
denseVectorFieldType.dims,
denseVectorFieldType.indexed,
r -> r
);
}
@Override
public void checkVectorBounds(float[] vector) {
checkNanAndInfinite(vector);
StringBuilder errorBuilder = null;
for (int index = 0; index < vector.length; ++index) {
float value = vector[index];
if (value % 1.0f != 0.0f) {
errorBuilder = new StringBuilder(
"element_type ["
+ this
+ "] vectors only support non-decimal values but found decimal value ["
+ value
+ "] at dim ["
+ index
+ "];"
);
break;
}
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
errorBuilder = new StringBuilder(
"element_type ["
+ this
+ "] vectors only support integers between ["
+ Byte.MIN_VALUE
+ ", "
+ Byte.MAX_VALUE
+ "] but found ["
+ value
+ "] at dim ["
+ index
+ "];"
);
break;
}
}
if (errorBuilder != null) {
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
}
}
@Override
void checkVectorMagnitude(
VectorSimilarity similarity,
Function<StringBuilder, StringBuilder> appender,
float squaredMagnitude
) {}
@Override
public double computeSquaredMagnitude(VectorData vectorData) {
int count = 0;
int i = 0;
byte[] byteBits = vectorData.asByteVector();
for (int upperBound = byteBits.length & -8; i < upperBound; i += 8) {
count += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(byteBits, i));
}
while (i < byteBits.length) {
count += Integer.bitCount(byteBits[i] & 255);
++i;
}
return count;
}
private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
int index = 0;
byte[] vector = new byte[fieldMapper.fieldType().dims / Byte.SIZE];
for (XContentParser.Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser()
.nextToken()) {
fieldMapper.checkDimensionExceeded(index, context);
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
final int value;
if (context.parser().numberType() != XContentParser.NumberType.INT) {
float floatValue = context.parser().floatValue(true);
if (floatValue % 1.0f != 0.0f) {
throw new IllegalArgumentException(
"element_type ["
+ this
+ "] vectors only support non-decimal values but found decimal value ["
+ floatValue
+ "] at dim ["
+ index
+ "];"
);
}
value = (int) floatValue;
} else {
value = context.parser().intValue(true);
}
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
throw new IllegalArgumentException(
"element_type ["
+ this
+ "] vectors only support integers between ["
+ Byte.MIN_VALUE
+ ", "
+ Byte.MAX_VALUE
+ "] but found ["
+ value
+ "] at dim ["
+ index
+ "];"
);
}
if (index >= vector.length) {
throw new IllegalArgumentException(
"The number of dimensions for field ["
+ fieldMapper.fieldType().name()
+ "] should be ["
+ fieldMapper.fieldType().dims
+ "] but found ["
+ (index + 1) * Byte.SIZE
+ "]"
);
}
vector[index++] = (byte) value;
}
fieldMapper.checkDimensionMatches(index * Byte.SIZE, context);
return VectorData.fromBytes(vector);
}
private VectorData parseHexEncodedVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
fieldMapper.checkDimensionMatches(decodedVector.length * Byte.SIZE, context);
return VectorData.fromBytes(decodedVector);
}
@Override
VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
XContentParser.Token token = context.parser().currentToken();
return switch (token) {
case START_ARRAY -> parseVectorArray(context, fieldMapper);
case VALUE_STRING -> parseHexEncodedVector(context, fieldMapper);
default -> throw new ParsingException(
context.parser().getTokenLocation(),
format("Unsupported type [%s] for provided value [%s]", token, context.parser().text())
);
};
}
@Override
public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
VectorData vectorData = parseKnnVector(context, fieldMapper);
Field field = createKnnVectorField(
fieldMapper.fieldType().name(),
vectorData.asByteVector(),
fieldMapper.fieldType().similarity.vectorSimilarityFunction(fieldMapper.indexCreatedVersion, this)
);
context.doc().addWithKey(fieldMapper.fieldType().name(), field);
}
@Override
int getNumBytes(int dimensions) {
assert dimensions % Byte.SIZE == 0;
return dimensions / Byte.SIZE;
}
@Override
ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes) {
return ByteBuffer.wrap(new byte[numBytes]);
}
@Override
int parseDimensionCount(DocumentParserContext context) throws IOException {
XContentParser.Token currentToken = context.parser().currentToken();
return switch (currentToken) {
case START_ARRAY -> {
int index = 0;
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
index++;
}
yield index * Byte.SIZE;
}
case VALUE_STRING -> {
byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
yield decodedVector.length * Byte.SIZE;
}
default -> throw new ParsingException(
context.parser().getTokenLocation(),
format("Unsupported type [%s] for provided value [%s]", currentToken, context.parser().text())
);
};
}
@Override
public void checkDimensions(int dvDims, int qvDims) {
if (dvDims != qvDims * Byte.SIZE) {
throw new IllegalArgumentException(
"The query vector has a different number of dimensions ["
+ qvDims * Byte.SIZE
+ "] than the document vectors ["
+ dvDims
+ "]."
);
}
}
};
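The `checkDimensions` override above compares the field's bit dimension count against a byte-length query vector. A standalone sketch of that rule (plain JDK, hypothetical class name):

```java
// Standalone sketch of the dimension rule above: a field declared with
// dvDims bit dimensions stores dvDims / 8 bytes, so a byte[] query vector
// of length n is only valid when dvDims == n * 8.
public class BitDimsCheck {

    public static void checkDimensions(int dvDims, int qvDims) {
        if (dvDims != qvDims * Byte.SIZE) {
            throw new IllegalArgumentException(
                "The query vector has a different number of dimensions ["
                    + qvDims * Byte.SIZE + "] than the document vectors [" + dvDims + "]."
            );
        }
    }

    public static void main(String[] args) {
        checkDimensions(32, 4); // ok: 4 bytes carry 32 bits
        try {
            checkDimensions(32, 3); // 24 bits != 32 bits
            System.out.println("no exception");
        } catch (IllegalArgumentException e) {
            System.out.println("rejected");
        }
    }
}
```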
public abstract void writeValue(ByteBuffer byteBuffer, float value);
public abstract void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException;
@@ -695,6 +961,14 @@ public class DenseVectorFieldMapper extends FieldMapper {
float squaredMagnitude
);
public void checkDimensions(int dvDims, int qvDims) {
if (dvDims != qvDims) {
throw new IllegalArgumentException(
"The query vector has a different number of dimensions [" + qvDims + "] than the document vectors [" + dvDims + "]."
);
}
}
int parseDimensionCount(DocumentParserContext context) throws IOException {
int index = 0;
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
@@ -775,7 +1049,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
return sb -> appendErrorElements(sb, vector);
}
-public abstract double computeDotProduct(VectorData vectorData);
+public abstract double computeSquaredMagnitude(VectorData vectorData);
public static ElementType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
@@ -786,7 +1060,9 @@ public class DenseVectorFieldMapper extends FieldMapper {
ElementType.BYTE.toString(),
ElementType.BYTE,
ElementType.FLOAT.toString(),
-ElementType.FLOAT
+ElementType.FLOAT,
ElementType.BIT.toString(),
ElementType.BIT
);
public enum VectorSimilarity {
@@ -795,6 +1071,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
float score(float similarity, ElementType elementType, int dim) {
return switch (elementType) {
case BYTE, FLOAT -> 1f / (1f + similarity * similarity);
case BIT -> (dim - similarity) / dim;
};
}
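For indexed `bit` vectors the only supported similarity is `l2_norm`, and the raw similarity reported back is the hamming distance (xor plus popcount); the `(dim - similarity) / dim` branch above rescales it into `[0, 1]`. A standalone sketch (plain JDK, not the Lucene implementation):

```java
// Standalone sketch: scoring an indexed bit vector under l2_norm
// similarity. The raw "similarity" is the hamming distance between the
// packed bit vectors, rescaled into [0, 1] as (dim - hamming) / dim.
public class BitVectorScore {

    // Hamming distance between two packed bit vectors of equal byte length.
    public static int hamming(byte[] a, byte[] b) {
        int distance = 0;
        for (int i = 0; i < a.length; i++) {
            distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        return distance;
    }

    public static float score(float hammingDistance, int dim) {
        return (dim - hammingDistance) / dim;
    }

    public static void main(String[] args) {
        byte[] doc = { (byte) 0b1111_0000 };
        byte[] query = { (byte) 0b1111_1100 };
        int d = hamming(doc, query);
        System.out.println(d);           // 2 bits differ
        System.out.println(score(d, 8)); // (8 - 2) / 8 = 0.75
    }
}
```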
@@ -806,8 +1083,10 @@ public class DenseVectorFieldMapper extends FieldMapper {
COSINE {
@Override
float score(float similarity, ElementType elementType, int dim) {
assert elementType != ElementType.BIT;
return switch (elementType) {
case BYTE, FLOAT -> (1 + similarity) / 2f;
default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]");
};
}
@@ -824,6 +1103,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
return switch (elementType) {
case BYTE -> 0.5f + similarity / (float) (dim * (1 << 15));
case FLOAT -> (1 + similarity) / 2f;
default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]");
};
}
@@ -837,6 +1117,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
float score(float similarity, ElementType elementType, int dim) {
return switch (elementType) {
case BYTE, FLOAT -> similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1;
default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]");
};
}
@@ -863,7 +1144,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
this.type = type;
}
-abstract KnnVectorsFormat getVectorsFormat();
+abstract KnnVectorsFormat getVectorsFormat(ElementType elementType);
boolean supportsElementType(ElementType elementType) {
return true;
@@ -1002,7 +1283,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
-KnnVectorsFormat getVectorsFormat() {
+KnnVectorsFormat getVectorsFormat(ElementType elementType) {
assert elementType == ElementType.FLOAT;
return new ES813Int8FlatVectorFormat(confidenceInterval, 7, false);
}
@@ -1021,7 +1303,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
@Override
boolean supportsElementType(ElementType elementType) {
-return elementType != ElementType.BYTE;
+return elementType == ElementType.FLOAT;
}
@Override
@@ -1047,7 +1329,10 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
-KnnVectorsFormat getVectorsFormat() {
+KnnVectorsFormat getVectorsFormat(ElementType elementType) {
if (elementType.equals(ElementType.BIT)) {
return new ES815BitFlatVectorFormat();
}
return new ES813FlatVectorFormat();
}
@@ -1083,7 +1368,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
-public KnnVectorsFormat getVectorsFormat() {
+public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
assert elementType == ElementType.FLOAT;
return new ES814HnswScalarQuantizedVectorsFormat(m, efConstruction, confidenceInterval, 4, true);
}
@@ -1126,7 +1412,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
@Override
boolean supportsElementType(ElementType elementType) {
-return elementType != ElementType.BYTE;
+return elementType == ElementType.FLOAT;
}
@Override
@@ -1153,7 +1439,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
-public KnnVectorsFormat getVectorsFormat() {
+public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
assert elementType == ElementType.FLOAT;
return new ES813Int8FlatVectorFormat(confidenceInterval, 4, true);
}
@@ -1186,7 +1473,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
@Override
boolean supportsElementType(ElementType elementType) {
-return elementType != ElementType.BYTE;
+return elementType == ElementType.FLOAT;
}
@Override
@@ -1216,7 +1503,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
-public KnnVectorsFormat getVectorsFormat() {
+public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
assert elementType == ElementType.FLOAT;
return new ES814HnswScalarQuantizedVectorsFormat(m, efConstruction, confidenceInterval, 7, false);
}
@@ -1261,7 +1549,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
@Override
boolean supportsElementType(ElementType elementType) {
-return elementType != ElementType.BYTE;
+return elementType == ElementType.FLOAT;
}
@Override
@@ -1291,7 +1579,10 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
-public KnnVectorsFormat getVectorsFormat() {
+public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
if (elementType == ElementType.BIT) {
return new ES815HnswBitVectorsFormat(m, efConstruction);
}
return new Lucene99HnswVectorsFormat(m, efConstruction, 1, null);
}
@@ -1412,48 +1703,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't support term queries");
}
public Query createKnnQuery(
byte[] queryVector,
int numCands,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
) {
if (isIndexed() == false) {
throw new IllegalArgumentException(
"to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
);
}
if (queryVector.length != dims) {
throw new IllegalArgumentException(
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
);
}
if (elementType != ElementType.BYTE) {
throw new IllegalArgumentException(
"only [" + ElementType.BYTE + "] elements are supported when querying field [" + name() + "]"
);
}
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
}
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter)
: new ESKnnByteVectorQuery(name(), queryVector, numCands, filter);
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
similarityThreshold,
similarity.score(similarityThreshold, elementType, dims)
);
}
return knnQuery;
}
public Query createExactKnnQuery(VectorData queryVector) {
if (isIndexed() == false) {
throw new IllegalArgumentException(
@@ -1463,15 +1712,17 @@ public class DenseVectorFieldMapper extends FieldMapper {
return switch (elementType) {
case BYTE -> createExactKnnByteQuery(queryVector.asByteVector());
case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector());
case BIT -> createExactKnnBitQuery(queryVector.asByteVector());
};
}
-private Query createExactKnnByteQuery(byte[] queryVector) {
-if (queryVector.length != dims) {
-throw new IllegalArgumentException(
-"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
-);
-}
+private Query createExactKnnBitQuery(byte[] queryVector) {
+elementType.checkDimensions(dims, queryVector.length);
+return new DenseVectorQuery.Bytes(queryVector, name());
+}
+private Query createExactKnnByteQuery(byte[] queryVector) {
+elementType.checkDimensions(dims, queryVector.length);
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
@@ -1480,11 +1731,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
private Query createExactKnnFloatQuery(float[] queryVector) {
-if (queryVector.length != dims) {
-throw new IllegalArgumentException(
-"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
-);
-}
+elementType.checkDimensions(dims, queryVector.length);
elementType.checkVectorBounds(queryVector);
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
@@ -1521,9 +1768,31 @@ public class DenseVectorFieldMapper extends FieldMapper {
return switch (getElementType()) {
case BYTE -> createKnnByteQuery(queryVector.asByteVector(), numCands, filter, similarityThreshold, parentFilter);
case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), numCands, filter, similarityThreshold, parentFilter);
case BIT -> createKnnBitQuery(queryVector.asByteVector(), numCands, filter, similarityThreshold, parentFilter);
};
}
private Query createKnnBitQuery(
byte[] queryVector,
int numCands,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
) {
elementType.checkDimensions(dims, queryVector.length);
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter)
: new ESKnnByteVectorQuery(name(), queryVector, numCands, filter);
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
similarityThreshold,
similarity.score(similarityThreshold, elementType, dims)
);
}
return knnQuery;
}
private Query createKnnByteQuery(
byte[] queryVector,
int numCands,
@@ -1531,11 +1800,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
Float similarityThreshold,
BitSetProducer parentFilter
) {
-if (queryVector.length != dims) {
-throw new IllegalArgumentException(
-"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
-);
-}
+elementType.checkDimensions(dims, queryVector.length);
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
@@ -1561,11 +1826,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
Float similarityThreshold,
BitSetProducer parentFilter
) {
-if (queryVector.length != dims) {
-throw new IllegalArgumentException(
-"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
-);
-}
+elementType.checkDimensions(dims, queryVector.length);
elementType.checkVectorBounds(queryVector);
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
@@ -1701,7 +1962,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
vectorData.addToBuffer(byteBuffer);
if (indexCreatedVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)) {
// encode vector magnitude at the end
-double dotProduct = elementType.computeDotProduct(vectorData);
+double dotProduct = elementType.computeSquaredMagnitude(vectorData);
float vectorMagnitude = (float) Math.sqrt(dotProduct);
byteBuffer.putFloat(vectorMagnitude);
}
@@ -1780,9 +2041,9 @@ public class DenseVectorFieldMapper extends FieldMapper {
public KnnVectorsFormat getKnnVectorsFormatForField(KnnVectorsFormat defaultFormat) {
final KnnVectorsFormat format;
if (indexOptions == null) {
-format = defaultFormat;
+format = fieldType().elementType == ElementType.BIT ? new ES815HnswBitVectorsFormat() : defaultFormat;
} else {
-format = indexOptions.getVectorsFormat();
+format = indexOptions.getVectorsFormat(fieldType().elementType);
}
// It's legal to reuse the same format name as this is the same on-disk format.
return new KnnVectorsFormat(format.getName()) {

View file

@@ -17,6 +17,8 @@ import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
import org.elasticsearch.script.field.DocValuesScriptFieldFactory;
import org.elasticsearch.script.field.vectors.BinaryDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.BitBinaryDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.BitKnnDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.ByteBinaryDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.ByteKnnDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.KnnDenseVectorDocValuesField;
@@ -58,12 +60,14 @@ final class VectorDVLeafFieldData implements LeafFieldData {
return switch (elementType) {
case BYTE -> new ByteKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims);
case FLOAT -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims);
case BIT -> new BitKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims);
};
} else {
BinaryDocValues values = DocValues.getBinary(reader, field);
return switch (elementType) {
case BYTE -> new ByteBinaryDenseVectorDocValuesField(values, name, elementType, dims);
case FLOAT -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion);
case BIT -> new BitBinaryDenseVectorDocValuesField(values, name, elementType, dims);
};
}
} catch (IOException e) {

View file

@@ -56,7 +56,7 @@ public class VectorScoreScriptUtils {
*/
public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
super(scoreScript, field);
-DenseVector.checkDimensions(field.get().getDims(), queryVector.size());
+field.getElementType().checkDimensions(field.get().getDims(), queryVector.size());
this.queryVector = new byte[queryVector.size()];
float[] validateValues = new float[queryVector.size()];
int queryMagnitude = 0;
@@ -168,7 +168,7 @@ public class VectorScoreScriptUtils {
public L1Norm(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
function = switch (field.getElementType()) {
-case BYTE -> {
+case BYTE, BIT -> {
if (queryVector instanceof List) {
yield new ByteL1Norm(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {
@@ -219,8 +219,8 @@ public class VectorScoreScriptUtils {
@SuppressWarnings("unchecked")
public Hamming(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
-if (field.getElementType() != DenseVectorFieldMapper.ElementType.BYTE) {
-throw new IllegalArgumentException("hamming distance is only supported for byte vectors");
+if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) {
+throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors");
}
if (queryVector instanceof List) {
function = new ByteHammingDistance(scoreScript, field, (List<Number>) queryVector);
@@ -278,7 +278,7 @@ public class VectorScoreScriptUtils {
public L2Norm(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
function = switch (field.getElementType()) {
-case BYTE -> {
+case BYTE, BIT -> {
if (queryVector instanceof List) {
yield new ByteL2Norm(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {
@@ -342,7 +342,7 @@ public class VectorScoreScriptUtils {
public DotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
function = switch (field.getElementType()) {
-case BYTE -> {
+case BYTE, BIT -> {
if (queryVector instanceof List) {
yield new ByteDotProduct(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {
@@ -406,7 +406,7 @@ public class VectorScoreScriptUtils {
public CosineSimilarity(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
function = switch (field.getElementType()) {
-case BYTE -> {
+case BYTE, BIT -> {
if (queryVector instanceof List) {
yield new ByteCosineSimilarity(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {

View file

@@ -0,0 +1,88 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.script.field.vectors;
import org.apache.lucene.util.BytesRef;
import java.util.List;
public class BitBinaryDenseVector extends ByteBinaryDenseVector {
public BitBinaryDenseVector(byte[] vectorValue, BytesRef docVector, int dims) {
super(vectorValue, docVector, dims);
}
@Override
public void checkDimensions(int qvDims) {
if (qvDims != dims) {
throw new IllegalArgumentException(
"The query vector has a different number of dimensions ["
+ qvDims * Byte.SIZE
+ "] than the document vectors ["
+ dims * Byte.SIZE
+ "]."
);
}
}
@Override
public int l1Norm(byte[] queryVector) {
return hamming(queryVector);
}
@Override
public double l1Norm(List<Number> queryVector) {
return hamming(queryVector);
}
@Override
public double l2Norm(byte[] queryVector) {
return Math.sqrt(hamming(queryVector));
}
@Override
public double l2Norm(List<Number> queryVector) {
return Math.sqrt(hamming(queryVector));
}
@Override
public int dotProduct(byte[] queryVector) {
throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
}
@Override
public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
}
@Override
public double dotProduct(List<Number> queryVector) {
throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
}
@Override
public double cosineSimilarity(byte[] queryVector, float qvMagnitude) {
throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
}
@Override
public double cosineSimilarity(List<Number> queryVector) {
throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
}
@Override
public double dotProduct(float[] queryVector) {
throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
}
@Override
public int getDims() {
return dims * Byte.SIZE;
}
}
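The overrides above pin down the script-side contract described in the commit message: for bit vectors `l1norm` is exactly the hamming distance, `l2norm` is `sqrt(l1norm)`, and `dotProduct`/`cosineSimilarity` throw. A standalone sketch of that relationship (plain JDK, hypothetical helper names, not the Elasticsearch classes):

```java
// Standalone sketch: the bit-vector script norms. l1Norm delegates to the
// hamming distance and l2Norm is its square root, mirroring the overrides
// in BitBinaryDenseVector.
public class BitNorms {

    // Hamming distance between two packed bit vectors of equal byte length.
    public static int hamming(byte[] a, byte[] b) {
        int distance = 0;
        for (int i = 0; i < a.length; i++) {
            distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        return distance;
    }

    public static int l1Norm(byte[] query, byte[] doc) {
        return hamming(query, doc);            // l1norm == hamming for bit vectors
    }

    public static double l2Norm(byte[] query, byte[] doc) {
        return Math.sqrt(hamming(query, doc)); // l2norm == sqrt(l1norm)
    }

    public static void main(String[] args) {
        byte[] doc = { (byte) 0xFF, 0x00 };
        byte[] query = { 0x0F, 0x00 };
        System.out.println(l1Norm(query, doc)); // 4 differing bits
        System.out.println(l2Norm(query, doc)); // sqrt(4) = 2.0
    }
}
```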

View file

@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.script.field.vectors;
import org.apache.lucene.index.BinaryDocValues;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
public class BitBinaryDenseVectorDocValuesField extends ByteBinaryDenseVectorDocValuesField {
public BitBinaryDenseVectorDocValuesField(BinaryDocValues input, String name, ElementType elementType, int dims) {
super(input, name, elementType, dims / 8);
}
@Override
protected DenseVector getVector() {
return new BitBinaryDenseVector(vectorValue, value, dims);
}
}

View file

@@ -0,0 +1,95 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.script.field.vectors;
import java.util.List;
public class BitKnnDenseVector extends ByteKnnDenseVector {
public BitKnnDenseVector(byte[] vector) {
super(vector);
}
@Override
public void checkDimensions(int qvDims) {
if (qvDims != docVector.length) {
throw new IllegalArgumentException(
"The query vector has a different number of dimensions ["
+ qvDims * Byte.SIZE
+ "] than the document vectors ["
+ docVector.length * Byte.SIZE
+ "]."
);
}
}
@Override
public float getMagnitude() {
if (magnitudeCalculated == false) {
magnitude = DenseVector.getBitMagnitude(docVector, docVector.length);
magnitudeCalculated = true;
}
return magnitude;
}
@Override
public int l1Norm(byte[] queryVector) {
return hamming(queryVector);
}
@Override
public double l1Norm(List<Number> queryVector) {
return hamming(queryVector);
}
@Override
public double l2Norm(byte[] queryVector) {
return Math.sqrt(hamming(queryVector));
}
@Override
public double l2Norm(List<Number> queryVector) {
return Math.sqrt(hamming(queryVector));
}
@Override
public int dotProduct(byte[] queryVector) {
throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
}
@Override
public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
}
@Override
public double dotProduct(List<Number> queryVector) {
throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
}
@Override
public double cosineSimilarity(byte[] queryVector, float qvMagnitude) {
throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
}
@Override
public double cosineSimilarity(List<Number> queryVector) {
throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
}
@Override
public double dotProduct(float[] queryVector) {
throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
}
@Override
public int getDims() {
return docVector.length * Byte.SIZE;
}
}
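`getMagnitude` above delegates to `DenseVector.getBitMagnitude`; assuming that helper follows the usual definition, the magnitude of a bit vector is `sqrt(popcount)`, since each set bit contributes `1 * 1` to the squared magnitude. A standalone sketch under that assumption (hypothetical method, plain JDK):

```java
// Standalone sketch (assumption: mirrors DenseVector.getBitMagnitude):
// every set bit contributes 1*1 to the squared magnitude, so the
// magnitude of a packed bit vector is sqrt(popcount).
public class BitMagnitude {

    public static float bitMagnitude(byte[] vector) {
        int setBits = 0;
        for (byte b : vector) {
            setBits += Integer.bitCount(b & 0xFF);
        }
        return (float) Math.sqrt(setBits);
    }

    public static void main(String[] args) {
        System.out.println(bitMagnitude(new byte[] { (byte) 0xF0 })); // 4 set bits -> 2.0
    }
}
```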

View file

@@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.script.field.vectors;
import org.apache.lucene.index.ByteVectorValues;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
public class BitKnnDenseVectorDocValuesField extends ByteKnnDenseVectorDocValuesField {
public BitKnnDenseVectorDocValuesField(@Nullable ByteVectorValues input, String name, int dims) {
super(input, name, dims / 8, DenseVectorFieldMapper.ElementType.BIT);
}
@Override
protected DenseVector getVector() {
return new BitKnnDenseVector(vector);
}
}

View file

@@ -21,7 +21,7 @@ public class ByteBinaryDenseVector implements DenseVector {
private final BytesRef docVector;
private final byte[] vectorValue;
-private final int dims;
+protected final int dims;
private float[] floatDocVector;
private boolean magnitudeDecoded;

View file

@@ -17,11 +17,11 @@ import java.io.IOException;
public class ByteBinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {
-private final BinaryDocValues input;
-private final int dims;
-private final byte[] vectorValue;
-private boolean decoded;
-private BytesRef value;
+protected final BinaryDocValues input;
+protected final int dims;
+protected final byte[] vectorValue;
+protected boolean decoded;
+protected BytesRef value;
public ByteBinaryDenseVectorDocValuesField(BinaryDocValues input, String name, ElementType elementType, int dims) {
super(name, elementType);
@@ -50,13 +50,17 @@ public class ByteBinaryDenseVectorDocValuesFie
return value == null;
}
protected DenseVector getVector() {
return new ByteBinaryDenseVector(vectorValue, value, dims);
}
@Override @Override
public DenseVector get() { public DenseVector get() {
if (isEmpty()) { if (isEmpty()) {
return DenseVector.EMPTY; return DenseVector.EMPTY;
} }
decodeVectorIfNecessary(); decodeVectorIfNecessary();
return new ByteBinaryDenseVector(vectorValue, value, dims); return getVector();
} }
@Override @Override
@ -65,7 +69,7 @@ public class ByteBinaryDenseVectorDocValuesField extends DenseVectorDocValuesFie
return defaultValue; return defaultValue;
} }
decodeVectorIfNecessary(); decodeVectorIfNecessary();
return new ByteBinaryDenseVector(vectorValue, value, dims); return getVector();
} }
@Override @Override
@@ -23,7 +23,11 @@ public class ByteKnnDenseVectorDocValuesField extends DenseVectorDocValuesField
     protected final int dims;

     public ByteKnnDenseVectorDocValuesField(@Nullable ByteVectorValues input, String name, int dims) {
-        super(name, ElementType.BYTE);
+        this(input, name, dims, ElementType.BYTE);
+    }
+
+    protected ByteKnnDenseVectorDocValuesField(@Nullable ByteVectorValues input, String name, int dims, ElementType elementType) {
+        super(name, elementType);
         this.dims = dims;
         this.input = input;
     }
@@ -57,13 +61,17 @@ public class ByteKnnDenseVectorDocValuesField extends DenseVectorDocValuesField
         return vector == null;
     }

+    protected DenseVector getVector() {
+        return new ByteKnnDenseVector(vector);
+    }
+
     @Override
     public DenseVector get() {
         if (isEmpty()) {
             return DenseVector.EMPTY;
         }
-        return new ByteKnnDenseVector(vector);
+        return getVector();
     }

     @Override
@@ -72,7 +80,7 @@ public class ByteKnnDenseVectorDocValuesField extends DenseVectorDocValuesField
             return defaultValue;
         }
-        return new ByteKnnDenseVector(vector);
+        return getVector();
     }

     @Override
@@ -8,6 +8,7 @@

 package org.elasticsearch.script.field.vectors;

+import org.apache.lucene.util.BitUtil;
 import org.apache.lucene.util.VectorUtil;

 import java.util.List;
@@ -25,6 +26,10 @@ import java.util.List;
  */
 public interface DenseVector {

+    default void checkDimensions(int qvDims) {
+        checkDimensions(getDims(), qvDims);
+    }
+
     float[] getVector();

     float getMagnitude();
@@ -38,13 +43,13 @@ public interface DenseVector {
     @SuppressWarnings("unchecked")
     default double dotProduct(Object queryVector) {
         if (queryVector instanceof float[] floats) {
-            checkDimensions(getDims(), floats.length);
+            checkDimensions(floats.length);
             return dotProduct(floats);
         } else if (queryVector instanceof List<?> list) {
-            checkDimensions(getDims(), list.size());
+            checkDimensions(list.size());
             return dotProduct((List<Number>) list);
         } else if (queryVector instanceof byte[] bytes) {
-            checkDimensions(getDims(), bytes.length);
+            checkDimensions(bytes.length);
             return dotProduct(bytes);
         }
@@ -60,13 +65,13 @@ public interface DenseVector {
     @SuppressWarnings("unchecked")
     default double l1Norm(Object queryVector) {
         if (queryVector instanceof float[] floats) {
-            checkDimensions(getDims(), floats.length);
+            checkDimensions(floats.length);
             return l1Norm(floats);
         } else if (queryVector instanceof List<?> list) {
-            checkDimensions(getDims(), list.size());
+            checkDimensions(list.size());
             return l1Norm((List<Number>) list);
         } else if (queryVector instanceof byte[] bytes) {
-            checkDimensions(getDims(), bytes.length);
+            checkDimensions(bytes.length);
             return l1Norm(bytes);
         }
@@ -80,11 +85,11 @@ public interface DenseVector {
     @SuppressWarnings("unchecked")
     default int hamming(Object queryVector) {
         if (queryVector instanceof List<?> list) {
-            checkDimensions(getDims(), list.size());
+            checkDimensions(list.size());
             return hamming((List<Number>) list);
         }
         if (queryVector instanceof byte[] bytes) {
-            checkDimensions(getDims(), bytes.length);
+            checkDimensions(bytes.length);
             return hamming(bytes);
         }
@@ -100,13 +105,13 @@ public interface DenseVector {
     @SuppressWarnings("unchecked")
     default double l2Norm(Object queryVector) {
         if (queryVector instanceof float[] floats) {
-            checkDimensions(getDims(), floats.length);
+            checkDimensions(floats.length);
             return l2Norm(floats);
         } else if (queryVector instanceof List<?> list) {
-            checkDimensions(getDims(), list.size());
+            checkDimensions(list.size());
             return l2Norm((List<Number>) list);
         } else if (queryVector instanceof byte[] bytes) {
-            checkDimensions(getDims(), bytes.length);
+            checkDimensions(bytes.length);
             return l2Norm(bytes);
         }
@@ -150,13 +155,13 @@ public interface DenseVector {
     @SuppressWarnings("unchecked")
     default double cosineSimilarity(Object queryVector) {
         if (queryVector instanceof float[] floats) {
-            checkDimensions(getDims(), floats.length);
+            checkDimensions(floats.length);
             return cosineSimilarity(floats);
         } else if (queryVector instanceof List<?> list) {
-            checkDimensions(getDims(), list.size());
+            checkDimensions(list.size());
             return cosineSimilarity((List<Number>) list);
         } else if (queryVector instanceof byte[] bytes) {
-            checkDimensions(getDims(), bytes.length);
+            checkDimensions(bytes.length);
             return cosineSimilarity(bytes);
         }
@@ -184,6 +189,20 @@ public interface DenseVector {
         return (float) Math.sqrt(mag);
     }

+    static float getBitMagnitude(byte[] vector, int dims) {
+        int count = 0;
+        int i = 0;
+        for (int upperBound = dims & -8; i < upperBound; i += 8) {
+            count += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(vector, i));
+        }
+        while (i < dims) {
+            count += Integer.bitCount(vector[i] & 255);
+            ++i;
+        }
+        return (float) Math.sqrt(count);
+    }
+
     static float getMagnitude(float[] vector) {
         return (float) Math.sqrt(VectorUtil.dotProduct(vector, vector));
     }
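A stdlib-only sketch of the `getBitMagnitude` computation above: popcount the set bits of the packed vector and take the square root. The `BitUtil.VH_NATIVE_LONG` loop in the real code is assumed here to be just an eight-bytes-at-a-time fast path over the same per-byte popcount:

```java
// Simplified per-byte version of the bit-vector magnitude computation.
public class BitMagnitude {

    static float getBitMagnitude(byte[] vector) {
        int count = 0;
        for (byte b : vector) {
            // mask to 0..255 so negative bytes count correctly
            count += Integer.bitCount(b & 0xFF);
        }
        return (float) Math.sqrt(count);
    }

    public static void main(String[] args) {
        byte[] v = { (byte) 0b00001111, (byte) 0b00000001 }; // 5 set bits
        System.out.println(getBitMagnitude(v)); // sqrt(5)
    }
}
```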
@@ -1,3 +1,5 @@
 org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat
 org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat
 org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat
+org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat
+org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat
@@ -0,0 +1,149 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseIndexFileFormatTestCase;
import org.elasticsearch.common.logging.LogConfigurator;
import java.io.IOException;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
abstract class BaseKnnBitVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
static {
LogConfigurator.loadLog4jPlugins();
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
}
@Override
protected void addRandomFields(Document doc) {
doc.add(new KnnByteVectorField("v2", randomVector(30), similarityFunction));
}
protected VectorSimilarityFunction similarityFunction;
protected VectorSimilarityFunction randomSimilarity() {
return VectorSimilarityFunction.values()[random().nextInt(VectorSimilarityFunction.values().length)];
}
byte[] randomVector(int dims) {
byte[] vector = new byte[dims];
random().nextBytes(vector);
return vector;
}
public void testRandom() throws Exception {
IndexWriterConfig iwc = newIndexWriterConfig();
if (random().nextBoolean()) {
iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT)));
}
String fieldName = "field";
try (Directory dir = newDirectory(); IndexWriter iw = new IndexWriter(dir, iwc)) {
int numDoc = atLeast(100);
int dimension = atLeast(10);
if (dimension % 2 != 0) {
dimension++;
}
byte[] scratch = new byte[dimension];
int numValues = 0;
byte[][] values = new byte[numDoc][];
for (int i = 0; i < numDoc; i++) {
if (random().nextInt(7) != 3) {
// usually index a vector value for a doc
values[i] = randomVector(dimension);
++numValues;
}
if (random().nextBoolean() && values[i] != null) {
// sometimes use a shared scratch array
System.arraycopy(values[i], 0, scratch, 0, scratch.length);
add(iw, fieldName, i, scratch, similarityFunction);
} else {
add(iw, fieldName, i, values[i], similarityFunction);
}
if (random().nextInt(10) == 2) {
// sometimes delete a random document
int idToDelete = random().nextInt(i + 1);
iw.deleteDocuments(new Term("id", Integer.toString(idToDelete)));
// and remember that it was deleted
if (values[idToDelete] != null) {
values[idToDelete] = null;
--numValues;
}
}
if (random().nextInt(10) == 3) {
iw.commit();
}
}
int numDeletes = 0;
try (IndexReader reader = DirectoryReader.open(iw)) {
int valueCount = 0, totalSize = 0;
for (LeafReaderContext ctx : reader.leaves()) {
ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
if (vectorValues == null) {
continue;
}
totalSize += vectorValues.size();
StoredFields storedFields = ctx.reader().storedFields();
int docId;
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
byte[] v = vectorValues.vectorValue();
assertEquals(dimension, v.length);
String idString = storedFields.document(docId).getField("id").stringValue();
int id = Integer.parseInt(idString);
if (ctx.reader().getLiveDocs() == null || ctx.reader().getLiveDocs().get(docId)) {
assertArrayEquals(idString, values[id], v);
++valueCount;
} else {
++numDeletes;
assertNull(values[id]);
}
}
}
assertEquals(numValues, valueCount);
assertEquals(numValues, totalSize - numDeletes);
}
}
}
private void add(IndexWriter iw, String field, int id, byte[] vector, VectorSimilarityFunction similarity) throws IOException {
add(iw, field, id, random().nextInt(100), vector, similarity);
}
private void add(IndexWriter iw, String field, int id, int sortKey, byte[] vector, VectorSimilarityFunction similarityFunction)
throws IOException {
Document doc = new Document();
if (vector != null) {
doc.add(new KnnByteVectorField(field, vector, similarityFunction));
}
doc.add(new NumericDocValuesField("sortkey", sortKey));
String idString = Integer.toString(id);
doc.add(new StringField("id", idString, Field.Store.YES));
Term idTerm = new Term("id", idString);
iw.updateDocument(idTerm, doc);
}
}
@@ -12,8 +12,15 @@ import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.lucene99.Lucene99Codec;
 import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.elasticsearch.common.logging.LogConfigurator;

 public class ES813FlatVectorFormatTests extends BaseKnnVectorsFormatTestCase {

+    static {
+        LogConfigurator.loadLog4jPlugins();
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
     @Override
     protected Codec getCodec() {
         return new Lucene99Codec() {
@@ -12,8 +12,15 @@ import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.lucene99.Lucene99Codec;
 import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.elasticsearch.common.logging.LogConfigurator;

 public class ES813Int8FlatVectorFormatTests extends BaseKnnVectorsFormatTestCase {

+    static {
+        LogConfigurator.loadLog4jPlugins();
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
     @Override
     protected Codec getCodec() {
         return new Lucene99Codec() {
@@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.junit.Before;
public class ES815BitFlatVectorFormatTests extends BaseKnnBitVectorsFormatTestCase {
@Override
protected Codec getCodec() {
return new Lucene99Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new ES815BitFlatVectorFormat();
}
};
}
@Before
public void init() {
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
}
}
@@ -0,0 +1,33 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.junit.Before;
public class ES815HnswBitVectorsFormatTests extends BaseKnnBitVectorsFormatTestCase {
@Override
protected Codec getCodec() {
return new Lucene99Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new ES815HnswBitVectorsFormat();
}
};
}
@Before
public void init() {
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
}
}
@@ -236,8 +236,8 @@ public class BinaryDenseVectorScriptDocValuesTests extends ESTestCase {
     public static BytesRef mockEncodeDenseVector(float[] values, ElementType elementType, IndexVersion indexVersion) {
         int numBytes = indexVersion.onOrAfter(DenseVectorFieldMapper.MAGNITUDE_STORED_INDEX_VERSION)
-            ? elementType.elementBytes * values.length + DenseVectorFieldMapper.MAGNITUDE_BYTES
-            : elementType.elementBytes * values.length;
+            ? elementType.getNumBytes(values.length) + DenseVectorFieldMapper.MAGNITUDE_BYTES
+            : elementType.getNumBytes(values.length);
         double dotProduct = 0f;
         ByteBuffer byteBuffer = elementType.createByteBuffer(indexVersion, numBytes);
         for (float value : values) {
@@ -71,11 +71,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
     private final ElementType elementType;
     private final boolean indexed;
     private final boolean indexOptionsSet;
+    private final int dims;

     public DenseVectorFieldMapperTests() {
-        this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT);
+        this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT);
         this.indexed = randomBoolean();
         this.indexOptionsSet = this.indexed && randomBoolean();
+        this.dims = ElementType.BIT == elementType ? 4 * Byte.SIZE : 4;
     }

     @Override
@@ -89,7 +91,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
     }

     private void indexMapping(XContentBuilder b, IndexVersion indexVersion) throws IOException {
-        b.field("type", "dense_vector").field("dims", 4);
+        b.field("type", "dense_vector").field("dims", dims);
         if (elementType != ElementType.FLOAT) {
             b.field("element_type", elementType.toString());
         }
@@ -108,7 +110,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
             b.endObject();
         }
         if (indexed) {
-            b.field("similarity", "dot_product");
+            b.field("similarity", elementType == ElementType.BIT ? "l2_norm" : "dot_product");
             if (indexOptionsSet) {
                 b.startObject("index_options");
                 b.field("type", "hnsw");
@@ -121,52 +123,86 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {

     @Override
     protected Object getSampleValueForDocument() {
-        return elementType == ElementType.BYTE ? List.of((byte) 1, (byte) 1, (byte) 1, (byte) 1) : List.of(0.5, 0.5, 0.5, 0.5);
+        return elementType == ElementType.FLOAT ? List.of(0.5, 0.5, 0.5, 0.5) : List.of((byte) 1, (byte) 1, (byte) 1, (byte) 1);
     }

     @Override
     protected void registerParameters(ParameterChecker checker) throws IOException {
         checker.registerConflictCheck(
             "dims",
-            fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4)),
-            fieldMapping(b -> b.field("type", "dense_vector").field("dims", 5))
+            fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims)),
+            fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims + 8))
         );
         checker.registerConflictCheck(
             "similarity",
-            fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", true).field("similarity", "dot_product")),
-            fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", true).field("similarity", "l2_norm"))
+            fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", true).field("similarity", "dot_product")),
+            fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", true).field("similarity", "l2_norm"))
         );
         checker.registerConflictCheck(
             "index",
-            fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", true).field("similarity", "dot_product")),
-            fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", false))
+            fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", true).field("similarity", "dot_product")),
+            fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", false))
         );
         checker.registerConflictCheck(
             "element_type",
             fieldMapping(
                 b -> b.field("type", "dense_vector")
-                    .field("dims", 4)
+                    .field("dims", dims)
                     .field("index", true)
                     .field("similarity", "dot_product")
                     .field("element_type", "byte")
             ),
             fieldMapping(
                 b -> b.field("type", "dense_vector")
-                    .field("dims", 4)
+                    .field("dims", dims)
                     .field("index", true)
                     .field("similarity", "dot_product")
                     .field("element_type", "float")
             )
         );
+        checker.registerConflictCheck(
+            "element_type",
+            fieldMapping(
+                b -> b.field("type", "dense_vector")
+                    .field("dims", dims)
+                    .field("index", true)
+                    .field("similarity", "l2_norm")
+                    .field("element_type", "float")
+            ),
+            fieldMapping(
+                b -> b.field("type", "dense_vector")
+                    .field("dims", dims)
+                    .field("index", true)
+                    .field("similarity", "l2_norm")
+                    .field("element_type", "bit")
+            )
+        );
+        checker.registerConflictCheck(
+            "element_type",
+            fieldMapping(
+                b -> b.field("type", "dense_vector")
+                    .field("dims", dims)
+                    .field("index", true)
+                    .field("similarity", "l2_norm")
+                    .field("element_type", "byte")
+            ),
+            fieldMapping(
+                b -> b.field("type", "dense_vector")
+                    .field("dims", dims)
+                    .field("index", true)
+                    .field("similarity", "l2_norm")
+                    .field("element_type", "bit")
+            )
+        );
         checker.registerUpdateCheck(
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "flat")
                 .endObject(),
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "int8_flat")
@@ -175,13 +211,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         );
         checker.registerUpdateCheck(
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "flat")
                 .endObject(),
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "hnsw")
@@ -190,13 +226,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         );
         checker.registerUpdateCheck(
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "flat")
                 .endObject(),
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "int8_hnsw")
@@ -205,13 +241,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         );
         checker.registerUpdateCheck(
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "int8_flat")
                 .endObject(),
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "hnsw")
@@ -220,13 +256,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         );
         checker.registerUpdateCheck(
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "int8_flat")
                 .endObject(),
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "int8_hnsw")
@@ -235,13 +271,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         );
         checker.registerUpdateCheck(
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "hnsw")
                 .endObject(),
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "int8_hnsw")
@@ -252,7 +288,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
             "index_options",
             fieldMapping(
                 b -> b.field("type", "dense_vector")
-                    .field("dims", 4)
+                    .field("dims", dims)
                     .field("index", true)
                     .startObject("index_options")
                     .field("type", "hnsw")
@@ -260,7 +296,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
             ),
             fieldMapping(
                 b -> b.field("type", "dense_vector")
-                    .field("dims", 4)
+                    .field("dims", dims)
                     .field("index", true)
                     .startObject("index_options")
                     .field("type", "flat")
@@ -353,7 +389,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         mapping = mapping(b -> {
             b.startObject("field");
             b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("similarity", "cosine")
                 .field("index", true)
                 .startObject("index_options")
@@ -648,7 +684,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
             () -> createDocumentMapper(
                 fieldMapping(
                     b -> b.field("type", "dense_vector")
-                        .field("dims", 4)
+                        .field("dims", dims)
                         .field("element_type", "byte")
                         .field("similarity", "l2_norm")
                         .field("index", true)
@@ -1020,6 +1056,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
                 }
                 yield floats;
             }
+            case BIT -> randomByteArrayOfLength(vectorFieldType.getVectorDimensions() / 8);
         };
     }
@@ -1196,7 +1233,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         boolean setEfConstruction = randomBoolean();
         MapperService mapperService = createMapperService(fieldMapping(b -> {
             b.field("type", "dense_vector");
-            b.field("dims", 4);
+            b.field("dims", dims);
             b.field("index", true);
             b.field("similarity", "dot_product");
             b.startObject("index_options");
@@ -1234,7 +1271,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         for (String quantizedFlatFormat : new String[] { "int8_flat", "int4_flat" }) {
             MapperService mapperService = createMapperService(fieldMapping(b -> {
                 b.field("type", "dense_vector");
-                b.field("dims", 4);
+                b.field("dims", dims);
                 b.field("index", true);
                 b.field("similarity", "dot_product");
                 b.startObject("index_options");
@@ -1275,7 +1312,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         float confidenceInterval = (float) randomDoubleBetween(0.90f, 1.0f, true);
         MapperService mapperService = createMapperService(fieldMapping(b -> {
             b.field("type", "dense_vector");
-            b.field("dims", 4);
+            b.field("dims", dims);
             b.field("index", true);
             b.field("similarity", "dot_product");
             b.startObject("index_options");
@@ -1316,7 +1353,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         float confidenceInterval = (float) randomDoubleBetween(0.90f, 1.0f, true);
         MapperService mapperService = createMapperService(fieldMapping(b -> {
             b.field("type", "dense_vector");
-            b.field("dims", 4);
+            b.field("dims", dims);
             b.field("index", true);
             b.field("similarity", "dot_product");
             b.startObject("index_options");
@@ -185,10 +185,12 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             queryVector[i] = randomByte();
             floatQueryVector[i] = queryVector[i];
         }
-        Query query = field.createKnnQuery(queryVector, 10, null, null, producer);
+        VectorData vectorData = new VectorData(null, queryVector);
+        Query query = field.createKnnQuery(vectorData, 10, null, null, producer);
         assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
-        query = field.createKnnQuery(floatQueryVector, 10, null, null, producer);
+        vectorData = new VectorData(floatQueryVector, null);
+        query = field.createKnnQuery(vectorData, 10, null, null, producer);
         assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
         }
     }
@@ -321,7 +323,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
         for (int i = 0; i < 4096; i++) {
             queryVector[i] = randomByte();
         }
-        Query query = fieldWith4096dims.createKnnQuery(queryVector, 10, null, null, null);
+        VectorData vectorData = new VectorData(null, queryVector);
+        Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, null, null, null);
         assertThat(query, instanceOf(KnnByteVectorQuery.class));
     }
 }
@@ -359,7 +362,10 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
         );
         assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
-        e = expectThrows(IllegalArgumentException.class, () -> cosineField.createKnnQuery(new byte[] { 0, 0, 0 }, 10, null, null, null));
+        e = expectThrows(
+            IllegalArgumentException.class,
+            () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, null, null, null)
+        );
         assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
     }
 }
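The tests now wrap query vectors in `VectorData`, which carries either a float or a byte representation of the query so one `createKnnQuery` signature can serve both. A rough sketch of such a holder; `VectorDataSketch` is a hypothetical stand-in, not the actual Elasticsearch class:

```java
// Illustrative holder in the spirit of VectorData: exactly one of the
// two representations is set, and callers pick the one that is non-null.
public final class VectorDataSketch {
    private final float[] floatVector;
    private final byte[] byteVector;

    public VectorDataSketch(float[] floatVector, byte[] byteVector) {
        // exactly one representation may be provided
        if ((floatVector == null) == (byteVector == null)) {
            throw new IllegalArgumentException("exactly one of float or byte vector must be non-null");
        }
        this.floatVector = floatVector;
        this.byteVector = byteVector;
    }

    public boolean isFloat() {
        return floatVector != null;
    }

    public float[] asFloatVector() {
        return floatVector;
    }

    public byte[] asByteVector() {
        return byteVector;
    }

    public static void main(String[] args) {
        VectorDataSketch byteData = new VectorDataSketch(null, new byte[] { 1, 2, 3 });
        System.out.println(byteData.isFloat());             // false
        System.out.println(byteData.asByteVector().length); // 3
    }
}
```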
@@ -114,10 +114,10 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
         );
         e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, queryVector, fieldName));
-        assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors"));
+        assertThat(e.getMessage(), containsString("hamming distance is only supported for byte or bit vectors"));
         e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, invalidQueryVector, fieldName));
-        assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors"));
+        assertThat(e.getMessage(), containsString("hamming distance is only supported for byte or bit vectors"));
         // Check scripting infrastructure integration
         DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName);
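For `bit` vectors, the `Hamming` comparison reduces to the xor-and-popcount summation the commit describes for indexed vectors; per the commit notes, `l1norm` equals this hamming distance and `l2norm` is its square root. A self-contained sketch, with an illustrative method name rather than the actual Elasticsearch implementation:

```java
// Sketch of hamming distance over packed bit vectors:
// xor the bytes, then sum the population counts.
public class BitHamming {
    public static int hamming(byte[] a, byte[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vectors must be the same length");
        }
        int distance = 0;
        for (int i = 0; i < a.length; i++) {
            distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF); // xor, then popcount
        }
        return distance;
    }

    public static void main(String[] args) {
        byte[] a = { (byte) 0xF0, 0x00 };
        byte[] b = { (byte) 0x0F, 0x00 };
        int l1 = hamming(a, b);
        System.out.println(l1);            // 8 differing bits (l1norm == hamming)
        System.out.println(Math.sqrt(l1)); // l2norm for bit vectors is sqrt(hamming)
    }
}
```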
@@ -122,7 +122,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
         Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery;
         // The field should always be resolved to the concrete field
         Query knnVectorQueryBuilt = switch (elementType()) {
-            case BYTE -> new ESKnnByteVectorQuery(
+            case BYTE, BIT -> new ESKnnByteVectorQuery(
                 VECTOR_FIELD,
                 queryBuilder.queryVector().asByteVector(),
                 queryBuilder.numCands(),
@@ -145,7 +145,10 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
         SearchExecutionContext context = createSearchExecutionContext();
         KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 10, null);
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
-        assertThat(e.getMessage(), containsString("the query vector has a different dimension [2] than the index vectors [3]"));
+        assertThat(
+            e.getMessage(),
+            containsString("The query vector has a different number of dimensions [2] than the document vectors [3]")
+        );
     }

     public void testNonexistentField() {
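The reworded message asserted above comes from a validation that the query vector's dimension count matches the indexed vectors'. A hedged sketch of such a check, not the actual Elasticsearch validation code:

```java
// Sketch: reject query vectors whose dimension count differs from the
// document vectors', producing a message like the one asserted above.
public class DimCheck {
    public static void checkDims(int queryDims, int indexDims) {
        if (queryDims != indexDims) {
            throw new IllegalArgumentException(
                "The query vector has a different number of dimensions ["
                    + queryDims + "] than the document vectors [" + indexDims + "]"
            );
        }
    }

    public static void main(String[] args) {
        try {
            checkDims(2, 3);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```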
@@ -46,6 +46,7 @@ public class EmbeddingRequestChunker {
         return switch (elementType) {
             case BYTE -> EmbeddingType.BYTE;
             case FLOAT -> EmbeddingType.FLOAT;
+            case BIT -> throw new IllegalArgumentException("Bit vectors are not supported");
         };
     }
 };