mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-29 01:44:36 -04:00
Adds new bit
element_type for dense_vectors (#110059)
This commit adds `bit` vector support by adding `element_type: bit` for vectors. This new element type works for indexed and non-indexed vectors. Additionally, it works with `hnsw` and `flat` index types. No quantization based codec works with this element type; this is consistent with `byte` vectors. `bit` vectors accept up to `32768` dimensions and expect vectors that are being indexed to be encoded either as a hexadecimal string or a `byte[]` array where each element of the `byte` array represents `8` bits of the vector. `bit` vectors support script usage and regular query usage. When indexed, all comparisons are `xor` and `popcount` summations (aka Hamming distance), and the scores are transformed and normalized given the vector dimensions. Note, indexed bit vectors require `l2_norm` to be the similarity. For scripts, `l1norm` is the same as `hamming` distance and `l2norm` is `sqrt(l1norm)`. `dotProduct` and `cosineSimilarity` are not supported. Note, the dimensions expected by this element_type must always be divisible by `8`, and the `byte[]` vectors provided for indexing must have size `dim/8`, where each byte element represents `8` bits of the vector. closes: https://github.com/elastic/elasticsearch/issues/48322
This commit is contained in:
parent
97651dfb9f
commit
5add44d7d1
38 changed files with 2713 additions and 187 deletions
32
docs/changelog/110059.yaml
Normal file
@ -0,0 +1,32 @@
pr: 110059
summary: Adds new `bit` `element_type` for `dense_vectors`
area: Vector Search
type: feature
issues: []
highlight:
  title: Adds new `bit` `element_type` for `dense_vectors`
  body: |-
    This adds `bit` vector support by adding `element_type: bit` for
    vectors. This new element type works for indexed and non-indexed
    vectors. Additionally, it works with `hnsw` and `flat` index types. No
    quantization based codec works with this element type; this is
    consistent with `byte` vectors.

    `bit` vectors accept up to `32768` dimensions and expect vectors
    that are being indexed to be encoded either as a hexadecimal string or a
    `byte[]` array where each element of the `byte` array represents `8`
    bits of the vector.

    `bit` vectors support script usage and regular query usage. When
    indexed, all comparisons are `xor` and `popcount` summations (aka
    Hamming distance), and the scores are transformed and normalized given
    the vector dimensions.

    For scripts, `l1norm` is the same as `hamming` distance and `l2norm` is
    `sqrt(l1norm)`. `dotProduct` and `cosineSimilarity` are not supported.

    Note, the dimensions expected by this element_type must always be
    divisible by `8`, and the `byte[]` vectors provided for indexing must
    have size `dim/8`, where each byte element represents `8` bits of
    the vector.
  notable: true
@ -183,11 +183,23 @@ The following mapping parameters are accepted:
`element_type`::
(Optional, string)
The data type used to encode vectors. The supported data types are
`float` (default), `byte`, and `bit`.

.Valid values for `element_type`
[%collapsible%open]
====
`float`:::
indexes a 4-byte floating-point
value per dimension. This is the default value.

`byte`:::
indexes a 1-byte integer value per dimension.

`bit`:::
indexes a single bit per dimension. Useful for very high-dimensional vectors or models that specifically support bit vectors.

NOTE: when using `bit`, the number of dimensions must be a multiple of 8 and must represent the number of bits.

====

`dims`::
(Optional, integer)
@ -205,7 +217,11 @@ API>>. Defaults to `true`.
The vector similarity metric to use in kNN search. Documents are ranked by
their vector field's similarity to the query vector. The `_score` of each
document will be derived from the similarity, in a way that ensures scores are
positive and that a larger score corresponds to a higher ranking.
Defaults to `l2_norm` when `element_type: bit`; otherwise, defaults to `cosine`.

NOTE: `bit` vectors only support `l2_norm` as their similarity metric.

+
^*^ This parameter can only be specified when `index` is `true`.
+
@ -217,6 +233,9 @@ Computes similarity based on the L^2^ distance (also known as Euclidean
distance) between the vectors. The document `_score` is computed as
`1 / (1 + l2_norm(query, vector)^2)`.

For `bit` vectors, instead of using `l2_norm`, the `hamming` distance between the vectors is used. The `_score`
transformation is `(numBits - hamming(a, b)) / numBits`.

`dot_product`:::
Computes the dot product of two unit vectors. This option provides an optimized way
to perform cosine similarity. The constraints and computed score are defined
@ -320,3 +339,112 @@ any issues, but features in technical preview are not subject to the support SLA
of official GA features.

`dense_vector` fields support <<synthetic-source,synthetic `_source`>>.

[[dense-vector-index-bit]]
==== Indexing & Searching bit vectors

When using `element_type: bit`, all vectors are treated as bit vectors. Bit vectors utilize only a single
bit per dimension and are internally encoded as bytes. This can be useful for very high-dimensional vectors or models.

When using `bit`, the number of dimensions must be a multiple of 8 and must represent the number of bits. Additionally,
with `bit` vectors, the typical vector similarity values are effectively all scored the same, e.g. with `hamming` distance.

Let's compare two `byte[]` arrays, each representing 40 individual bits.

`[-127, 0, 1, 42, 127]` in bits `1000000100000000000000010010101001111111`

`[127, -127, 0, 1, 42]` in bits `0111111110000001000000000000000100101010`

When comparing these two bit vectors, we first take the {wikipedia}/Hamming_distance[`hamming` distance].

`xor` result:
```
1000000100000000000000010010101001111111
^
0111111110000001000000000000000100101010
=
1111111010000001000000010010101101010101
```

Then, we gather the count of `1` bits in the `xor` result: `18`. To scale for scoring, we subtract from the total number
of bits and divide by the total number of bits: `(40 - 18) / 40 = 0.55`. This would be the `_score` between these two
vectors.
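The arithmetic above can be reproduced with a short sketch (plain Python for illustration; this is not Elasticsearch code):

```python
def bit_score(a, b):
    """Hamming-based _score for two bit vectors given as signed byte arrays."""
    num_bits = len(a) * 8
    # XOR each byte pair (masked to unsigned 8 bits), then count the 1 bits
    hamming = sum(bin((x ^ y) & 0xFF).count("1") for x, y in zip(a, b))
    return (num_bits - hamming) / num_bits

print(bit_score([-127, 0, 1, 42, 127], [127, -127, 0, 1, 42]))  # 0.55
```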
Here is an example of indexing and searching bit vectors:

[source,console]
--------------------------------------------------
PUT my-bit-vectors
{
  "mappings": {
    "properties": {
      "my_vector": {
        "type": "dense_vector",
        "dims": 40, <1>
        "element_type": "bit"
      }
    }
  }
}
--------------------------------------------------
<1> The number of dimensions that represents the number of bits

[source,console]
--------------------------------------------------
POST /my-bit-vectors/_bulk?refresh
{"index": {"_id" : "1"}}
{"my_vector": [127, -127, 0, 1, 42]} <1>
{"index": {"_id" : "2"}}
{"my_vector": "8100012a7f"} <2>
--------------------------------------------------
// TEST[continued]
<1> 5 bytes representing the 40 bit dimensioned vector
<2> A hexadecimal string representing the 40 bit dimensioned vector
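The hexadecimal form used for document `2` is just the hex encoding of the packed signed bytes. A sketch of the round trip (illustrative helper names, not an Elasticsearch API):

```python
def to_hex(vector):
    # mask signed bytes to unsigned form, pack, then hex-encode
    return bytes((x & 0xFF) for x in vector).hex()

def from_hex(s):
    # decode hex back into signed bytes
    return [b - 256 if b > 127 else b for b in bytes.fromhex(s)]

print(to_hex([-127, 0, 1, 42, 127]))  # '8100012a7f'
print(from_hex("8100012a7f"))         # [-127, 0, 1, 42, 127]
```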
Then, when searching, you can use the `knn` query to search for similar bit vectors:

[source,console]
--------------------------------------------------
POST /my-bit-vectors/_search?filter_path=hits.hits
{
  "query": {
    "knn": {
      "query_vector": [127, -127, 0, 1, 42],
      "field": "my_vector"
    }
  }
}
--------------------------------------------------
// TEST[continued]

[source,console-result]
----
{
  "hits": {
    "hits": [
      {
        "_index": "my-bit-vectors",
        "_id": "1",
        "_score": 1.0,
        "_source": {
          "my_vector": [
            127,
            -127,
            0,
            1,
            42
          ]
        }
      },
      {
        "_index": "my-bit-vectors",
        "_id": "2",
        "_score": 0.55,
        "_source": {
          "my_vector": "8100012a7f"
        }
      }
    ]
  }
}
----
@ -1,4 +1,3 @@
[[vector-functions]]
===== Functions for vector fields

@ -17,6 +16,8 @@ This is the list of available vector functions and vector access methods:
6. <<vector-functions-accessing-vectors,`doc[<field>].vectorValue`>> – returns a vector's value as an array of floats
7. <<vector-functions-accessing-vectors,`doc[<field>].magnitude`>> – returns a vector's magnitude

NOTE: The `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.

NOTE: The recommended way to access dense vectors is through the
`cosineSimilarity`, `dotProduct`, `l1norm` or `l2norm` functions. Please note,
however, that you should call these functions only once per script. For example,
@ -193,7 +194,7 @@ we added `1` in the denominator.
====== Hamming distance

The `hamming` function calculates {wikipedia}/Hamming_distance[Hamming distance] between a given query vector and
document vectors. It is only available for byte and bit vectors.

[source,console]
--------------------------------------------------
@ -278,10 +279,14 @@ You can access vector values directly through the following functions:

- `doc[<field>].vectorValue` – returns a vector's value as an array of floats

NOTE: For `bit` vectors, this returns a `float[]`, where each element represents 8 bits.

- `doc[<field>].magnitude` – returns a vector's magnitude as a float
(for vectors created prior to version 7.5 the magnitude is not stored,
so this function calculates it anew every time it is called).

NOTE: For `bit` vectors, this is just the square root of the sum of `1` bits.

For example, the script below implements a cosine similarity using these
two functions:
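For `bit` vectors, the magnitude note above reduces to the square root of the total popcount. A minimal sketch (plain Python for illustration, not the Painless API):

```python
import math

def bit_magnitude(vector):
    # count the 1 bits across all bytes, then take the square root
    ones = sum(bin(x & 0xFF).count("1") for x in vector)
    return math.sqrt(ones)

print(bit_magnitude([8, 5, -15, 1, -7]))  # sqrt(15), about 3.873
```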
@ -319,3 +324,14 @@ GET my-index-000001/_search
}
}
--------------------------------------------------

[[vector-functions-bit-vectors]]
====== Bit vectors and vector functions

When using `bit` vectors, not all the vector functions are available. The supported functions are:

* <<vector-functions-hamming,`hamming`>> – calculates Hamming distance, the sum of the bitwise XOR of the two vectors
* <<vector-functions-l1,`l1norm`>> – calculates L^1^ distance, this is simply the `hamming` distance
* <<vector-functions-l2,`l2norm`>> – calculates L^2^ distance, this is the square root of the `hamming` distance

Currently, the `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
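The relationships above (`l1norm` equals `hamming`, and `l2norm` is its square root) can be sketched as (plain Python for illustration, not Painless):

```python
import math

def hamming(a, b):
    # XOR each signed byte pair (masked to 8 bits) and count the 1 bits
    return sum(bin((x ^ y) & 0xFF).count("1") for x, y in zip(a, b))

q = [-127, 0, 1, 42, 127]
d = [127, -127, 0, 1, 42]

l1 = hamming(q, d)   # l1norm == hamming distance
l2 = math.sqrt(l1)   # l2norm == sqrt(hamming distance)
print(l1, l2)        # 18 and sqrt(18)
```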
@ -0,0 +1,392 @@
setup:
  - requires:
      cluster_features: ["mapper.vectors.bit_vectors"]
      reason: "support for bit vectors added in 8.15"
      test_runner_features: headers

  - do:
      indices.create:
        index: test-index
        body:
          mappings:
            properties:
              vector:
                type: dense_vector
                index: false
                element_type: bit
                dims: 40
              indexed_vector:
                type: dense_vector
                element_type: bit
                dims: 40
                index: true
                similarity: l2_norm

  - do:
      index:
        index: test-index
        id: "1"
        body:
          vector: [8, 5, -15, 1, -7]
          indexed_vector: [8, 5, -15, 1, -7]

  - do:
      index:
        index: test-index
        id: "2"
        body:
          vector: [-1, 115, -3, 4, -128]
          indexed_vector: [-1, 115, -3, 4, -128]

  - do:
      index:
        index: test-index
        id: "3"
        body:
          vector: [2, 18, -5, 0, -124]
          indexed_vector: [2, 18, -5, 0, -124]

  - do:
      indices.refresh: {}

---
"Test vector magnitude equality":
  - skip:
      features: close_to

  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "doc['vector'].magnitude"

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - close_to: {hits.hits.0._score: {value: 4.690416, error: 0.01}}

  - match: {hits.hits.1._id: "1"}
  - close_to: {hits.hits.1._score: {value: 3.8729835, error: 0.01}}

  - match: {hits.hits.2._id: "3"}
  - close_to: {hits.hits.2._score: {value: 3.4641016, error: 0.01}}

  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "doc['indexed_vector'].magnitude"

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - close_to: {hits.hits.0._score: {value: 4.690416, error: 0.01}}

  - match: {hits.hits.1._id: "1"}
  - close_to: {hits.hits.1._score: {value: 3.8729835, error: 0.01}}

  - match: {hits.hits.2._id: "3"}
  - close_to: {hits.hits.2._score: {value: 3.4641016, error: 0.01}}

---
"Dot Product is not supported":
  - do:
      catch: bad_request
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "dotProduct(params.query_vector, 'vector')"
                params:
                  query_vector: [0, 111, -13, 14, -124]
  - do:
      catch: bad_request
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "dotProduct(params.query_vector, 'vector')"
                params:
                  query_vector: "006ff30e84"

---
"Cosine Similarity is not supported":
  - do:
      catch: bad_request
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "cosineSimilarity(params.query_vector, 'vector')"
                params:
                  query_vector: [0, 111, -13, 14, -124]
  - do:
      catch: bad_request
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "cosineSimilarity(params.query_vector, 'vector')"
                params:
                  query_vector: "006ff30e84"

  - do:
      catch: bad_request
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "cosineSimilarity(params.query_vector, 'indexed_vector')"
                params:
                  query_vector: [0, 111, -13, 14, -124]
---
"L1 norm":
  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "l1norm(params.query_vector, 'vector')"
                params:
                  query_vector: [0, 111, -13, 14, -124]

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0._score: 17.0}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1._score: 16.0}

  - match: {hits.hits.2._id: "3"}
  - match: {hits.hits.2._score: 11.0}

---
"L1 norm hexadecimal":
  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "l1norm(params.query_vector, 'vector')"
                params:
                  query_vector: "006ff30e84"

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0._score: 17.0}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1._score: 16.0}

  - match: {hits.hits.2._id: "3"}
  - match: {hits.hits.2._score: 11.0}

---
"L2 norm":
  - requires:
      test_runner_features: close_to
  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "l2norm(params.query_vector, 'vector')"
                params:
                  query_vector: [0, 111, -13, 14, -124]

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - close_to: {hits.hits.0._score: {value: 4.123, error: 0.001}}

  - match: {hits.hits.1._id: "1"}
  - close_to: {hits.hits.1._score: {value: 4, error: 0.001}}

  - match: {hits.hits.2._id: "3"}
  - close_to: {hits.hits.2._score: {value: 3.316, error: 0.001}}
---
"L2 norm hexadecimal":
  - requires:
      test_runner_features: close_to

  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "l2norm(params.query_vector, 'vector')"
                params:
                  query_vector: "006ff30e84"

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - close_to: {hits.hits.0._score: {value: 4.123, error: 0.001}}

  - match: {hits.hits.1._id: "1"}
  - close_to: {hits.hits.1._score: {value: 4, error: 0.001}}

  - match: {hits.hits.2._id: "3"}
  - close_to: {hits.hits.2._score: {value: 3.316, error: 0.001}}
---
"Hamming distance":
  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "hamming(params.query_vector, 'vector')"
                params:
                  query_vector: [0, 111, -13, 14, -124]

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0._score: 17.0}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1._score: 16.0}

  - match: {hits.hits.2._id: "3"}
  - match: {hits.hits.2._score: 11.0}

  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "hamming(params.query_vector, 'indexed_vector')"
                params:
                  query_vector: [0, 111, -13, 14, -124]

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0._score: 17.0}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1._score: 16.0}

  - match: {hits.hits.2._id: "3"}
  - match: {hits.hits.2._score: 11.0}

---
"Hamming distance hexadecimal":
  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "hamming(params.query_vector, 'vector')"
                params:
                  query_vector: "006ff30e84"

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0._score: 17.0}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1._score: 16.0}

  - match: {hits.hits.2._id: "3"}
  - match: {hits.hits.2._score: 11.0}

  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "hamming(params.query_vector, 'indexed_vector')"
                params:
                  query_vector: "006ff30e84"

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0._score: 17.0}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1._score: 16.0}

  - match: {hits.hits.2._id: "3"}
  - match: {hits.hits.2._score: 11.0}
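The expected `hamming`/`l1norm` scores asserted above (17, 16, 11) can be sanity-checked with a short sketch of byte-wise XOR popcounts (plain Python, not part of the test suite):

```python
def hamming(a, b):
    # XOR each signed byte pair (masked to 8 bits) and count the 1 bits
    return sum(bin((x ^ y) & 0xFF).count("1") for x, y in zip(a, b))

query = [0, 111, -13, 14, -124]
docs = {"1": [8, 5, -15, 1, -7],
        "2": [-1, 115, -3, 4, -128],
        "3": [2, 18, -5, 0, -124]}

scores = {doc_id: hamming(query, v) for doc_id, v in docs.items()}
print(scores)  # doc 2 -> 17, doc 1 -> 16, doc 3 -> 11
```

The `l2norm` expectations follow directly: `sqrt(17) ≈ 4.123`, `sqrt(16) = 4`, `sqrt(11) ≈ 3.316`.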
@ -0,0 +1,301 @@
setup:
  - requires:
      cluster_features: "mapper.vectors.bit_vectors"
      test_runner_features: close_to
      reason: 'bit vectors added in 8.15'
  - do:
      indices.create:
        index: test
        body:
          settings:
            index:
              number_of_shards: 2
          mappings:
            properties:
              name:
                type: keyword
              nested:
                type: nested
                properties:
                  paragraph_id:
                    type: keyword
                  vector:
                    type: dense_vector
                    dims: 40
                    index: true
                    element_type: bit
                    similarity: l2_norm

  - do:
      index:
        index: test
        id: "1"
        body:
          name: cow.jpg
          nested:
            - paragraph_id: 0
              vector: [100, 20, -34, 15, -100]
            - paragraph_id: 1
              vector: [40, 30, -3, 1, -20]

  - do:
      index:
        index: test
        id: "2"
        body:
          name: moose.jpg
          nested:
            - paragraph_id: 0
              vector: [-1, 100, -13, 14, -127]
            - paragraph_id: 2
              vector: [0, 100, 0, 15, -127]
            - paragraph_id: 3
              vector: [0, 1, 0, 2, -15]

  - do:
      index:
        index: test
        id: "3"
        body:
          name: rabbit.jpg
          nested:
            - paragraph_id: 0
              vector: [1, 111, -13, 14, -1]

  - do:
      indices.refresh: {}

---
"nested kNN search only":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: nested.vector
            query_vector: [-1, 90, -10, 14, -127]
            k: 2
            num_candidates: 3

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0.fields.name.0: "moose.jpg"}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1.fields.name.0: "cow.jpg"}

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: nested.vector
            query_vector: [-1, 90, -10, 14, -127]
            k: 2
            num_candidates: 3
            inner_hits: {size: 1, "fields": ["nested.paragraph_id"], _source: false}

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0.fields.name.0: "moose.jpg"}
  - match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1.fields.name.0: "cow.jpg"}
  - match: {hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}

---
"nested kNN search filtered":

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: nested.vector
            query_vector: [-1, 90, -10, 14, -127]
            k: 2
            num_candidates: 3
            filter: {term: {name: "rabbit.jpg"}}

  - match: {hits.total.value: 1}
  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: nested.vector
            query_vector: [-1, 90, -10, 14, -127]
            k: 3
            num_candidates: 3
            filter: {term: {name: "rabbit.jpg"}}
            inner_hits: {size: 1, fields: ["nested.paragraph_id"], _source: false}

  - match: {hits.total.value: 1}
  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
  - match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}
---
"nested kNN search inner_hits size > 1":
  - do:
      index:
        index: test
        id: "4"
        body:
          name: moose.jpg
          nested:
            - paragraph_id: 0
              vector: [-1, 90, -10, 14, -127]
            - paragraph_id: 2
              vector: [ 0, 100.0, 0, 14, -127 ]
            - paragraph_id: 3
              vector: [ 0, 1.0, 0, 2, -15 ]

  - do:
      index:
        index: test
        id: "5"
        body:
          name: moose.jpg
          nested:
            - paragraph_id: 0
              vector: [ -1, 100, -13, 14, -127 ]
            - paragraph_id: 2
              vector: [ 0, 100, 0, 15, -127 ]
            - paragraph_id: 3
              vector: [ 0, 1, 0, 2, -15 ]

  - do:
      index:
        index: test
        id: "6"
        body:
          name: moose.jpg
          nested:
            - paragraph_id: 0
              vector: [ -1, 100, -13, 15, -127 ]
            - paragraph_id: 2
              vector: [ 0, 100, 0, 15, -127 ]
            - paragraph_id: 3
              vector: [ 0, 1, 0, 2, -15 ]
  - do:
      indices.refresh: { }

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: nested.vector
            query_vector: [-1, 90, -10, 15, -127]
            k: 3
            num_candidates: 5
            inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}

  - match: {hits.total.value: 3}
  - length: { hits.hits.0.inner_hits.nested.hits.hits: 2 }
  - length: { hits.hits.1.inner_hits.nested.hits.hits: 2 }
  - length: { hits.hits.2.inner_hits.nested.hits.hits: 2 }

  - match: { hits.hits.0.fields.name.0: "moose.jpg" }
  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: nested.vector
            query_vector: [-1, 90, -10, 15, -127]
            k: 5
            num_candidates: 5
            inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}

  - match: {hits.total.value: 5}
  # All these initial matches are "moose.jpg", which has 3 nested vectors, but two are closest
  - match: {hits.hits.0.fields.name.0: "moose.jpg"}
  - length: { hits.hits.0.inner_hits.nested.hits.hits: 2 }
  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
  - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
  - match: {hits.hits.1.fields.name.0: "moose.jpg"}
  - length: { hits.hits.1.inner_hits.nested.hits.hits: 2 }
  - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
  - match: { hits.hits.1.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
|
||||||
|
- match: {hits.hits.2.fields.name.0: "moose.jpg"}
|
||||||
|
- length: { hits.hits.2.inner_hits.nested.hits.hits: 2 }
|
||||||
|
- match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
|
||||||
|
- match: { hits.hits.2.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
|
||||||
|
- match: {hits.hits.3.fields.name.0: "moose.jpg"}
|
||||||
|
- length: { hits.hits.3.inner_hits.nested.hits.hits: 2 }
|
||||||
|
- match: { hits.hits.3.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
|
||||||
|
- match: { hits.hits.3.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
|
||||||
|
# Rabbit only has one passage vector
|
||||||
|
- match: {hits.hits.4.fields.name.0: "cow.jpg"}
|
||||||
|
- length: { hits.hits.4.inner_hits.nested.hits.hits: 2 }
|
||||||
|
|
||||||
|
- do:
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
fields: [ "name" ]
|
||||||
|
knn:
|
||||||
|
field: nested.vector
|
||||||
|
query_vector: [ -1, 90, -10, 15, -127 ]
|
||||||
|
k: 3
|
||||||
|
num_candidates: 3
|
||||||
|
filter: {term: {name: "cow.jpg"}}
|
||||||
|
inner_hits: {size: 3, fields: ["nested.paragraph_id"], _source: false}
|
||||||
|
|
||||||
|
- match: {hits.total.value: 1}
|
||||||
|
- match: { hits.hits.0._id: "1" }
|
||||||
|
- length: { hits.hits.0.inner_hits.nested.hits.hits: 2 }
|
||||||
|
- match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
|
||||||
|
- match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "1" }
|
||||||
|
---
|
||||||
|
"nested kNN search inner_hits & boosting":
|
||||||
|
- do:
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
fields: [ "name" ]
|
||||||
|
knn:
|
||||||
|
field: nested.vector
|
||||||
|
query_vector: [-1, 90, -10, 15, -127]
|
||||||
|
k: 3
|
||||||
|
num_candidates: 5
|
||||||
|
inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}
|
||||||
|
|
||||||
|
- close_to: { hits.hits.0._score: {value: 0.8, error: 0.00001} }
|
||||||
|
- close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: {value: 0.8, error: 0.00001} }
|
||||||
|
- close_to: { hits.hits.1._score: {value: 0.625, error: 0.00001} }
|
||||||
|
- close_to: { hits.hits.1.inner_hits.nested.hits.hits.0._score: {value: 0.625, error: 0.00001} }
|
||||||
|
- close_to: { hits.hits.2._score: {value: 0.5, error: 0.00001} }
|
||||||
|
- close_to: { hits.hits.2.inner_hits.nested.hits.hits.0._score: {value: 0.5, error: 0.00001} }
|
||||||
|
|
||||||
|
- do:
|
||||||
|
search:
|
||||||
|
index: test
|
||||||
|
body:
|
||||||
|
fields: [ "name" ]
|
||||||
|
knn:
|
||||||
|
field: nested.vector
|
||||||
|
query_vector: [-1, 90, -10, 15, -127]
|
||||||
|
k: 3
|
||||||
|
num_candidates: 5
|
||||||
|
boost: 2
|
||||||
|
inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}
|
||||||
|
- close_to: { hits.hits.0._score: {value: 1.6, error: 0.00001} }
|
||||||
|
- close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: {value: 1.6, error: 0.00001} }
|
||||||
|
- close_to: { hits.hits.1._score: {value: 1.25, error: 0.00001} }
|
||||||
|
- close_to: { hits.hits.1.inner_hits.nested.hits.hits.0._score: {value: 1.25, error: 0.00001} }
|
||||||
|
- close_to: { hits.hits.2._score: {value: 1, error: 0.00001} }
|
||||||
|
- close_to: { hits.hits.2.inner_hits.nested.hits.hits.0._score: {value: 1.0, error: 0.00001} }
|
|
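The boosted expectations in the last test follow directly from the unboosted scores earlier in the same file: `boost: 2` multiplies each normalized kNN score. A quick sanity check of that arithmetic, as a standalone Python sketch (not part of the test suite):

```python
# Unboosted scores asserted by the first search in the boosting test.
unboosted = [0.8, 0.625, 0.5]
boost = 2

# The second search repeats the query with boost: 2; each hit's score
# (and its inner hit's score) is simply multiplied by the boost factor.
boosted = [s * boost for s in unboosted]
print(boosted)  # [1.6, 1.25, 1.0] - matches the close_to assertions
```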
@@ -116,8 +116,9 @@ setup:
 ---
 "Knn search with hex string for byte field - dimensions mismatch" :
   # [64, 10, -30, 10] - is encoded as '400ae20a'
+  # the error message has been adjusted in later versions
   - do:
-      catch: /the query vector has a different dimension \[4\] than the index vectors \[3\]/
+      catch: /dimension|dimensions \[4\] than the document|index vectors \[3\]/
       search:
         index: knn_hex_vector_index
         body:

@@ -116,8 +116,9 @@ setup:
 ---
 "Knn query with hex string for byte field - dimensions mismatch" :
   # [64, 10, -30, 10] - is encoded as '400ae20a'
+  # the error message has been adjusted in later versions
   - do:
-      catch: /the query vector has a different dimension \[4\] than the index vectors \[3\]/
+      catch: /dimension|dimensions \[4\] than the document|index vectors \[3\]/
       search:
         index: knn_hex_vector_index
         body:
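The hex encoding referenced in the test comments above can be reproduced directly: each signed byte is rendered as two hex digits of its unsigned (two's complement) value. A minimal Python sketch, independent of any Elasticsearch client:

```python
def encode_byte_vector(vector):
    # Mask each signed byte to its unsigned 0..255 value, then hex-encode;
    # e.g. 64 -> "40", 10 -> "0a", -30 -> "e2" (two's complement).
    return bytes(b & 0xFF for b in vector).hex()

print(encode_byte_vector([64, 10, -30, 10]))  # 400ae20a
```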
@@ -0,0 +1,356 @@
setup:
  - requires:
      cluster_features: "mapper.vectors.bit_vectors"
      reason: 'mapper.vectors.bit_vectors'

  - do:
      indices.create:
        index: test
        body:
          mappings:
            properties:
              name:
                type: keyword
              vector:
                type: dense_vector
                element_type: bit
                dims: 40
                index: true
                similarity: l2_norm

  - do:
      index:
        index: test
        id: "1"
        body:
          name: cow.jpg
          vector: [2, -1, 1, 4, -3]

  - do:
      index:
        index: test
        id: "2"
        body:
          name: moose.jpg
          vector: [127.0, -128.0, 0.0, 1.0, -1.0]

  - do:
      index:
        index: test
        id: "3"
        body:
          name: rabbit.jpg
          vector: [5, 4.0, 3, 2.0, 127]

  - do:
      indices.refresh: {}

---
"kNN search only":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [127, 127, -128, -128, 127]
            k: 2
            num_candidates: 3

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0.fields.name.0: "moose.jpg"}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1.fields.name.0: "cow.jpg"}

---
"kNN search plus query":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [127.0, -128.0, 0.0, 1.0, -1.0]
            k: 2
            num_candidates: 3
          query:
            term:
              name: rabbit.jpg

  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

  - match: {hits.hits.1._id: "2"}
  - match: {hits.hits.1.fields.name.0: "moose.jpg"}

---
"kNN search with filter":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [5.0, 4, 3.0, 2, 127.0]
            k: 2
            num_candidates: 3
            filter:
              term:
                name: "rabbit.jpg"

  - match: {hits.total.value: 1}
  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [2, -1, 1, 4, -3]
            k: 2
            num_candidates: 3
            filter:
              - term:
                  name: "rabbit.jpg"
              - term:
                  _id: 2

  - match: {hits.total.value: 0}

---
"Vector similarity search only":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            num_candidates: 3
            k: 3
            field: vector
            similarity: 0.98
            query_vector: [5, 4.0, 3, 2.0, 127]

  - length: {hits.hits: 1}

  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
---
"Vector similarity with filter only":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            num_candidates: 3
            k: 3
            field: vector
            similarity: 0.98
            query_vector: [5, 4.0, 3, 2.0, 127]
            filter: {"term": {"name": "rabbit.jpg"}}

  - length: {hits.hits: 1}

  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            num_candidates: 3
            k: 3
            field: vector
            similarity: 0.98
            query_vector: [5, 4.0, 3, 2.0, 127]
            filter: {"term": {"name": "cow.jpg"}}

  - length: {hits.hits: 0}
---
"dim mismatch":
  - do:
      catch: bad_request
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [1, 2, 3, 4, 5, 6]
            k: 2
            num_candidates: 3
---
"disallow quantized vector types":
  - do:
      catch: bad_request
      indices.create:
        index: test
        body:
          mappings:
            properties:
              name:
                type: keyword
              vector:
                type: dense_vector
                element_type: bit
                dims: 32
                index: true
                similarity: l2_norm
                index_options:
                  type: int8_flat

  - do:
      catch: bad_request
      indices.create:
        index: test
        body:
          mappings:
            properties:
              name:
                type: keyword
              vector:
                type: dense_vector
                element_type: bit
                dims: 32
                index: true
                similarity: l2_norm
                index_options:
                  type: int4_flat

  - do:
      catch: bad_request
      indices.create:
        index: test
        body:
          mappings:
            properties:
              name:
                type: keyword
              vector:
                type: dense_vector
                element_type: bit
                dims: 32
                index: true
                similarity: l2_norm
                index_options:
                  type: int8_hnsw

  - do:
      catch: bad_request
      indices.create:
        index: test
        body:
          mappings:
            properties:
              name:
                type: keyword
              vector:
                type: dense_vector
                element_type: bit
                dims: 32
                index: true
                similarity: l2_norm
                index_options:
                  type: int4_hnsw
---
"disallow vector index type change to quantized type":
  - do:
      catch: bad_request
      indices.put_mapping:
        index: test
        body:
          properties:
            vector:
              type: dense_vector
              element_type: bit
              dims: 32
              index: true
              similarity: l2_norm
              index_options:
                type: int4_hnsw
  - do:
      catch: bad_request
      indices.put_mapping:
        index: test
        body:
          properties:
            vector:
              type: dense_vector
              element_type: bit
              dims: 32
              index: true
              similarity: l2_norm
              index_options:
                type: int8_hnsw
---
"Defaults to l2_norm with bit vectors":
  - do:
      indices.create:
        index: default_to_l2_norm_bit
        body:
          mappings:
            properties:
              vector:
                type: dense_vector
                element_type: bit
                dims: 40
                index: true

  - do:
      indices.get_mapping:
        index: default_to_l2_norm_bit

  - match: { default_to_l2_norm_bit.mappings.properties.vector.similarity: l2_norm }

---
"Only allow l2_norm with bit vectors":
  - do:
      catch: bad_request
      indices.create:
        index: dot_product_fails_for_bits
        body:
          mappings:
            properties:
              vector:
                type: dense_vector
                element_type: bit
                dims: 40
                index: true
                similarity: dot_product

  - do:
      catch: bad_request
      indices.create:
        index: cosine_product_fails_for_bits
        body:
          mappings:
            properties:
              vector:
                type: dense_vector
                element_type: bit
                dims: 40
                index: true
                similarity: cosine

  - do:
      catch: bad_request
      indices.create:
        index: cosine_product_fails_for_bits
        body:
          mappings:
            properties:
              vector:
                type: dense_vector
                element_type: bit
                dims: 40
                index: true
                similarity: max_inner_product
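For `element_type: bit`, `dims` counts bits rather than numbers, so the `dims: 40` mapping above stores five packed bytes per vector, and the five byte values in each indexed document are those packed bits. The dims-to-bytes relationship sketched in plain Python (independent of Elasticsearch):

```python
def packed_length(dims):
    # Bit vectors require dims divisible by 8; the stored byte[] has
    # dims // 8 elements, each byte packing 8 bits of the vector.
    if dims % 8 != 0:
        raise ValueError("bit vector dims must be divisible by 8")
    return dims // 8

print(packed_length(40))  # 5 - matches the five byte values per document above
```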
@@ -0,0 +1,223 @@
setup:
  - requires:
      cluster_features: "mapper.vectors.bit_vectors"
      reason: 'mapper.vectors.bit_vectors'

  - do:
      indices.create:
        index: test
        body:
          mappings:
            properties:
              name:
                type: keyword
              vector:
                type: dense_vector
                element_type: bit
                dims: 40
                index: true
                similarity: l2_norm
                index_options:
                  type: flat

  - do:
      index:
        index: test
        id: "1"
        body:
          name: cow.jpg
          vector: [2, -1, 1, 4, -3]

  - do:
      index:
        index: test
        id: "2"
        body:
          name: moose.jpg
          vector: [127.0, -128.0, 0.0, 1.0, -1.0]

  - do:
      index:
        index: test
        id: "3"
        body:
          name: rabbit.jpg
          vector: [5, 4.0, 3, 2.0, 127]

  - do:
      indices.refresh: {}

---
"kNN search only":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [127, 127, -128, -128, 127]
            k: 2
            num_candidates: 3

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0.fields.name.0: "moose.jpg"}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1.fields.name.0: "cow.jpg"}

---
"kNN search plus query":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [127.0, -128.0, 0.0, 1.0, -1.0]
            k: 2
            num_candidates: 3
          query:
            term:
              name: rabbit.jpg

  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

  - match: {hits.hits.1._id: "2"}
  - match: {hits.hits.1.fields.name.0: "moose.jpg"}

---
"kNN search with filter":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [5.0, 4, 3.0, 2, 127.0]
            k: 2
            num_candidates: 3
            filter:
              term:
                name: "rabbit.jpg"

  - match: {hits.total.value: 1}
  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [2, -1, 1, 4, -3]
            k: 2
            num_candidates: 3
            filter:
              - term:
                  name: "rabbit.jpg"
              - term:
                  _id: 2

  - match: {hits.total.value: 0}

---
"Vector similarity search only":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            num_candidates: 3
            k: 3
            field: vector
            similarity: 0.98
            query_vector: [5, 4.0, 3, 2.0, 127]

  - length: {hits.hits: 1}

  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
---
"Vector similarity with filter only":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            num_candidates: 3
            k: 3
            field: vector
            similarity: 0.98
            query_vector: [5, 4.0, 3, 2.0, 127]
            filter: {"term": {"name": "rabbit.jpg"}}

  - length: {hits.hits: 1}

  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            num_candidates: 3
            k: 3
            field: vector
            similarity: 0.98
            query_vector: [5, 4.0, 3, 2.0, 127]
            filter: {"term": {"name": "cow.jpg"}}

  - length: {hits.hits: 0}
---
"dim mismatch":
  - do:
      catch: bad_request
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [1, 2, 3, 4, 5, 6]
            k: 2
            num_candidates: 3
---
"disallow vector index type change to quantized type":
  - do:
      catch: bad_request
      indices.put_mapping:
        index: test
        body:
          properties:
            vector:
              type: dense_vector
              element_type: bit
              dims: 32
              index: true
              similarity: l2_norm
              index_options:
                type: int4_hnsw
  - do:
      catch: bad_request
      indices.put_mapping:
        index: test
        body:
          properties:
            vector:
              type: dense_vector
              element_type: bit
              dims: 32
              index: true
              similarity: l2_norm
              index_options:
                type: int8_hnsw
@@ -449,7 +449,10 @@ module org.elasticsearch.server {
             with
                 org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat,
                 org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat,
-                org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
+                org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat,
+                org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat,
+                org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat;

     provides org.apache.lucene.codecs.Codec with Elasticsearch814Codec;

     provides org.apache.logging.log4j.core.util.ContextDataProvider with org.elasticsearch.common.logging.DynamicContextDataProvider;
@@ -54,11 +54,11 @@ public class ES813FlatVectorFormat extends KnnVectorsFormat {
         return new ES813FlatVectorReader(format.fieldsReader(state));
     }

-    public static class ES813FlatVectorWriter extends KnnVectorsWriter {
+    static class ES813FlatVectorWriter extends KnnVectorsWriter {

         private final FlatVectorsWriter writer;

-        public ES813FlatVectorWriter(FlatVectorsWriter writer) {
+        ES813FlatVectorWriter(FlatVectorsWriter writer) {
             super();
             this.writer = writer;
         }
@@ -94,11 +94,11 @@ public class ES813FlatVectorFormat extends KnnVectorsFormat {
         }
     }

-    public static class ES813FlatVectorReader extends KnnVectorsReader {
+    static class ES813FlatVectorReader extends KnnVectorsReader {

         private final FlatVectorsReader reader;

-        public ES813FlatVectorReader(FlatVectorsReader reader) {
+        ES813FlatVectorReader(FlatVectorsReader reader) {
             super();
             this.reader = reader;
         }
@@ -0,0 +1,47 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;

import java.io.IOException;

public class ES815BitFlatVectorFormat extends KnnVectorsFormat {

    static final String NAME = "ES815BitFlatVectorFormat";

    private final FlatVectorsFormat format = new ES815BitFlatVectorsFormat();

    /**
     * Sole constructor
     */
    public ES815BitFlatVectorFormat() {
        super(NAME);
    }

    @Override
    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
        return new ES813FlatVectorFormat.ES813FlatVectorWriter(format.fieldsWriter(state));
    }

    @Override
    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
        return new ES813FlatVectorFormat.ES813FlatVectorReader(format.fieldsReader(state));
    }

    @Override
    public String toString() {
        return NAME;
    }
}
@ -0,0 +1,143 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0 and the Server Side Public License, v 1; you may not use this file except
|
||||||
|
* in compliance with, at your election, the Elastic License 2.0 or the Server
|
||||||
|
* Side Public License, v 1.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.index.codec.vectors;
|
||||||
|
|
||||||
|
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
|
||||||
|
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
||||||
|
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||||
|
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
|
||||||
|
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
|
||||||
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||||
|
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||||
|
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||||
|
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
|
||||||
|
class ES815BitFlatVectorsFormat extends FlatVectorsFormat {
|
||||||
|
|
||||||
|
private final FlatVectorsFormat delegate = new Lucene99FlatVectorsFormat(FlatBitVectorScorer.INSTANCE);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FlatVectorsWriter fieldsWriter(SegmentWriteState segmentWriteState) throws IOException {
|
||||||
|
return delegate.fieldsWriter(segmentWriteState);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FlatVectorsReader fieldsReader(SegmentReadState segmentReadState) throws IOException {
|
||||||
|
return delegate.fieldsReader(segmentReadState);
|
||||||
|
}
|
||||||
|
|
||||||
|
static class FlatBitVectorScorer implements FlatVectorsScorer {
|
||||||
|
|
||||||
|
static final FlatBitVectorScorer INSTANCE = new FlatBitVectorScorer();
|
||||||
|
|
||||||
|
static void checkDimensions(int queryLen, int fieldLen) {
|
||||||
|
if (queryLen != fieldLen) {
|
||||||
|
throw new IllegalArgumentException("vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return super.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
|
||||||
|
VectorSimilarityFunction vectorSimilarityFunction,
|
||||||
|
RandomAccessVectorValues randomAccessVectorValues
|
||||||
|
) throws IOException {
|
||||||
|
assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes;
|
||||||
|
assert vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN;
|
||||||
|
if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes randomAccessVectorValuesBytes) {
|
||||||
|
assert randomAccessVectorValues instanceof RandomAccessQuantizedByteVectorValues == false;
|
||||||
|
return switch (vectorSimilarityFunction) {
|
||||||
|
case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT, COSINE, EUCLIDEAN -> new HammingScorerSupplier(randomAccessVectorValuesBytes);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException("Unsupported vector type or similarity function");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RandomVectorScorer getRandomVectorScorer(
|
||||||
|
VectorSimilarityFunction vectorSimilarityFunction,
|
||||||
|
RandomAccessVectorValues randomAccessVectorValues,
|
||||||
|
byte[] bytes
|
||||||
|
) {
|
||||||
|
assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes;
|
||||||
|
assert vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN;
|
||||||
|
        if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes randomAccessVectorValuesBytes) {
            checkDimensions(bytes.length, randomAccessVectorValuesBytes.dimension());
            return switch (vectorSimilarityFunction) {
                case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT, COSINE, EUCLIDEAN -> new HammingVectorScorer(
                    randomAccessVectorValuesBytes,
                    bytes
                );
            };
        }
        throw new IllegalArgumentException("Unsupported vector type or similarity function");
        }

        @Override
        public RandomVectorScorer getRandomVectorScorer(
            VectorSimilarityFunction vectorSimilarityFunction,
            RandomAccessVectorValues randomAccessVectorValues,
            float[] floats
        ) {
            throw new IllegalArgumentException("Unsupported vector type");
        }
    }

    static float hammingScore(byte[] a, byte[] b) {
        return ((a.length * Byte.SIZE) - VectorUtil.xorBitCount(a, b)) / (float) (a.length * Byte.SIZE);
    }

    static class HammingVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
        private final byte[] query;
        private final RandomAccessVectorValues.Bytes byteValues;

        HammingVectorScorer(RandomAccessVectorValues.Bytes byteValues, byte[] query) {
            super(byteValues);
            this.query = query;
            this.byteValues = byteValues;
        }

        @Override
        public float score(int i) throws IOException {
            return hammingScore(byteValues.vectorValue(i), query);
        }
    }

    static class HammingScorerSupplier implements RandomVectorScorerSupplier {
        private final RandomAccessVectorValues.Bytes byteValues, byteValues1, byteValues2;

        HammingScorerSupplier(RandomAccessVectorValues.Bytes byteValues) throws IOException {
            this.byteValues = byteValues;
            this.byteValues1 = byteValues.copy();
            this.byteValues2 = byteValues.copy();
        }

        @Override
        public RandomVectorScorer scorer(int i) throws IOException {
            byte[] query = byteValues1.vectorValue(i);
            return new HammingVectorScorer(byteValues2, query);
        }

        @Override
        public RandomVectorScorerSupplier copy() throws IOException {
            return new HammingScorerSupplier(byteValues);
        }
    }

}
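The Hamming-based scorer above turns bit disagreement into a score in [0, 1]: identical vectors score 1.0, bitwise opposites score 0.0. A minimal standalone sketch of that math (the class name is illustrative, and a plain popcount loop stands in for Lucene's `VectorUtil.xorBitCount`):

```java
// Illustrative sketch of the hammingScore normalization used by the bit-vector scorer.
public class HammingScoreDemo {

    // Number of bit positions where a and b differ (the Hamming distance),
    // standing in for Lucene's VectorUtil.xorBitCount.
    static int xorBitCount(byte[] a, byte[] b) {
        int distance = 0;
        for (int i = 0; i < a.length; i++) {
            distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        return distance;
    }

    // Mirrors hammingScore: (totalBits - distance) / totalBits.
    static float hammingScore(byte[] a, byte[] b) {
        int totalBits = a.length * Byte.SIZE;
        return (totalBits - xorBitCount(a, b)) / (float) totalBits;
    }

    public static void main(String[] args) {
        byte[] a = { 0b0000_1111, 0 };  // a 16-dimensional bit vector packed into 2 bytes
        byte[] b = { 0, 0 };
        System.out.println(hammingScore(a, b)); // 12 of 16 bits agree -> 0.75
    }
}
```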
@@ -0,0 +1,74 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;

import java.io.IOException;

public class ES815HnswBitVectorsFormat extends KnnVectorsFormat {

    static final String NAME = "ES815HnswBitVectorsFormat";

    static final int MAXIMUM_MAX_CONN = 512;
    static final int MAXIMUM_BEAM_WIDTH = 3200;

    private final int maxConn;
    private final int beamWidth;

    private final FlatVectorsFormat flatVectorsFormat = new ES815BitFlatVectorsFormat();

    public ES815HnswBitVectorsFormat() {
        this(16, 100);
    }

    public ES815HnswBitVectorsFormat(int maxConn, int beamWidth) {
        super(NAME);
        if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
            throw new IllegalArgumentException(
                "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
            );
        }
        if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
            throw new IllegalArgumentException(
                "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
            );
        }
        this.maxConn = maxConn;
        this.beamWidth = beamWidth;
    }

    @Override
    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
        return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null);
    }

    @Override
    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
        return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
    }

    @Override
    public String toString() {
        return "ES815HnswBitVectorsFormat(name=ES815HnswBitVectorsFormat, maxConn="
            + maxConn
            + ", beamWidth="
            + beamWidth
            + ", flatVectorFormat="
            + flatVectorsFormat
            + ")";
    }
}
@@ -25,7 +25,8 @@ public class MapperFeatures implements FeatureSpecification {
             PassThroughObjectMapper.PASS_THROUGH_PRIORITY,
             RangeFieldMapper.NULL_VALUES_OFF_BY_ONE_FIX,
             SourceFieldMapper.SYNTHETIC_SOURCE_FALLBACK,
-            DenseVectorFieldMapper.INT4_QUANTIZATION
+            DenseVectorFieldMapper.INT4_QUANTIZATION,
+            DenseVectorFieldMapper.BIT_VECTORS
         );
     }
 }
@@ -31,6 +31,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.search.FieldExistsQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.join.BitSetProducer;
+import org.apache.lucene.util.BitUtil;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.VectorUtil;
 import org.elasticsearch.common.ParsingException;
@@ -41,6 +42,8 @@ import org.elasticsearch.index.IndexVersions;
 import org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat;
 import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat;
 import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
+import org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat;
+import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat;
 import org.elasticsearch.index.fielddata.FieldDataContext;
 import org.elasticsearch.index.fielddata.IndexFieldData;
 import org.elasticsearch.index.mapper.ArraySourceValueFetcher;
@@ -100,6 +103,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
     }

     public static final NodeFeature INT4_QUANTIZATION = new NodeFeature("mapper.vectors.int4_quantization");
+    public static final NodeFeature BIT_VECTORS = new NodeFeature("mapper.vectors.bit_vectors");

     public static final IndexVersion MAGNITUDE_STORED_INDEX_VERSION = IndexVersions.V_7_5_0;
     public static final IndexVersion INDEXED_BY_DEFAULT_INDEX_VERSION = IndexVersions.FIRST_DETACHED_INDEX_VERSION;
@@ -109,6 +113,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

     public static final String CONTENT_TYPE = "dense_vector";
     public static short MAX_DIMS_COUNT = 4096; // maximum allowed number of dimensions
+    public static int MAX_DIMS_COUNT_BIT = 4096 * Byte.SIZE; // maximum allowed number of dimensions

     public static short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to vector
     public static final int MAGNITUDE_BYTES = 4;
@@ -134,17 +139,28 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 throw new MapperParsingException("Property [dims] on field [" + n + "] must be an integer but got [" + o + "]");
             }
             int dims = XContentMapValues.nodeIntegerValue(o);
-            if (dims < 1 || dims > MAX_DIMS_COUNT) {
+            int maxDims = elementType.getValue() == ElementType.BIT ? MAX_DIMS_COUNT_BIT : MAX_DIMS_COUNT;
+            int minDims = elementType.getValue() == ElementType.BIT ? Byte.SIZE : 1;
+            if (dims < minDims || dims > maxDims) {
                 throw new MapperParsingException(
                     "The number of dimensions for field ["
                         + n
-                        + "] should be in the range [1, "
-                        + MAX_DIMS_COUNT
+                        + "] should be in the range ["
+                        + minDims
+                        + ", "
+                        + maxDims
                         + "] but was ["
                         + dims
                         + "]"
                 );
             }
+            if (elementType.getValue() == ElementType.BIT) {
+                if (dims % Byte.SIZE != 0) {
+                    throw new MapperParsingException(
+                        "The number of dimensions for field [" + n + "] should be a multiple of 8 but was [" + dims + "]"
+                    );
+                }
+            }
             return dims;
         }, m -> toType(m).fieldType().dims, XContentBuilder::field, Object::toString).setSerializerCheck((id, ic, v) -> v != null)
             .setMergeValidator((previous, current, c) -> previous == null || Objects.equals(previous, current));
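The hunk above constrains `bit` vector dims to [8, 32768] and to multiples of 8, since each indexed byte packs 8 bits. A hypothetical standalone sketch of that validation logic (class and method names are illustrative, not part of the Elasticsearch API):

```java
// Illustrative sketch of the dims validation applied to element_type `bit`.
public class BitDimsValidationDemo {

    static void validate(int dims) {
        int minDims = Byte.SIZE;        // 8
        int maxDims = 4096 * Byte.SIZE; // 32768, mirroring MAX_DIMS_COUNT_BIT
        if (dims < minDims || dims > maxDims) {
            throw new IllegalArgumentException(
                "dims should be in the range [" + minDims + ", " + maxDims + "] but was [" + dims + "]"
            );
        }
        if (dims % Byte.SIZE != 0) {
            throw new IllegalArgumentException("dims should be a multiple of 8 but was [" + dims + "]");
        }
    }

    public static void main(String[] args) {
        validate(64); // fine: in range and a multiple of 8
        try {
            validate(12); // rejected: not a multiple of 8
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```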
@@ -171,13 +187,27 @@ public class DenseVectorFieldMapper extends FieldMapper {
             "similarity",
             false,
             m -> toType(m).fieldType().similarity,
-            (Supplier<VectorSimilarity>) () -> indexedByDefault && indexed.getValue() ? VectorSimilarity.COSINE : null,
+            (Supplier<VectorSimilarity>) () -> {
+                if (indexedByDefault && indexed.getValue()) {
+                    return elementType.getValue() == ElementType.BIT ? VectorSimilarity.L2_NORM : VectorSimilarity.COSINE;
+                }
+                return null;
+            },
             VectorSimilarity.class
-        ).acceptsNull().setSerializerCheck((id, ic, v) -> v != null);
+        ).acceptsNull().setSerializerCheck((id, ic, v) -> v != null).addValidator(vectorSim -> {
+            if (vectorSim == null) {
+                return;
+            }
+            if (elementType.getValue() == ElementType.BIT && vectorSim != VectorSimilarity.L2_NORM) {
+                throw new IllegalArgumentException(
+                    "The [" + VectorSimilarity.L2_NORM + "] similarity is the only supported similarity for bit vectors"
+                );
+            }
+        });
         this.indexOptions = new Parameter<>(
             "index_options",
             true,
-            () -> defaultInt8Hnsw && elementType.getValue() != ElementType.BYTE && this.indexed.getValue()
+            () -> defaultInt8Hnsw && elementType.getValue() == ElementType.FLOAT && this.indexed.getValue()
                 ? new Int8HnswIndexOptions(
                     Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
                     Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
@@ -266,7 +296,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

     public enum ElementType {

-        BYTE(1) {
+        BYTE {

             @Override
             public String toString() {
@@ -371,7 +401,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
             }

             @Override
-            public double computeDotProduct(VectorData vectorData) {
+            public double computeSquaredMagnitude(VectorData vectorData) {
                 return VectorUtil.dotProduct(vectorData.asByteVector(), vectorData.asByteVector());
             }

@@ -428,7 +458,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
                 fieldMapper.checkDimensionMatches(decodedVector.length, context);
                 VectorData vectorData = VectorData.fromBytes(decodedVector);
-                double squaredMagnitude = computeDotProduct(vectorData);
+                double squaredMagnitude = computeSquaredMagnitude(vectorData);
                 checkVectorMagnitude(
                     fieldMapper.fieldType().similarity,
                     errorByteElementsAppender(decodedVector),
@@ -463,7 +493,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

             @Override
             int getNumBytes(int dimensions) {
-                return dimensions * elementBytes;
+                return dimensions;
             }

             @Override
@@ -494,7 +524,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
             }
         },

-        FLOAT(4) {
+        FLOAT {

             @Override
             public String toString() {
@@ -596,7 +626,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
             }

             @Override
-            public double computeDotProduct(VectorData vectorData) {
+            public double computeSquaredMagnitude(VectorData vectorData) {
                 return VectorUtil.dotProduct(vectorData.asFloatVector(), vectorData.asFloatVector());
             }

@@ -656,7 +686,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

             @Override
             int getNumBytes(int dimensions) {
-                return dimensions * elementBytes;
+                return dimensions * Float.BYTES;
             }

             @Override
@@ -665,14 +695,250 @@ public class DenseVectorFieldMapper extends FieldMapper {
                     ? ByteBuffer.wrap(new byte[numBytes]).order(ByteOrder.LITTLE_ENDIAN)
                     : ByteBuffer.wrap(new byte[numBytes]);
             }
+        },
+
+        BIT {
+
+            @Override
+            public String toString() {
+                return "bit";
+            }
+
+            @Override
+            public void writeValue(ByteBuffer byteBuffer, float value) {
+                byteBuffer.put((byte) value);
+            }
+
+            @Override
+            public void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException {
+                b.value(byteBuffer.get());
+            }
+
+            private KnnByteVectorField createKnnVectorField(String name, byte[] vector, VectorSimilarityFunction function) {
+                if (vector == null) {
+                    throw new IllegalArgumentException("vector value must not be null");
+                }
+                FieldType denseVectorFieldType = new FieldType();
+                denseVectorFieldType.setVectorAttributes(vector.length, VectorEncoding.BYTE, function);
+                denseVectorFieldType.freeze();
+                return new KnnByteVectorField(name, vector, denseVectorFieldType);
+            }
+
+            @Override
+            IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldType, FieldDataContext fieldDataContext) {
+                return new VectorIndexFieldData.Builder(
+                    denseVectorFieldType.name(),
+                    CoreValuesSourceType.KEYWORD,
+                    denseVectorFieldType.indexVersionCreated,
+                    this,
+                    denseVectorFieldType.dims,
+                    denseVectorFieldType.indexed,
+                    r -> r
+                );
+            }
+
+            @Override
+            public void checkVectorBounds(float[] vector) {
+                checkNanAndInfinite(vector);
+
+                StringBuilder errorBuilder = null;
+
+                for (int index = 0; index < vector.length; ++index) {
+                    float value = vector[index];
+
+                    if (value % 1.0f != 0.0f) {
+                        errorBuilder = new StringBuilder(
+                            "element_type ["
+                                + this
+                                + "] vectors only support non-decimal values but found decimal value ["
+                                + value
+                                + "] at dim ["
+                                + index
+                                + "];"
+                        );
+                        break;
+                    }
+
+                    if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
+                        errorBuilder = new StringBuilder(
+                            "element_type ["
+                                + this
+                                + "] vectors only support integers between ["
+                                + Byte.MIN_VALUE
+                                + ", "
+                                + Byte.MAX_VALUE
+                                + "] but found ["
+                                + value
+                                + "] at dim ["
+                                + index
+                                + "];"
+                        );
+                        break;
+                    }
+                }
+
+                if (errorBuilder != null) {
+                    throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
+                }
+            }
+
+            @Override
+            void checkVectorMagnitude(
+                VectorSimilarity similarity,
+                Function<StringBuilder, StringBuilder> appender,
+                float squaredMagnitude
+            ) {}
+
+            @Override
+            public double computeSquaredMagnitude(VectorData vectorData) {
+                int count = 0;
+                int i = 0;
+                byte[] byteBits = vectorData.asByteVector();
+                for (int upperBound = byteBits.length & -8; i < upperBound; i += 8) {
+                    count += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(byteBits, i));
+                }
+
+                while (i < byteBits.length) {
+                    count += Integer.bitCount(byteBits[i] & 255);
+                    ++i;
+                }
+                return count;
+            }
+
+            private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
+                int index = 0;
+                byte[] vector = new byte[fieldMapper.fieldType().dims / Byte.SIZE];
+                for (XContentParser.Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser()
+                    .nextToken()) {
+                    fieldMapper.checkDimensionExceeded(index, context);
+                    ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
+                    final int value;
+                    if (context.parser().numberType() != XContentParser.NumberType.INT) {
+                        float floatValue = context.parser().floatValue(true);
+                        if (floatValue % 1.0f != 0.0f) {
+                            throw new IllegalArgumentException(
+                                "element_type ["
+                                    + this
+                                    + "] vectors only support non-decimal values but found decimal value ["
+                                    + floatValue
+                                    + "] at dim ["
+                                    + index
+                                    + "];"
+                            );
+                        }
+                        value = (int) floatValue;
+                    } else {
+                        value = context.parser().intValue(true);
+                    }
+                    if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
+                        throw new IllegalArgumentException(
+                            "element_type ["
+                                + this
+                                + "] vectors only support integers between ["
+                                + Byte.MIN_VALUE
+                                + ", "
+                                + Byte.MAX_VALUE
+                                + "] but found ["
+                                + value
+                                + "] at dim ["
+                                + index
+                                + "];"
+                        );
+                    }
+                    if (index >= vector.length) {
+                        throw new IllegalArgumentException(
+                            "The number of dimensions for field ["
+                                + fieldMapper.fieldType().name()
+                                + "] should be ["
+                                + fieldMapper.fieldType().dims
+                                + "] but found ["
+                                + (index + 1) * Byte.SIZE
+                                + "]"
+                        );
+                    }
+                    vector[index++] = (byte) value;
+                }
+                fieldMapper.checkDimensionMatches(index * Byte.SIZE, context);
+                return VectorData.fromBytes(vector);
+            }
+
+            private VectorData parseHexEncodedVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
+                byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
+                fieldMapper.checkDimensionMatches(decodedVector.length * Byte.SIZE, context);
+                return VectorData.fromBytes(decodedVector);
+            }
+
+            @Override
+            VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
+                XContentParser.Token token = context.parser().currentToken();
+                return switch (token) {
+                    case START_ARRAY -> parseVectorArray(context, fieldMapper);
+                    case VALUE_STRING -> parseHexEncodedVector(context, fieldMapper);
+                    default -> throw new ParsingException(
+                        context.parser().getTokenLocation(),
+                        format("Unsupported type [%s] for provided value [%s]", token, context.parser().text())
+                    );
+                };
+            }
+
+            @Override
+            public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
+                VectorData vectorData = parseKnnVector(context, fieldMapper);
+                Field field = createKnnVectorField(
+                    fieldMapper.fieldType().name(),
+                    vectorData.asByteVector(),
+                    fieldMapper.fieldType().similarity.vectorSimilarityFunction(fieldMapper.indexCreatedVersion, this)
+                );
+                context.doc().addWithKey(fieldMapper.fieldType().name(), field);
+            }
+
+            @Override
+            int getNumBytes(int dimensions) {
+                assert dimensions % Byte.SIZE == 0;
+                return dimensions / Byte.SIZE;
+            }
+
+            @Override
+            ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes) {
+                return ByteBuffer.wrap(new byte[numBytes]);
+            }
+
+            @Override
+            int parseDimensionCount(DocumentParserContext context) throws IOException {
+                XContentParser.Token currentToken = context.parser().currentToken();
+                return switch (currentToken) {
+                    case START_ARRAY -> {
+                        int index = 0;
+                        for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
+                            index++;
+                        }
+                        yield index * Byte.SIZE;
+                    }
+                    case VALUE_STRING -> {
+                        byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
+                        yield decodedVector.length * Byte.SIZE;
+                    }
+                    default -> throw new ParsingException(
+                        context.parser().getTokenLocation(),
+                        format("Unsupported type [%s] for provided value [%s]", currentToken, context.parser().text())
+                    );
+                };
+            }
+
+            @Override
+            public void checkDimensions(int dvDims, int qvDims) {
+                if (dvDims != qvDims * Byte.SIZE) {
+                    throw new IllegalArgumentException(
+                        "The query vector has a different number of dimensions ["
+                            + qvDims * Byte.SIZE
+                            + "] than the document vectors ["
+                            + dvDims
+                            + "]."
+                    );
+                }
+            }
         };

-        final int elementBytes;
-
-        ElementType(int elementBytes) {
-            this.elementBytes = elementBytes;
-        }
-
         public abstract void writeValue(ByteBuffer byteBuffer, float value);

         public abstract void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException;
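The `BIT` element type above accepts vectors either as a `byte[]` array (one byte per 8 dimensions) or as a hexadecimal string of the same packed bytes. A hypothetical sketch of that encoding relationship, using the standard `java.util.HexFormat` that the parser itself relies on (the class name here is illustrative):

```java
import java.util.HexFormat;

// Illustrative sketch: how a `bit` vector's declared dims relate to its packed
// byte[] payload and the equivalent hexadecimal string accepted at index time.
public class BitVectorEncodingDemo {
    public static void main(String[] args) {
        int dims = 16;                              // must be a multiple of 8
        byte[] packed = new byte[dims / Byte.SIZE]; // each byte carries 8 bits
        packed[0] = (byte) 0xFF;                    // first byte: all 8 bits set
        packed[1] = 0x0F;                           // second byte: 4 bits set

        String hex = HexFormat.of().formatHex(packed); // hex-string form of the same vector
        System.out.println(hex);                       // -> ff0f

        byte[] roundTripped = HexFormat.of().parseHex(hex);
        System.out.println(roundTripped.length * Byte.SIZE); // -> 16 (recovered dims)
    }
}
```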
@@ -695,6 +961,14 @@ public class DenseVectorFieldMapper extends FieldMapper {
             float squaredMagnitude
         );

+        public void checkDimensions(int dvDims, int qvDims) {
+            if (dvDims != qvDims) {
+                throw new IllegalArgumentException(
+                    "The query vector has a different number of dimensions [" + qvDims + "] than the document vectors [" + dvDims + "]."
+                );
+            }
+        }
+
         int parseDimensionCount(DocumentParserContext context) throws IOException {
             int index = 0;
             for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
@@ -775,7 +1049,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
             return sb -> appendErrorElements(sb, vector);
         }

-        public abstract double computeDotProduct(VectorData vectorData);
+        public abstract double computeSquaredMagnitude(VectorData vectorData);

         public static ElementType fromString(String name) {
             return valueOf(name.trim().toUpperCase(Locale.ROOT));
@@ -786,7 +1060,9 @@ public class DenseVectorFieldMapper extends FieldMapper {
         ElementType.BYTE.toString(),
         ElementType.BYTE,
         ElementType.FLOAT.toString(),
-        ElementType.FLOAT
+        ElementType.FLOAT,
+        ElementType.BIT.toString(),
+        ElementType.BIT
     );

     public enum VectorSimilarity {
@@ -795,6 +1071,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
             float score(float similarity, ElementType elementType, int dim) {
                 return switch (elementType) {
                     case BYTE, FLOAT -> 1f / (1f + similarity * similarity);
+                    case BIT -> (dim - similarity) / dim;
                 };
             }
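For indexed `bit` vectors the raw similarity fed into the `L2_NORM` case above is the Hamming distance, and `(dim - similarity) / dim` normalizes it into [0, 1] using the vector's dimension count. A small sketch of that transform in isolation (the class name is illustrative):

```java
// Illustrative sketch of the bit-vector score transform: (dim - hammingDistance) / dim.
public class BitScoreTransformDemo {

    static float score(float hammingDistance, int dim) {
        return (dim - hammingDistance) / dim;
    }

    public static void main(String[] args) {
        System.out.println(score(0f, 64));  // identical vectors -> 1.0
        System.out.println(score(64f, 64)); // all bits differ   -> 0.0
        System.out.println(score(16f, 64)); // 16 differing bits -> 0.75
    }
}
```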
@@ -806,8 +1083,10 @@ public class DenseVectorFieldMapper extends FieldMapper {
         COSINE {
             @Override
             float score(float similarity, ElementType elementType, int dim) {
+                assert elementType != ElementType.BIT;
                 return switch (elementType) {
                     case BYTE, FLOAT -> (1 + similarity) / 2f;
+                    default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]");
                 };
             }

@@ -824,6 +1103,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 return switch (elementType) {
                     case BYTE -> 0.5f + similarity / (float) (dim * (1 << 15));
                     case FLOAT -> (1 + similarity) / 2f;
+                    default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]");
                 };
             }

@@ -837,6 +1117,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
             float score(float similarity, ElementType elementType, int dim) {
                 return switch (elementType) {
                     case BYTE, FLOAT -> similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1;
+                    default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]");
                 };
             }

@@ -863,7 +1144,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
             this.type = type;
         }

-        abstract KnnVectorsFormat getVectorsFormat();
+        abstract KnnVectorsFormat getVectorsFormat(ElementType elementType);

         boolean supportsElementType(ElementType elementType) {
             return true;
@@ -1002,7 +1283,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
         }

         @Override
-        KnnVectorsFormat getVectorsFormat() {
+        KnnVectorsFormat getVectorsFormat(ElementType elementType) {
+            assert elementType == ElementType.FLOAT;
             return new ES813Int8FlatVectorFormat(confidenceInterval, 7, false);
         }

@@ -1021,7 +1303,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

         @Override
         boolean supportsElementType(ElementType elementType) {
-            return elementType != ElementType.BYTE;
+            return elementType == ElementType.FLOAT;
         }

         @Override
@@ -1047,7 +1329,10 @@ public class DenseVectorFieldMapper extends FieldMapper {
         }

         @Override
-        KnnVectorsFormat getVectorsFormat() {
+        KnnVectorsFormat getVectorsFormat(ElementType elementType) {
+            if (elementType.equals(ElementType.BIT)) {
+                return new ES815BitFlatVectorFormat();
+            }
             return new ES813FlatVectorFormat();
         }

@@ -1083,7 +1368,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
         }

         @Override
-        public KnnVectorsFormat getVectorsFormat() {
+        public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
+            assert elementType == ElementType.FLOAT;
             return new ES814HnswScalarQuantizedVectorsFormat(m, efConstruction, confidenceInterval, 4, true);
         }

@@ -1126,7 +1412,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

         @Override
         boolean supportsElementType(ElementType elementType) {
-            return elementType != ElementType.BYTE;
+            return elementType == ElementType.FLOAT;
         }

         @Override
@@ -1153,7 +1439,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
         }

         @Override
-        public KnnVectorsFormat getVectorsFormat() {
+        public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
+            assert elementType == ElementType.FLOAT;
             return new ES813Int8FlatVectorFormat(confidenceInterval, 4, true);
         }

@@ -1186,7 +1473,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

         @Override
         boolean supportsElementType(ElementType elementType) {
-            return elementType != ElementType.BYTE;
+            return elementType == ElementType.FLOAT;
         }

         @Override
@@ -1216,7 +1503,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
         }

         @Override
-        public KnnVectorsFormat getVectorsFormat() {
+        public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
+            assert elementType == ElementType.FLOAT;
             return new ES814HnswScalarQuantizedVectorsFormat(m, efConstruction, confidenceInterval, 7, false);
         }

@@ -1261,7 +1549,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

         @Override
         boolean supportsElementType(ElementType elementType) {
-            return elementType != ElementType.BYTE;
+            return elementType == ElementType.FLOAT;
         }

         @Override
|
@ -1291,7 +1579,10 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public KnnVectorsFormat getVectorsFormat() {
|
public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
|
||||||
|
if (elementType == ElementType.BIT) {
|
||||||
|
return new ES815HnswBitVectorsFormat(m, efConstruction);
|
||||||
|
}
|
||||||
return new Lucene99HnswVectorsFormat(m, efConstruction, 1, null);
|
return new Lucene99HnswVectorsFormat(m, efConstruction, 1, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1412,48 +1703,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
||||||
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't support term queries");
|
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't support term queries");
|
||||||
}
|
}
|
||||||
|
|
||||||
public Query createKnnQuery(
|
|
||||||
byte[] queryVector,
|
|
||||||
int numCands,
|
|
||||||
Query filter,
|
|
||||||
Float similarityThreshold,
|
|
||||||
BitSetProducer parentFilter
|
|
||||||
) {
|
|
||||||
if (isIndexed() == false) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (queryVector.length != dims) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (elementType != ElementType.BYTE) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"only [" + ElementType.BYTE + "] elements are supported when querying field [" + name() + "]"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
|
|
||||||
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
|
|
||||||
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
|
|
||||||
}
|
|
||||||
Query knnQuery = parentFilter != null
|
|
||||||
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter)
|
|
||||||
: new ESKnnByteVectorQuery(name(), queryVector, numCands, filter);
|
|
||||||
if (similarityThreshold != null) {
|
|
||||||
knnQuery = new VectorSimilarityQuery(
|
|
||||||
knnQuery,
|
|
||||||
similarityThreshold,
|
|
||||||
similarity.score(similarityThreshold, elementType, dims)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return knnQuery;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Query createExactKnnQuery(VectorData queryVector) {
|
public Query createExactKnnQuery(VectorData queryVector) {
|
||||||
if (isIndexed() == false) {
|
if (isIndexed() == false) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
|
@ -1463,15 +1712,17 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
||||||
return switch (elementType) {
|
return switch (elementType) {
|
||||||
case BYTE -> createExactKnnByteQuery(queryVector.asByteVector());
|
case BYTE -> createExactKnnByteQuery(queryVector.asByteVector());
|
||||||
case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector());
|
case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector());
|
||||||
|
case BIT -> createExactKnnBitQuery(queryVector.asByteVector());
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Query createExactKnnBitQuery(byte[] queryVector) {
|
||||||
|
elementType.checkDimensions(dims, queryVector.length);
|
||||||
|
return new DenseVectorQuery.Bytes(queryVector, name());
|
||||||
|
}
|
||||||
|
|
||||||
private Query createExactKnnByteQuery(byte[] queryVector) {
|
private Query createExactKnnByteQuery(byte[] queryVector) {
|
||||||
if (queryVector.length != dims) {
|
elementType.checkDimensions(dims, queryVector.length);
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
|
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
|
||||||
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
|
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
|
||||||
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
|
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
|
||||||
|
@ -1480,11 +1731,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
||||||
}
|
}
|
||||||
|
|
||||||
private Query createExactKnnFloatQuery(float[] queryVector) {
|
private Query createExactKnnFloatQuery(float[] queryVector) {
|
||||||
if (queryVector.length != dims) {
|
elementType.checkDimensions(dims, queryVector.length);
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
elementType.checkVectorBounds(queryVector);
|
elementType.checkVectorBounds(queryVector);
|
||||||
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
|
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
|
||||||
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
|
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
|
||||||
|
@ -1521,9 +1768,31 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
||||||
return switch (getElementType()) {
|
return switch (getElementType()) {
|
||||||
case BYTE -> createKnnByteQuery(queryVector.asByteVector(), numCands, filter, similarityThreshold, parentFilter);
|
case BYTE -> createKnnByteQuery(queryVector.asByteVector(), numCands, filter, similarityThreshold, parentFilter);
|
||||||
case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), numCands, filter, similarityThreshold, parentFilter);
|
case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), numCands, filter, similarityThreshold, parentFilter);
|
||||||
|
case BIT -> createKnnBitQuery(queryVector.asByteVector(), numCands, filter, similarityThreshold, parentFilter);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Query createKnnBitQuery(
|
||||||
|
byte[] queryVector,
|
||||||
|
int numCands,
|
||||||
|
Query filter,
|
||||||
|
Float similarityThreshold,
|
||||||
|
BitSetProducer parentFilter
|
||||||
|
) {
|
||||||
|
elementType.checkDimensions(dims, queryVector.length);
|
||||||
|
Query knnQuery = parentFilter != null
|
||||||
|
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter)
|
||||||
|
: new ESKnnByteVectorQuery(name(), queryVector, numCands, filter);
|
||||||
|
if (similarityThreshold != null) {
|
||||||
|
knnQuery = new VectorSimilarityQuery(
|
||||||
|
knnQuery,
|
||||||
|
similarityThreshold,
|
||||||
|
similarity.score(similarityThreshold, elementType, dims)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return knnQuery;
|
||||||
|
}
|
||||||
|
|
||||||
private Query createKnnByteQuery(
|
private Query createKnnByteQuery(
|
||||||
byte[] queryVector,
|
byte[] queryVector,
|
||||||
int numCands,
|
int numCands,
|
||||||
|
@ -1531,11 +1800,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
||||||
Float similarityThreshold,
|
Float similarityThreshold,
|
||||||
BitSetProducer parentFilter
|
BitSetProducer parentFilter
|
||||||
) {
|
) {
|
||||||
if (queryVector.length != dims) {
|
elementType.checkDimensions(dims, queryVector.length);
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
|
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
|
||||||
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
|
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
|
||||||
|
@ -1561,11 +1826,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
||||||
Float similarityThreshold,
|
Float similarityThreshold,
|
||||||
BitSetProducer parentFilter
|
BitSetProducer parentFilter
|
||||||
) {
|
) {
|
||||||
if (queryVector.length != dims) {
|
elementType.checkDimensions(dims, queryVector.length);
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
elementType.checkVectorBounds(queryVector);
|
elementType.checkVectorBounds(queryVector);
|
||||||
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
|
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
|
||||||
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
|
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
|
||||||
|
@ -1701,7 +1962,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
||||||
vectorData.addToBuffer(byteBuffer);
|
vectorData.addToBuffer(byteBuffer);
|
||||||
if (indexCreatedVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)) {
|
if (indexCreatedVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)) {
|
||||||
// encode vector magnitude at the end
|
// encode vector magnitude at the end
|
||||||
double dotProduct = elementType.computeDotProduct(vectorData);
|
double dotProduct = elementType.computeSquaredMagnitude(vectorData);
|
||||||
float vectorMagnitude = (float) Math.sqrt(dotProduct);
|
float vectorMagnitude = (float) Math.sqrt(dotProduct);
|
||||||
byteBuffer.putFloat(vectorMagnitude);
|
byteBuffer.putFloat(vectorMagnitude);
|
||||||
}
|
}
|
||||||
|
@ -1780,9 +2041,9 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
||||||
public KnnVectorsFormat getKnnVectorsFormatForField(KnnVectorsFormat defaultFormat) {
|
public KnnVectorsFormat getKnnVectorsFormatForField(KnnVectorsFormat defaultFormat) {
|
||||||
final KnnVectorsFormat format;
|
final KnnVectorsFormat format;
|
||||||
if (indexOptions == null) {
|
if (indexOptions == null) {
|
||||||
format = defaultFormat;
|
format = fieldType().elementType == ElementType.BIT ? new ES815HnswBitVectorsFormat() : defaultFormat;
|
||||||
} else {
|
} else {
|
||||||
format = indexOptions.getVectorsFormat();
|
format = indexOptions.getVectorsFormat(fieldType().elementType);
|
||||||
}
|
}
|
||||||
// It's legal to reuse the same format name as this is the same on-disk format.
|
// It's legal to reuse the same format name as this is the same on-disk format.
|
||||||
return new KnnVectorsFormat(format.getName()) {
|
return new KnnVectorsFormat(format.getName()) {
|
||||||
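Per the changes above, indexed `bit` vectors are compared with xor and popcount summations, i.e. hamming distance over the packed `byte[]` representation where each byte carries 8 bits. A standalone sketch of that comparison (plain Java, illustrative names only, not Elasticsearch code):

```java
// Hypothetical sketch of the hamming comparison used for `bit` vectors:
// xor the packed bytes, then popcount the differing bits.
public class BitHammingSketch {
    // a and b each pack dims bits into dims/8 bytes.
    static int hamming(byte[] a, byte[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vectors must have the same packed length");
        }
        int distance = 0;
        for (int i = 0; i < a.length; i++) {
            // mask to 0xFF so the sign-extended byte does not add phantom bits
            distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        return distance;
    }

    public static void main(String[] args) {
        byte[] a = { (byte) 0b1111_0000, (byte) 0b0000_1111 }; // a 16-dim bit vector
        byte[] b = { (byte) 0b1111_0000, (byte) 0b1111_1111 }; // differs in 4 bits
        System.out.println(hamming(a, b)); // 4
    }
}
```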
@@ -17,6 +17,8 @@ import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
 import org.elasticsearch.script.field.DocValuesScriptFieldFactory;
 import org.elasticsearch.script.field.vectors.BinaryDenseVectorDocValuesField;
+import org.elasticsearch.script.field.vectors.BitBinaryDenseVectorDocValuesField;
+import org.elasticsearch.script.field.vectors.BitKnnDenseVectorDocValuesField;
 import org.elasticsearch.script.field.vectors.ByteBinaryDenseVectorDocValuesField;
 import org.elasticsearch.script.field.vectors.ByteKnnDenseVectorDocValuesField;
 import org.elasticsearch.script.field.vectors.KnnDenseVectorDocValuesField;
@@ -58,12 +60,14 @@ final class VectorDVLeafFieldData implements LeafFieldData {
                 return switch (elementType) {
                     case BYTE -> new ByteKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims);
                     case FLOAT -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims);
+                    case BIT -> new BitKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims);
                 };
             } else {
                 BinaryDocValues values = DocValues.getBinary(reader, field);
                 return switch (elementType) {
                     case BYTE -> new ByteBinaryDenseVectorDocValuesField(values, name, elementType, dims);
                     case FLOAT -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion);
+                    case BIT -> new BitBinaryDenseVectorDocValuesField(values, name, elementType, dims);
                 };
             }
         } catch (IOException e) {
@@ -56,7 +56,7 @@ public class VectorScoreScriptUtils {
         */
        public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
            super(scoreScript, field);
-           DenseVector.checkDimensions(field.get().getDims(), queryVector.size());
+           field.getElementType().checkDimensions(field.get().getDims(), queryVector.size());
            this.queryVector = new byte[queryVector.size()];
            float[] validateValues = new float[queryVector.size()];
            int queryMagnitude = 0;
@@ -168,7 +168,7 @@ public class VectorScoreScriptUtils {
        public L1Norm(ScoreScript scoreScript, Object queryVector, String fieldName) {
            DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
            function = switch (field.getElementType()) {
-               case BYTE -> {
+               case BYTE, BIT -> {
                    if (queryVector instanceof List) {
                        yield new ByteL1Norm(scoreScript, field, (List<Number>) queryVector);
                    } else if (queryVector instanceof String s) {
@@ -219,8 +219,8 @@ public class VectorScoreScriptUtils {
        @SuppressWarnings("unchecked")
        public Hamming(ScoreScript scoreScript, Object queryVector, String fieldName) {
            DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
-           if (field.getElementType() != DenseVectorFieldMapper.ElementType.BYTE) {
-               throw new IllegalArgumentException("hamming distance is only supported for byte vectors");
+           if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) {
+               throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors");
            }
            if (queryVector instanceof List) {
                function = new ByteHammingDistance(scoreScript, field, (List<Number>) queryVector);
@@ -278,7 +278,7 @@ public class VectorScoreScriptUtils {
        public L2Norm(ScoreScript scoreScript, Object queryVector, String fieldName) {
            DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
            function = switch (field.getElementType()) {
-               case BYTE -> {
+               case BYTE, BIT -> {
                    if (queryVector instanceof List) {
                        yield new ByteL2Norm(scoreScript, field, (List<Number>) queryVector);
                    } else if (queryVector instanceof String s) {
@@ -342,7 +342,7 @@ public class VectorScoreScriptUtils {
        public DotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) {
            DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
            function = switch (field.getElementType()) {
-               case BYTE -> {
+               case BYTE, BIT -> {
                    if (queryVector instanceof List) {
                        yield new ByteDotProduct(scoreScript, field, (List<Number>) queryVector);
                    } else if (queryVector instanceof String s) {
@@ -406,7 +406,7 @@ public class VectorScoreScriptUtils {
        public CosineSimilarity(ScoreScript scoreScript, Object queryVector, String fieldName) {
            DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
            function = switch (field.getElementType()) {
-               case BYTE -> {
+               case BYTE, BIT -> {
                    if (queryVector instanceof List) {
                        yield new ByteCosineSimilarity(scoreScript, field, (List<Number>) queryVector);
                    } else if (queryVector instanceof String s) {
@@ -0,0 +1,88 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.script.field.vectors;
+
+import org.apache.lucene.util.BytesRef;
+
+import java.util.List;
+
+public class BitBinaryDenseVector extends ByteBinaryDenseVector {
+
+    public BitBinaryDenseVector(byte[] vectorValue, BytesRef docVector, int dims) {
+        super(vectorValue, docVector, dims);
+    }
+
+    @Override
+    public void checkDimensions(int qvDims) {
+        if (qvDims != dims) {
+            throw new IllegalArgumentException(
+                "The query vector has a different number of dimensions ["
+                    + qvDims * Byte.SIZE
+                    + "] than the document vectors ["
+                    + dims * Byte.SIZE
+                    + "]."
+            );
+        }
+    }
+
+    @Override
+    public int l1Norm(byte[] queryVector) {
+        return hamming(queryVector);
+    }
+
+    @Override
+    public double l1Norm(List<Number> queryVector) {
+        return hamming(queryVector);
+    }
+
+    @Override
+    public double l2Norm(byte[] queryVector) {
+        return Math.sqrt(hamming(queryVector));
+    }
+
+    @Override
+    public double l2Norm(List<Number> queryVector) {
+        return Math.sqrt(hamming(queryVector));
+    }
+
+    @Override
+    public int dotProduct(byte[] queryVector) {
+        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
+    }
+
+    @Override
+    public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
+        throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
+    }
+
+    @Override
+    public double dotProduct(List<Number> queryVector) {
+        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
+    }
+
+    @Override
+    public double cosineSimilarity(byte[] queryVector, float qvMagnitude) {
+        throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
+    }
+
+    @Override
+    public double cosineSimilarity(List<Number> queryVector) {
+        throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
+    }
+
+    @Override
+    public double dotProduct(float[] queryVector) {
+        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
+    }
+
+    @Override
+    public int getDims() {
+        return dims * Byte.SIZE;
+    }
+}
@@ -0,0 +1,24 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.script.field.vectors;
+
+import org.apache.lucene.index.BinaryDocValues;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
+
+public class BitBinaryDenseVectorDocValuesField extends ByteBinaryDenseVectorDocValuesField {
+
+    public BitBinaryDenseVectorDocValuesField(BinaryDocValues input, String name, ElementType elementType, int dims) {
+        super(input, name, elementType, dims / 8);
+    }
+
+    @Override
+    protected DenseVector getVector() {
+        return new BitBinaryDenseVector(vectorValue, value, dims);
+    }
+}
@@ -0,0 +1,95 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.script.field.vectors;
+
+import java.util.List;
+
+public class BitKnnDenseVector extends ByteKnnDenseVector {
+
+    public BitKnnDenseVector(byte[] vector) {
+        super(vector);
+    }
+
+    @Override
+    public void checkDimensions(int qvDims) {
+        if (qvDims != docVector.length) {
+            throw new IllegalArgumentException(
+                "The query vector has a different number of dimensions ["
+                    + qvDims * Byte.SIZE
+                    + "] than the document vectors ["
+                    + docVector.length * Byte.SIZE
+                    + "]."
+            );
+        }
+    }
+
+    @Override
+    public float getMagnitude() {
+        if (magnitudeCalculated == false) {
+            magnitude = DenseVector.getBitMagnitude(docVector, docVector.length);
+            magnitudeCalculated = true;
+        }
+        return magnitude;
+    }
+
+    @Override
+    public int l1Norm(byte[] queryVector) {
+        return hamming(queryVector);
+    }
+
+    @Override
+    public double l1Norm(List<Number> queryVector) {
+        return hamming(queryVector);
+    }
+
+    @Override
+    public double l2Norm(byte[] queryVector) {
+        return Math.sqrt(hamming(queryVector));
+    }
+
+    @Override
+    public double l2Norm(List<Number> queryVector) {
+        return Math.sqrt(hamming(queryVector));
+    }
+
+    @Override
+    public int dotProduct(byte[] queryVector) {
+        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
+    }
+
+    @Override
+    public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
+        throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
+    }
+
+    @Override
+    public double dotProduct(List<Number> queryVector) {
+        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
+    }
+
+    @Override
+    public double cosineSimilarity(byte[] queryVector, float qvMagnitude) {
+        throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
+    }
+
+    @Override
+    public double cosineSimilarity(List<Number> queryVector) {
+        throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
+    }
+
+    @Override
+    public double dotProduct(float[] queryVector) {
+        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
+    }
+
+    @Override
+    public int getDims() {
+        return docVector.length * Byte.SIZE;
+    }
+}
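For `bit` vectors in scripts, `l1norm` is defined as the hamming distance and `l2norm` as its square root, while `dotProduct` and `cosineSimilarity` throw. A minimal standalone sketch of that relationship (illustrative names, not the Elasticsearch classes):

```java
// Hypothetical sketch: the l1norm/l2norm definitions used for bit vectors,
// both derived from the hamming distance over the packed bytes.
public class BitNormSketch {
    static int hamming(byte[] a, byte[] b) {
        int d = 0;
        for (int i = 0; i < a.length; i++) {
            d += Integer.bitCount((a[i] ^ b[i]) & 0xFF); // differing bits per byte
        }
        return d;
    }

    // l1norm for bit vectors is the hamming distance itself.
    static int l1Norm(byte[] doc, byte[] query) {
        return hamming(doc, query);
    }

    // l2norm is the square root of the hamming distance.
    static double l2Norm(byte[] doc, byte[] query) {
        return Math.sqrt(hamming(doc, query));
    }

    public static void main(String[] args) {
        byte[] doc = { (byte) 0xFF, 0x00 };
        byte[] query = { 0x0F, 0x00 }; // differs from doc in 4 bits
        System.out.println(l1Norm(doc, query)); // 4
        System.out.println(l2Norm(doc, query)); // 2.0
    }
}
```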
@@ -0,0 +1,26 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.script.field.vectors;
+
+import org.apache.lucene.index.ByteVectorValues;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+
+public class BitKnnDenseVectorDocValuesField extends ByteKnnDenseVectorDocValuesField {
+
+    public BitKnnDenseVectorDocValuesField(@Nullable ByteVectorValues input, String name, int dims) {
+        super(input, name, dims / 8, DenseVectorFieldMapper.ElementType.BIT);
+    }
+
+    @Override
+    protected DenseVector getVector() {
+        return new BitKnnDenseVector(vector);
+    }
+
+}
@@ -21,7 +21,7 @@ public class ByteBinaryDenseVector implements DenseVector {
 
     private final BytesRef docVector;
     private final byte[] vectorValue;
-    private final int dims;
+    protected final int dims;
 
     private float[] floatDocVector;
     private boolean magnitudeDecoded;
@@ -17,11 +17,11 @@ import java.io.IOException;
 
 public class ByteBinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {
 
-    private final BinaryDocValues input;
-    private final int dims;
-    private final byte[] vectorValue;
-    private boolean decoded;
-    private BytesRef value;
+    protected final BinaryDocValues input;
+    protected final int dims;
+    protected final byte[] vectorValue;
+    protected boolean decoded;
+    protected BytesRef value;
 
     public ByteBinaryDenseVectorDocValuesField(BinaryDocValues input, String name, ElementType elementType, int dims) {
         super(name, elementType);
@@ -50,13 +50,17 @@ public class ByteBinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {
         return value == null;
     }
 
+    protected DenseVector getVector() {
+        return new ByteBinaryDenseVector(vectorValue, value, dims);
+    }
+
     @Override
     public DenseVector get() {
         if (isEmpty()) {
             return DenseVector.EMPTY;
         }
         decodeVectorIfNecessary();
-        return new ByteBinaryDenseVector(vectorValue, value, dims);
+        return getVector();
     }
 
     @Override
@@ -65,7 +69,7 @@ public class ByteBinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {
             return defaultValue;
         }
         decodeVectorIfNecessary();
-        return new ByteBinaryDenseVector(vectorValue, value, dims);
+        return getVector();
     }
 
     @Override
@@ -23,7 +23,11 @@ public class ByteKnnDenseVectorDocValuesField extends DenseVectorDocValuesField
     protected final int dims;

     public ByteKnnDenseVectorDocValuesField(@Nullable ByteVectorValues input, String name, int dims) {
-        super(name, ElementType.BYTE);
+        this(input, name, dims, ElementType.BYTE);
+    }
+
+    protected ByteKnnDenseVectorDocValuesField(@Nullable ByteVectorValues input, String name, int dims, ElementType elementType) {
+        super(name, elementType);
         this.dims = dims;
         this.input = input;
     }
@@ -57,13 +61,17 @@ public class ByteKnnDenseVectorDocValuesField extends DenseVectorDocValuesField
         return vector == null;
     }

+    protected DenseVector getVector() {
+        return new ByteKnnDenseVector(vector);
+    }
+
     @Override
     public DenseVector get() {
         if (isEmpty()) {
             return DenseVector.EMPTY;
         }

-        return new ByteKnnDenseVector(vector);
+        return getVector();
     }

     @Override
@@ -72,7 +80,7 @@ public class ByteKnnDenseVectorDocValuesField extends DenseVectorDocValuesField
             return defaultValue;
         }

-        return new ByteKnnDenseVector(vector);
+        return getVector();
     }

     @Override
|
@@ -8,6 +8,7 @@

 package org.elasticsearch.script.field.vectors;

+import org.apache.lucene.util.BitUtil;
 import org.apache.lucene.util.VectorUtil;

 import java.util.List;
@@ -25,6 +26,10 @@ import java.util.List;
  */
 public interface DenseVector {

+    default void checkDimensions(int qvDims) {
+        checkDimensions(getDims(), qvDims);
+    }
+
     float[] getVector();

     float getMagnitude();
@@ -38,13 +43,13 @@ public interface DenseVector {
     @SuppressWarnings("unchecked")
     default double dotProduct(Object queryVector) {
         if (queryVector instanceof float[] floats) {
-            checkDimensions(getDims(), floats.length);
+            checkDimensions(floats.length);
             return dotProduct(floats);
         } else if (queryVector instanceof List<?> list) {
-            checkDimensions(getDims(), list.size());
+            checkDimensions(list.size());
             return dotProduct((List<Number>) list);
         } else if (queryVector instanceof byte[] bytes) {
-            checkDimensions(getDims(), bytes.length);
+            checkDimensions(bytes.length);
             return dotProduct(bytes);
         }

@@ -60,13 +65,13 @@ public interface DenseVector {
     @SuppressWarnings("unchecked")
     default double l1Norm(Object queryVector) {
         if (queryVector instanceof float[] floats) {
-            checkDimensions(getDims(), floats.length);
+            checkDimensions(floats.length);
             return l1Norm(floats);
         } else if (queryVector instanceof List<?> list) {
-            checkDimensions(getDims(), list.size());
+            checkDimensions(list.size());
             return l1Norm((List<Number>) list);
         } else if (queryVector instanceof byte[] bytes) {
-            checkDimensions(getDims(), bytes.length);
+            checkDimensions(bytes.length);
             return l1Norm(bytes);
         }

@@ -80,11 +85,11 @@ public interface DenseVector {
     @SuppressWarnings("unchecked")
     default int hamming(Object queryVector) {
         if (queryVector instanceof List<?> list) {
-            checkDimensions(getDims(), list.size());
+            checkDimensions(list.size());
             return hamming((List<Number>) list);
         }
         if (queryVector instanceof byte[] bytes) {
-            checkDimensions(getDims(), bytes.length);
+            checkDimensions(bytes.length);
             return hamming(bytes);
         }

@@ -100,13 +105,13 @@ public interface DenseVector {
     @SuppressWarnings("unchecked")
     default double l2Norm(Object queryVector) {
         if (queryVector instanceof float[] floats) {
-            checkDimensions(getDims(), floats.length);
+            checkDimensions(floats.length);
             return l2Norm(floats);
         } else if (queryVector instanceof List<?> list) {
-            checkDimensions(getDims(), list.size());
+            checkDimensions(list.size());
             return l2Norm((List<Number>) list);
         } else if (queryVector instanceof byte[] bytes) {
-            checkDimensions(getDims(), bytes.length);
+            checkDimensions(bytes.length);
             return l2Norm(bytes);
         }

@@ -150,13 +155,13 @@ public interface DenseVector {
     @SuppressWarnings("unchecked")
     default double cosineSimilarity(Object queryVector) {
         if (queryVector instanceof float[] floats) {
-            checkDimensions(getDims(), floats.length);
+            checkDimensions(floats.length);
             return cosineSimilarity(floats);
         } else if (queryVector instanceof List<?> list) {
-            checkDimensions(getDims(), list.size());
+            checkDimensions(list.size());
             return cosineSimilarity((List<Number>) list);
         } else if (queryVector instanceof byte[] bytes) {
-            checkDimensions(getDims(), bytes.length);
+            checkDimensions(bytes.length);
             return cosineSimilarity(bytes);
         }

@@ -184,6 +189,20 @@ public interface DenseVector {
         return (float) Math.sqrt(mag);
     }

+    static float getBitMagnitude(byte[] vector, int dims) {
+        int count = 0;
+        int i = 0;
+        for (int upperBound = dims & -8; i < upperBound; i += 8) {
+            count += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(vector, i));
+        }
+
+        while (i < dims) {
+            count += Integer.bitCount(vector[i] & 255);
+            ++i;
+        }
+        return (float) Math.sqrt(count);
+    }
+
     static float getMagnitude(float[] vector) {
         return (float) Math.sqrt(VectorUtil.dotProduct(vector, vector));
     }
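The `getBitMagnitude` addition above is a popcount over the packed bytes followed by a square root, which lines up with the script semantics the commit message describes: on `bit` vectors, `l1norm` equals the hamming distance and `l2norm` is its square root. A simplified, byte-at-a-time sketch of that math (without the `BitUtil.VH_NATIVE_LONG` long-at-a-time fast path; this is not the Elasticsearch implementation):

```java
// Simplified sketch of the bit-vector math: popcount-based magnitude and
// xor+popcount hamming distance. Processes one byte (8 dimensions) at a time.
public class BitMathSketch {

    // Magnitude of a packed bit vector: sqrt of the number of set bits.
    static float bitMagnitude(byte[] vector) {
        int count = 0;
        for (byte b : vector) {
            count += Integer.bitCount(b & 0xFF); // popcount of one packed byte
        }
        return (float) Math.sqrt(count);
    }

    // Hamming distance: xor each byte pair, then count the differing bits.
    static int hamming(byte[] a, byte[] b) {
        int distance = 0;
        for (int i = 0; i < a.length; i++) {
            distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        return distance;
    }

    public static void main(String[] args) {
        byte[] a = { (byte) 0xF0, (byte) 0x0F };
        byte[] b = { (byte) 0xF0, (byte) 0xF0 };
        System.out.println(hamming(a, b)); // prints 8 (second bytes differ in all 8 bits)
        System.out.println(bitMagnitude(a) == (float) Math.sqrt(8)); // prints true
    }
}
```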
@@ -1,3 +1,5 @@
 org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat
 org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat
 org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat
+org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat
+org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat
@@ -0,0 +1,149 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.index.codec.vectors;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnByteVectorField;
+import org.apache.lucene.document.NumericDocValuesField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.StoredFields;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.SortField;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.BaseIndexFileFormatTestCase;
+import org.elasticsearch.common.logging.LogConfigurator;
+
+import java.io.IOException;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+abstract class BaseKnnBitVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
+
+    static {
+        LogConfigurator.loadLog4jPlugins();
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
+    @Override
+    protected void addRandomFields(Document doc) {
+        doc.add(new KnnByteVectorField("v2", randomVector(30), similarityFunction));
+    }
+
+    protected VectorSimilarityFunction similarityFunction;
+
+    protected VectorSimilarityFunction randomSimilarity() {
+        return VectorSimilarityFunction.values()[random().nextInt(VectorSimilarityFunction.values().length)];
+    }
+
+    byte[] randomVector(int dims) {
+        byte[] vector = new byte[dims];
+        random().nextBytes(vector);
+        return vector;
+    }
+
+    public void testRandom() throws Exception {
+        IndexWriterConfig iwc = newIndexWriterConfig();
+        if (random().nextBoolean()) {
+            iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT)));
+        }
+        String fieldName = "field";
+        try (Directory dir = newDirectory(); IndexWriter iw = new IndexWriter(dir, iwc)) {
+            int numDoc = atLeast(100);
+            int dimension = atLeast(10);
+            if (dimension % 2 != 0) {
+                dimension++;
+            }
+            byte[] scratch = new byte[dimension];
+            int numValues = 0;
+            byte[][] values = new byte[numDoc][];
+            for (int i = 0; i < numDoc; i++) {
+                if (random().nextInt(7) != 3) {
+                    // usually index a vector value for a doc
+                    values[i] = randomVector(dimension);
+                    ++numValues;
+                }
+                if (random().nextBoolean() && values[i] != null) {
+                    // sometimes use a shared scratch array
+                    System.arraycopy(values[i], 0, scratch, 0, scratch.length);
+                    add(iw, fieldName, i, scratch, similarityFunction);
+                } else {
+                    add(iw, fieldName, i, values[i], similarityFunction);
+                }
+                if (random().nextInt(10) == 2) {
+                    // sometimes delete a random document
+                    int idToDelete = random().nextInt(i + 1);
+                    iw.deleteDocuments(new Term("id", Integer.toString(idToDelete)));
+                    // and remember that it was deleted
+                    if (values[idToDelete] != null) {
+                        values[idToDelete] = null;
+                        --numValues;
+                    }
+                }
+                if (random().nextInt(10) == 3) {
+                    iw.commit();
+                }
+            }
+            int numDeletes = 0;
+            try (IndexReader reader = DirectoryReader.open(iw)) {
+                int valueCount = 0, totalSize = 0;
+                for (LeafReaderContext ctx : reader.leaves()) {
+                    ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
+                    if (vectorValues == null) {
+                        continue;
+                    }
+                    totalSize += vectorValues.size();
+                    StoredFields storedFields = ctx.reader().storedFields();
+                    int docId;
+                    while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
+                        byte[] v = vectorValues.vectorValue();
+                        assertEquals(dimension, v.length);
+                        String idString = storedFields.document(docId).getField("id").stringValue();
+                        int id = Integer.parseInt(idString);
+                        if (ctx.reader().getLiveDocs() == null || ctx.reader().getLiveDocs().get(docId)) {
+                            assertArrayEquals(idString, values[id], v);
+                            ++valueCount;
+                        } else {
+                            ++numDeletes;
+                            assertNull(values[id]);
+                        }
+                    }
+                }
+                assertEquals(numValues, valueCount);
+                assertEquals(numValues, totalSize - numDeletes);
+            }
+        }
+    }
+
+    private void add(IndexWriter iw, String field, int id, byte[] vector, VectorSimilarityFunction similarity) throws IOException {
+        add(iw, field, id, random().nextInt(100), vector, similarity);
+    }
+
+    private void add(IndexWriter iw, String field, int id, int sortKey, byte[] vector, VectorSimilarityFunction similarityFunction)
+        throws IOException {
+        Document doc = new Document();
+        if (vector != null) {
+            doc.add(new KnnByteVectorField(field, vector, similarityFunction));
+        }
+        doc.add(new NumericDocValuesField("sortkey", sortKey));
+        String idString = Integer.toString(id);
+        doc.add(new StringField("id", idString, Field.Store.YES));
+        Term idTerm = new Term("id", idString);
+        iw.updateDocument(idTerm, doc);
+    }
+}
@@ -12,8 +12,15 @@ import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.lucene99.Lucene99Codec;
 import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.elasticsearch.common.logging.LogConfigurator;

 public class ES813FlatVectorFormatTests extends BaseKnnVectorsFormatTestCase {

+    static {
+        LogConfigurator.loadLog4jPlugins();
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
     @Override
     protected Codec getCodec() {
         return new Lucene99Codec() {
@@ -12,8 +12,15 @@ import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.lucene99.Lucene99Codec;
 import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.elasticsearch.common.logging.LogConfigurator;

 public class ES813Int8FlatVectorFormatTests extends BaseKnnVectorsFormatTestCase {

+    static {
+        LogConfigurator.loadLog4jPlugins();
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
     @Override
     protected Codec getCodec() {
         return new Lucene99Codec() {
@@ -0,0 +1,34 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.index.codec.vectors;
+
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99Codec;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.junit.Before;
+
+public class ES815BitFlatVectorFormatTests extends BaseKnnBitVectorsFormatTestCase {
+
+    @Override
+    protected Codec getCodec() {
+        return new Lucene99Codec() {
+            @Override
+            public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+                return new ES815BitFlatVectorFormat();
+            }
+        };
+    }
+
+    @Before
+    public void init() {
+        similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+    }
+
+}
@@ -0,0 +1,33 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.index.codec.vectors;
+
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99Codec;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.junit.Before;
+
+public class ES815HnswBitVectorsFormatTests extends BaseKnnBitVectorsFormatTestCase {
+
+    @Override
+    protected Codec getCodec() {
+        return new Lucene99Codec() {
+            @Override
+            public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+                return new ES815HnswBitVectorsFormat();
+            }
+        };
+    }
+
+    @Before
+    public void init() {
+        similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+    }
+}
@@ -236,8 +236,8 @@ public class BinaryDenseVectorScriptDocValuesTests extends ESTestCase {

     public static BytesRef mockEncodeDenseVector(float[] values, ElementType elementType, IndexVersion indexVersion) {
         int numBytes = indexVersion.onOrAfter(DenseVectorFieldMapper.MAGNITUDE_STORED_INDEX_VERSION)
-            ? elementType.elementBytes * values.length + DenseVectorFieldMapper.MAGNITUDE_BYTES
-            : elementType.elementBytes * values.length;
+            ? elementType.getNumBytes(values.length) + DenseVectorFieldMapper.MAGNITUDE_BYTES
+            : elementType.getNumBytes(values.length);
         double dotProduct = 0f;
         ByteBuffer byteBuffer = elementType.createByteBuffer(indexVersion, numBytes);
         for (float value : values) {
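The switch above from `elementType.elementBytes * values.length` to `elementType.getNumBytes(values.length)` matters because storage size is no longer linear in a fixed per-element byte count for every type: a `bit` vector of `dims` dimensions occupies `dims / 8` bytes. A hypothetical sketch of that relationship (the enum and its constants here are illustrative stand-ins, not the mapper's actual `ElementType`):

```java
// Hypothetical sketch of per-element-type storage sizing: FLOAT uses 4 bytes
// per dimension, BYTE uses 1, and BIT packs 8 dimensions into each byte.
public class NumBytesSketch {

    enum ElementType {
        FLOAT, BYTE, BIT;

        int getNumBytes(int dims) {
            return switch (this) {
                case FLOAT -> 4 * dims; // one float per dimension
                case BYTE -> dims;      // one byte per dimension
                case BIT -> dims / 8;   // eight dimensions per byte
            };
        }
    }

    public static void main(String[] args) {
        System.out.println(ElementType.FLOAT.getNumBytes(4)); // prints 16
        System.out.println(ElementType.BYTE.getNumBytes(4));  // prints 4
        System.out.println(ElementType.BIT.getNumBytes(32));  // prints 4
    }
}
```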
@@ -71,11 +71,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
     private final ElementType elementType;
     private final boolean indexed;
     private final boolean indexOptionsSet;
+    private final int dims;

     public DenseVectorFieldMapperTests() {
-        this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT);
+        this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT);
         this.indexed = randomBoolean();
         this.indexOptionsSet = this.indexed && randomBoolean();
+        this.dims = ElementType.BIT == elementType ? 4 * Byte.SIZE : 4;
     }

     @Override
@@ -89,7 +91,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
     }

     private void indexMapping(XContentBuilder b, IndexVersion indexVersion) throws IOException {
-        b.field("type", "dense_vector").field("dims", 4);
+        b.field("type", "dense_vector").field("dims", dims);
         if (elementType != ElementType.FLOAT) {
             b.field("element_type", elementType.toString());
         }
@@ -108,7 +110,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
             b.endObject();
         }
         if (indexed) {
-            b.field("similarity", "dot_product");
+            b.field("similarity", elementType == ElementType.BIT ? "l2_norm" : "dot_product");
             if (indexOptionsSet) {
                 b.startObject("index_options");
                 b.field("type", "hnsw");
@@ -121,52 +123,86 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {

     @Override
     protected Object getSampleValueForDocument() {
-        return elementType == ElementType.BYTE ? List.of((byte) 1, (byte) 1, (byte) 1, (byte) 1) : List.of(0.5, 0.5, 0.5, 0.5);
+        return elementType == ElementType.FLOAT ? List.of(0.5, 0.5, 0.5, 0.5) : List.of((byte) 1, (byte) 1, (byte) 1, (byte) 1);
     }

     @Override
     protected void registerParameters(ParameterChecker checker) throws IOException {
         checker.registerConflictCheck(
             "dims",
-            fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4)),
-            fieldMapping(b -> b.field("type", "dense_vector").field("dims", 5))
+            fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims)),
+            fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims + 8))
         );
         checker.registerConflictCheck(
             "similarity",
-            fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", true).field("similarity", "dot_product")),
-            fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", true).field("similarity", "l2_norm"))
+            fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", true).field("similarity", "dot_product")),
+            fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", true).field("similarity", "l2_norm"))
         );
         checker.registerConflictCheck(
             "index",
-            fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", true).field("similarity", "dot_product")),
-            fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", false))
+            fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", true).field("similarity", "dot_product")),
+            fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", false))
         );
         checker.registerConflictCheck(
             "element_type",
             fieldMapping(
                 b -> b.field("type", "dense_vector")
-                    .field("dims", 4)
+                    .field("dims", dims)
                     .field("index", true)
                     .field("similarity", "dot_product")
                     .field("element_type", "byte")
             ),
             fieldMapping(
                 b -> b.field("type", "dense_vector")
-                    .field("dims", 4)
+                    .field("dims", dims)
                     .field("index", true)
                     .field("similarity", "dot_product")
                     .field("element_type", "float")
             )
         );
+        checker.registerConflictCheck(
+            "element_type",
+            fieldMapping(
+                b -> b.field("type", "dense_vector")
+                    .field("dims", dims)
+                    .field("index", true)
+                    .field("similarity", "l2_norm")
+                    .field("element_type", "float")
+            ),
+            fieldMapping(
+                b -> b.field("type", "dense_vector")
+                    .field("dims", dims)
+                    .field("index", true)
+                    .field("similarity", "l2_norm")
+                    .field("element_type", "bit")
+            )
+        );
+        checker.registerConflictCheck(
+            "element_type",
+            fieldMapping(
+                b -> b.field("type", "dense_vector")
+                    .field("dims", dims)
+                    .field("index", true)
+                    .field("similarity", "l2_norm")
+                    .field("element_type", "byte")
+            ),
+            fieldMapping(
+                b -> b.field("type", "dense_vector")
+                    .field("dims", dims)
+                    .field("index", true)
+                    .field("similarity", "l2_norm")
+                    .field("element_type", "bit")
+            )
+        );
         checker.registerUpdateCheck(
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "flat")
                 .endObject(),
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "int8_flat")
@@ -175,13 +211,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         );
         checker.registerUpdateCheck(
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "flat")
                 .endObject(),
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "hnsw")
@@ -190,13 +226,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         );
         checker.registerUpdateCheck(
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "flat")
                 .endObject(),
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "int8_hnsw")
@@ -205,13 +241,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         );
         checker.registerUpdateCheck(
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "int8_flat")
                 .endObject(),
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "hnsw")
@@ -220,13 +256,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         );
         checker.registerUpdateCheck(
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "int8_flat")
                 .endObject(),
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "int8_hnsw")
|
.field("type", "int8_hnsw")
|
||||||
|
@ -235,13 +271,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
|
||||||
);
|
);
|
||||||
checker.registerUpdateCheck(
|
checker.registerUpdateCheck(
|
||||||
b -> b.field("type", "dense_vector")
|
b -> b.field("type", "dense_vector")
|
||||||
.field("dims", 4)
|
.field("dims", dims)
|
||||||
.field("index", true)
|
.field("index", true)
|
||||||
.startObject("index_options")
|
.startObject("index_options")
|
||||||
.field("type", "hnsw")
|
.field("type", "hnsw")
|
||||||
.endObject(),
|
.endObject(),
|
||||||
b -> b.field("type", "dense_vector")
|
b -> b.field("type", "dense_vector")
|
||||||
.field("dims", 4)
|
.field("dims", dims)
|
||||||
.field("index", true)
|
.field("index", true)
|
||||||
.startObject("index_options")
|
.startObject("index_options")
|
||||||
.field("type", "int8_hnsw")
|
.field("type", "int8_hnsw")
|
||||||
|
@ -252,7 +288,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
|
||||||
"index_options",
|
"index_options",
|
||||||
fieldMapping(
|
fieldMapping(
|
||||||
b -> b.field("type", "dense_vector")
|
b -> b.field("type", "dense_vector")
|
||||||
.field("dims", 4)
|
.field("dims", dims)
|
||||||
.field("index", true)
|
.field("index", true)
|
||||||
.startObject("index_options")
|
.startObject("index_options")
|
||||||
.field("type", "hnsw")
|
.field("type", "hnsw")
|
||||||
|
@ -260,7 +296,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
|
||||||
),
|
),
|
||||||
fieldMapping(
|
fieldMapping(
|
||||||
b -> b.field("type", "dense_vector")
|
b -> b.field("type", "dense_vector")
|
||||||
.field("dims", 4)
|
.field("dims", dims)
|
||||||
.field("index", true)
|
.field("index", true)
|
||||||
.startObject("index_options")
|
.startObject("index_options")
|
||||||
.field("type", "flat")
|
.field("type", "flat")
|
||||||
|
@ -353,7 +389,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
|
||||||
mapping = mapping(b -> {
|
mapping = mapping(b -> {
|
||||||
b.startObject("field");
|
b.startObject("field");
|
||||||
b.field("type", "dense_vector")
|
b.field("type", "dense_vector")
|
||||||
.field("dims", 4)
|
.field("dims", dims)
|
||||||
.field("similarity", "cosine")
|
.field("similarity", "cosine")
|
||||||
.field("index", true)
|
.field("index", true)
|
||||||
.startObject("index_options")
|
.startObject("index_options")
|
||||||
|
@ -648,7 +684,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
|
||||||
() -> createDocumentMapper(
|
() -> createDocumentMapper(
|
||||||
fieldMapping(
|
fieldMapping(
|
||||||
b -> b.field("type", "dense_vector")
|
b -> b.field("type", "dense_vector")
|
||||||
.field("dims", 4)
|
.field("dims", dims)
|
||||||
.field("element_type", "byte")
|
.field("element_type", "byte")
|
||||||
.field("similarity", "l2_norm")
|
.field("similarity", "l2_norm")
|
||||||
.field("index", true)
|
.field("index", true)
|
||||||
|
@ -1020,6 +1056,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
|
||||||
}
|
}
|
||||||
yield floats;
|
yield floats;
|
||||||
}
|
}
|
||||||
|
case BIT -> randomByteArrayOfLength(vectorFieldType.getVectorDimensions() / 8);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1196,7 +1233,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
|
||||||
boolean setEfConstruction = randomBoolean();
|
boolean setEfConstruction = randomBoolean();
|
||||||
MapperService mapperService = createMapperService(fieldMapping(b -> {
|
MapperService mapperService = createMapperService(fieldMapping(b -> {
|
||||||
b.field("type", "dense_vector");
|
b.field("type", "dense_vector");
|
||||||
b.field("dims", 4);
|
b.field("dims", dims);
|
||||||
b.field("index", true);
|
b.field("index", true);
|
||||||
b.field("similarity", "dot_product");
|
b.field("similarity", "dot_product");
|
||||||
b.startObject("index_options");
|
b.startObject("index_options");
|
||||||
|
@ -1234,7 +1271,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
|
||||||
for (String quantizedFlatFormat : new String[] { "int8_flat", "int4_flat" }) {
|
for (String quantizedFlatFormat : new String[] { "int8_flat", "int4_flat" }) {
|
||||||
MapperService mapperService = createMapperService(fieldMapping(b -> {
|
MapperService mapperService = createMapperService(fieldMapping(b -> {
|
||||||
b.field("type", "dense_vector");
|
b.field("type", "dense_vector");
|
||||||
b.field("dims", 4);
|
b.field("dims", dims);
|
||||||
b.field("index", true);
|
b.field("index", true);
|
||||||
b.field("similarity", "dot_product");
|
b.field("similarity", "dot_product");
|
||||||
b.startObject("index_options");
|
b.startObject("index_options");
|
||||||
|
@ -1275,7 +1312,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
|
||||||
float confidenceInterval = (float) randomDoubleBetween(0.90f, 1.0f, true);
|
float confidenceInterval = (float) randomDoubleBetween(0.90f, 1.0f, true);
|
||||||
MapperService mapperService = createMapperService(fieldMapping(b -> {
|
MapperService mapperService = createMapperService(fieldMapping(b -> {
|
||||||
b.field("type", "dense_vector");
|
b.field("type", "dense_vector");
|
||||||
b.field("dims", 4);
|
b.field("dims", dims);
|
||||||
b.field("index", true);
|
b.field("index", true);
|
||||||
b.field("similarity", "dot_product");
|
b.field("similarity", "dot_product");
|
||||||
b.startObject("index_options");
|
b.startObject("index_options");
|
||||||
|
@ -1316,7 +1353,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
|
||||||
float confidenceInterval = (float) randomDoubleBetween(0.90f, 1.0f, true);
|
float confidenceInterval = (float) randomDoubleBetween(0.90f, 1.0f, true);
|
||||||
MapperService mapperService = createMapperService(fieldMapping(b -> {
|
MapperService mapperService = createMapperService(fieldMapping(b -> {
|
||||||
b.field("type", "dense_vector");
|
b.field("type", "dense_vector");
|
||||||
b.field("dims", 4);
|
b.field("dims", dims);
|
||||||
b.field("index", true);
|
b.field("index", true);
|
||||||
b.field("similarity", "dot_product");
|
b.field("similarity", "dot_product");
|
||||||
b.startObject("index_options");
|
b.startObject("index_options");
|
||||||
|
|
|
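The mapper-test hunks above parameterize `dims` and add a `BIT` case that generates `dims / 8` random bytes, reflecting how `bit` vectors are packed: each byte carries 8 dimensions, and a hexadecimal string is an accepted encoding alongside `byte[]`. A minimal standalone sketch of that packing rule, assuming nothing from the PR itself (the class and method names here are illustrative):

```java
import java.util.HexFormat;

public class BitVectorEncoding {

    // A `bit` vector of `dims` dimensions is packed into dims / 8 bytes,
    // so `dims` must be divisible by 8. Hex strings are one accepted encoding.
    static byte[] decodeHex(String hex, int dims) {
        byte[] packed = HexFormat.of().parseHex(hex);
        if (packed.length * 8 != dims) {
            throw new IllegalArgumentException(
                "expected " + (dims / 8) + " bytes for " + dims + " dims, got " + packed.length
            );
        }
        return packed;
    }

    public static void main(String[] args) {
        // 16 dimensions -> 2 bytes; "ff01" packs the bits 11111111 00000001
        byte[] vector = decodeHex("ff01", 16);
        System.out.println(vector.length);    // 2
        System.out.println(vector[0] & 0xff); // 255
    }
}
```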
@@ -185,10 +185,12 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
 queryVector[i] = randomByte();
 floatQueryVector[i] = queryVector[i];
 }
-Query query = field.createKnnQuery(queryVector, 10, null, null, producer);
+VectorData vectorData = new VectorData(null, queryVector);
+Query query = field.createKnnQuery(vectorData, 10, null, null, producer);
 assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));

-query = field.createKnnQuery(floatQueryVector, 10, null, null, producer);
+vectorData = new VectorData(floatQueryVector, null);
+query = field.createKnnQuery(vectorData, 10, null, null, producer);
 assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
 }
 }
@@ -321,7 +323,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
 for (int i = 0; i < 4096; i++) {
 queryVector[i] = randomByte();
 }
-Query query = fieldWith4096dims.createKnnQuery(queryVector, 10, null, null, null);
+VectorData vectorData = new VectorData(null, queryVector);
+Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, null, null, null);
 assertThat(query, instanceOf(KnnByteVectorQuery.class));
 }
 }
@@ -359,7 +362,10 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
 );
 assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));

-e = expectThrows(IllegalArgumentException.class, () -> cosineField.createKnnQuery(new byte[] { 0, 0, 0 }, 10, null, null, null));
+e = expectThrows(
+IllegalArgumentException.class,
+() -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, null, null, null)
+);
 assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
 }
 }
@@ -114,10 +114,10 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
 );

 e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, queryVector, fieldName));
-assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors"));
+assertThat(e.getMessage(), containsString("hamming distance is only supported for byte or bit vectors"));

 e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, invalidQueryVector, fieldName));
-assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors"));
+assertThat(e.getMessage(), containsString("hamming distance is only supported for byte or bit vectors"));

 // Check scripting infrastructure integration
 DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName);
@@ -122,7 +122,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
 Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery;
 // The field should always be resolved to the concrete field
 Query knnVectorQueryBuilt = switch (elementType()) {
-case BYTE -> new ESKnnByteVectorQuery(
+case BYTE, BIT -> new ESKnnByteVectorQuery(
 VECTOR_FIELD,
 queryBuilder.queryVector().asByteVector(),
 queryBuilder.numCands(),
@@ -145,7 +145,10 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
 SearchExecutionContext context = createSearchExecutionContext();
 KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 10, null);
 IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
-assertThat(e.getMessage(), containsString("the query vector has a different dimension [2] than the index vectors [3]"));
+assertThat(
+e.getMessage(),
+containsString("The query vector has a different number of dimensions [2] than the document vectors [3]")
+);
 }

 public void testNonexistentField() {
@@ -46,6 +46,7 @@ public class EmbeddingRequestChunker {
 return switch (elementType) {
 case BYTE -> EmbeddingType.BYTE;
 case FLOAT -> EmbeddingType.FLOAT;
+case BIT -> throw new IllegalArgumentException("Bit vectors are not supported");
 };
 }
 };