Adds new bit element_type for dense_vectors (#110059)

This commit adds `bit` vector support by adding `element_type: bit` for
vectors. This new element type works for indexed and non-indexed
vectors. Additionally, it works with `hnsw` and `flat` index types. No
quantization-based codec works with this element type; this is
consistent with `byte` vectors.

`bit` vectors accept up to `32768` dimensions in size and expect vectors
that are being indexed to be encoded either as a hexadecimal string or a
`byte[]` array where each element of the `byte` array represents `8`
bits of the vector.

`bit` vectors support script usage and regular query usage. When
indexed, all comparisons are `xor` and `popcount` summations (i.e.,
Hamming distance), and the scores are transformed and normalized given
the vector dimensions. Note that indexed bit vectors require `l2_norm` as
the similarity.

For scripts, `l1norm` is the same as `hamming` distance and `l2norm` is
`sqrt(l1norm)`. `dotProduct` and `cosineSimilarity` are not supported.

Note that the dimensions expected by this element_type must always be
divisible by `8`, and the `byte[]` vectors provided for indexing must
have size `dims / 8`, where each byte element represents `8` bits of
the vector.
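The `dims`-to-payload relationship above can be sketched in Python (illustrative only, not Elasticsearch code):

```python
# Illustrative sketch: how a `bit` vector's dimension count relates to
# the signed byte[] payload that gets indexed.
dims = 40                       # must be divisible by 8
vector = [127, -127, 0, 1, 42]  # dims / 8 == 5 signed byte elements

assert dims % 8 == 0
assert len(vector) == dims // 8

# Each signed byte contributes 8 bits of the vector.
bits = "".join(format(b & 0xFF, "08b") for b in vector)
print(bits)       # 0111111110000001000000000000000100101010
print(len(bits))  # 40
```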

closes: https://github.com/elastic/elasticsearch/issues/48322
This commit is contained in:
Benjamin Trent 2024-06-26 14:48:41 -04:00 committed by GitHub
parent 97651dfb9f
commit 5add44d7d1
38 changed files with 2713 additions and 187 deletions


@ -0,0 +1,32 @@
pr: 110059
summary: Adds new `bit` `element_type` for `dense_vectors`
area: Vector Search
type: feature
issues: []
highlight:
title: Adds new `bit` `element_type` for `dense_vectors`
body: |-
This adds `bit` vector support by adding `element_type: bit` for
vectors. This new element type works for indexed and non-indexed
vectors. Additionally, it works with `hnsw` and `flat` index types. No
quantization-based codec works with this element type; this is
consistent with `byte` vectors.
`bit` vectors accept up to `32768` dimensions in size and expect vectors
that are being indexed to be encoded either as a hexadecimal string or a
`byte[]` array where each element of the `byte` array represents `8`
bits of the vector.
`bit` vectors support script usage and regular query usage. When
indexed, all comparisons are `xor` and `popcount` summations (i.e.,
Hamming distance), and the scores are transformed and normalized given
the vector dimensions.
For scripts, `l1norm` is the same as `hamming` distance and `l2norm` is
`sqrt(l1norm)`. `dotProduct` and `cosineSimilarity` are not supported.
Note that the dimensions expected by this element_type must always be
divisible by `8`, and the `byte[]` vectors provided for indexing must
have size `dims / 8`, where each byte element represents `8` bits of
the vector.
notable: true


@ -183,11 +183,23 @@ The following mapping parameters are accepted:
`element_type`::
(Optional, string)
The data type used to encode vectors. The supported data types are
`float` (default) and `byte`. `float` indexes a 4-byte floating-point
value per dimension. `byte` indexes a 1-byte integer value per dimension.
Using `byte` can result in a substantially smaller index size with the
trade off of lower precision. Vectors using `byte` require dimensions with
integer values between -128 to 127, inclusive for both indexing and searching.
`float` (default), `byte`, and `bit`.
.Valid values for `element_type`
[%collapsible%open]
====
`float`:::
indexes a 4-byte floating-point
value per dimension. This is the default value.
`byte`:::
indexes a 1-byte integer value per dimension.
`bit`:::
indexes a single bit per dimension. Useful for very high-dimensional vectors or models that specifically support bit vectors.
NOTE: When using `bit`, the number of dimensions must be a multiple of 8 and must represent the number of bits.
====
`dims`::
(Optional, integer)
@ -205,7 +217,11 @@ API>>. Defaults to `true`.
The vector similarity metric to use in kNN search. Documents are ranked by
their vector field's similarity to the query vector. The `_score` of each
document will be derived from the similarity, in a way that ensures scores are
positive and that a larger score corresponds to a higher ranking. Defaults to `cosine`.
positive and that a larger score corresponds to a higher ranking.
Defaults to `l2_norm` when `element_type: bit` otherwise defaults to `cosine`.
NOTE: `bit` vectors only support `l2_norm` as their similarity metric.
+
^*^ This parameter can only be specified when `index` is `true`.
+
@ -217,6 +233,9 @@ Computes similarity based on the L^2^ distance (also known as Euclidean
distance) between the vectors. The document `_score` is computed as
`1 / (1 + l2_norm(query, vector)^2)`.
For `bit` vectors, instead of using `l2_norm`, the `hamming` distance between the vectors is used. The `_score`
transformation is `(numBits - hamming(a, b)) / numBits`.
`dot_product`:::
Computes the dot product of two unit vectors. This option provides an optimized way
to perform cosine similarity. The constraints and computed score are defined
@ -320,3 +339,112 @@ any issues, but features in technical preview are not subject to the support SLA
of official GA features.
`dense_vector` fields support <<synthetic-source,synthetic `_source`>>.
[[dense-vector-index-bit]]
==== Indexing & Searching bit vectors
When using `element_type: bit`, all vectors are treated as bit vectors. Bit vectors utilize only a single
bit per dimension and are internally encoded as bytes. This can be useful for very high-dimensional vectors or models.
When using `bit`, the number of dimensions must be a multiple of 8 and must represent the number of bits. Additionally,
with `bit` vectors, vector similarity is always effectively scored via the `hamming` distance.
Let's compare two `byte[]` arrays, each representing 40 individual bits.
`[-127, 0, 1, 42, 127]` in bits `1000000100000000000000010010101001111111`
`[127, -127, 0, 1, 42]` in bits `0111111110000001000000000000000100101010`
When comparing these two bit vectors, we first take the {wikipedia}/Hamming_distance[`hamming` distance].
`xor` result:
```
1000000100000000000000010010101001111111
^
0111111110000001000000000000000100101010
=
1111111010000001000000010010101101010101
```
Then, we gather the count of `1` bits in the `xor` result: `18`. To scale for scoring, we subtract from the total number
of bits and divide by the total number of bits: `(40 - 18) / 40 = 0.55`. This would be the `_score` between these two
vectors.
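The walk-through above can be reproduced with a short Python sketch (illustrative only, not the actual Lucene implementation):

```python
# Sketch of the scoring walk-through: xor, popcount, then scale the
# Hamming distance into the [0, 1] `_score` range.
a = [-127, 0, 1, 42, 127]
b = [127, -127, 0, 1, 42]
num_bits = len(a) * 8  # 40

# Per-byte xor, masked to 8 bits, summing the set bits.
hamming = sum(bin((x ^ y) & 0xFF).count("1") for x, y in zip(a, b))
score = (num_bits - hamming) / num_bits
print(hamming)  # 18
print(score)    # 0.55
```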
Here is an example of indexing and searching bit vectors:
[source,console]
--------------------------------------------------
PUT my-bit-vectors
{
"mappings": {
"properties": {
"my_vector": {
"type": "dense_vector",
"dims": 40, <1>
"element_type": "bit"
}
}
}
}
--------------------------------------------------
<1> The number of dimensions, which represents the number of bits
[source,console]
--------------------------------------------------
POST /my-bit-vectors/_bulk?refresh
{"index": {"_id" : "1"}}
{"my_vector": [127, -127, 0, 1, 42]} <1>
{"index": {"_id" : "2"}}
{"my_vector": "8100012a7f"} <2>
--------------------------------------------------
// TEST[continued]
<1> 5 bytes representing the 40-bit vector
<2> A hexadecimal string representing the 40-bit vector
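The two accepted encodings are interchangeable; here is a Python sketch (illustrative only) of converting between a signed `byte[]` array and its hexadecimal string:

```python
# Document 1's signed byte[] array, rendered as a hex string.
doc1 = [127, -127, 0, 1, 42]
hex1 = bytes(b & 0xFF for b in doc1).hex()
print(hex1)  # 7f8100012a

# Document 2's hex string, decoded back to signed bytes.
doc2 = [b - 256 if b > 127 else b for b in bytes.fromhex("8100012a7f")]
print(doc2)  # [-127, 0, 1, 42, 127]
```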
Then, when searching, you can use the `knn` query to search for similar bit vectors:
[source,console]
--------------------------------------------------
POST /my-bit-vectors/_search?filter_path=hits.hits
{
"query": {
"knn": {
"query_vector": [127, -127, 0, 1, 42],
"field": "my_vector"
}
}
}
--------------------------------------------------
// TEST[continued]
[source,console-result]
----
{
"hits": {
"hits": [
{
"_index": "my-bit-vectors",
"_id": "1",
"_score": 1.0,
"_source": {
"my_vector": [
127,
-127,
0,
1,
42
]
}
},
{
"_index": "my-bit-vectors",
"_id": "2",
"_score": 0.55,
"_source": {
"my_vector": "8100012a7f"
}
}
]
}
}
----


@ -1,4 +1,3 @@
[role="xpack"]
[[vector-functions]]
===== Functions for vector fields
@ -17,6 +16,8 @@ This is the list of available vector functions and vector access methods:
6. <<vector-functions-accessing-vectors,`doc[<field>].vectorValue`>> returns a vector's value as an array of floats
7. <<vector-functions-accessing-vectors,`doc[<field>].magnitude`>> returns a vector's magnitude
NOTE: The `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
NOTE: The recommended way to access dense vectors is through the
`cosineSimilarity`, `dotProduct`, `l1norm` or `l2norm` functions. Please note
however, that you should call these functions only once per script. For example,
@ -193,7 +194,7 @@ we added `1` in the denominator.
====== Hamming distance
The `hamming` function calculates {wikipedia}/Hamming_distance[Hamming distance] between a given query vector and
document vectors. It is only available for byte vectors.
document vectors. It is only available for byte and bit vectors.
[source,console]
--------------------------------------------------
@ -278,10 +279,14 @@ You can access vector values directly through the following functions:
- `doc[<field>].vectorValue` returns a vector's value as an array of floats
NOTE: For `bit` vectors, this returns a `float[]` where each element represents 8 bits.
- `doc[<field>].magnitude` returns a vector's magnitude as a float
(for vectors created prior to version 7.5 the magnitude is not stored.
So this function calculates it anew every time it is called).
NOTE: For `bit` vectors, this is the square root of the number of `1` bits.
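A minimal Python sketch of that magnitude computation (illustrative; the example vector is an arbitrary 40-bit value stored as signed bytes):

```python
import math

# Magnitude of a bit vector: the square root of its number of `1` bits.
vector = [-1, 115, -3, 4, -128]  # 5 signed bytes == 40 bits
ones = sum(bin(b & 0xFF).count("1") for b in vector)
magnitude = math.sqrt(ones)
print(ones)                 # 22
print(round(magnitude, 6))  # 4.690416
```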
For example, the script below implements a cosine similarity using these
two functions:
@ -319,3 +324,14 @@ GET my-index-000001/_search
}
}
--------------------------------------------------
[[vector-functions-bit-vectors]]
====== Bit vectors and vector functions
When using `bit` vectors, not all the vector functions are available. The supported functions are:
* <<vector-functions-hamming,`hamming`>> calculates Hamming distance, the sum of the bitwise XOR of the two vectors
* <<vector-functions-l1,`l1norm`>> calculates L^1^ distance; this is simply the `hamming` distance
* <<vector-functions-l2,`l2norm`>> calculates L^2^ distance; this is the square root of the `hamming` distance
Currently, the `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
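These relationships can be sketched in Python (illustrative only; the vectors are arbitrary signed-byte examples):

```python
import math

# For `bit` vectors: l1norm == hamming, l2norm == sqrt(hamming).
stored = [-1, 115, -3, 4, -128]
query = [0, 111, -13, 14, -124]

hamming = sum(bin((a ^ b) & 0xFF).count("1") for a, b in zip(stored, query))
l1norm = hamming
l2norm = math.sqrt(hamming)
print(hamming)           # 17
print(round(l2norm, 3))  # 4.123
```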


@ -0,0 +1,392 @@
setup:
- requires:
cluster_features: ["mapper.vectors.bit_vectors"]
reason: "support for bit vectors added in 8.15"
test_runner_features: headers
- do:
indices.create:
index: test-index
body:
mappings:
properties:
vector:
type: dense_vector
index: false
element_type: bit
dims: 40
indexed_vector:
type: dense_vector
element_type: bit
dims: 40
index: true
similarity: l2_norm
- do:
index:
index: test-index
id: "1"
body:
vector: [8, 5, -15, 1, -7]
indexed_vector: [8, 5, -15, 1, -7]
- do:
index:
index: test-index
id: "2"
body:
vector: [-1, 115, -3, 4, -128]
indexed_vector: [-1, 115, -3, 4, -128]
- do:
index:
index: test-index
id: "3"
body:
vector: [2, 18, -5, 0, -124]
indexed_vector: [2, 18, -5, 0, -124]
- do:
indices.refresh: {}
---
"Test vector magnitude equality":
- skip:
features: close_to
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "doc['vector'].magnitude"
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 4.690416, error: 0.01}}
- match: {hits.hits.1._id: "1"}
- close_to: {hits.hits.1._score: {value: 3.8729835, error: 0.01}}
- match: {hits.hits.2._id: "3"}
- close_to: {hits.hits.2._score: {value: 3.4641016, error: 0.01}}
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "doc['indexed_vector'].magnitude"
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 4.690416, error: 0.01}}
- match: {hits.hits.1._id: "1"}
- close_to: {hits.hits.1._score: {value: 3.8729835, error: 0.01}}
- match: {hits.hits.2._id: "3"}
- close_to: {hits.hits.2._score: {value: 3.4641016, error: 0.01}}
---
"Dot Product is not supported":
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: "006ff30e84"
---
"Cosine Similarity is not supported":
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, 'vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, 'vector')"
params:
query_vector: "006ff30e84"
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, 'indexed_vector')"
params:
query_vector: [0, 111, -13, 14, -124]
---
"L1 norm":
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "l1norm(params.query_vector, 'vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
---
"L1 norm hexadecimal":
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "l1norm(params.query_vector, 'vector')"
params:
query_vector: "006ff30e84"
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
---
"L2 norm":
- requires:
test_runner_features: close_to
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "l2norm(params.query_vector, 'vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 4.123, error: 0.001}}
- match: {hits.hits.1._id: "1"}
- close_to: {hits.hits.1._score: {value: 4, error: 0.001}}
- match: {hits.hits.2._id: "3"}
- close_to: {hits.hits.2._score: {value: 3.316, error: 0.001}}
---
"L2 norm hexadecimal":
- requires:
test_runner_features: close_to
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "l2norm(params.query_vector, 'vector')"
params:
query_vector: "006ff30e84"
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 4.123, error: 0.001}}
- match: {hits.hits.1._id: "1"}
- close_to: {hits.hits.1._score: {value: 4, error: 0.001}}
- match: {hits.hits.2._id: "3"}
- close_to: {hits.hits.2._score: {value: 3.316, error: 0.001}}
---
"Hamming distance":
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "hamming(params.query_vector, 'vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "hamming(params.query_vector, 'indexed_vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
---
"Hamming distance hexadecimal":
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "hamming(params.query_vector, 'vector')"
params:
query_vector: "006ff30e84"
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "hamming(params.query_vector, 'indexed_vector')"
params:
query_vector: "006ff30e84"
- match: {hits.total: 3}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0._score: 17.0}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1._score: 16.0}
- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}


@ -0,0 +1,301 @@
setup:
- requires:
cluster_features: "mapper.vectors.bit_vectors"
test_runner_features: close_to
reason: 'bit vectors added in 8.15'
- do:
indices.create:
index: test
body:
settings:
index:
number_of_shards: 2
mappings:
properties:
name:
type: keyword
nested:
type: nested
properties:
paragraph_id:
type: keyword
vector:
type: dense_vector
dims: 40
index: true
element_type: bit
similarity: l2_norm
- do:
index:
index: test
id: "1"
body:
name: cow.jpg
nested:
- paragraph_id: 0
vector: [100, 20, -34, 15, -100]
- paragraph_id: 1
vector: [40, 30, -3, 1, -20]
- do:
index:
index: test
id: "2"
body:
name: moose.jpg
nested:
- paragraph_id: 0
vector: [-1, 100, -13, 14, -127]
- paragraph_id: 2
vector: [0, 100, 0, 15, -127]
- paragraph_id: 3
vector: [0, 1, 0, 2, -15]
- do:
index:
index: test
id: "3"
body:
name: rabbit.jpg
nested:
- paragraph_id: 0
vector: [1, 111, -13, 14, -1]
- do:
indices.refresh: {}
---
"nested kNN search only":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 14, -127]
k: 2
num_candidates: 3
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0.fields.name.0: "moose.jpg"}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1.fields.name.0: "cow.jpg"}
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 14, -127]
k: 2
num_candidates: 3
inner_hits: {size: 1, "fields": ["nested.paragraph_id"], _source: false}
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0.fields.name.0: "moose.jpg"}
- match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1.fields.name.0: "cow.jpg"}
- match: {hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}
---
"nested kNN search filtered":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 14, -127]
k: 2
num_candidates: 3
filter: {term: {name: "rabbit.jpg"}}
- match: {hits.total.value: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 14, -127]
k: 3
num_candidates: 3
filter: {term: {name: "rabbit.jpg"}}
inner_hits: {size: 1, fields: ["nested.paragraph_id"], _source: false}
- match: {hits.total.value: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}
---
"nested kNN search inner_hits size > 1":
- do:
index:
index: test
id: "4"
body:
name: moose.jpg
nested:
- paragraph_id: 0
vector: [-1, 90, -10, 14, -127]
- paragraph_id: 2
vector: [ 0, 100.0, 0, 14, -127 ]
- paragraph_id: 3
vector: [ 0, 1.0, 0, 2, -15 ]
- do:
index:
index: test
id: "5"
body:
name: moose.jpg
nested:
- paragraph_id: 0
vector: [ -1, 100, -13, 14, -127 ]
- paragraph_id: 2
vector: [ 0, 100, 0, 15, -127 ]
- paragraph_id: 3
vector: [ 0, 1, 0, 2, -15 ]
- do:
index:
index: test
id: "6"
body:
name: moose.jpg
nested:
- paragraph_id: 0
vector: [ -1, 100, -13, 15, -127 ]
- paragraph_id: 2
vector: [ 0, 100, 0, 15, -127 ]
- paragraph_id: 3
vector: [ 0, 1, 0, 2, -15 ]
- do:
indices.refresh: { }
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 15, -127]
k: 3
num_candidates: 5
inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}
- match: {hits.total.value: 3}
- length: { hits.hits.0.inner_hits.nested.hits.hits: 2 }
- length: { hits.hits.1.inner_hits.nested.hits.hits: 2 }
- length: { hits.hits.2.inner_hits.nested.hits.hits: 2 }
- match: { hits.hits.0.fields.name.0: "moose.jpg" }
- match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 15, -127]
k: 5
num_candidates: 5
inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}
- match: {hits.total.value: 5}
# All these initial matches are "moose.jpg", which has 3 nested vectors, but two are closest
- match: {hits.hits.0.fields.name.0: "moose.jpg"}
- length: { hits.hits.0.inner_hits.nested.hits.hits: 2 }
- match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
- match: {hits.hits.1.fields.name.0: "moose.jpg"}
- length: { hits.hits.1.inner_hits.nested.hits.hits: 2 }
- match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.1.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
- match: {hits.hits.2.fields.name.0: "moose.jpg"}
- length: { hits.hits.2.inner_hits.nested.hits.hits: 2 }
- match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.2.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
- match: {hits.hits.3.fields.name.0: "moose.jpg"}
- length: { hits.hits.3.inner_hits.nested.hits.hits: 2 }
- match: { hits.hits.3.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.3.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
# cow.jpg has two nested vectors, so both appear in inner_hits
- match: {hits.hits.4.fields.name.0: "cow.jpg"}
- length: { hits.hits.4.inner_hits.nested.hits.hits: 2 }
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [ -1, 90, -10, 15, -127 ]
k: 3
num_candidates: 3
filter: {term: {name: "cow.jpg"}}
inner_hits: {size: 3, fields: ["nested.paragraph_id"], _source: false}
- match: {hits.total.value: 1}
- match: { hits.hits.0._id: "1" }
- length: { hits.hits.0.inner_hits.nested.hits.hits: 2 }
- match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
- match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "1" }
---
"nested kNN search inner_hits & boosting":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 15, -127]
k: 3
num_candidates: 5
inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}
- close_to: { hits.hits.0._score: {value: 0.8, error: 0.00001} }
- close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: {value: 0.8, error: 0.00001} }
- close_to: { hits.hits.1._score: {value: 0.625, error: 0.00001} }
- close_to: { hits.hits.1.inner_hits.nested.hits.hits.0._score: {value: 0.625, error: 0.00001} }
- close_to: { hits.hits.2._score: {value: 0.5, error: 0.00001} }
- close_to: { hits.hits.2.inner_hits.nested.hits.hits.0._score: {value: 0.5, error: 0.00001} }
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: nested.vector
query_vector: [-1, 90, -10, 15, -127]
k: 3
num_candidates: 5
boost: 2
inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}
- close_to: { hits.hits.0._score: {value: 1.6, error: 0.00001} }
- close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: {value: 1.6, error: 0.00001} }
- close_to: { hits.hits.1._score: {value: 1.25, error: 0.00001} }
- close_to: { hits.hits.1.inner_hits.nested.hits.hits.0._score: {value: 1.25, error: 0.00001} }
- close_to: { hits.hits.2._score: {value: 1, error: 0.00001} }
- close_to: { hits.hits.2.inner_hits.nested.hits.hits.0._score: {value: 1.0, error: 0.00001} }


@ -116,8 +116,9 @@ setup:
---
"Knn search with hex string for byte field - dimensions mismatch" :
# [64, 10, -30, 10] - is encoded as '400ae20a'
# the error message has been adjusted in later versions
- do:
catch: /the query vector has a different dimension \[4\] than the index vectors \[3\]/
catch: /dimensions? \[4\] than the (document|index) vectors \[3\]/
search:
index: knn_hex_vector_index
body:


@ -116,8 +116,9 @@ setup:
---
"Knn query with hex string for byte field - dimensions mismatch" :
# [64, 10, -30, 10] - is encoded as '400ae20a'
# the error message has been adjusted in later versions
- do:
catch: /the query vector has a different dimension \[4\] than the index vectors \[3\]/
catch: /dimensions? \[4\] than the (document|index) vectors \[3\]/
search:
index: knn_hex_vector_index
body:


@ -0,0 +1,356 @@
setup:
- requires:
cluster_features: "mapper.vectors.bit_vectors"
reason: 'bit vectors added in 8.15'
- do:
indices.create:
index: test
body:
mappings:
properties:
name:
type: keyword
vector:
type: dense_vector
element_type: bit
dims: 40
index: true
similarity: l2_norm
- do:
index:
index: test
id: "1"
body:
name: cow.jpg
vector: [2, -1, 1, 4, -3]
- do:
index:
index: test
id: "2"
body:
name: moose.jpg
vector: [127.0, -128.0, 0.0, 1.0, -1.0]
- do:
index:
index: test
id: "3"
body:
name: rabbit.jpg
vector: [5, 4.0, 3, 2.0, 127]
- do:
indices.refresh: {}
---
"kNN search only":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [127, 127, -128, -128, 127]
k: 2
num_candidates: 3
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0.fields.name.0: "moose.jpg"}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1.fields.name.0: "cow.jpg"}
---
"kNN search plus query":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [127.0, -128.0, 0.0, 1.0, -1.0]
k: 2
num_candidates: 3
query:
term:
name: rabbit.jpg
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- match: {hits.hits.1._id: "2"}
- match: {hits.hits.1.fields.name.0: "moose.jpg"}
---
"kNN search with filter":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [5.0, 4, 3.0, 2, 127.0]
k: 2
num_candidates: 3
filter:
term:
name: "rabbit.jpg"
- match: {hits.total.value: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [2, -1, 1, 4, -3]
k: 2
num_candidates: 3
filter:
- term:
name: "rabbit.jpg"
- term:
_id: 2
- match: {hits.total.value: 0}
---
"Vector similarity search only":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
num_candidates: 3
k: 3
field: vector
similarity: 0.98
query_vector: [5, 4.0, 3, 2.0, 127]
- length: {hits.hits: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
---
"Vector similarity with filter only":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
num_candidates: 3
k: 3
field: vector
similarity: 0.98
query_vector: [5, 4.0, 3, 2.0, 127]
filter: {"term": {"name": "rabbit.jpg"}}
- length: {hits.hits: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
num_candidates: 3
k: 3
field: vector
similarity: 0.98
query_vector: [5, 4.0, 3, 2.0, 127]
filter: {"term": {"name": "cow.jpg"}}
- length: {hits.hits: 0}
---
"dim mismatch":
- do:
catch: bad_request
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [1, 2, 3, 4, 5, 6]
k: 2
num_candidates: 3
---
"disallow quantized vector types":
- do:
catch: bad_request
indices.create:
index: test
body:
mappings:
properties:
name:
type: keyword
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int8_flat
- do:
catch: bad_request
indices.create:
index: test
body:
mappings:
properties:
name:
type: keyword
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int4_flat
- do:
catch: bad_request
indices.create:
index: test
body:
mappings:
properties:
name:
type: keyword
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int8_hnsw
- do:
catch: bad_request
indices.create:
index: test
body:
mappings:
properties:
name:
type: keyword
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int4_hnsw
---
"disallow vector index type change to quantized type":
- do:
catch: bad_request
indices.put_mapping:
index: test
body:
properties:
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int4_hnsw
- do:
catch: bad_request
indices.put_mapping:
index: test
body:
properties:
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int8_hnsw
---
"Defaults to l2_norm with bit vectors":
- do:
indices.create:
index: default_to_l2_norm_bit
body:
mappings:
properties:
vector:
type: dense_vector
element_type: bit
dims: 40
index: true
- do:
indices.get_mapping:
index: default_to_l2_norm_bit
- match: { default_to_l2_norm_bit.mappings.properties.vector.similarity: l2_norm }
---
"Only allow l2_norm with bit vectors":
- do:
catch: bad_request
indices.create:
index: dot_product_fails_for_bits
body:
mappings:
properties:
vector:
type: dense_vector
element_type: bit
dims: 40
index: true
similarity: dot_product
- do:
catch: bad_request
indices.create:
index: cosine_product_fails_for_bits
body:
mappings:
properties:
vector:
type: dense_vector
element_type: bit
dims: 40
index: true
similarity: cosine
- do:
catch: bad_request
indices.create:
index: max_inner_product_fails_for_bits
body:
mappings:
properties:
vector:
type: dense_vector
element_type: bit
dims: 40
index: true
similarity: max_inner_product


@ -0,0 +1,223 @@
setup:
- requires:
cluster_features: "mapper.vectors.bit_vectors"
reason: 'bit vectors added in 8.15'
- do:
indices.create:
index: test
body:
mappings:
properties:
name:
type: keyword
vector:
type: dense_vector
element_type: bit
dims: 40
index: true
similarity: l2_norm
index_options:
type: flat
- do:
index:
index: test
id: "1"
body:
name: cow.jpg
vector: [2, -1, 1, 4, -3]
- do:
index:
index: test
id: "2"
body:
name: moose.jpg
vector: [127.0, -128.0, 0.0, 1.0, -1.0]
- do:
index:
index: test
id: "3"
body:
name: rabbit.jpg
vector: [5, 4.0, 3, 2.0, 127]
- do:
indices.refresh: {}
---
"kNN search only":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [127, 127, -128, -128, 127]
k: 2
num_candidates: 3
- match: {hits.hits.0._id: "2"}
- match: {hits.hits.0.fields.name.0: "moose.jpg"}
- match: {hits.hits.1._id: "1"}
- match: {hits.hits.1.fields.name.0: "cow.jpg"}
---
"kNN search plus query":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [127.0, -128.0, 0.0, 1.0, -1.0]
k: 2
num_candidates: 3
query:
term:
name: rabbit.jpg
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- match: {hits.hits.1._id: "2"}
- match: {hits.hits.1.fields.name.0: "moose.jpg"}
---
"kNN search with filter":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [5.0, 4, 3.0, 2, 127.0]
k: 2
num_candidates: 3
filter:
term:
name: "rabbit.jpg"
- match: {hits.total.value: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [2, -1, 1, 4, -3]
k: 2
num_candidates: 3
filter:
- term:
name: "rabbit.jpg"
- term:
_id: 2
- match: {hits.total.value: 0}
---
"Vector similarity search only":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
num_candidates: 3
k: 3
field: vector
similarity: 0.98
query_vector: [5, 4.0, 3, 2.0, 127]
- length: {hits.hits: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
---
"Vector similarity with filter only":
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
num_candidates: 3
k: 3
field: vector
similarity: 0.98
query_vector: [5, 4.0, 3, 2.0, 127]
filter: {"term": {"name": "rabbit.jpg"}}
- length: {hits.hits: 1}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
- do:
search:
index: test
body:
fields: [ "name" ]
knn:
num_candidates: 3
k: 3
field: vector
similarity: 0.98
query_vector: [5, 4.0, 3, 2.0, 127]
filter: {"term": {"name": "cow.jpg"}}
- length: {hits.hits: 0}
---
"dim mismatch":
- do:
catch: bad_request
search:
index: test
body:
fields: [ "name" ]
knn:
field: vector
query_vector: [1, 2, 3, 4, 5, 6]
k: 2
num_candidates: 3
---
"disallow vector index type change to quantized type":
- do:
catch: bad_request
indices.put_mapping:
index: test
body:
properties:
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int4_hnsw
- do:
catch: bad_request
indices.put_mapping:
index: test
body:
properties:
vector:
type: dense_vector
element_type: bit
dims: 32
index: true
similarity: l2_norm
index_options:
type: int8_hnsw


@@ -449,7 +449,10 @@ module org.elasticsearch.server {
with
org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat,
org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat,
org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat,
org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat,
org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat;
provides org.apache.lucene.codecs.Codec with Elasticsearch814Codec;
provides org.apache.logging.log4j.core.util.ContextDataProvider with org.elasticsearch.common.logging.DynamicContextDataProvider;


@@ -54,11 +54,11 @@ public class ES813FlatVectorFormat extends KnnVectorsFormat {
return new ES813FlatVectorReader(format.fieldsReader(state));
}
public static class ES813FlatVectorWriter extends KnnVectorsWriter {
static class ES813FlatVectorWriter extends KnnVectorsWriter {
private final FlatVectorsWriter writer;
public ES813FlatVectorWriter(FlatVectorsWriter writer) {
ES813FlatVectorWriter(FlatVectorsWriter writer) {
super();
this.writer = writer;
}
@@ -94,11 +94,11 @@ public class ES813FlatVectorFormat extends KnnVectorsFormat {
}
}
public static class ES813FlatVectorReader extends KnnVectorsReader {
static class ES813FlatVectorReader extends KnnVectorsReader {
private final FlatVectorsReader reader;
public ES813FlatVectorReader(FlatVectorsReader reader) {
ES813FlatVectorReader(FlatVectorsReader reader) {
super();
this.reader = reader;
}


@@ -0,0 +1,47 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import java.io.IOException;
public class ES815BitFlatVectorFormat extends KnnVectorsFormat {
static final String NAME = "ES815BitFlatVectorFormat";
private final FlatVectorsFormat format = new ES815BitFlatVectorsFormat();
/**
* Sole constructor
*/
public ES815BitFlatVectorFormat() {
super(NAME);
}
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new ES813FlatVectorFormat.ES813FlatVectorWriter(format.fieldsWriter(state));
}
@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new ES813FlatVectorFormat.ES813FlatVectorReader(format.fieldsReader(state));
}
@Override
public String toString() {
return NAME;
}
}


@@ -0,0 +1,143 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import java.io.IOException;
class ES815BitFlatVectorsFormat extends FlatVectorsFormat {
private final FlatVectorsFormat delegate = new Lucene99FlatVectorsFormat(FlatBitVectorScorer.INSTANCE);
@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState segmentWriteState) throws IOException {
return delegate.fieldsWriter(segmentWriteState);
}
@Override
public FlatVectorsReader fieldsReader(SegmentReadState segmentReadState) throws IOException {
return delegate.fieldsReader(segmentReadState);
}
static class FlatBitVectorScorer implements FlatVectorsScorer {
static final FlatBitVectorScorer INSTANCE = new FlatBitVectorScorer();
static void checkDimensions(int queryLen, int fieldLen) {
if (queryLen != fieldLen) {
throw new IllegalArgumentException("vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen);
}
}
@Override
public String toString() {
return super.toString();
}
@Override
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues
) throws IOException {
assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes;
assert vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN;
if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes randomAccessVectorValuesBytes) {
assert randomAccessVectorValues instanceof RandomAccessQuantizedByteVectorValues == false;
return switch (vectorSimilarityFunction) {
case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT, COSINE, EUCLIDEAN -> new HammingScorerSupplier(randomAccessVectorValuesBytes);
};
}
throw new IllegalArgumentException("Unsupported vector type or similarity function");
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues,
byte[] bytes
) {
assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes;
assert vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN;
if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes randomAccessVectorValuesBytes) {
checkDimensions(bytes.length, randomAccessVectorValuesBytes.dimension());
return switch (vectorSimilarityFunction) {
case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT, COSINE, EUCLIDEAN -> new HammingVectorScorer(
randomAccessVectorValuesBytes,
bytes
);
};
}
throw new IllegalArgumentException("Unsupported vector type or similarity function");
}
@Override
public RandomVectorScorer getRandomVectorScorer(
VectorSimilarityFunction vectorSimilarityFunction,
RandomAccessVectorValues randomAccessVectorValues,
float[] floats
) {
throw new IllegalArgumentException("Unsupported vector type");
}
}
static float hammingScore(byte[] a, byte[] b) {
return ((a.length * Byte.SIZE) - VectorUtil.xorBitCount(a, b)) / (float) (a.length * Byte.SIZE);
}
static class HammingVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
private final byte[] query;
private final RandomAccessVectorValues.Bytes byteValues;
HammingVectorScorer(RandomAccessVectorValues.Bytes byteValues, byte[] query) {
super(byteValues);
this.query = query;
this.byteValues = byteValues;
}
@Override
public float score(int i) throws IOException {
return hammingScore(byteValues.vectorValue(i), query);
}
}
static class HammingScorerSupplier implements RandomVectorScorerSupplier {
private final RandomAccessVectorValues.Bytes byteValues, byteValues1, byteValues2;
HammingScorerSupplier(RandomAccessVectorValues.Bytes byteValues) throws IOException {
this.byteValues = byteValues;
this.byteValues1 = byteValues.copy();
this.byteValues2 = byteValues.copy();
}
@Override
public RandomVectorScorer scorer(int i) throws IOException {
byte[] query = byteValues1.vectorValue(i);
return new HammingVectorScorer(byteValues2, query);
}
@Override
public RandomVectorScorerSupplier copy() throws IOException {
return new HammingScorerSupplier(byteValues);
}
}
}
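The `hammingScore` normalization in the format above maps identical vectors to 1.0 and fully inverted vectors to 0.0. A minimal sketch of the same arithmetic, assuming plain byte strings rather than Lucene's `RandomAccessVectorValues` (illustrative only, not the `VectorUtil.xorBitCount` implementation):

```python
def hamming_score(a: bytes, b: bytes) -> float:
    # (total_bits - popcount(a XOR b)) / total_bits, as in hammingScore.
    total_bits = len(a) * 8
    xor_popcount = sum(bin(x ^ y).count("1") for x, y in zip(a, b))
    return (total_bits - xor_popcount) / total_bits

print(hamming_score(b"\xff", b"\xff"))  # 1.0  (identical vectors)
print(hamming_score(b"\xf0", b"\x00"))  # 0.5  (half the bits differ)
print(hamming_score(b"\xff", b"\x00"))  # 0.0  (every bit differs)
```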


@@ -0,0 +1,74 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import java.io.IOException;
public class ES815HnswBitVectorsFormat extends KnnVectorsFormat {
static final String NAME = "ES815HnswBitVectorsFormat";
static final int MAXIMUM_MAX_CONN = 512;
static final int MAXIMUM_BEAM_WIDTH = 3200;
private final int maxConn;
private final int beamWidth;
private final FlatVectorsFormat flatVectorsFormat = new ES815BitFlatVectorsFormat();
public ES815HnswBitVectorsFormat() {
this(16, 100);
}
public ES815HnswBitVectorsFormat(int maxConn, int beamWidth) {
super(NAME);
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException(
"maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
);
}
if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
throw new IllegalArgumentException(
"beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
);
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
}
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null);
}
@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
}
@Override
public String toString() {
return "ES815HnswBitVectorsFormat(name=ES815HnswBitVectorsFormat, maxConn="
+ maxConn
+ ", beamWidth="
+ beamWidth
+ ", flatVectorFormat="
+ flatVectorsFormat
+ ")";
}
}


@@ -25,7 +25,8 @@ public class MapperFeatures implements FeatureSpecification {
PassThroughObjectMapper.PASS_THROUGH_PRIORITY,
RangeFieldMapper.NULL_VALUES_OFF_BY_ONE_FIX,
SourceFieldMapper.SYNTHETIC_SOURCE_FALLBACK,
DenseVectorFieldMapper.INT4_QUANTIZATION
DenseVectorFieldMapper.INT4_QUANTIZATION,
DenseVectorFieldMapper.BIT_VECTORS
);
}
}


@@ -31,6 +31,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.common.ParsingException;
@@ -41,6 +42,8 @@ import org.elasticsearch.index.IndexVersions;
import org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat;
import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat;
import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
import org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat;
import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat;
import org.elasticsearch.index.fielddata.FieldDataContext;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.ArraySourceValueFetcher;
@@ -100,6 +103,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
public static final NodeFeature INT4_QUANTIZATION = new NodeFeature("mapper.vectors.int4_quantization");
public static final NodeFeature BIT_VECTORS = new NodeFeature("mapper.vectors.bit_vectors");
public static final IndexVersion MAGNITUDE_STORED_INDEX_VERSION = IndexVersions.V_7_5_0;
public static final IndexVersion INDEXED_BY_DEFAULT_INDEX_VERSION = IndexVersions.FIRST_DETACHED_INDEX_VERSION;
@@ -109,6 +113,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
public static final String CONTENT_TYPE = "dense_vector";
public static short MAX_DIMS_COUNT = 4096; // maximum allowed number of dimensions
public static int MAX_DIMS_COUNT_BIT = 4096 * Byte.SIZE; // maximum allowed number of dimensions
public static short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to vector
public static final int MAGNITUDE_BYTES = 4;
@@ -134,17 +139,28 @@ public class DenseVectorFieldMapper extends FieldMapper {
throw new MapperParsingException("Property [dims] on field [" + n + "] must be an integer but got [" + o + "]");
}
int dims = XContentMapValues.nodeIntegerValue(o);
if (dims < 1 || dims > MAX_DIMS_COUNT) {
int maxDims = elementType.getValue() == ElementType.BIT ? MAX_DIMS_COUNT_BIT : MAX_DIMS_COUNT;
int minDims = elementType.getValue() == ElementType.BIT ? Byte.SIZE : 1;
if (dims < minDims || dims > maxDims) {
throw new MapperParsingException(
"The number of dimensions for field ["
+ n
+ "] should be in the range [1, "
+ MAX_DIMS_COUNT
+ "] should be in the range ["
+ minDims
+ ", "
+ maxDims
+ "] but was ["
+ dims
+ "]"
);
}
if (elementType.getValue() == ElementType.BIT) {
if (dims % Byte.SIZE != 0) {
throw new MapperParsingException(
"The number of dimensions for field [" + n + "] should be a multiple of 8 but was [" + dims + "]"
);
}
}
return dims;
}, m -> toType(m).fieldType().dims, XContentBuilder::field, Object::toString).setSerializerCheck((id, ic, v) -> v != null)
.setMergeValidator((previous, current, c) -> previous == null || Objects.equals(previous, current));
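The dims validation added in the hunk above applies two checks for `bit` vectors: the range becomes `[8, 32768]` instead of `[1, 4096]`, and the count must be a multiple of 8 so it maps onto whole bytes. A quick sketch restating the logic (names are illustrative, not Elasticsearch code):

```python
MAX_DIMS_COUNT_BIT = 4096 * 8  # 32768, mirroring the constant above
MIN_DIMS_BIT = 8               # Byte.SIZE

def validate_bit_dims(dims: int) -> None:
    # Range check: bit vectors accept between 8 and 32768 dimensions.
    if dims < MIN_DIMS_BIT or dims > MAX_DIMS_COUNT_BIT:
        raise ValueError(
            f"dims should be in the range [{MIN_DIMS_BIT}, {MAX_DIMS_COUNT_BIT}] but was [{dims}]"
        )
    # Byte-alignment check: the bits must fill whole bytes.
    if dims % 8 != 0:
        raise ValueError(f"dims should be a multiple of 8 but was [{dims}]")

validate_bit_dims(40)  # accepted: maps to 5 whole bytes
```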
@@ -171,13 +187,27 @@ public class DenseVectorFieldMapper extends FieldMapper {
"similarity",
false,
m -> toType(m).fieldType().similarity,
(Supplier<VectorSimilarity>) () -> indexedByDefault && indexed.getValue() ? VectorSimilarity.COSINE : null,
(Supplier<VectorSimilarity>) () -> {
if (indexedByDefault && indexed.getValue()) {
return elementType.getValue() == ElementType.BIT ? VectorSimilarity.L2_NORM : VectorSimilarity.COSINE;
}
return null;
},
VectorSimilarity.class
).acceptsNull().setSerializerCheck((id, ic, v) -> v != null);
).acceptsNull().setSerializerCheck((id, ic, v) -> v != null).addValidator(vectorSim -> {
if (vectorSim == null) {
return;
}
if (elementType.getValue() == ElementType.BIT && vectorSim != VectorSimilarity.L2_NORM) {
throw new IllegalArgumentException(
"The [" + VectorSimilarity.L2_NORM + "] similarity is the only supported similarity for bit vectors"
);
}
});
this.indexOptions = new Parameter<>(
"index_options",
true,
() -> defaultInt8Hnsw && elementType.getValue() != ElementType.BYTE && this.indexed.getValue()
() -> defaultInt8Hnsw && elementType.getValue() == ElementType.FLOAT && this.indexed.getValue()
? new Int8HnswIndexOptions(
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
@@ -266,7 +296,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
public enum ElementType {
BYTE(1) {
BYTE {
@Override
public String toString() {
@@ -371,7 +401,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
public double computeDotProduct(VectorData vectorData) {
public double computeSquaredMagnitude(VectorData vectorData) {
return VectorUtil.dotProduct(vectorData.asByteVector(), vectorData.asByteVector());
}
@@ -428,7 +458,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
fieldMapper.checkDimensionMatches(decodedVector.length, context);
VectorData vectorData = VectorData.fromBytes(decodedVector);
double squaredMagnitude = computeDotProduct(vectorData);
double squaredMagnitude = computeSquaredMagnitude(vectorData);
checkVectorMagnitude(
fieldMapper.fieldType().similarity,
errorByteElementsAppender(decodedVector),
@@ -463,7 +493,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
@Override
int getNumBytes(int dimensions) {
return dimensions * elementBytes;
return dimensions;
}
@Override
@@ -494,7 +524,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
},
FLOAT(4) {
FLOAT {
@Override
public String toString() {
@@ -596,7 +626,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
public double computeDotProduct(VectorData vectorData) {
public double computeSquaredMagnitude(VectorData vectorData) {
return VectorUtil.dotProduct(vectorData.asFloatVector(), vectorData.asFloatVector());
}
@@ -656,7 +686,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
@Override
int getNumBytes(int dimensions) {
return dimensions * elementBytes;
return dimensions * Float.BYTES;
}
@Override
@@ -665,14 +695,250 @@ public class DenseVectorFieldMapper extends FieldMapper {
? ByteBuffer.wrap(new byte[numBytes]).order(ByteOrder.LITTLE_ENDIAN)
: ByteBuffer.wrap(new byte[numBytes]);
}
},
BIT {
@Override
public String toString() {
return "bit";
}
@Override
public void writeValue(ByteBuffer byteBuffer, float value) {
byteBuffer.put((byte) value);
}
@Override
public void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException {
b.value(byteBuffer.get());
}
private KnnByteVectorField createKnnVectorField(String name, byte[] vector, VectorSimilarityFunction function) {
if (vector == null) {
throw new IllegalArgumentException("vector value must not be null");
}
FieldType denseVectorFieldType = new FieldType();
denseVectorFieldType.setVectorAttributes(vector.length, VectorEncoding.BYTE, function);
denseVectorFieldType.freeze();
return new KnnByteVectorField(name, vector, denseVectorFieldType);
}
@Override
IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldType, FieldDataContext fieldDataContext) {
return new VectorIndexFieldData.Builder(
denseVectorFieldType.name(),
CoreValuesSourceType.KEYWORD,
denseVectorFieldType.indexVersionCreated,
this,
denseVectorFieldType.dims,
denseVectorFieldType.indexed,
r -> r
);
}
@Override
public void checkVectorBounds(float[] vector) {
checkNanAndInfinite(vector);
StringBuilder errorBuilder = null;
for (int index = 0; index < vector.length; ++index) {
float value = vector[index];
if (value % 1.0f != 0.0f) {
errorBuilder = new StringBuilder(
"element_type ["
+ this
+ "] vectors only support non-decimal values but found decimal value ["
+ value
+ "] at dim ["
+ index
+ "];"
);
break;
}
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
errorBuilder = new StringBuilder(
"element_type ["
+ this
+ "] vectors only support integers between ["
+ Byte.MIN_VALUE
+ ", "
+ Byte.MAX_VALUE
+ "] but found ["
+ value
+ "] at dim ["
+ index
+ "];"
);
break;
}
}
if (errorBuilder != null) {
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
}
}
@Override
void checkVectorMagnitude(
VectorSimilarity similarity,
Function<StringBuilder, StringBuilder> appender,
float squaredMagnitude
) {}
@Override
public double computeSquaredMagnitude(VectorData vectorData) {
int count = 0;
int i = 0;
byte[] byteBits = vectorData.asByteVector();
for (int upperBound = byteBits.length & -8; i < upperBound; i += 8) {
count += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(byteBits, i));
}
while (i < byteBits.length) {
count += Integer.bitCount(byteBits[i] & 255);
++i;
}
return count;
}
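The `computeSquaredMagnitude` override above exploits the fact that for a bit vector every component is 0 or 1, so the squared magnitude collapses to a population count over the packed bytes. A sketch of the same result without the `BitUtil` long-stride optimization (illustrative only):

```python
def squared_magnitude(bits: bytes) -> int:
    # For a bit vector, sum of squared components == number of set bits.
    return sum(bin(b).count("1") for b in bits)

print(squared_magnitude(bytes([0xFF, 0x0F])))  # 12 set bits
print(squared_magnitude(b"\x00"))              # 0
```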
private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
int index = 0;
byte[] vector = new byte[fieldMapper.fieldType().dims / Byte.SIZE];
for (XContentParser.Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser()
.nextToken()) {
fieldMapper.checkDimensionExceeded(index, context);
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
final int value;
if (context.parser().numberType() != XContentParser.NumberType.INT) {
float floatValue = context.parser().floatValue(true);
if (floatValue % 1.0f != 0.0f) {
throw new IllegalArgumentException(
"element_type ["
+ this
+ "] vectors only support non-decimal values but found decimal value ["
+ floatValue
+ "] at dim ["
+ index
+ "];"
);
}
value = (int) floatValue;
} else {
value = context.parser().intValue(true);
}
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
throw new IllegalArgumentException(
"element_type ["
+ this
+ "] vectors only support integers between ["
+ Byte.MIN_VALUE
+ ", "
+ Byte.MAX_VALUE
+ "] but found ["
+ value
+ "] at dim ["
+ index
+ "];"
);
}
if (index >= vector.length) {
throw new IllegalArgumentException(
"The number of dimensions for field ["
+ fieldMapper.fieldType().name()
+ "] should be ["
+ fieldMapper.fieldType().dims
+ "] but found ["
+ (index + 1) * Byte.SIZE
+ "]"
);
}
vector[index++] = (byte) value;
}
fieldMapper.checkDimensionMatches(index * Byte.SIZE, context);
return VectorData.fromBytes(vector);
}
private VectorData parseHexEncodedVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
fieldMapper.checkDimensionMatches(decodedVector.length * Byte.SIZE, context);
return VectorData.fromBytes(decodedVector);
}
@Override
VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
XContentParser.Token token = context.parser().currentToken();
return switch (token) {
case START_ARRAY -> parseVectorArray(context, fieldMapper);
case VALUE_STRING -> parseHexEncodedVector(context, fieldMapper);
default -> throw new ParsingException(
context.parser().getTokenLocation(),
format("Unsupported type [%s] for provided value [%s]", token, context.parser().text())
);
};
}
@Override
public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
VectorData vectorData = parseKnnVector(context, fieldMapper);
Field field = createKnnVectorField(
fieldMapper.fieldType().name(),
vectorData.asByteVector(),
fieldMapper.fieldType().similarity.vectorSimilarityFunction(fieldMapper.indexCreatedVersion, this)
);
context.doc().addWithKey(fieldMapper.fieldType().name(), field);
}
@Override
int getNumBytes(int dimensions) {
assert dimensions % Byte.SIZE == 0;
return dimensions / Byte.SIZE;
}
@Override
ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes) {
return ByteBuffer.wrap(new byte[numBytes]);
}
@Override
int parseDimensionCount(DocumentParserContext context) throws IOException {
XContentParser.Token currentToken = context.parser().currentToken();
return switch (currentToken) {
case START_ARRAY -> {
int index = 0;
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
index++;
}
yield index * Byte.SIZE;
}
case VALUE_STRING -> {
byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
yield decodedVector.length * Byte.SIZE;
}
default -> throw new ParsingException(
context.parser().getTokenLocation(),
format("Unsupported type [%s] for provided value [%s]", currentToken, context.parser().text())
);
};
}
@Override
public void checkDimensions(int dvDims, int qvDims) {
if (dvDims != qvDims * Byte.SIZE) {
throw new IllegalArgumentException(
"The query vector has a different number of dimensions ["
+ qvDims * Byte.SIZE
+ "] than the document vectors ["
+ dvDims
+ "]."
);
}
}
};
final int elementBytes;
ElementType(int elementBytes) {
this.elementBytes = elementBytes;
}
public abstract void writeValue(ByteBuffer byteBuffer, float value);
public abstract void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException;
@@ -695,6 +961,14 @@ public class DenseVectorFieldMapper extends FieldMapper {
float squaredMagnitude
);
public void checkDimensions(int dvDims, int qvDims) {
if (dvDims != qvDims) {
throw new IllegalArgumentException(
"The query vector has a different number of dimensions [" + qvDims + "] than the document vectors [" + dvDims + "]."
);
}
}
int parseDimensionCount(DocumentParserContext context) throws IOException {
int index = 0;
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
@@ -775,7 +1049,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
return sb -> appendErrorElements(sb, vector);
}
public abstract double computeDotProduct(VectorData vectorData);
public abstract double computeSquaredMagnitude(VectorData vectorData);
public static ElementType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
@@ -786,7 +1060,9 @@ public class DenseVectorFieldMapper extends FieldMapper {
ElementType.BYTE.toString(),
ElementType.BYTE,
ElementType.FLOAT.toString(),
ElementType.FLOAT
ElementType.FLOAT,
ElementType.BIT.toString(),
ElementType.BIT
);
public enum VectorSimilarity {
@@ -795,6 +1071,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
float score(float similarity, ElementType elementType, int dim) {
return switch (elementType) {
case BYTE, FLOAT -> 1f / (1f + similarity * similarity);
case BIT -> (dim - similarity) / dim;
};
}
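For `bit` vectors, the `l2_norm` branch added above turns the raw Hamming distance reported by the scorer into a score in [0, 1], with 1.0 for identical vectors. A minimal sketch of that transform (illustrative, not Elasticsearch code):

```python
def l2_norm_bit_score(hamming_distance: float, dims: int) -> float:
    # (dim - similarity) / dim: 0 differing bits -> 1.0, all differing -> 0.0.
    return (dims - hamming_distance) / dims

print(l2_norm_bit_score(0, 40))   # 1.0
print(l2_norm_bit_score(20, 40))  # 0.5
print(l2_norm_bit_score(40, 40))  # 0.0
```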
@@ -806,8 +1083,10 @@ public class DenseVectorFieldMapper extends FieldMapper {
COSINE {
@Override
float score(float similarity, ElementType elementType, int dim) {
assert elementType != ElementType.BIT;
return switch (elementType) {
case BYTE, FLOAT -> (1 + similarity) / 2f;
default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]");
};
}
@@ -824,6 +1103,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
return switch (elementType) {
case BYTE -> 0.5f + similarity / (float) (dim * (1 << 15));
case FLOAT -> (1 + similarity) / 2f;
default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]");
};
}
@@ -837,6 +1117,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
float score(float similarity, ElementType elementType, int dim) {
return switch (elementType) {
case BYTE, FLOAT -> similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1;
default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]");
};
}
@@ -863,7 +1144,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
this.type = type;
}
abstract KnnVectorsFormat getVectorsFormat();
abstract KnnVectorsFormat getVectorsFormat(ElementType elementType);
boolean supportsElementType(ElementType elementType) {
return true;
@@ -1002,7 +1283,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
KnnVectorsFormat getVectorsFormat() {
KnnVectorsFormat getVectorsFormat(ElementType elementType) {
assert elementType == ElementType.FLOAT;
return new ES813Int8FlatVectorFormat(confidenceInterval, 7, false);
}
@@ -1021,7 +1303,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
@Override
boolean supportsElementType(ElementType elementType) {
return elementType != ElementType.BYTE;
return elementType == ElementType.FLOAT;
}
@Override
@@ -1047,7 +1329,10 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
KnnVectorsFormat getVectorsFormat() {
KnnVectorsFormat getVectorsFormat(ElementType elementType) {
if (elementType.equals(ElementType.BIT)) {
return new ES815BitFlatVectorFormat();
}
return new ES813FlatVectorFormat();
}
@@ -1083,7 +1368,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
public KnnVectorsFormat getVectorsFormat() {
public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
assert elementType == ElementType.FLOAT;
return new ES814HnswScalarQuantizedVectorsFormat(m, efConstruction, confidenceInterval, 4, true);
}
@@ -1126,7 +1412,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
@Override
boolean supportsElementType(ElementType elementType) {
return elementType != ElementType.BYTE;
return elementType == ElementType.FLOAT;
}
@Override
@@ -1153,7 +1439,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
public KnnVectorsFormat getVectorsFormat() {
public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
assert elementType == ElementType.FLOAT;
return new ES813Int8FlatVectorFormat(confidenceInterval, 4, true);
}
@@ -1186,7 +1473,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
@Override
boolean supportsElementType(ElementType elementType) {
return elementType != ElementType.BYTE;
return elementType == ElementType.FLOAT;
}
@Override
@@ -1216,7 +1503,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
public KnnVectorsFormat getVectorsFormat() {
public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
assert elementType == ElementType.FLOAT;
return new ES814HnswScalarQuantizedVectorsFormat(m, efConstruction, confidenceInterval, 7, false);
}
@@ -1261,7 +1549,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
@Override
boolean supportsElementType(ElementType elementType) {
return elementType != ElementType.BYTE;
return elementType == ElementType.FLOAT;
}
@Override
@@ -1291,7 +1579,10 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
@Override
public KnnVectorsFormat getVectorsFormat() {
public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
if (elementType == ElementType.BIT) {
return new ES815HnswBitVectorsFormat(m, efConstruction);
}
return new Lucene99HnswVectorsFormat(m, efConstruction, 1, null);
}
@@ -1412,48 +1703,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't support term queries");
}
public Query createKnnQuery(
byte[] queryVector,
int numCands,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
) {
if (isIndexed() == false) {
throw new IllegalArgumentException(
"to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
);
}
if (queryVector.length != dims) {
throw new IllegalArgumentException(
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
);
}
if (elementType != ElementType.BYTE) {
throw new IllegalArgumentException(
"only [" + ElementType.BYTE + "] elements are supported when querying field [" + name() + "]"
);
}
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
}
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter)
: new ESKnnByteVectorQuery(name(), queryVector, numCands, filter);
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
similarityThreshold,
similarity.score(similarityThreshold, elementType, dims)
);
}
return knnQuery;
}
public Query createExactKnnQuery(VectorData queryVector) {
if (isIndexed() == false) {
throw new IllegalArgumentException(
@@ -1463,15 +1712,17 @@ public class DenseVectorFieldMapper extends FieldMapper {
return switch (elementType) {
case BYTE -> createExactKnnByteQuery(queryVector.asByteVector());
case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector());
case BIT -> createExactKnnBitQuery(queryVector.asByteVector());
};
}
private Query createExactKnnBitQuery(byte[] queryVector) {
elementType.checkDimensions(dims, queryVector.length);
return new DenseVectorQuery.Bytes(queryVector, name());
}
private Query createExactKnnByteQuery(byte[] queryVector) {
if (queryVector.length != dims) {
throw new IllegalArgumentException(
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
);
}
elementType.checkDimensions(dims, queryVector.length);
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
@@ -1480,11 +1731,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
}
private Query createExactKnnFloatQuery(float[] queryVector) {
if (queryVector.length != dims) {
throw new IllegalArgumentException(
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
);
}
elementType.checkDimensions(dims, queryVector.length);
elementType.checkVectorBounds(queryVector);
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
@@ -1521,9 +1768,31 @@ public class DenseVectorFieldMapper extends FieldMapper {
return switch (getElementType()) {
case BYTE -> createKnnByteQuery(queryVector.asByteVector(), numCands, filter, similarityThreshold, parentFilter);
case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), numCands, filter, similarityThreshold, parentFilter);
case BIT -> createKnnBitQuery(queryVector.asByteVector(), numCands, filter, similarityThreshold, parentFilter);
};
}
private Query createKnnBitQuery(
byte[] queryVector,
int numCands,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
) {
elementType.checkDimensions(dims, queryVector.length);
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter)
: new ESKnnByteVectorQuery(name(), queryVector, numCands, filter);
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
similarityThreshold,
similarity.score(similarityThreshold, elementType, dims)
);
}
return knnQuery;
}
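Because `bit` vectors are stored as packed bytes, `createKnnBitQuery` above reuses the byte-based query classes (`ESKnnByteVectorQuery`, `ESDiversifyingChildrenByteKnnVectorQuery`), and the underlying comparison reduces to xor plus popcount, i.e. hamming distance. A minimal sketch of that distance over the `byte[dims / 8]` packing (illustrative only, not the Elasticsearch implementation; the class name is hypothetical):

```java
// Illustrative sketch: hamming distance between two bit vectors packed
// as byte[dims / 8], the layout used by element_type: bit.
public final class HammingSketch {
    static int hamming(byte[] a, byte[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("dimension mismatch");
        }
        int distance = 0;
        for (int i = 0; i < a.length; i++) {
            // xor the packed bytes, then count the differing bits
            distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        return distance;
    }

    public static void main(String[] args) {
        byte[] a = { (byte) 0b1111_0000 };
        byte[] b = { (byte) 0b0000_1111 };
        System.out.println(hamming(a, b)); // all 8 bits differ -> 8
    }
}
```

The indexed scores are then a normalization of this distance given the vector dimensions, which is why indexed bit vectors require `l2_norm` as the similarity.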
private Query createKnnByteQuery(
byte[] queryVector,
int numCands,
@@ -1531,11 +1800,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
Float similarityThreshold,
BitSetProducer parentFilter
) {
if (queryVector.length != dims) {
throw new IllegalArgumentException(
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
);
}
elementType.checkDimensions(dims, queryVector.length);
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
@@ -1561,11 +1826,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
Float similarityThreshold,
BitSetProducer parentFilter
) {
if (queryVector.length != dims) {
throw new IllegalArgumentException(
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
);
}
elementType.checkDimensions(dims, queryVector.length);
elementType.checkVectorBounds(queryVector);
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
@@ -1701,7 +1962,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
vectorData.addToBuffer(byteBuffer);
if (indexCreatedVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)) {
// encode vector magnitude at the end
double dotProduct = elementType.computeDotProduct(vectorData);
double dotProduct = elementType.computeSquaredMagnitude(vectorData);
float vectorMagnitude = (float) Math.sqrt(dotProduct);
byteBuffer.putFloat(vectorMagnitude);
}
@@ -1780,9 +2041,9 @@ public class DenseVectorFieldMapper extends FieldMapper {
public KnnVectorsFormat getKnnVectorsFormatForField(KnnVectorsFormat defaultFormat) {
final KnnVectorsFormat format;
if (indexOptions == null) {
format = defaultFormat;
format = fieldType().elementType == ElementType.BIT ? new ES815HnswBitVectorsFormat() : defaultFormat;
} else {
format = indexOptions.getVectorsFormat();
format = indexOptions.getVectorsFormat(fieldType().elementType);
}
// It's legal to reuse the same format name as this is the same on-disk format.
return new KnnVectorsFormat(format.getName()) {

View file

@@ -17,6 +17,8 @@ import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
import org.elasticsearch.script.field.DocValuesScriptFieldFactory;
import org.elasticsearch.script.field.vectors.BinaryDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.BitBinaryDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.BitKnnDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.ByteBinaryDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.ByteKnnDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.KnnDenseVectorDocValuesField;
@@ -58,12 +60,14 @@ final class VectorDVLeafFieldData implements LeafFieldData {
return switch (elementType) {
case BYTE -> new ByteKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims);
case FLOAT -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims);
case BIT -> new BitKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims);
};
} else {
BinaryDocValues values = DocValues.getBinary(reader, field);
return switch (elementType) {
case BYTE -> new ByteBinaryDenseVectorDocValuesField(values, name, elementType, dims);
case FLOAT -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion);
case BIT -> new BitBinaryDenseVectorDocValuesField(values, name, elementType, dims);
};
}
} catch (IOException e) {

View file

@@ -56,7 +56,7 @@ public class VectorScoreScriptUtils {
*/
public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
super(scoreScript, field);
DenseVector.checkDimensions(field.get().getDims(), queryVector.size());
field.getElementType().checkDimensions(field.get().getDims(), queryVector.size());
this.queryVector = new byte[queryVector.size()];
float[] validateValues = new float[queryVector.size()];
int queryMagnitude = 0;
@@ -168,7 +168,7 @@ public class VectorScoreScriptUtils {
public L1Norm(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
function = switch (field.getElementType()) {
case BYTE -> {
case BYTE, BIT -> {
if (queryVector instanceof List) {
yield new ByteL1Norm(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {
@@ -219,8 +219,8 @@ public class VectorScoreScriptUtils {
@SuppressWarnings("unchecked")
public Hamming(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
if (field.getElementType() != DenseVectorFieldMapper.ElementType.BYTE) {
throw new IllegalArgumentException("hamming distance is only supported for byte vectors");
if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) {
throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors");
}
if (queryVector instanceof List) {
function = new ByteHammingDistance(scoreScript, field, (List<Number>) queryVector);
@@ -278,7 +278,7 @@ public class VectorScoreScriptUtils {
public L2Norm(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
function = switch (field.getElementType()) {
case BYTE -> {
case BYTE, BIT -> {
if (queryVector instanceof List) {
yield new ByteL2Norm(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {
@@ -342,7 +342,7 @@ public class VectorScoreScriptUtils {
public DotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
function = switch (field.getElementType()) {
case BYTE -> {
case BYTE, BIT -> {
if (queryVector instanceof List) {
yield new ByteDotProduct(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {
@@ -406,7 +406,7 @@ public class VectorScoreScriptUtils {
public CosineSimilarity(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
function = switch (field.getElementType()) {
case BYTE -> {
case BYTE, BIT -> {
if (queryVector instanceof List) {
yield new ByteCosineSimilarity(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {

View file

@@ -0,0 +1,88 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.script.field.vectors;
import org.apache.lucene.util.BytesRef;
import java.util.List;
public class BitBinaryDenseVector extends ByteBinaryDenseVector {
public BitBinaryDenseVector(byte[] vectorValue, BytesRef docVector, int dims) {
super(vectorValue, docVector, dims);
}
@Override
public void checkDimensions(int qvDims) {
if (qvDims != dims) {
throw new IllegalArgumentException(
"The query vector has a different number of dimensions ["
+ qvDims * Byte.SIZE
+ "] than the document vectors ["
+ dims * Byte.SIZE
+ "]."
);
}
}
@Override
public int l1Norm(byte[] queryVector) {
return hamming(queryVector);
}
@Override
public double l1Norm(List<Number> queryVector) {
return hamming(queryVector);
}
@Override
public double l2Norm(byte[] queryVector) {
return Math.sqrt(hamming(queryVector));
}
@Override
public double l2Norm(List<Number> queryVector) {
return Math.sqrt(hamming(queryVector));
}
@Override
public int dotProduct(byte[] queryVector) {
throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
}
@Override
public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
}
@Override
public double dotProduct(List<Number> queryVector) {
throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
}
@Override
public double cosineSimilarity(byte[] queryVector, float qvMagnitude) {
throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
}
@Override
public double cosineSimilarity(List<Number> queryVector) {
throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
}
@Override
public double dotProduct(float[] queryVector) {
throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
}
@Override
public int getDims() {
return dims * Byte.SIZE;
}
}

View file

@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.script.field.vectors;
import org.apache.lucene.index.BinaryDocValues;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
public class BitBinaryDenseVectorDocValuesField extends ByteBinaryDenseVectorDocValuesField {
public BitBinaryDenseVectorDocValuesField(BinaryDocValues input, String name, ElementType elementType, int dims) {
super(input, name, elementType, dims / 8);
}
@Override
protected DenseVector getVector() {
return new BitBinaryDenseVector(vectorValue, value, dims);
}
}
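The constructor above passes `dims / 8` to the byte-based superclass because each stored byte carries 8 bits of the logical vector. A hypothetical packing helper showing that layout (illustrative only; the MSB-first bit order and the class name are assumptions, not taken from this change):

```java
public final class BitPackSketch {
    // Hypothetical helper: pack dims booleans (dims divisible by 8) into
    // the byte[dims / 8] layout; MSB-first bit order assumed for illustration.
    static byte[] pack(boolean[] bits) {
        if (bits.length % 8 != 0) {
            throw new IllegalArgumentException("dims must be divisible by 8");
        }
        byte[] packed = new byte[bits.length / 8];
        for (int i = 0; i < bits.length; i++) {
            if (bits[i]) {
                packed[i / 8] |= (byte) (1 << (7 - (i % 8)));
            }
        }
        return packed;
    }

    public static void main(String[] args) {
        boolean[] bits = new boolean[8];
        bits[0] = true; // set the highest bit of the first byte
        System.out.printf("%02x%n", pack(bits)[0]); // prints 80
    }
}
```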

View file

@@ -0,0 +1,95 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.script.field.vectors;
import java.util.List;
public class BitKnnDenseVector extends ByteKnnDenseVector {
public BitKnnDenseVector(byte[] vector) {
super(vector);
}
@Override
public void checkDimensions(int qvDims) {
if (qvDims != docVector.length) {
throw new IllegalArgumentException(
"The query vector has a different number of dimensions ["
+ qvDims * Byte.SIZE
+ "] than the document vectors ["
+ docVector.length * Byte.SIZE
+ "]."
);
}
}
@Override
public float getMagnitude() {
if (magnitudeCalculated == false) {
magnitude = DenseVector.getBitMagnitude(docVector, docVector.length);
magnitudeCalculated = true;
}
return magnitude;
}
@Override
public int l1Norm(byte[] queryVector) {
return hamming(queryVector);
}
@Override
public double l1Norm(List<Number> queryVector) {
return hamming(queryVector);
}
@Override
public double l2Norm(byte[] queryVector) {
return Math.sqrt(hamming(queryVector));
}
@Override
public double l2Norm(List<Number> queryVector) {
return Math.sqrt(hamming(queryVector));
}
@Override
public int dotProduct(byte[] queryVector) {
throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
}
@Override
public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
}
@Override
public double dotProduct(List<Number> queryVector) {
throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
}
@Override
public double cosineSimilarity(byte[] queryVector, float qvMagnitude) {
throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
}
@Override
public double cosineSimilarity(List<Number> queryVector) {
throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
}
@Override
public double dotProduct(float[] queryVector) {
throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
}
@Override
public int getDims() {
return docVector.length * Byte.SIZE;
}
}
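As the overrides above encode, for bit vectors the script functions collapse onto hamming distance: `l1norm` is the hamming distance itself and `l2norm` is its square root, while `dotProduct` and `cosineSimilarity` throw. A small numeric sketch of that relation (illustrative only; class and method names are hypothetical):

```java
public final class BitNormSketch {
    static int l1Norm(byte[] doc, byte[] query) {
        int hamming = 0;
        for (int i = 0; i < doc.length; i++) {
            hamming += Integer.bitCount((doc[i] ^ query[i]) & 0xFF);
        }
        return hamming; // l1norm == hamming for bit vectors
    }

    static double l2Norm(byte[] doc, byte[] query) {
        return Math.sqrt(l1Norm(doc, query)); // l2norm == sqrt(hamming)
    }

    public static void main(String[] args) {
        byte[] doc = { (byte) 0b0000_1111 };
        byte[] query = { (byte) 0b0000_0011 };
        System.out.println(l1Norm(doc, query)); // 2 bits differ -> prints 2
        System.out.println(l2Norm(doc, query)); // sqrt(2)
    }
}
```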

View file

@@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.script.field.vectors;
import org.apache.lucene.index.ByteVectorValues;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
public class BitKnnDenseVectorDocValuesField extends ByteKnnDenseVectorDocValuesField {
public BitKnnDenseVectorDocValuesField(@Nullable ByteVectorValues input, String name, int dims) {
super(input, name, dims / 8, DenseVectorFieldMapper.ElementType.BIT);
}
@Override
protected DenseVector getVector() {
return new BitKnnDenseVector(vector);
}
}

View file

@@ -21,7 +21,7 @@ public class ByteBinaryDenseVector implements DenseVector {
private final BytesRef docVector;
private final byte[] vectorValue;
private final int dims;
protected final int dims;
private float[] floatDocVector;
private boolean magnitudeDecoded;

View file

@@ -17,11 +17,11 @@ import java.io.IOException;
public class ByteBinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {
private final BinaryDocValues input;
private final int dims;
private final byte[] vectorValue;
private boolean decoded;
private BytesRef value;
protected final BinaryDocValues input;
protected final int dims;
protected final byte[] vectorValue;
protected boolean decoded;
protected BytesRef value;
public ByteBinaryDenseVectorDocValuesField(BinaryDocValues input, String name, ElementType elementType, int dims) {
super(name, elementType);
@@ -50,13 +50,17 @@ public class ByteBinaryDenseVectorDocValuesField extends DenseVectorDocValuesField
return value == null;
}
protected DenseVector getVector() {
return new ByteBinaryDenseVector(vectorValue, value, dims);
}
@Override
public DenseVector get() {
if (isEmpty()) {
return DenseVector.EMPTY;
}
decodeVectorIfNecessary();
return new ByteBinaryDenseVector(vectorValue, value, dims);
return getVector();
}
@Override
@@ -65,7 +69,7 @@ public class ByteBinaryDenseVectorDocValuesField extends DenseVectorDocValuesField
return defaultValue;
}
decodeVectorIfNecessary();
return new ByteBinaryDenseVector(vectorValue, value, dims);
return getVector();
}
@Override

View file

@@ -23,7 +23,11 @@ public class ByteKnnDenseVectorDocValuesField extends DenseVectorDocValuesField
protected final int dims;
public ByteKnnDenseVectorDocValuesField(@Nullable ByteVectorValues input, String name, int dims) {
super(name, ElementType.BYTE);
this(input, name, dims, ElementType.BYTE);
}
protected ByteKnnDenseVectorDocValuesField(@Nullable ByteVectorValues input, String name, int dims, ElementType elementType) {
super(name, elementType);
this.dims = dims;
this.input = input;
}
@@ -57,13 +61,17 @@ public class ByteKnnDenseVectorDocValuesField extends DenseVectorDocValuesField
return vector == null;
}
protected DenseVector getVector() {
return new ByteKnnDenseVector(vector);
}
@Override
public DenseVector get() {
if (isEmpty()) {
return DenseVector.EMPTY;
}
return new ByteKnnDenseVector(vector);
return getVector();
}
@Override
@@ -72,7 +80,7 @@ public class ByteKnnDenseVectorDocValuesField extends DenseVectorDocValuesField
return defaultValue;
}
return new ByteKnnDenseVector(vector);
return getVector();
}
@Override

View file

@@ -8,6 +8,7 @@
package org.elasticsearch.script.field.vectors;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.VectorUtil;
import java.util.List;
@@ -25,6 +26,10 @@ import java.util.List;
*/
public interface DenseVector {
default void checkDimensions(int qvDims) {
checkDimensions(getDims(), qvDims);
}
float[] getVector();
float getMagnitude();
@@ -38,13 +43,13 @@ public interface DenseVector {
@SuppressWarnings("unchecked")
default double dotProduct(Object queryVector) {
if (queryVector instanceof float[] floats) {
checkDimensions(getDims(), floats.length);
checkDimensions(floats.length);
return dotProduct(floats);
} else if (queryVector instanceof List<?> list) {
checkDimensions(getDims(), list.size());
checkDimensions(list.size());
return dotProduct((List<Number>) list);
} else if (queryVector instanceof byte[] bytes) {
checkDimensions(getDims(), bytes.length);
checkDimensions(bytes.length);
return dotProduct(bytes);
}
@@ -60,13 +65,13 @@ public interface DenseVector {
@SuppressWarnings("unchecked")
default double l1Norm(Object queryVector) {
if (queryVector instanceof float[] floats) {
checkDimensions(getDims(), floats.length);
checkDimensions(floats.length);
return l1Norm(floats);
} else if (queryVector instanceof List<?> list) {
checkDimensions(getDims(), list.size());
checkDimensions(list.size());
return l1Norm((List<Number>) list);
} else if (queryVector instanceof byte[] bytes) {
checkDimensions(getDims(), bytes.length);
checkDimensions(bytes.length);
return l1Norm(bytes);
}
@@ -80,11 +85,11 @@ public interface DenseVector {
@SuppressWarnings("unchecked")
default int hamming(Object queryVector) {
if (queryVector instanceof List<?> list) {
checkDimensions(getDims(), list.size());
checkDimensions(list.size());
return hamming((List<Number>) list);
}
if (queryVector instanceof byte[] bytes) {
checkDimensions(getDims(), bytes.length);
checkDimensions(bytes.length);
return hamming(bytes);
}
@@ -100,13 +105,13 @@ public interface DenseVector {
@SuppressWarnings("unchecked")
default double l2Norm(Object queryVector) {
if (queryVector instanceof float[] floats) {
checkDimensions(getDims(), floats.length);
checkDimensions(floats.length);
return l2Norm(floats);
} else if (queryVector instanceof List<?> list) {
checkDimensions(getDims(), list.size());
checkDimensions(list.size());
return l2Norm((List<Number>) list);
} else if (queryVector instanceof byte[] bytes) {
checkDimensions(getDims(), bytes.length);
checkDimensions(bytes.length);
return l2Norm(bytes);
}
@@ -150,13 +155,13 @@ public interface DenseVector {
@SuppressWarnings("unchecked")
default double cosineSimilarity(Object queryVector) {
if (queryVector instanceof float[] floats) {
checkDimensions(getDims(), floats.length);
checkDimensions(floats.length);
return cosineSimilarity(floats);
} else if (queryVector instanceof List<?> list) {
checkDimensions(getDims(), list.size());
checkDimensions(list.size());
return cosineSimilarity((List<Number>) list);
} else if (queryVector instanceof byte[] bytes) {
checkDimensions(getDims(), bytes.length);
checkDimensions(bytes.length);
return cosineSimilarity(bytes);
}
@@ -184,6 +189,20 @@ public interface DenseVector {
return (float) Math.sqrt(mag);
}
static float getBitMagnitude(byte[] vector, int dims) {
int count = 0;
int i = 0;
for (int upperBound = dims & -8; i < upperBound; i += 8) {
count += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(vector, i));
}
while (i < dims) {
count += Integer.bitCount(vector[i] & 255);
++i;
}
return (float) Math.sqrt(count);
}
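`getBitMagnitude` above is an unrolled popcount: it counts set bits eight bytes at a time through a native-order long view, finishes any remaining bytes one at a time, and returns the square root of the total. A straightforward equivalent without the unrolling (illustrative only; the class name is hypothetical):

```java
public final class BitMagnitudeSketch {
    // Naive equivalent of the unrolled version: the magnitude of a bit
    // vector is the square root of its total number of set bits.
    static float bitMagnitude(byte[] vector) {
        int setBits = 0;
        for (byte b : vector) {
            setBits += Integer.bitCount(b & 0xFF);
        }
        return (float) Math.sqrt(setBits);
    }

    public static void main(String[] args) {
        byte[] v = { (byte) 0xFF, 0x01 }; // 8 + 1 = 9 set bits
        System.out.println(bitMagnitude(v)); // prints 3.0
    }
}
```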
static float getMagnitude(float[] vector) {
return (float) Math.sqrt(VectorUtil.dotProduct(vector, vector));
}

View file

@@ -1,3 +1,5 @@
org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat
org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat
org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat
org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat
org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat

View file

@@ -0,0 +1,149 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseIndexFileFormatTestCase;
import org.elasticsearch.common.logging.LogConfigurator;
import java.io.IOException;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
abstract class BaseKnnBitVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
static {
LogConfigurator.loadLog4jPlugins();
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
}
@Override
protected void addRandomFields(Document doc) {
doc.add(new KnnByteVectorField("v2", randomVector(30), similarityFunction));
}
protected VectorSimilarityFunction similarityFunction;
protected VectorSimilarityFunction randomSimilarity() {
return VectorSimilarityFunction.values()[random().nextInt(VectorSimilarityFunction.values().length)];
}
byte[] randomVector(int dims) {
byte[] vector = new byte[dims];
random().nextBytes(vector);
return vector;
}
public void testRandom() throws Exception {
IndexWriterConfig iwc = newIndexWriterConfig();
if (random().nextBoolean()) {
iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT)));
}
String fieldName = "field";
try (Directory dir = newDirectory(); IndexWriter iw = new IndexWriter(dir, iwc)) {
int numDoc = atLeast(100);
int dimension = atLeast(10);
if (dimension % 2 != 0) {
dimension++;
}
byte[] scratch = new byte[dimension];
int numValues = 0;
byte[][] values = new byte[numDoc][];
for (int i = 0; i < numDoc; i++) {
if (random().nextInt(7) != 3) {
// usually index a vector value for a doc
values[i] = randomVector(dimension);
++numValues;
}
if (random().nextBoolean() && values[i] != null) {
// sometimes use a shared scratch array
System.arraycopy(values[i], 0, scratch, 0, scratch.length);
add(iw, fieldName, i, scratch, similarityFunction);
} else {
add(iw, fieldName, i, values[i], similarityFunction);
}
if (random().nextInt(10) == 2) {
// sometimes delete a random document
int idToDelete = random().nextInt(i + 1);
iw.deleteDocuments(new Term("id", Integer.toString(idToDelete)));
// and remember that it was deleted
if (values[idToDelete] != null) {
values[idToDelete] = null;
--numValues;
}
}
if (random().nextInt(10) == 3) {
iw.commit();
}
}
int numDeletes = 0;
try (IndexReader reader = DirectoryReader.open(iw)) {
int valueCount = 0, totalSize = 0;
for (LeafReaderContext ctx : reader.leaves()) {
ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
if (vectorValues == null) {
continue;
}
totalSize += vectorValues.size();
StoredFields storedFields = ctx.reader().storedFields();
int docId;
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
byte[] v = vectorValues.vectorValue();
assertEquals(dimension, v.length);
String idString = storedFields.document(docId).getField("id").stringValue();
int id = Integer.parseInt(idString);
if (ctx.reader().getLiveDocs() == null || ctx.reader().getLiveDocs().get(docId)) {
assertArrayEquals(idString, values[id], v);
++valueCount;
} else {
++numDeletes;
assertNull(values[id]);
}
}
}
assertEquals(numValues, valueCount);
assertEquals(numValues, totalSize - numDeletes);
}
}
}
private void add(IndexWriter iw, String field, int id, byte[] vector, VectorSimilarityFunction similarity) throws IOException {
add(iw, field, id, random().nextInt(100), vector, similarity);
}
private void add(IndexWriter iw, String field, int id, int sortKey, byte[] vector, VectorSimilarityFunction similarityFunction)
throws IOException {
Document doc = new Document();
if (vector != null) {
doc.add(new KnnByteVectorField(field, vector, similarityFunction));
}
doc.add(new NumericDocValuesField("sortkey", sortKey));
String idString = Integer.toString(id);
doc.add(new StringField("id", idString, Field.Store.YES));
Term idTerm = new Term("id", idString);
iw.updateDocument(idTerm, doc);
}
}

View file

@@ -12,8 +12,15 @@ import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.elasticsearch.common.logging.LogConfigurator;
public class ES813FlatVectorFormatTests extends BaseKnnVectorsFormatTestCase {
static {
LogConfigurator.loadLog4jPlugins();
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
}
@Override
protected Codec getCodec() {
return new Lucene99Codec() {

View file

@@ -12,8 +12,15 @@ import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.elasticsearch.common.logging.LogConfigurator;
public class ES813Int8FlatVectorFormatTests extends BaseKnnVectorsFormatTestCase {
static {
LogConfigurator.loadLog4jPlugins();
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
}
@Override
protected Codec getCodec() {
return new Lucene99Codec() {

View file

@@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.junit.Before;
public class ES815BitFlatVectorFormatTests extends BaseKnnBitVectorsFormatTestCase {
@Override
protected Codec getCodec() {
return new Lucene99Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new ES815BitFlatVectorFormat();
}
};
}
@Before
public void init() {
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
}
}

View file

@@ -0,0 +1,33 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.index.codec.vectors;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.junit.Before;
public class ES815HnswBitVectorsFormatTests extends BaseKnnBitVectorsFormatTestCase {
@Override
protected Codec getCodec() {
return new Lucene99Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new ES815HnswBitVectorsFormat();
}
};
}
@Before
public void init() {
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
}
}
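Both new format tests pin `similarityFunction` to `EUCLIDEAN`, matching the requirement that indexed `bit` vectors use `l2_norm`. Under the hood, comparisons are `xor` plus `popcount` summations (hamming distance), with the score normalized by the vector dimensions. The hamming computation below follows that description; the `normalize` transform is an illustrative assumption, not the exact formula the codec ships:

```java
public class BitScoreSketch {
    // hamming distance over packed bit vectors: xor each byte pair, count set bits
    public static int hamming(byte[] a, byte[] b) {
        int dist = 0;
        for (int i = 0; i < a.length; i++) {
            // mask to 8 bits so the sign-extended high bits do not inflate the count
            dist += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        return dist;
    }

    // Assumed normalization only: identical vectors score 1.0, fully opposite 0.0.
    public static double normalize(int hammingDistance, int dims) {
        return (dims - hammingDistance) / (double) dims;
    }

    public static void main(String[] args) {
        byte[] a = { (byte) 0b11110000, 0 };
        byte[] b = { (byte) 0b00001111, 0 };
        int dims = a.length * Byte.SIZE; // 16 bit dimensions packed into 2 bytes
        System.out.println(hamming(a, b));                   // 8
        System.out.println(normalize(hamming(a, b), dims));  // 0.5
    }
}
```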

View file

@@ -236,8 +236,8 @@ public class BinaryDenseVectorScriptDocValuesTests extends ESTestCase {
public static BytesRef mockEncodeDenseVector(float[] values, ElementType elementType, IndexVersion indexVersion) {
int numBytes = indexVersion.onOrAfter(DenseVectorFieldMapper.MAGNITUDE_STORED_INDEX_VERSION)
? elementType.elementBytes * values.length + DenseVectorFieldMapper.MAGNITUDE_BYTES
: elementType.elementBytes * values.length;
? elementType.getNumBytes(values.length) + DenseVectorFieldMapper.MAGNITUDE_BYTES
: elementType.getNumBytes(values.length);
double dotProduct = 0f;
ByteBuffer byteBuffer = elementType.createByteBuffer(indexVersion, numBytes);
for (float value : values) {
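The hunk above swaps `elementType.elementBytes * values.length` for `elementType.getNumBytes(values.length)`: a plain per-element multiply no longer works once `bit` packs eight dimensions into each byte. A hypothetical sketch of the sizing logic (the enum bodies here are assumptions that mirror the documented behavior, not the production implementation):

```java
public class NumBytesSketch {
    public enum ElementType {
        FLOAT { public int getNumBytes(int dims) { return dims * Float.BYTES; } },
        BYTE  { public int getNumBytes(int dims) { return dims; } },
        // bit dims must be divisible by 8; storage is one byte per 8 dimensions
        BIT   { public int getNumBytes(int dims) { return dims / Byte.SIZE; } };

        public abstract int getNumBytes(int dims);
    }

    public static void main(String[] args) {
        System.out.println(ElementType.FLOAT.getNumBytes(4)); // 16
        System.out.println(ElementType.BYTE.getNumBytes(4));  // 4
        System.out.println(ElementType.BIT.getNumBytes(32));  // 4
    }
}
```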

View file

@@ -71,11 +71,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
private final ElementType elementType;
private final boolean indexed;
private final boolean indexOptionsSet;
private final int dims;
public DenseVectorFieldMapperTests() {
this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT);
this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT);
this.indexed = randomBoolean();
this.indexOptionsSet = this.indexed && randomBoolean();
this.dims = ElementType.BIT == elementType ? 4 * Byte.SIZE : 4;
}
@Override
@@ -89,7 +91,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
}
private void indexMapping(XContentBuilder b, IndexVersion indexVersion) throws IOException {
b.field("type", "dense_vector").field("dims", 4);
b.field("type", "dense_vector").field("dims", dims);
if (elementType != ElementType.FLOAT) {
b.field("element_type", elementType.toString());
}
@@ -108,7 +110,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
b.endObject();
}
if (indexed) {
b.field("similarity", "dot_product");
b.field("similarity", elementType == ElementType.BIT ? "l2_norm" : "dot_product");
if (indexOptionsSet) {
b.startObject("index_options");
b.field("type", "hnsw");
@@ -121,52 +123,86 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
@Override
protected Object getSampleValueForDocument() {
return elementType == ElementType.BYTE ? List.of((byte) 1, (byte) 1, (byte) 1, (byte) 1) : List.of(0.5, 0.5, 0.5, 0.5);
return elementType == ElementType.FLOAT ? List.of(0.5, 0.5, 0.5, 0.5) : List.of((byte) 1, (byte) 1, (byte) 1, (byte) 1);
}
@Override
protected void registerParameters(ParameterChecker checker) throws IOException {
checker.registerConflictCheck(
"dims",
fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4)),
fieldMapping(b -> b.field("type", "dense_vector").field("dims", 5))
fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims)),
fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims + 8))
);
checker.registerConflictCheck(
"similarity",
fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", true).field("similarity", "dot_product")),
fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", true).field("similarity", "l2_norm"))
fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", true).field("similarity", "dot_product")),
fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", true).field("similarity", "l2_norm"))
);
checker.registerConflictCheck(
"index",
fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", true).field("similarity", "dot_product")),
fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", false))
fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", true).field("similarity", "dot_product")),
fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", false))
);
checker.registerConflictCheck(
"element_type",
fieldMapping(
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.field("similarity", "dot_product")
.field("element_type", "byte")
),
fieldMapping(
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.field("similarity", "dot_product")
.field("element_type", "float")
)
);
checker.registerConflictCheck(
"element_type",
fieldMapping(
b -> b.field("type", "dense_vector")
.field("dims", dims)
.field("index", true)
.field("similarity", "l2_norm")
.field("element_type", "float")
),
fieldMapping(
b -> b.field("type", "dense_vector")
.field("dims", dims)
.field("index", true)
.field("similarity", "l2_norm")
.field("element_type", "bit")
)
);
checker.registerConflictCheck(
"element_type",
fieldMapping(
b -> b.field("type", "dense_vector")
.field("dims", dims)
.field("index", true)
.field("similarity", "l2_norm")
.field("element_type", "byte")
),
fieldMapping(
b -> b.field("type", "dense_vector")
.field("dims", dims)
.field("index", true)
.field("similarity", "l2_norm")
.field("element_type", "bit")
)
);
checker.registerUpdateCheck(
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "flat")
.endObject(),
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "int8_flat")
@@ -175,13 +211,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
);
checker.registerUpdateCheck(
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "flat")
.endObject(),
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "hnsw")
@@ -190,13 +226,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
);
checker.registerUpdateCheck(
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "flat")
.endObject(),
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "int8_hnsw")
@@ -205,13 +241,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
);
checker.registerUpdateCheck(
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "int8_flat")
.endObject(),
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "hnsw")
@@ -220,13 +256,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
);
checker.registerUpdateCheck(
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "int8_flat")
.endObject(),
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "int8_hnsw")
@@ -235,13 +271,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
);
checker.registerUpdateCheck(
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "hnsw")
.endObject(),
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "int8_hnsw")
@@ -252,7 +288,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
"index_options",
fieldMapping(
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "hnsw")
@@ -260,7 +296,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
),
fieldMapping(
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("index", true)
.startObject("index_options")
.field("type", "flat")
@@ -353,7 +389,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
mapping = mapping(b -> {
b.startObject("field");
b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("similarity", "cosine")
.field("index", true)
.startObject("index_options")
@@ -648,7 +684,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
() -> createDocumentMapper(
fieldMapping(
b -> b.field("type", "dense_vector")
.field("dims", 4)
.field("dims", dims)
.field("element_type", "byte")
.field("similarity", "l2_norm")
.field("index", true)
@@ -1020,6 +1056,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
}
yield floats;
}
case BIT -> randomByteArrayOfLength(vectorFieldType.getVectorDimensions() / 8);
};
}
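The new `case BIT` branch above generates random document source as `dims / 8` bytes. Per the PR description, a bit vector can be supplied either as a `byte[]` of that length or as a hexadecimal string; a sketch of the equivalence (the `dims` value and hex literal are illustrative):

```java
import java.util.Arrays;
import java.util.HexFormat;

public class BitVectorSource {
    public static void main(String[] args) {
        int dims = 32; // must be divisible by 8
        // byte[] form: each element carries 8 bits of the vector
        byte[] packed = new byte[] { (byte) 0xFF, 0x00, (byte) 0xAB, 0x12 };
        // hex string form of the same vector
        byte[] fromHex = HexFormat.of().parseHex("ff00ab12");
        System.out.println(packed.length == dims / Byte.SIZE); // true
        System.out.println(Arrays.equals(packed, fromHex));    // true
    }
}
```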
@@ -1196,7 +1233,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
boolean setEfConstruction = randomBoolean();
MapperService mapperService = createMapperService(fieldMapping(b -> {
b.field("type", "dense_vector");
b.field("dims", 4);
b.field("dims", dims);
b.field("index", true);
b.field("similarity", "dot_product");
b.startObject("index_options");
@@ -1234,7 +1271,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
for (String quantizedFlatFormat : new String[] { "int8_flat", "int4_flat" }) {
MapperService mapperService = createMapperService(fieldMapping(b -> {
b.field("type", "dense_vector");
b.field("dims", 4);
b.field("dims", dims);
b.field("index", true);
b.field("similarity", "dot_product");
b.startObject("index_options");
@@ -1275,7 +1312,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
float confidenceInterval = (float) randomDoubleBetween(0.90f, 1.0f, true);
MapperService mapperService = createMapperService(fieldMapping(b -> {
b.field("type", "dense_vector");
b.field("dims", 4);
b.field("dims", dims);
b.field("index", true);
b.field("similarity", "dot_product");
b.startObject("index_options");
@@ -1316,7 +1353,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
float confidenceInterval = (float) randomDoubleBetween(0.90f, 1.0f, true);
MapperService mapperService = createMapperService(fieldMapping(b -> {
b.field("type", "dense_vector");
b.field("dims", 4);
b.field("dims", dims);
b.field("index", true);
b.field("similarity", "dot_product");
b.startObject("index_options");

View file

@@ -185,10 +185,12 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
queryVector[i] = randomByte();
floatQueryVector[i] = queryVector[i];
}
Query query = field.createKnnQuery(queryVector, 10, null, null, producer);
VectorData vectorData = new VectorData(null, queryVector);
Query query = field.createKnnQuery(vectorData, 10, null, null, producer);
assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
query = field.createKnnQuery(floatQueryVector, 10, null, null, producer);
vectorData = new VectorData(floatQueryVector, null);
query = field.createKnnQuery(vectorData, 10, null, null, producer);
assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
}
}
@@ -321,7 +323,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
for (int i = 0; i < 4096; i++) {
queryVector[i] = randomByte();
}
Query query = fieldWith4096dims.createKnnQuery(queryVector, 10, null, null, null);
VectorData vectorData = new VectorData(null, queryVector);
Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, null, null, null);
assertThat(query, instanceOf(KnnByteVectorQuery.class));
}
}
@@ -359,7 +362,10 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
);
assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
e = expectThrows(IllegalArgumentException.class, () -> cosineField.createKnnQuery(new byte[] { 0, 0, 0 }, 10, null, null, null));
e = expectThrows(
IllegalArgumentException.class,
() -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, null, null, null)
);
assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
}
}
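These tests now wrap query vectors in `VectorData` before calling `createKnnQuery`. Inferred from the call sites alone — `new VectorData(null, byteVector)` and `new VectorData(floatQueryVector, null)` — it appears to hold exactly one of the two vector forms. The record below is a sketch of that shape, not the production class; the conversion in `asByteVector` is an assumption:

```java
import java.util.Arrays;

// Hypothetical stand-in for the VectorData holder used in the tests above.
public record VectorDataSketch(float[] floatVector, byte[] byteVector) {
    public boolean isFloat() {
        return floatVector != null;
    }

    public byte[] asByteVector() {
        if (byteVector != null) {
            return byteVector;
        }
        // assumed conversion: narrow each float to a byte
        byte[] bytes = new byte[floatVector.length];
        for (int i = 0; i < floatVector.length; i++) {
            bytes[i] = (byte) floatVector[i];
        }
        return bytes;
    }
}
```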

View file

@@ -114,10 +114,10 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
);
e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, queryVector, fieldName));
assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors"));
assertThat(e.getMessage(), containsString("hamming distance is only supported for byte or bit vectors"));
e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, invalidQueryVector, fieldName));
assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors"));
assertThat(e.getMessage(), containsString("hamming distance is only supported for byte or bit vectors"));
// Check scripting infrastructure integration
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName);
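The updated error messages above reflect that `Hamming` now accepts bit vectors as well as byte vectors. Per the PR description, for bit vectors the script-side `l1norm` equals hamming distance, `l2norm` is `sqrt(l1norm)`, and `dotProduct`/`cosineSimilarity` stay unsupported; a sketch of that relationship:

```java
public class BitScriptNorms {
    public static int hamming(byte[] a, byte[] b) {
        int dist = 0;
        for (int i = 0; i < a.length; i++) {
            dist += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        return dist;
    }

    public static int l1norm(byte[] a, byte[] b) {
        return hamming(a, b); // identical for bit vectors
    }

    public static double l2norm(byte[] a, byte[] b) {
        return Math.sqrt(l1norm(a, b));
    }

    public static void main(String[] args) {
        byte[] q = { (byte) 0b10101010 };
        byte[] d = { (byte) 0b01010101 };
        System.out.println(l1norm(q, d)); // 8: every bit differs
        System.out.println(l2norm(q, d)); // sqrt(8)
    }
}
```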

View file

@@ -122,7 +122,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery;
// The field should always be resolved to the concrete field
Query knnVectorQueryBuilt = switch (elementType()) {
case BYTE -> new ESKnnByteVectorQuery(
case BYTE, BIT -> new ESKnnByteVectorQuery(
VECTOR_FIELD,
queryBuilder.queryVector().asByteVector(),
queryBuilder.numCands(),
@@ -145,7 +145,10 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
SearchExecutionContext context = createSearchExecutionContext();
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 10, null);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
assertThat(e.getMessage(), containsString("the query vector has a different dimension [2] than the index vectors [3]"));
assertThat(
e.getMessage(),
containsString("The query vector has a different number of dimensions [2] than the document vectors [3]")
);
}
public void testNonexistentField() {

View file

@@ -46,6 +46,7 @@ public class EmbeddingRequestChunker {
return switch (elementType) {
case BYTE -> EmbeddingType.BYTE;
case FLOAT -> EmbeddingType.FLOAT;
case BIT -> throw new IllegalArgumentException("Bit vectors are not supported");
};
}
};