Mirror of https://github.com/elastic/elasticsearch.git, synced 2025-06-28 09:28:55 -04:00
Adds new `bit` element_type for dense_vectors (#110059)
This commit adds `bit` vector support by adding `element_type: bit` for vectors. This new element type works for indexed and non-indexed vectors. Additionally, it works with `hnsw` and `flat` index types. No quantization-based codec works with this element type; this is consistent with `byte` vectors.

`bit` vectors accept up to `32768` dimensions in size and expect vectors that are being indexed to be encoded either as a hexadecimal string or a `byte[]` array where each element of the `byte` array represents `8` bits of the vector.

`bit` vectors support script usage and regular query usage. When indexed, all comparisons done are `xor` and `popcount` summations (i.e., hamming distance), and the scores are transformed and normalized given the vector dimensions. Note, indexed bit vectors require `l2_norm` to be the similarity. For scripts, `l1norm` is the same as `hamming` distance and `l2norm` is `sqrt(l1norm)`. `dotProduct` and `cosineSimilarity` are not supported.

Note, the dimensions expected by this element_type must always be divisible by `8`, and the `byte[]` vectors provided for indexing must have size `dim/8`, where each byte element represents `8` bits of the vector.

closes: https://github.com/elastic/elasticsearch/issues/48322
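The size constraints described above (dimensions divisible by `8`, a `byte[]` payload of `dim/8` elements, or a hex string of `dim/4` characters) can be sketched as a small validation helper. This is illustrative Python under stated assumptions, not the actual mapper code:

```python
def validate_bit_vector(dims: int, data) -> None:
    """Sketch of the input constraints `element_type: bit` imposes."""
    if dims % 8 != 0:
        raise ValueError(f"bit vectors require dims divisible by 8, got {dims}")
    if isinstance(data, str):
        # hexadecimal string: two hex characters per byte, dims/8 bytes total
        if len(data) != dims // 4:
            raise ValueError("hex string must encode dims/8 bytes")
    else:
        # byte[] input: each byte carries 8 bits of the vector
        if len(data) != dims // 8:
            raise ValueError("byte array must have dims/8 elements")

validate_bit_vector(40, [127, -127, 0, 1, 42])  # ok: 5 bytes = 40 bits
validate_bit_vector(40, "8100012a7f")           # ok: 10 hex chars = 5 bytes
```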
This commit is contained in: parent 97651dfb9f, commit 5add44d7d1.
38 changed files with 2713 additions and 187 deletions
docs/changelog/110059.yaml (new file, 32 lines)
@@ -0,0 +1,32 @@
pr: 110059
summary: Adds new `bit` `element_type` for `dense_vectors`
area: Vector Search
type: feature
issues: []
highlight:
  title: Adds new `bit` `element_type` for `dense_vectors`
  body: |-
    This adds `bit` vector support by adding `element_type: bit` for
    vectors. This new element type works for indexed and non-indexed
    vectors. Additionally, it works with `hnsw` and `flat` index types. No
    quantization-based codec works with this element type; this is
    consistent with `byte` vectors.

    `bit` vectors accept up to `32768` dimensions in size and expect vectors
    that are being indexed to be encoded either as a hexadecimal string or a
    `byte[]` array where each element of the `byte` array represents `8`
    bits of the vector.

    `bit` vectors support script usage and regular query usage. When
    indexed, all comparisons done are `xor` and `popcount` summations (i.e.,
    hamming distance), and the scores are transformed and normalized given
    the vector dimensions.

    For scripts, `l1norm` is the same as `hamming` distance and `l2norm` is
    `sqrt(l1norm)`. `dotProduct` and `cosineSimilarity` are not supported.

    Note, the dimensions expected by this element_type must always be
    divisible by `8`, and the `byte[]` vectors provided for indexing must
    have size `dim/8`, where each byte element represents `8` bits of
    the vector.
  notable: true
@@ -183,11 +183,23 @@ The following mapping parameters are accepted:
 `element_type`::
 (Optional, string)
 The data type used to encode vectors. The supported data types are
-`float` (default) and `byte`. `float` indexes a 4-byte floating-point
-value per dimension. `byte` indexes a 1-byte integer value per dimension.
-Using `byte` can result in a substantially smaller index size with the
-trade off of lower precision. Vectors using `byte` require dimensions with
-integer values between -128 to 127, inclusive for both indexing and searching.
+`float` (default), `byte`, and `bit`.
+
+.Valid values for `element_type`
+[%collapsible%open]
+====
+`float`:::
+indexes a 4-byte floating-point
+value per dimension. This is the default value.
+
+`byte`:::
+indexes a 1-byte integer value per dimension.
+
+`bit`:::
+indexes a single bit per dimension. Useful for very high-dimensional vectors or models that specifically support bit vectors.
+NOTE: when using `bit`, the number of dimensions must be a multiple of 8 and must represent the number of bits.
+
+====
 
 `dims`::
 (Optional, integer)
@@ -205,7 +217,11 @@ API>>. Defaults to `true`.
 The vector similarity metric to use in kNN search. Documents are ranked by
 their vector field's similarity to the query vector. The `_score` of each
 document will be derived from the similarity, in a way that ensures scores are
-positive and that a larger score corresponds to a higher ranking. Defaults to `cosine`.
+positive and that a larger score corresponds to a higher ranking.
+Defaults to `l2_norm` when `element_type: bit`, otherwise defaults to `cosine`.
+
+NOTE: `bit` vectors only support `l2_norm` as their similarity metric.
+
 +
 ^*^ This parameter can only be specified when `index` is `true`.
 +
@@ -217,6 +233,9 @@ Computes similarity based on the L^2^ distance (also known as Euclidean
 distance) between the vectors. The document `_score` is computed as
 `1 / (1 + l2_norm(query, vector)^2)`.
 
+For `bit` vectors, instead of using `l2_norm`, the `hamming` distance between the vectors is used. The `_score`
+transformation is `(numBits - hamming(a, b)) / numBits`.
+
 `dot_product`:::
 Computes the dot product of two unit vectors. This option provides an optimized way
 to perform cosine similarity. The constraints and computed score are defined
@@ -320,3 +339,112 @@ any issues, but features in technical preview are not subject to the support SLA
of official GA features.

`dense_vector` fields support <<synthetic-source,synthetic `_source`>>.

[[dense-vector-index-bit]]
==== Indexing & Searching bit vectors

When using `element_type: bit`, this will treat all vectors as bit vectors. Bit vectors utilize only a single
bit per dimension and are internally encoded as bytes. This can be useful for very high-dimensional vectors or models.

When using `bit`, the number of dimensions must be a multiple of 8 and must represent the number of bits. Additionally,
with `bit` vectors, the typical vector similarity values are effectively all scored the same, e.g. with `hamming` distance.

Let's compare two `byte[]` arrays, each representing 40 individual bits.

`[-127, 0, 1, 42, 127]` in bits `1000000100000000000000010010101001111111`
`[127, -127, 0, 1, 42]` in bits `0111111110000001000000000000000100101010`

When comparing these two bit vectors, we first take the {wikipedia}/Hamming_distance[`hamming` distance].

`xor` result:
```
1000000100000000000000010010101001111111
^
0111111110000001000000000000000100101010
=
1111111010000001000000010010101101010101
```

Then, we gather the count of `1` bits in the `xor` result: `18`. To scale for scoring, we subtract from the total number
of bits and divide by the total number of bits: `(40 - 18) / 40 = 0.55`. This would be the `_score` between these two
vectors.
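The xor-and-popcount arithmetic above can be sketched in plain Python. This is an illustration of the scoring formula, not the Elasticsearch implementation:

```python
# Two byte[] vectors, each encoding 40 bits, as in the example above.
a = [-127, 0, 1, 42, 127]
b = [127, -127, 0, 1, 42]

# Hamming distance: popcount of the bytewise XOR.
hamming = sum(bin((x ^ y) & 0xFF).count("1") for x, y in zip(a, b))

num_bits = len(a) * 8
score = (num_bits - hamming) / num_bits
print(hamming, score)  # 18 0.55
```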
Here is an example of indexing and searching bit vectors:

[source,console]
--------------------------------------------------
PUT my-bit-vectors
{
  "mappings": {
    "properties": {
      "my_vector": {
        "type": "dense_vector",
        "dims": 40, <1>
        "element_type": "bit"
      }
    }
  }
}
--------------------------------------------------
<1> The number of dimensions that represents the number of bits

[source,console]
--------------------------------------------------
POST /my-bit-vectors/_bulk?refresh
{"index": {"_id" : "1"}}
{"my_vector": [127, -127, 0, 1, 42]} <1>
{"index": {"_id" : "2"}}
{"my_vector": "8100012a7f"} <2>
--------------------------------------------------
// TEST[continued]
<1> 5 bytes representing the 40-bit dimensioned vector
<2> A hexadecimal string representing the 40-bit dimensioned vector
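The hexadecimal form is simply the hex encoding of the same bytes. A quick sketch of the round trip in plain Python, for illustration:

```python
vec = [-127, 0, 1, 42, 127]  # 5 signed bytes = 40 bits

# Signed bytes -> hex string, as accepted by the bulk request above.
hex_str = bytes(b & 0xFF for b in vec).hex()
print(hex_str)  # 8100012a7f

# Hex string -> signed bytes.
decoded = [b - 256 if b > 127 else b for b in bytes.fromhex(hex_str)]
print(decoded)  # [-127, 0, 1, 42, 127]
```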
|
||||
|
||||
Then, when searching, you can use the `knn` query to search for similar bit vectors:
|
||||
|
||||
[source,console]
|
||||
--------------------------------------------------
|
||||
POST /my-bit-vectors/_search?filter_path=hits.hits
|
||||
{
|
||||
"query": {
|
||||
"knn": {
|
||||
"query_vector": [127, -127, 0, 1, 42],
|
||||
"field": "my_vector"
|
||||
}
|
||||
}
|
||||
}
|
||||
--------------------------------------------------
|
||||
// TEST[continued]
|
||||
|
||||
[source,console-result]
|
||||
----
|
||||
{
|
||||
"hits": {
|
||||
"hits": [
|
||||
{
|
||||
"_index": "my-bit-vectors",
|
||||
"_id": "1",
|
||||
"_score": 1.0,
|
||||
"_source": {
|
||||
"my_vector": [
|
||||
127,
|
||||
-127,
|
||||
0,
|
||||
1,
|
||||
42
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"_index": "my-bit-vectors",
|
||||
"_id": "2",
|
||||
"_score": 0.55,
|
||||
"_source": {
|
||||
"my_vector": "8100012a7f"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
----
|
||||
|
||||
|
|
@@ -1,4 +1,3 @@
 [role="xpack"]
 [[vector-functions]]
 ===== Functions for vector fields
 
@@ -17,6 +16,8 @@ This is the list of available vector functions and vector access methods:
 6. <<vector-functions-accessing-vectors,`doc[<field>].vectorValue`>> – returns a vector's value as an array of floats
 7. <<vector-functions-accessing-vectors,`doc[<field>].magnitude`>> – returns a vector's magnitude
 
+NOTE: The `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
+
 NOTE: The recommended way to access dense vectors is through the
 `cosineSimilarity`, `dotProduct`, `l1norm` or `l2norm` functions. Please note
 however, that you should call these functions only once per script. For example,
 
@@ -193,7 +194,7 @@ we added `1` in the denominator.
 ====== Hamming distance
 
 The `hamming` function calculates {wikipedia}/Hamming_distance[Hamming distance] between a given query vector and
-document vectors. It is only available for byte vectors.
+document vectors. It is only available for byte and bit vectors.
 
 [source,console]
 --------------------------------------------------
 
@@ -278,10 +279,14 @@ You can access vector values directly through the following functions:
 
 - `doc[<field>].vectorValue` – returns a vector's value as an array of floats
 
+NOTE: For `bit` vectors, it does return a `float[]`, where each element represents 8 bits.
+
 - `doc[<field>].magnitude` – returns a vector's magnitude as a float
 (for vectors created prior to version 7.5 the magnitude is not stored.
 So this function calculates it anew every time it is called).
 
+NOTE: For `bit` vectors, this is just the square root of the sum of `1` bits.
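As a sketch of that magnitude rule in plain Python (illustrative, not the Painless implementation), using one of the vectors from this commit's REST tests:

```python
import math

vec = [-1, 115, -3, 4, -128]  # 40-bit vector as signed bytes

# Magnitude of a bit vector: square root of its popcount.
ones = sum(bin(b & 0xFF).count("1") for b in vec)
magnitude = math.sqrt(ones)
print(ones, round(magnitude, 6))  # 22 4.690416
```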
 
 For example, the script below implements a cosine similarity using these
 two functions:
 
@@ -319,3 +324,14 @@ GET my-index-000001/_search
 }
 }
 --------------------------------------------------
+[[vector-functions-bit-vectors]]
+====== Bit vectors and vector functions
+
+When using `bit` vectors, not all the vector functions are available. The supported functions are:
+
+* <<vector-functions-hamming,`hamming`>> – calculates Hamming distance, the sum of the bitwise XOR of the two vectors
+* <<vector-functions-l1,`l1norm`>> – calculates L^1^ distance; this is simply the `hamming` distance
+* <<vector-functions-l2,`l2norm`>> – calculates L^2^ distance; this is the square root of the `hamming` distance
+
+Currently, the `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
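The relationship between the three supported functions can be sketched in plain Python (illustrative arithmetic under stated assumptions, not the Painless implementation), using the query and document vectors from this commit's REST tests:

```python
import math

def hamming(a, b):
    # popcount of the bytewise XOR of two signed-byte vectors
    return sum(bin((x ^ y) & 0xFF).count("1") for x, y in zip(a, b))

query = [0, 111, -13, 14, -124]
doc = [8, 5, -15, 1, -7]

h = hamming(query, doc)
l1 = h             # for bit vectors, l1norm is the hamming distance
l2 = math.sqrt(h)  # and l2norm is its square root
print(h, l1, l2)  # 16 16 4.0
```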
@@ -0,0 +1,392 @@
setup:
  - requires:
      cluster_features: ["mapper.vectors.bit_vectors"]
      reason: "support for bit vectors added in 8.15"
      test_runner_features: headers

  - do:
      indices.create:
        index: test-index
        body:
          mappings:
            properties:
              vector:
                type: dense_vector
                index: false
                element_type: bit
                dims: 40
              indexed_vector:
                type: dense_vector
                element_type: bit
                dims: 40
                index: true
                similarity: l2_norm

  - do:
      index:
        index: test-index
        id: "1"
        body:
          vector: [8, 5, -15, 1, -7]
          indexed_vector: [8, 5, -15, 1, -7]

  - do:
      index:
        index: test-index
        id: "2"
        body:
          vector: [-1, 115, -3, 4, -128]
          indexed_vector: [-1, 115, -3, 4, -128]

  - do:
      index:
        index: test-index
        id: "3"
        body:
          vector: [2, 18, -5, 0, -124]
          indexed_vector: [2, 18, -5, 0, -124]

  - do:
      indices.refresh: {}

---
"Test vector magnitude equality":
  - skip:
      features: close_to

  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "doc['vector'].magnitude"

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - close_to: {hits.hits.0._score: {value: 4.690416, error: 0.01}}

  - match: {hits.hits.1._id: "1"}
  - close_to: {hits.hits.1._score: {value: 3.8729835, error: 0.01}}

  - match: {hits.hits.2._id: "3"}
  - close_to: {hits.hits.2._score: {value: 3.4641016, error: 0.01}}

  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "doc['indexed_vector'].magnitude"

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - close_to: {hits.hits.0._score: {value: 4.690416, error: 0.01}}

  - match: {hits.hits.1._id: "1"}
  - close_to: {hits.hits.1._score: {value: 3.8729835, error: 0.01}}

  - match: {hits.hits.2._id: "3"}
  - close_to: {hits.hits.2._score: {value: 3.4641016, error: 0.01}}

---
"Dot Product is not supported":
  - do:
      catch: bad_request
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "dotProduct(params.query_vector, 'vector')"
                params:
                  query_vector: [0, 111, -13, 14, -124]
  - do:
      catch: bad_request
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "dotProduct(params.query_vector, 'vector')"
                params:
                  query_vector: "006ff30e84"

---
"Cosine Similarity is not supported":
  - do:
      catch: bad_request
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "cosineSimilarity(params.query_vector, 'vector')"
                params:
                  query_vector: [0, 111, -13, 14, -124]
  - do:
      catch: bad_request
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "cosineSimilarity(params.query_vector, 'vector')"
                params:
                  query_vector: "006ff30e84"

  - do:
      catch: bad_request
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "cosineSimilarity(params.query_vector, 'indexed_vector')"
                params:
                  query_vector: [0, 111, -13, 14, -124]
---
"L1 norm":
  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "l1norm(params.query_vector, 'vector')"
                params:
                  query_vector: [0, 111, -13, 14, -124]

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0._score: 17.0}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1._score: 16.0}

  - match: {hits.hits.2._id: "3"}
  - match: {hits.hits.2._score: 11.0}

---
"L1 norm hexidecimal":
  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "l1norm(params.query_vector, 'vector')"
                params:
                  query_vector: "006ff30e84"

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0._score: 17.0}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1._score: 16.0}

  - match: {hits.hits.2._id: "3"}
  - match: {hits.hits.2._score: 11.0}
---
"L2 norm":
  - requires:
      test_runner_features: close_to
  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "l2norm(params.query_vector, 'vector')"
                params:
                  query_vector: [0, 111, -13, 14, -124]

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - close_to: {hits.hits.0._score: {value: 4.123, error: 0.001}}

  - match: {hits.hits.1._id: "1"}
  - close_to: {hits.hits.1._score: {value: 4, error: 0.001}}

  - match: {hits.hits.2._id: "3"}
  - close_to: {hits.hits.2._score: {value: 3.316, error: 0.001}}
---
"L2 norm hexidecimal":
  - requires:
      test_runner_features: close_to

  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "l2norm(params.query_vector, 'vector')"
                params:
                  query_vector: "006ff30e84"

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - close_to: {hits.hits.0._score: {value: 4.123, error: 0.001}}

  - match: {hits.hits.1._id: "1"}
  - close_to: {hits.hits.1._score: {value: 4, error: 0.001}}

  - match: {hits.hits.2._id: "3"}
  - close_to: {hits.hits.2._score: {value: 3.316, error: 0.001}}
---
"Hamming distance":
  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "hamming(params.query_vector, 'vector')"
                params:
                  query_vector: [0, 111, -13, 14, -124]

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0._score: 17.0}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1._score: 16.0}

  - match: {hits.hits.2._id: "3"}
  - match: {hits.hits.2._score: 11.0}


  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "hamming(params.query_vector, 'indexed_vector')"
                params:
                  query_vector: [0, 111, -13, 14, -124]

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0._score: 17.0}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1._score: 16.0}

  - match: {hits.hits.2._id: "3"}
  - match: {hits.hits.2._score: 11.0}
---
"Hamming distance hexidecimal":
  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "hamming(params.query_vector, 'vector')"
                params:
                  query_vector: "006ff30e84"

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0._score: 17.0}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1._score: 16.0}

  - match: {hits.hits.2._id: "3"}
  - match: {hits.hits.2._score: 11.0}


  - do:
      headers:
        Content-Type: application/json
      search:
        rest_total_hits_as_int: true
        body:
          query:
            script_score:
              query: {match_all: {} }
              script:
                source: "hamming(params.query_vector, 'indexed_vector')"
                params:
                  query_vector: "006ff30e84"

  - match: {hits.total: 3}

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0._score: 17.0}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1._score: 16.0}

  - match: {hits.hits.2._id: "3"}
  - match: {hits.hits.2._score: 11.0}
@@ -0,0 +1,301 @@
setup:
  - requires:
      cluster_features: "mapper.vectors.bit_vectors"
      test_runner_features: close_to
      reason: 'bit vectors added in 8.15'
  - do:
      indices.create:
        index: test
        body:
          settings:
            index:
              number_of_shards: 2
          mappings:
            properties:
              name:
                type: keyword
              nested:
                type: nested
                properties:
                  paragraph_id:
                    type: keyword
                  vector:
                    type: dense_vector
                    dims: 40
                    index: true
                    element_type: bit
                    similarity: l2_norm

  - do:
      index:
        index: test
        id: "1"
        body:
          name: cow.jpg
          nested:
            - paragraph_id: 0
              vector: [100, 20, -34, 15, -100]
            - paragraph_id: 1
              vector: [40, 30, -3, 1, -20]

  - do:
      index:
        index: test
        id: "2"
        body:
          name: moose.jpg
          nested:
            - paragraph_id: 0
              vector: [-1, 100, -13, 14, -127]
            - paragraph_id: 2
              vector: [0, 100, 0, 15, -127]
            - paragraph_id: 3
              vector: [0, 1, 0, 2, -15]

  - do:
      index:
        index: test
        id: "3"
        body:
          name: rabbit.jpg
          nested:
            - paragraph_id: 0
              vector: [1, 111, -13, 14, -1]

  - do:
      indices.refresh: {}

---
"nested kNN search only":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: nested.vector
            query_vector: [-1, 90, -10, 14, -127]
            k: 2
            num_candidates: 3

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0.fields.name.0: "moose.jpg"}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1.fields.name.0: "cow.jpg"}


  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: nested.vector
            query_vector: [-1, 90, -10, 14, -127]
            k: 2
            num_candidates: 3
            inner_hits: {size: 1, "fields": ["nested.paragraph_id"], _source: false}


  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0.fields.name.0: "moose.jpg"}
  - match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1.fields.name.0: "cow.jpg"}
  - match: {hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}

---
"nested kNN search filtered":

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: nested.vector
            query_vector: [-1, 90, -10, 14, -127]
            k: 2
            num_candidates: 3
            filter: {term: {name: "rabbit.jpg"}}

  - match: {hits.total.value: 1}
  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: nested.vector
            query_vector: [-1, 90, -10, 14, -127]
            k: 3
            num_candidates: 3
            filter: {term: {name: "rabbit.jpg"}}
            inner_hits: {size: 1, fields: ["nested.paragraph_id"], _source: false}

  - match: {hits.total.value: 1}
  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
  - match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}
---
"nested kNN search inner_hits size > 1":
  - do:
      index:
        index: test
        id: "4"
        body:
          name: moose.jpg
          nested:
            - paragraph_id: 0
              vector: [-1, 90, -10, 14, -127]
            - paragraph_id: 2
              vector: [ 0, 100.0, 0, 14, -127 ]
            - paragraph_id: 3
              vector: [ 0, 1.0, 0, 2, -15 ]

  - do:
      index:
        index: test
        id: "5"
        body:
          name: moose.jpg
          nested:
            - paragraph_id: 0
              vector: [ -1, 100, -13, 14, -127 ]
            - paragraph_id: 2
              vector: [ 0, 100, 0, 15, -127 ]
            - paragraph_id: 3
              vector: [ 0, 1, 0, 2, -15 ]

  - do:
      index:
        index: test
        id: "6"
        body:
          name: moose.jpg
          nested:
            - paragraph_id: 0
              vector: [ -1, 100, -13, 15, -127 ]
            - paragraph_id: 2
              vector: [ 0, 100, 0, 15, -127 ]
            - paragraph_id: 3
              vector: [ 0, 1, 0, 2, -15 ]
  - do:
      indices.refresh: { }

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: nested.vector
            query_vector: [-1, 90, -10, 15, -127]
            k: 3
            num_candidates: 5
            inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}

  - match: {hits.total.value: 3}
  - length: { hits.hits.0.inner_hits.nested.hits.hits: 2 }
  - length: { hits.hits.1.inner_hits.nested.hits.hits: 2 }
  - length: { hits.hits.2.inner_hits.nested.hits.hits: 2 }

  - match: { hits.hits.0.fields.name.0: "moose.jpg" }
  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: nested.vector
            query_vector: [-1, 90, -10, 15, -127]
            k: 5
            num_candidates: 5
            inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}

  - match: {hits.total.value: 5}
  # All these initial matches are "moose.jpg", which has 3 nested vectors, but two are closest
  - match: {hits.hits.0.fields.name.0: "moose.jpg"}
  - length: { hits.hits.0.inner_hits.nested.hits.hits: 2 }
  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
  - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
  - match: {hits.hits.1.fields.name.0: "moose.jpg"}
  - length: { hits.hits.1.inner_hits.nested.hits.hits: 2 }
  - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
  - match: { hits.hits.1.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
  - match: {hits.hits.2.fields.name.0: "moose.jpg"}
  - length: { hits.hits.2.inner_hits.nested.hits.hits: 2 }
  - match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
  - match: { hits.hits.2.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
  - match: {hits.hits.3.fields.name.0: "moose.jpg"}
  - length: { hits.hits.3.inner_hits.nested.hits.hits: 2 }
  - match: { hits.hits.3.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
  - match: { hits.hits.3.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
  # Rabbit only has one passage vector
  - match: {hits.hits.4.fields.name.0: "cow.jpg"}
  - length: { hits.hits.4.inner_hits.nested.hits.hits: 2 }

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: nested.vector
            query_vector: [ -1, 90, -10, 15, -127 ]
            k: 3
            num_candidates: 3
            filter: {term: {name: "cow.jpg"}}
            inner_hits: {size: 3, fields: ["nested.paragraph_id"], _source: false}

  - match: {hits.total.value: 1}
  - match: { hits.hits.0._id: "1" }
  - length: { hits.hits.0.inner_hits.nested.hits.hits: 2 }
  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
  - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "1" }
---
"nested kNN search inner_hits & boosting":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: nested.vector
            query_vector: [-1, 90, -10, 15, -127]
            k: 3
            num_candidates: 5
            inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}

  - close_to: { hits.hits.0._score: {value: 0.8, error: 0.00001} }
  - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: {value: 0.8, error: 0.00001} }
  - close_to: { hits.hits.1._score: {value: 0.625, error: 0.00001} }
  - close_to: { hits.hits.1.inner_hits.nested.hits.hits.0._score: {value: 0.625, error: 0.00001} }
  - close_to: { hits.hits.2._score: {value: 0.5, error: 0.00001} }
  - close_to: { hits.hits.2.inner_hits.nested.hits.hits.0._score: {value: 0.5, error: 0.00001} }

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: nested.vector
            query_vector: [-1, 90, -10, 15, -127]
            k: 3
            num_candidates: 5
            boost: 2
            inner_hits: {size: 2, fields: ["nested.paragraph_id"], _source: false}
  - close_to: { hits.hits.0._score: {value: 1.6, error: 0.00001} }
  - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: {value: 1.6, error: 0.00001} }
  - close_to: { hits.hits.1._score: {value: 1.25, error: 0.00001} }
  - close_to: { hits.hits.1.inner_hits.nested.hits.hits.0._score: {value: 1.25, error: 0.00001} }
  - close_to: { hits.hits.2._score: {value: 1, error: 0.00001} }
  - close_to: { hits.hits.2.inner_hits.nested.hits.hits.0._score: {value: 1.0, error: 0.00001} }
@@ -116,8 +116,9 @@ setup:
 ---
 "Knn search with hex string for byte field - dimensions mismatch" :
   # [64, 10, -30, 10] - is encoded as '400ae20a'
+  # the error message has been adjusted in later versions
   - do:
-      catch: /the query vector has a different dimension \[4\] than the index vectors \[3\]/
+      catch: /dimension|dimensions \[4\] than the document|index vectors \[3\]/
       search:
         index: knn_hex_vector_index
         body:

@@ -116,8 +116,9 @@ setup:
 ---
 "Knn query with hex string for byte field - dimensions mismatch" :
   # [64, 10, -30, 10] - is encoded as '400ae20a'
+  # the error message has been adjusted in later versions
   - do:
-      catch: /the query vector has a different dimension \[4\] than the index vectors \[3\]/
+      catch: /dimension|dimensions \[4\] than the document|index vectors \[3\]/
       search:
         index: knn_hex_vector_index
         body:
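The test comments above pin down the hex convention for byte vectors: each signed byte becomes two lowercase hex digits, so `[64, 10, -30, 10]` is `'400ae20a'`. A minimal standalone sketch of that round trip (hypothetical demo class; `java.util.HexFormat` is the same helper the mapper uses to parse hex input):

```java
import java.util.HexFormat;

public class HexVectorDemo {
    public static void main(String[] args) {
        // Each signed byte maps to two lowercase hex digits: -30 is 0xe2.
        byte[] vector = {64, 10, -30, 10};
        String hex = HexFormat.of().formatHex(vector);
        System.out.println(hex); // 400ae20a

        // Decoding is symmetric, matching HexFormat.of().parseHex(...) in the mapper.
        byte[] decoded = HexFormat.of().parseHex(hex);
        System.out.println(decoded.length); // 4
    }
}
```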
@@ -0,0 +1,356 @@
setup:
  - requires:
      cluster_features: "mapper.vectors.bit_vectors"
      reason: 'mapper.vectors.bit_vectors'

  - do:
      indices.create:
        index: test
        body:
          mappings:
            properties:
              name:
                type: keyword
              vector:
                type: dense_vector
                element_type: bit
                dims: 40
                index: true
                similarity: l2_norm

  - do:
      index:
        index: test
        id: "1"
        body:
          name: cow.jpg
          vector: [2, -1, 1, 4, -3]

  - do:
      index:
        index: test
        id: "2"
        body:
          name: moose.jpg
          vector: [127.0, -128.0, 0.0, 1.0, -1.0]

  - do:
      index:
        index: test
        id: "3"
        body:
          name: rabbit.jpg
          vector: [5, 4.0, 3, 2.0, 127]

  - do:
      indices.refresh: {}

---
"kNN search only":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [127, 127, -128, -128, 127]
            k: 2
            num_candidates: 3

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0.fields.name.0: "moose.jpg"}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1.fields.name.0: "cow.jpg"}

---
"kNN search plus query":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [127.0, -128.0, 0.0, 1.0, -1.0]
            k: 2
            num_candidates: 3
          query:
            term:
              name: rabbit.jpg

  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

  - match: {hits.hits.1._id: "2"}
  - match: {hits.hits.1.fields.name.0: "moose.jpg"}

---
"kNN search with filter":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [5.0, 4, 3.0, 2, 127.0]
            k: 2
            num_candidates: 3
            filter:
              term:
                name: "rabbit.jpg"

  - match: {hits.total.value: 1}
  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [2, -1, 1, 4, -3]
            k: 2
            num_candidates: 3
            filter:
              - term:
                  name: "rabbit.jpg"
              - term:
                  _id: 2

  - match: {hits.total.value: 0}

---
"Vector similarity search only":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            num_candidates: 3
            k: 3
            field: vector
            similarity: 0.98
            query_vector: [5, 4.0, 3, 2.0, 127]

  - length: {hits.hits: 1}

  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
---
"Vector similarity with filter only":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            num_candidates: 3
            k: 3
            field: vector
            similarity: 0.98
            query_vector: [5, 4.0, 3, 2.0, 127]
            filter: {"term": {"name": "rabbit.jpg"}}

  - length: {hits.hits: 1}

  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            num_candidates: 3
            k: 3
            field: vector
            similarity: 0.98
            query_vector: [5, 4.0, 3, 2.0, 127]
            filter: {"term": {"name": "cow.jpg"}}

  - length: {hits.hits: 0}
---
"dim mismatch":
  - do:
      catch: bad_request
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [1, 2, 3, 4, 5, 6]
            k: 2
            num_candidates: 3
---
"disallow quantized vector types":
  - do:
      catch: bad_request
      indices.create:
        index: test
        body:
          mappings:
            properties:
              name:
                type: keyword
              vector:
                type: dense_vector
                element_type: bit
                dims: 32
                index: true
                similarity: l2_norm
                index_options:
                  type: int8_flat

  - do:
      catch: bad_request
      indices.create:
        index: test
        body:
          mappings:
            properties:
              name:
                type: keyword
              vector:
                type: dense_vector
                element_type: bit
                dims: 32
                index: true
                similarity: l2_norm
                index_options:
                  type: int4_flat

  - do:
      catch: bad_request
      indices.create:
        index: test
        body:
          mappings:
            properties:
              name:
                type: keyword
              vector:
                type: dense_vector
                element_type: bit
                dims: 32
                index: true
                similarity: l2_norm
                index_options:
                  type: int8_hnsw

  - do:
      catch: bad_request
      indices.create:
        index: test
        body:
          mappings:
            properties:
              name:
                type: keyword
              vector:
                type: dense_vector
                element_type: bit
                dims: 32
                index: true
                similarity: l2_norm
                index_options:
                  type: int4_hnsw
---
"disallow vector index type change to quantized type":
  - do:
      catch: bad_request
      indices.put_mapping:
        index: test
        body:
          properties:
            vector:
              type: dense_vector
              element_type: bit
              dims: 32
              index: true
              similarity: l2_norm
              index_options:
                type: int4_hnsw
  - do:
      catch: bad_request
      indices.put_mapping:
        index: test
        body:
          properties:
            vector:
              type: dense_vector
              element_type: bit
              dims: 32
              index: true
              similarity: l2_norm
              index_options:
                type: int8_hnsw
---
"Defaults to l2_norm with bit vectors":
  - do:
      indices.create:
        index: default_to_l2_norm_bit
        body:
          mappings:
            properties:
              vector:
                type: dense_vector
                element_type: bit
                dims: 40
                index: true

  - do:
      indices.get_mapping:
        index: default_to_l2_norm_bit

  - match: { default_to_l2_norm_bit.mappings.properties.vector.similarity: l2_norm }

---
"Only allow l2_norm with bit vectors":
  - do:
      catch: bad_request
      indices.create:
        index: dot_product_fails_for_bits
        body:
          mappings:
            properties:
              vector:
                type: dense_vector
                element_type: bit
                dims: 40
                index: true
                similarity: dot_product

  - do:
      catch: bad_request
      indices.create:
        index: cosine_product_fails_for_bits
        body:
          mappings:
            properties:
              vector:
                type: dense_vector
                element_type: bit
                dims: 40
                index: true
                similarity: cosine

  - do:
      catch: bad_request
      indices.create:
        index: cosine_product_fails_for_bits
        body:
          mappings:
            properties:
              vector:
                type: dense_vector
                element_type: bit
                dims: 40
                index: true
                similarity: max_inner_product

@@ -0,0 +1,223 @@
setup:
  - requires:
      cluster_features: "mapper.vectors.bit_vectors"
      reason: 'mapper.vectors.bit_vectors'

  - do:
      indices.create:
        index: test
        body:
          mappings:
            properties:
              name:
                type: keyword
              vector:
                type: dense_vector
                element_type: bit
                dims: 40
                index: true
                similarity: l2_norm
                index_options:
                  type: flat

  - do:
      index:
        index: test
        id: "1"
        body:
          name: cow.jpg
          vector: [2, -1, 1, 4, -3]

  - do:
      index:
        index: test
        id: "2"
        body:
          name: moose.jpg
          vector: [127.0, -128.0, 0.0, 1.0, -1.0]

  - do:
      index:
        index: test
        id: "3"
        body:
          name: rabbit.jpg
          vector: [5, 4.0, 3, 2.0, 127]

  - do:
      indices.refresh: {}

---
"kNN search only":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [127, 127, -128, -128, 127]
            k: 2
            num_candidates: 3

  - match: {hits.hits.0._id: "2"}
  - match: {hits.hits.0.fields.name.0: "moose.jpg"}

  - match: {hits.hits.1._id: "1"}
  - match: {hits.hits.1.fields.name.0: "cow.jpg"}

---
"kNN search plus query":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [127.0, -128.0, 0.0, 1.0, -1.0]
            k: 2
            num_candidates: 3
          query:
            term:
              name: rabbit.jpg

  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

  - match: {hits.hits.1._id: "2"}
  - match: {hits.hits.1.fields.name.0: "moose.jpg"}

---
"kNN search with filter":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [5.0, 4, 3.0, 2, 127.0]
            k: 2
            num_candidates: 3
            filter:
              term:
                name: "rabbit.jpg"

  - match: {hits.total.value: 1}
  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [2, -1, 1, 4, -3]
            k: 2
            num_candidates: 3
            filter:
              - term:
                  name: "rabbit.jpg"
              - term:
                  _id: 2

  - match: {hits.total.value: 0}

---
"Vector similarity search only":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            num_candidates: 3
            k: 3
            field: vector
            similarity: 0.98
            query_vector: [5, 4.0, 3, 2.0, 127]

  - length: {hits.hits: 1}

  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
---
"Vector similarity with filter only":
  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            num_candidates: 3
            k: 3
            field: vector
            similarity: 0.98
            query_vector: [5, 4.0, 3, 2.0, 127]
            filter: {"term": {"name": "rabbit.jpg"}}

  - length: {hits.hits: 1}

  - match: {hits.hits.0._id: "3"}
  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}

  - do:
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            num_candidates: 3
            k: 3
            field: vector
            similarity: 0.98
            query_vector: [5, 4.0, 3, 2.0, 127]
            filter: {"term": {"name": "cow.jpg"}}

  - length: {hits.hits: 0}
---
"dim mismatch":
  - do:
      catch: bad_request
      search:
        index: test
        body:
          fields: [ "name" ]
          knn:
            field: vector
            query_vector: [1, 2, 3, 4, 5, 6]
            k: 2
            num_candidates: 3
---
"disallow vector index type change to quantized type":
  - do:
      catch: bad_request
      indices.put_mapping:
        index: test
        body:
          properties:
            vector:
              type: dense_vector
              element_type: bit
              dims: 32
              index: true
              similarity: l2_norm
              index_options:
                type: int4_hnsw
  - do:
      catch: bad_request
      indices.put_mapping:
        index: test
        body:
          properties:
            vector:
              type: dense_vector
              element_type: bit
              dims: 32
              index: true
              similarity: l2_norm
              index_options:
                type: int8_hnsw

@@ -449,7 +449,10 @@ module org.elasticsearch.server {
             with
                 org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat,
                 org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat,
-                org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
+                org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat,
+                org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat,
+                org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat;

     provides org.apache.lucene.codecs.Codec with Elasticsearch814Codec;

     provides org.apache.logging.log4j.core.util.ContextDataProvider with org.elasticsearch.common.logging.DynamicContextDataProvider;

@@ -54,11 +54,11 @@ public class ES813FlatVectorFormat extends KnnVectorsFormat {
         return new ES813FlatVectorReader(format.fieldsReader(state));
     }

-    public static class ES813FlatVectorWriter extends KnnVectorsWriter {
+    static class ES813FlatVectorWriter extends KnnVectorsWriter {

         private final FlatVectorsWriter writer;

-        public ES813FlatVectorWriter(FlatVectorsWriter writer) {
+        ES813FlatVectorWriter(FlatVectorsWriter writer) {
             super();
             this.writer = writer;
         }

@@ -94,11 +94,11 @@ public class ES813FlatVectorFormat extends KnnVectorsFormat {
         }
     }

-    public static class ES813FlatVectorReader extends KnnVectorsReader {
+    static class ES813FlatVectorReader extends KnnVectorsReader {

         private final FlatVectorsReader reader;

-        public ES813FlatVectorReader(FlatVectorsReader reader) {
+        ES813FlatVectorReader(FlatVectorsReader reader) {
             super();
             this.reader = reader;
         }

@@ -0,0 +1,47 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;

import java.io.IOException;

public class ES815BitFlatVectorFormat extends KnnVectorsFormat {

    static final String NAME = "ES815BitFlatVectorFormat";

    private final FlatVectorsFormat format = new ES815BitFlatVectorsFormat();

    /**
     * Sole constructor
     */
    public ES815BitFlatVectorFormat() {
        super(NAME);
    }

    @Override
    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
        return new ES813FlatVectorFormat.ES813FlatVectorWriter(format.fieldsWriter(state));
    }

    @Override
    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
        return new ES813FlatVectorFormat.ES813FlatVectorReader(format.fieldsReader(state));
    }

    @Override
    public String toString() {
        return NAME;
    }
}

@@ -0,0 +1,143 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;

import java.io.IOException;

class ES815BitFlatVectorsFormat extends FlatVectorsFormat {

    private final FlatVectorsFormat delegate = new Lucene99FlatVectorsFormat(FlatBitVectorScorer.INSTANCE);

    @Override
    public FlatVectorsWriter fieldsWriter(SegmentWriteState segmentWriteState) throws IOException {
        return delegate.fieldsWriter(segmentWriteState);
    }

    @Override
    public FlatVectorsReader fieldsReader(SegmentReadState segmentReadState) throws IOException {
        return delegate.fieldsReader(segmentReadState);
    }

    static class FlatBitVectorScorer implements FlatVectorsScorer {

        static final FlatBitVectorScorer INSTANCE = new FlatBitVectorScorer();

        static void checkDimensions(int queryLen, int fieldLen) {
            if (queryLen != fieldLen) {
                throw new IllegalArgumentException("vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen);
            }
        }

        @Override
        public String toString() {
            return super.toString();
        }

        @Override
        public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
            VectorSimilarityFunction vectorSimilarityFunction,
            RandomAccessVectorValues randomAccessVectorValues
        ) throws IOException {
            assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes;
            assert vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN;
            if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes randomAccessVectorValuesBytes) {
                assert randomAccessVectorValues instanceof RandomAccessQuantizedByteVectorValues == false;
                return switch (vectorSimilarityFunction) {
                    case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT, COSINE, EUCLIDEAN -> new HammingScorerSupplier(randomAccessVectorValuesBytes);
                };
            }
            throw new IllegalArgumentException("Unsupported vector type or similarity function");
        }

        @Override
        public RandomVectorScorer getRandomVectorScorer(
            VectorSimilarityFunction vectorSimilarityFunction,
            RandomAccessVectorValues randomAccessVectorValues,
            byte[] bytes
        ) {
            assert randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes;
            assert vectorSimilarityFunction == VectorSimilarityFunction.EUCLIDEAN;
            if (randomAccessVectorValues instanceof RandomAccessVectorValues.Bytes randomAccessVectorValuesBytes) {
                checkDimensions(bytes.length, randomAccessVectorValuesBytes.dimension());
                return switch (vectorSimilarityFunction) {
                    case DOT_PRODUCT, MAXIMUM_INNER_PRODUCT, COSINE, EUCLIDEAN -> new HammingVectorScorer(
                        randomAccessVectorValuesBytes,
                        bytes
                    );
                };
            }
            throw new IllegalArgumentException("Unsupported vector type or similarity function");
        }

        @Override
        public RandomVectorScorer getRandomVectorScorer(
            VectorSimilarityFunction vectorSimilarityFunction,
            RandomAccessVectorValues randomAccessVectorValues,
            float[] floats
        ) {
            throw new IllegalArgumentException("Unsupported vector type");
        }
    }

    static float hammingScore(byte[] a, byte[] b) {
        return ((a.length * Byte.SIZE) - VectorUtil.xorBitCount(a, b)) / (float) (a.length * Byte.SIZE);
    }

    static class HammingVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
        private final byte[] query;
        private final RandomAccessVectorValues.Bytes byteValues;

        HammingVectorScorer(RandomAccessVectorValues.Bytes byteValues, byte[] query) {
            super(byteValues);
            this.query = query;
            this.byteValues = byteValues;
        }

        @Override
        public float score(int i) throws IOException {
            return hammingScore(byteValues.vectorValue(i), query);
        }
    }

    static class HammingScorerSupplier implements RandomVectorScorerSupplier {
        private final RandomAccessVectorValues.Bytes byteValues, byteValues1, byteValues2;

        HammingScorerSupplier(RandomAccessVectorValues.Bytes byteValues) throws IOException {
            this.byteValues = byteValues;
            this.byteValues1 = byteValues.copy();
            this.byteValues2 = byteValues.copy();
        }

        @Override
        public RandomVectorScorer scorer(int i) throws IOException {
            byte[] query = byteValues1.vectorValue(i);
            return new HammingVectorScorer(byteValues2, query);
        }

        @Override
        public RandomVectorScorerSupplier copy() throws IOException {
            return new HammingScorerSupplier(byteValues);
        }
    }

}

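The `hammingScore` helper above normalizes hamming distance into a score in `[0, 1]`: `(totalBits - xorBitCount) / totalBits`, where `totalBits` is `a.length * 8`. A self-contained sketch of the same arithmetic (hypothetical demo class; `Integer.bitCount` over each xor'd byte stands in for Lucene's `VectorUtil.xorBitCount`):

```java
public class HammingScoreDemo {
    // Mirrors hammingScore: fraction of matching bits, normalized to [0, 1].
    static float hammingScore(byte[] a, byte[] b) {
        int xorBitCount = 0;
        for (int i = 0; i < a.length; i++) {
            // Mask to 0xFF so sign extension of the byte does not inflate the popcount.
            xorBitCount += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        int totalBits = a.length * Byte.SIZE;
        return (totalBits - xorBitCount) / (float) totalBits;
    }

    public static void main(String[] args) {
        byte[] v1 = { (byte) 0b1010_1010 };
        byte[] v2 = { (byte) 0b1010_1011 }; // differs in exactly 1 of 8 bits
        System.out.println(hammingScore(v1, v1)); // 1.0
        System.out.println(hammingScore(v1, v2)); // 0.875
    }
}
```

Identical vectors score 1.0; fully complementary vectors score 0.0, which matches the commit's description of xor/popcount scoring being transformed by the vector dimensions.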
@@ -0,0 +1,74 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;

import java.io.IOException;

public class ES815HnswBitVectorsFormat extends KnnVectorsFormat {

    static final String NAME = "ES815HnswBitVectorsFormat";

    static final int MAXIMUM_MAX_CONN = 512;
    static final int MAXIMUM_BEAM_WIDTH = 3200;

    private final int maxConn;
    private final int beamWidth;

    private final FlatVectorsFormat flatVectorsFormat = new ES815BitFlatVectorsFormat();

    public ES815HnswBitVectorsFormat() {
        this(16, 100);
    }

    public ES815HnswBitVectorsFormat(int maxConn, int beamWidth) {
        super(NAME);
        if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
            throw new IllegalArgumentException(
                "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
            );
        }
        if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
            throw new IllegalArgumentException(
                "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
            );
        }
        this.maxConn = maxConn;
        this.beamWidth = beamWidth;
    }

    @Override
    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
        return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null);
    }

    @Override
    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
        return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
    }

    @Override
    public String toString() {
        return "ES815HnswBitVectorsFormat(name=ES815HnswBitVectorsFormat, maxConn="
            + maxConn
            + ", beamWidth="
            + beamWidth
            + ", flatVectorFormat="
            + flatVectorsFormat
            + ")";
    }
}

@@ -25,7 +25,8 @@ public class MapperFeatures implements FeatureSpecification {
             PassThroughObjectMapper.PASS_THROUGH_PRIORITY,
             RangeFieldMapper.NULL_VALUES_OFF_BY_ONE_FIX,
             SourceFieldMapper.SYNTHETIC_SOURCE_FALLBACK,
-            DenseVectorFieldMapper.INT4_QUANTIZATION
+            DenseVectorFieldMapper.INT4_QUANTIZATION,
+            DenseVectorFieldMapper.BIT_VECTORS
         );
     }
 }

@@ -31,6 +31,7 @@ import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.search.FieldExistsQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.join.BitSetProducer;
+import org.apache.lucene.util.BitUtil;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.VectorUtil;
 import org.elasticsearch.common.ParsingException;

@@ -41,6 +42,8 @@ import org.elasticsearch.index.IndexVersions;
 import org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat;
 import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat;
 import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
+import org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat;
+import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat;
 import org.elasticsearch.index.fielddata.FieldDataContext;
 import org.elasticsearch.index.fielddata.IndexFieldData;
 import org.elasticsearch.index.mapper.ArraySourceValueFetcher;

@@ -100,6 +103,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
     }

     public static final NodeFeature INT4_QUANTIZATION = new NodeFeature("mapper.vectors.int4_quantization");
+    public static final NodeFeature BIT_VECTORS = new NodeFeature("mapper.vectors.bit_vectors");

     public static final IndexVersion MAGNITUDE_STORED_INDEX_VERSION = IndexVersions.V_7_5_0;
     public static final IndexVersion INDEXED_BY_DEFAULT_INDEX_VERSION = IndexVersions.FIRST_DETACHED_INDEX_VERSION;

@@ -109,6 +113,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

     public static final String CONTENT_TYPE = "dense_vector";
     public static short MAX_DIMS_COUNT = 4096; // maximum allowed number of dimensions
+    public static int MAX_DIMS_COUNT_BIT = 4096 * Byte.SIZE; // maximum allowed number of dimensions

     public static short MIN_DIMS_FOR_DYNAMIC_FLOAT_MAPPING = 128; // minimum number of dims for floats to be dynamically mapped to vector
     public static final int MAGNITUDE_BYTES = 4;

@@ -134,17 +139,28 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 throw new MapperParsingException("Property [dims] on field [" + n + "] must be an integer but got [" + o + "]");
             }
             int dims = XContentMapValues.nodeIntegerValue(o);
-            if (dims < 1 || dims > MAX_DIMS_COUNT) {
+            int maxDims = elementType.getValue() == ElementType.BIT ? MAX_DIMS_COUNT_BIT : MAX_DIMS_COUNT;
+            int minDims = elementType.getValue() == ElementType.BIT ? Byte.SIZE : 1;
+            if (dims < minDims || dims > maxDims) {
                 throw new MapperParsingException(
                     "The number of dimensions for field ["
                         + n
-                        + "] should be in the range [1, "
-                        + MAX_DIMS_COUNT
+                        + "] should be in the range ["
+                        + minDims
+                        + ", "
+                        + maxDims
                         + "] but was ["
                         + dims
                         + "]"
                 );
             }
+            if (elementType.getValue() == ElementType.BIT) {
+                if (dims % Byte.SIZE != 0) {
+                    throw new MapperParsingException(
+                        "The number of dimensions for field [" + n + "] should be a multiple of 8 but was [" + dims + "]"
+                    );
+                }
+            }
             return dims;
         }, m -> toType(m).fieldType().dims, XContentBuilder::field, Object::toString).setSerializerCheck((id, ic, v) -> v != null)
             .setMergeValidator((previous, current, c) -> previous == null || Objects.equals(previous, current));

@@ -171,13 +187,27 @@ public class DenseVectorFieldMapper extends FieldMapper {
             "similarity",
             false,
             m -> toType(m).fieldType().similarity,
-            (Supplier<VectorSimilarity>) () -> indexedByDefault && indexed.getValue() ? VectorSimilarity.COSINE : null,
+            (Supplier<VectorSimilarity>) () -> {
+                if (indexedByDefault && indexed.getValue()) {
+                    return elementType.getValue() == ElementType.BIT ? VectorSimilarity.L2_NORM : VectorSimilarity.COSINE;
+                }
+                return null;
+            },
             VectorSimilarity.class
-        ).acceptsNull().setSerializerCheck((id, ic, v) -> v != null);
+        ).acceptsNull().setSerializerCheck((id, ic, v) -> v != null).addValidator(vectorSim -> {
+            if (vectorSim == null) {
+                return;
+            }
+            if (elementType.getValue() == ElementType.BIT && vectorSim != VectorSimilarity.L2_NORM) {
+                throw new IllegalArgumentException(
+                    "The [" + VectorSimilarity.L2_NORM + "] similarity is the only supported similarity for bit vectors"
+                );
+            }
+        });
         this.indexOptions = new Parameter<>(
             "index_options",
             true,
-            () -> defaultInt8Hnsw && elementType.getValue() != ElementType.BYTE && this.indexed.getValue()
+            () -> defaultInt8Hnsw && elementType.getValue() == ElementType.FLOAT && this.indexed.getValue()
                 ? new Int8HnswIndexOptions(
                     Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
                     Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,

@@ -266,7 +296,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

     public enum ElementType {

-        BYTE(1) {
+        BYTE {

             @Override
             public String toString() {

@@ -371,7 +401,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
         }

         @Override
-        public double computeDotProduct(VectorData vectorData) {
+        public double computeSquaredMagnitude(VectorData vectorData) {
             return VectorUtil.dotProduct(vectorData.asByteVector(), vectorData.asByteVector());
         }

@@ -428,7 +458,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
             byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
             fieldMapper.checkDimensionMatches(decodedVector.length, context);
             VectorData vectorData = VectorData.fromBytes(decodedVector);
-            double squaredMagnitude = computeDotProduct(vectorData);
+            double squaredMagnitude = computeSquaredMagnitude(vectorData);
             checkVectorMagnitude(
                 fieldMapper.fieldType().similarity,
                 errorByteElementsAppender(decodedVector),

@@ -463,7 +493,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

         @Override
         int getNumBytes(int dimensions) {
-            return dimensions * elementBytes;
+            return dimensions;
         }

         @Override

@@ -494,7 +524,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
         }
     },

-    FLOAT(4) {
+    FLOAT {

         @Override
         public String toString() {

@@ -596,7 +626,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
         }

         @Override
-        public double computeDotProduct(VectorData vectorData) {
+        public double computeSquaredMagnitude(VectorData vectorData) {
             return VectorUtil.dotProduct(vectorData.asFloatVector(), vectorData.asFloatVector());
         }

@@ -656,7 +686,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

         @Override
         int getNumBytes(int dimensions) {
-            return dimensions * elementBytes;
+            return dimensions * Float.BYTES;
         }

         @Override

@@ -665,14 +695,250 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 ? ByteBuffer.wrap(new byte[numBytes]).order(ByteOrder.LITTLE_ENDIAN)
                 : ByteBuffer.wrap(new byte[numBytes]);
         }
     },

+    BIT {
+
+        @Override
+        public String toString() {
+            return "bit";
+        }
+
+        @Override
+        public void writeValue(ByteBuffer byteBuffer, float value) {
+            byteBuffer.put((byte) value);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException {
|
||||
b.value(byteBuffer.get());
|
||||
}
|
||||
|
||||
private KnnByteVectorField createKnnVectorField(String name, byte[] vector, VectorSimilarityFunction function) {
|
||||
if (vector == null) {
|
||||
throw new IllegalArgumentException("vector value must not be null");
|
||||
}
|
||||
FieldType denseVectorFieldType = new FieldType();
|
||||
denseVectorFieldType.setVectorAttributes(vector.length, VectorEncoding.BYTE, function);
|
||||
denseVectorFieldType.freeze();
|
||||
return new KnnByteVectorField(name, vector, denseVectorFieldType);
|
||||
}
|
||||
|
||||
@Override
|
||||
IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldType, FieldDataContext fieldDataContext) {
|
||||
return new VectorIndexFieldData.Builder(
|
||||
denseVectorFieldType.name(),
|
||||
CoreValuesSourceType.KEYWORD,
|
||||
denseVectorFieldType.indexVersionCreated,
|
||||
this,
|
||||
denseVectorFieldType.dims,
|
||||
denseVectorFieldType.indexed,
|
||||
r -> r
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkVectorBounds(float[] vector) {
|
||||
checkNanAndInfinite(vector);
|
||||
|
||||
StringBuilder errorBuilder = null;
|
||||
|
||||
for (int index = 0; index < vector.length; ++index) {
|
||||
float value = vector[index];
|
||||
|
||||
if (value % 1.0f != 0.0f) {
|
||||
errorBuilder = new StringBuilder(
|
||||
"element_type ["
|
||||
+ this
|
||||
+ "] vectors only support non-decimal values but found decimal value ["
|
||||
+ value
|
||||
+ "] at dim ["
|
||||
+ index
|
||||
+ "];"
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
|
||||
errorBuilder = new StringBuilder(
|
||||
"element_type ["
|
||||
+ this
|
||||
+ "] vectors only support integers between ["
|
||||
+ Byte.MIN_VALUE
|
||||
+ ", "
|
||||
+ Byte.MAX_VALUE
|
||||
+ "] but found ["
|
||||
+ value
|
||||
+ "] at dim ["
|
||||
+ index
|
||||
+ "];"
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (errorBuilder != null) {
|
||||
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
void checkVectorMagnitude(
|
||||
VectorSimilarity similarity,
|
||||
Function<StringBuilder, StringBuilder> appender,
|
||||
float squaredMagnitude
|
||||
) {}
|
||||
|
||||
@Override
|
||||
public double computeSquaredMagnitude(VectorData vectorData) {
|
||||
int count = 0;
|
||||
int i = 0;
|
||||
byte[] byteBits = vectorData.asByteVector();
|
||||
for (int upperBound = byteBits.length & -8; i < upperBound; i += 8) {
|
||||
count += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(byteBits, i));
|
||||
}
|
||||
|
||||
while (i < byteBits.length) {
|
||||
count += Integer.bitCount(byteBits[i] & 255);
|
||||
++i;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
|
||||
int index = 0;
|
||||
byte[] vector = new byte[fieldMapper.fieldType().dims / Byte.SIZE];
|
||||
for (XContentParser.Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser()
|
||||
.nextToken()) {
|
||||
fieldMapper.checkDimensionExceeded(index, context);
|
||||
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
|
||||
final int value;
|
||||
if (context.parser().numberType() != XContentParser.NumberType.INT) {
|
||||
float floatValue = context.parser().floatValue(true);
|
||||
if (floatValue % 1.0f != 0.0f) {
|
||||
throw new IllegalArgumentException(
|
||||
"element_type ["
|
||||
+ this
|
||||
+ "] vectors only support non-decimal values but found decimal value ["
|
||||
+ floatValue
|
||||
+ "] at dim ["
|
||||
+ index
|
||||
+ "];"
|
||||
);
|
||||
}
|
||||
value = (int) floatValue;
|
||||
} else {
|
||||
value = context.parser().intValue(true);
|
||||
}
|
||||
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
|
||||
throw new IllegalArgumentException(
|
||||
"element_type ["
|
||||
+ this
|
||||
+ "] vectors only support integers between ["
|
||||
+ Byte.MIN_VALUE
|
||||
+ ", "
|
||||
+ Byte.MAX_VALUE
|
||||
+ "] but found ["
|
||||
+ value
|
||||
+ "] at dim ["
|
||||
+ index
|
||||
+ "];"
|
||||
);
|
||||
}
|
||||
if (index >= vector.length) {
|
||||
throw new IllegalArgumentException(
|
||||
"The number of dimensions for field ["
|
||||
+ fieldMapper.fieldType().name()
|
||||
+ "] should be ["
|
||||
+ fieldMapper.fieldType().dims
|
||||
+ "] but found ["
|
||||
+ (index + 1) * Byte.SIZE
|
||||
+ "]"
|
||||
);
|
||||
}
|
||||
vector[index++] = (byte) value;
|
||||
}
|
||||
fieldMapper.checkDimensionMatches(index * Byte.SIZE, context);
|
||||
return VectorData.fromBytes(vector);
|
||||
}
|
||||
|
||||
private VectorData parseHexEncodedVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
|
||||
byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
|
||||
fieldMapper.checkDimensionMatches(decodedVector.length * Byte.SIZE, context);
|
||||
return VectorData.fromBytes(decodedVector);
|
||||
}
|
||||
|
||||
@Override
|
||||
VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
|
||||
XContentParser.Token token = context.parser().currentToken();
|
||||
return switch (token) {
|
||||
case START_ARRAY -> parseVectorArray(context, fieldMapper);
|
||||
case VALUE_STRING -> parseHexEncodedVector(context, fieldMapper);
|
||||
default -> throw new ParsingException(
|
||||
context.parser().getTokenLocation(),
|
||||
format("Unsupported type [%s] for provided value [%s]", token, context.parser().text())
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
|
||||
VectorData vectorData = parseKnnVector(context, fieldMapper);
|
||||
Field field = createKnnVectorField(
|
||||
fieldMapper.fieldType().name(),
|
||||
vectorData.asByteVector(),
|
||||
fieldMapper.fieldType().similarity.vectorSimilarityFunction(fieldMapper.indexCreatedVersion, this)
|
||||
);
|
||||
context.doc().addWithKey(fieldMapper.fieldType().name(), field);
|
||||
}
|
||||
|
||||
@Override
|
||||
int getNumBytes(int dimensions) {
|
||||
assert dimensions % Byte.SIZE == 0;
|
||||
return dimensions / Byte.SIZE;
|
||||
}
|
||||
|
||||
@Override
|
||||
ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes) {
|
||||
return ByteBuffer.wrap(new byte[numBytes]);
|
||||
}
|
||||
|
||||
@Override
|
||||
int parseDimensionCount(DocumentParserContext context) throws IOException {
|
||||
XContentParser.Token currentToken = context.parser().currentToken();
|
||||
return switch (currentToken) {
|
||||
case START_ARRAY -> {
|
||||
int index = 0;
|
||||
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
|
||||
index++;
|
||||
}
|
||||
yield index * Byte.SIZE;
|
||||
}
|
||||
case VALUE_STRING -> {
|
||||
byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
|
||||
yield decodedVector.length * Byte.SIZE;
|
||||
}
|
||||
default -> throw new ParsingException(
|
||||
context.parser().getTokenLocation(),
|
||||
format("Unsupported type [%s] for provided value [%s]", currentToken, context.parser().text())
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkDimensions(int dvDims, int qvDims) {
|
||||
if (dvDims != qvDims * Byte.SIZE) {
|
||||
throw new IllegalArgumentException(
|
||||
"The query vector has a different number of dimensions ["
|
||||
+ qvDims * Byte.SIZE
|
||||
+ "] than the document vectors ["
|
||||
+ dvDims
|
||||
+ "]."
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
final int elementBytes;
|
||||
|
||||
ElementType(int elementBytes) {
|
||||
this.elementBytes = elementBytes;
|
||||
}
|
||||
|
||||
public abstract void writeValue(ByteBuffer byteBuffer, float value);
|
||||
|
||||
public abstract void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException;
|
||||
|
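The `BIT` branch above computes a bit vector's squared magnitude as the popcount of its packed bytes: every component is 0 or 1, so each squared component contributes exactly its bit value. A minimal standalone sketch of the same idea (without the `VarHandle` eight-byte fast path; the class name `BitMagnitude` is ours, not from the diff):

```java
public class BitMagnitude {
    // Squared magnitude of a packed bit vector equals the number of set
    // bits, since 0*0 == 0 and 1*1 == 1 for every dimension.
    public static int squaredMagnitude(byte[] bits) {
        int count = 0;
        for (byte b : bits) {
            count += Integer.bitCount(b & 0xFF); // mask to avoid sign extension
        }
        return count;
    }
}
```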
@@ -695,6 +961,14 @@ public class DenseVectorFieldMapper extends FieldMapper {
float squaredMagnitude
);

public void checkDimensions(int dvDims, int qvDims) {
if (dvDims != qvDims) {
throw new IllegalArgumentException(
"The query vector has a different number of dimensions [" + qvDims + "] than the document vectors [" + dvDims + "]."
);
}
}

int parseDimensionCount(DocumentParserContext context) throws IOException {
int index = 0;
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
@@ -775,7 +1049,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
return sb -> appendErrorElements(sb, vector);
}

public abstract double computeDotProduct(VectorData vectorData);
public abstract double computeSquaredMagnitude(VectorData vectorData);

public static ElementType fromString(String name) {
return valueOf(name.trim().toUpperCase(Locale.ROOT));
@@ -786,7 +1060,9 @@ public class DenseVectorFieldMapper extends FieldMapper {
ElementType.BYTE.toString(),
ElementType.BYTE,
ElementType.FLOAT.toString(),
ElementType.FLOAT
ElementType.FLOAT,
ElementType.BIT.toString(),
ElementType.BIT
);

public enum VectorSimilarity {
@@ -795,6 +1071,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
float score(float similarity, ElementType elementType, int dim) {
return switch (elementType) {
case BYTE, FLOAT -> 1f / (1f + similarity * similarity);
case BIT -> (dim - similarity) / dim;
};
}
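For indexed bit vectors the raw `l2_norm` "similarity" is the hamming distance, and the `case BIT` branch above normalizes it into `[0, 1]` as `(dim - similarity) / dim`: a hamming distance of 0 scores 1, a maximal distance of `dim` scores 0. A small sketch of that transform (the class name `BitScore` is ours, not from the diff):

```java
public class BitScore {
    // Mirrors `case BIT -> (dim - similarity) / dim`: maps hamming
    // distance in [0, dim] to a score in [1, 0].
    public static float score(float hamming, int dim) {
        return (dim - hamming) / dim;
    }
}
```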
@@ -806,8 +1083,10 @@ public class DenseVectorFieldMapper extends FieldMapper {
COSINE {
@Override
float score(float similarity, ElementType elementType, int dim) {
assert elementType != ElementType.BIT;
return switch (elementType) {
case BYTE, FLOAT -> (1 + similarity) / 2f;
default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]");
};
}
@@ -824,6 +1103,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
return switch (elementType) {
case BYTE -> 0.5f + similarity / (float) (dim * (1 << 15));
case FLOAT -> (1 + similarity) / 2f;
default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]");
};
}
@@ -837,6 +1117,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
float score(float similarity, ElementType elementType, int dim) {
return switch (elementType) {
case BYTE, FLOAT -> similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1;
default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]");
};
}
@@ -863,7 +1144,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
this.type = type;
}

abstract KnnVectorsFormat getVectorsFormat();
abstract KnnVectorsFormat getVectorsFormat(ElementType elementType);

boolean supportsElementType(ElementType elementType) {
return true;
@@ -1002,7 +1283,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
}

@Override
KnnVectorsFormat getVectorsFormat() {
KnnVectorsFormat getVectorsFormat(ElementType elementType) {
assert elementType == ElementType.FLOAT;
return new ES813Int8FlatVectorFormat(confidenceInterval, 7, false);
}
@@ -1021,7 +1303,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

@Override
boolean supportsElementType(ElementType elementType) {
return elementType != ElementType.BYTE;
return elementType == ElementType.FLOAT;
}

@Override
@@ -1047,7 +1329,10 @@ public class DenseVectorFieldMapper extends FieldMapper {
}

@Override
KnnVectorsFormat getVectorsFormat() {
KnnVectorsFormat getVectorsFormat(ElementType elementType) {
if (elementType.equals(ElementType.BIT)) {
return new ES815BitFlatVectorFormat();
}
return new ES813FlatVectorFormat();
}
@@ -1083,7 +1368,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
}

@Override
public KnnVectorsFormat getVectorsFormat() {
public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
assert elementType == ElementType.FLOAT;
return new ES814HnswScalarQuantizedVectorsFormat(m, efConstruction, confidenceInterval, 4, true);
}
@@ -1126,7 +1412,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

@Override
boolean supportsElementType(ElementType elementType) {
return elementType != ElementType.BYTE;
return elementType == ElementType.FLOAT;
}

@Override
@@ -1153,7 +1439,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
}

@Override
public KnnVectorsFormat getVectorsFormat() {
public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
assert elementType == ElementType.FLOAT;
return new ES813Int8FlatVectorFormat(confidenceInterval, 4, true);
}
@@ -1186,7 +1473,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

@Override
boolean supportsElementType(ElementType elementType) {
return elementType != ElementType.BYTE;
return elementType == ElementType.FLOAT;
}

@Override
@@ -1216,7 +1503,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
}

@Override
public KnnVectorsFormat getVectorsFormat() {
public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
assert elementType == ElementType.FLOAT;
return new ES814HnswScalarQuantizedVectorsFormat(m, efConstruction, confidenceInterval, 7, false);
}
@@ -1261,7 +1549,7 @@ public class DenseVectorFieldMapper extends FieldMapper {

@Override
boolean supportsElementType(ElementType elementType) {
return elementType != ElementType.BYTE;
return elementType == ElementType.FLOAT;
}

@Override
@@ -1291,7 +1579,10 @@ public class DenseVectorFieldMapper extends FieldMapper {
}

@Override
public KnnVectorsFormat getVectorsFormat() {
public KnnVectorsFormat getVectorsFormat(ElementType elementType) {
if (elementType == ElementType.BIT) {
return new ES815HnswBitVectorsFormat(m, efConstruction);
}
return new Lucene99HnswVectorsFormat(m, efConstruction, 1, null);
}
@@ -1412,48 +1703,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't support term queries");
}

public Query createKnnQuery(
byte[] queryVector,
int numCands,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
) {
if (isIndexed() == false) {
throw new IllegalArgumentException(
"to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
);
}

if (queryVector.length != dims) {
throw new IllegalArgumentException(
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
);
}

if (elementType != ElementType.BYTE) {
throw new IllegalArgumentException(
"only [" + ElementType.BYTE + "] elements are supported when querying field [" + name() + "]"
);
}

if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
}
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter)
: new ESKnnByteVectorQuery(name(), queryVector, numCands, filter);
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
similarityThreshold,
similarity.score(similarityThreshold, elementType, dims)
);
}
return knnQuery;
}

public Query createExactKnnQuery(VectorData queryVector) {
if (isIndexed() == false) {
throw new IllegalArgumentException(
@@ -1463,15 +1712,17 @@ public class DenseVectorFieldMapper extends FieldMapper {
return switch (elementType) {
case BYTE -> createExactKnnByteQuery(queryVector.asByteVector());
case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector());
case BIT -> createExactKnnBitQuery(queryVector.asByteVector());
};
}

private Query createExactKnnBitQuery(byte[] queryVector) {
elementType.checkDimensions(dims, queryVector.length);
return new DenseVectorQuery.Bytes(queryVector, name());
}

private Query createExactKnnByteQuery(byte[] queryVector) {
if (queryVector.length != dims) {
throw new IllegalArgumentException(
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
);
}
elementType.checkDimensions(dims, queryVector.length);
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
@@ -1480,11 +1731,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
}

private Query createExactKnnFloatQuery(float[] queryVector) {
if (queryVector.length != dims) {
throw new IllegalArgumentException(
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
);
}
elementType.checkDimensions(dims, queryVector.length);
elementType.checkVectorBounds(queryVector);
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
@@ -1521,9 +1768,31 @@ public class DenseVectorFieldMapper extends FieldMapper {
return switch (getElementType()) {
case BYTE -> createKnnByteQuery(queryVector.asByteVector(), numCands, filter, similarityThreshold, parentFilter);
case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), numCands, filter, similarityThreshold, parentFilter);
case BIT -> createKnnBitQuery(queryVector.asByteVector(), numCands, filter, similarityThreshold, parentFilter);
};
}

private Query createKnnBitQuery(
byte[] queryVector,
int numCands,
Query filter,
Float similarityThreshold,
BitSetProducer parentFilter
) {
elementType.checkDimensions(dims, queryVector.length);
Query knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter)
: new ESKnnByteVectorQuery(name(), queryVector, numCands, filter);
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
knnQuery,
similarityThreshold,
similarity.score(similarityThreshold, elementType, dims)
);
}
return knnQuery;
}

private Query createKnnByteQuery(
byte[] queryVector,
int numCands,
@@ -1531,11 +1800,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
Float similarityThreshold,
BitSetProducer parentFilter
) {
if (queryVector.length != dims) {
throw new IllegalArgumentException(
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
);
}
elementType.checkDimensions(dims, queryVector.length);

if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
@@ -1561,11 +1826,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
Float similarityThreshold,
BitSetProducer parentFilter
) {
if (queryVector.length != dims) {
throw new IllegalArgumentException(
"the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
);
}
elementType.checkDimensions(dims, queryVector.length);
elementType.checkVectorBounds(queryVector);
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
@@ -1701,7 +1962,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
vectorData.addToBuffer(byteBuffer);
if (indexCreatedVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)) {
// encode vector magnitude at the end
double dotProduct = elementType.computeDotProduct(vectorData);
double dotProduct = elementType.computeSquaredMagnitude(vectorData);
float vectorMagnitude = (float) Math.sqrt(dotProduct);
byteBuffer.putFloat(vectorMagnitude);
}
@@ -1780,9 +2041,9 @@ public class DenseVectorFieldMapper extends FieldMapper {
public KnnVectorsFormat getKnnVectorsFormatForField(KnnVectorsFormat defaultFormat) {
final KnnVectorsFormat format;
if (indexOptions == null) {
format = defaultFormat;
format = fieldType().elementType == ElementType.BIT ? new ES815HnswBitVectorsFormat() : defaultFormat;
} else {
format = indexOptions.getVectorsFormat();
format = indexOptions.getVectorsFormat(fieldType().elementType);
}
// It's legal to reuse the same format name as this is the same on-disk format.
return new KnnVectorsFormat(format.getName()) {
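The mapper above accepts hex-encoded bit vectors (`parseHexEncodedVector`) and checks dimensions against the decoded byte count times `Byte.SIZE`: a `dims`-dimensional bit vector packs into `dims / 8` bytes, i.e. `dims / 4` hex characters. A sketch of that relationship (the class name `BitVectorHex` is ours, not from the diff):

```java
import java.util.HexFormat;

public class BitVectorHex {
    // Two hex characters encode one byte; one byte encodes 8 bit
    // dimensions, matching `decodedVector.length * Byte.SIZE` above.
    public static int dimsOfHex(String hex) {
        return HexFormat.of().parseHex(hex).length * Byte.SIZE;
    }
}
```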
@@ -17,6 +17,8 @@ import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
import org.elasticsearch.script.field.DocValuesScriptFieldFactory;
import org.elasticsearch.script.field.vectors.BinaryDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.BitBinaryDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.BitKnnDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.ByteBinaryDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.ByteKnnDenseVectorDocValuesField;
import org.elasticsearch.script.field.vectors.KnnDenseVectorDocValuesField;
@@ -58,12 +60,14 @@ final class VectorDVLeafFieldData implements LeafFieldData {
return switch (elementType) {
case BYTE -> new ByteKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims);
case FLOAT -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims);
case BIT -> new BitKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims);
};
} else {
BinaryDocValues values = DocValues.getBinary(reader, field);
return switch (elementType) {
case BYTE -> new ByteBinaryDenseVectorDocValuesField(values, name, elementType, dims);
case FLOAT -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion);
case BIT -> new BitBinaryDenseVectorDocValuesField(values, name, elementType, dims);
};
}
} catch (IOException e) {
@@ -56,7 +56,7 @@ public class VectorScoreScriptUtils {
*/
public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
super(scoreScript, field);
DenseVector.checkDimensions(field.get().getDims(), queryVector.size());
field.getElementType().checkDimensions(field.get().getDims(), queryVector.size());
this.queryVector = new byte[queryVector.size()];
float[] validateValues = new float[queryVector.size()];
int queryMagnitude = 0;
@@ -168,7 +168,7 @@ public class VectorScoreScriptUtils {
public L1Norm(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
function = switch (field.getElementType()) {
case BYTE -> {
case BYTE, BIT -> {
if (queryVector instanceof List) {
yield new ByteL1Norm(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {
@@ -219,8 +219,8 @@ public class VectorScoreScriptUtils {
@SuppressWarnings("unchecked")
public Hamming(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
if (field.getElementType() != DenseVectorFieldMapper.ElementType.BYTE) {
throw new IllegalArgumentException("hamming distance is only supported for byte vectors");
if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) {
throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors");
}
if (queryVector instanceof List) {
function = new ByteHammingDistance(scoreScript, field, (List<Number>) queryVector);
@@ -278,7 +278,7 @@ public class VectorScoreScriptUtils {
public L2Norm(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
function = switch (field.getElementType()) {
case BYTE -> {
case BYTE, BIT -> {
if (queryVector instanceof List) {
yield new ByteL2Norm(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {
@@ -342,7 +342,7 @@ public class VectorScoreScriptUtils {
public DotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
function = switch (field.getElementType()) {
case BYTE -> {
case BYTE, BIT -> {
if (queryVector instanceof List) {
yield new ByteDotProduct(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {
@@ -406,7 +406,7 @@ public class VectorScoreScriptUtils {
public CosineSimilarity(ScoreScript scoreScript, Object queryVector, String fieldName) {
DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
function = switch (field.getElementType()) {
case BYTE -> {
case BYTE, BIT -> {
if (queryVector instanceof List) {
yield new ByteCosineSimilarity(scoreScript, field, (List<Number>) queryVector);
} else if (queryVector instanceof String s) {
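The script `hamming` function above now accepts byte and bit vectors, rejecting only `float`. For packed vectors, hamming distance is the popcount of the bytewise XOR of the two arrays; a minimal sketch (the class name `HammingSketch` is ours, not from the diff):

```java
public class HammingSketch {
    // Hamming distance between two equal-length packed vectors: XOR each
    // byte pair and count the differing bits.
    public static int hamming(byte[] a, byte[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vectors must have the same length");
        }
        int distance = 0;
        for (int i = 0; i < a.length; i++) {
            distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        return distance;
    }
}
```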
BitBinaryDenseVector.java (new file)
@@ -0,0 +1,88 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.script.field.vectors;

import org.apache.lucene.util.BytesRef;

import java.util.List;

public class BitBinaryDenseVector extends ByteBinaryDenseVector {

    public BitBinaryDenseVector(byte[] vectorValue, BytesRef docVector, int dims) {
        super(vectorValue, docVector, dims);
    }

    @Override
    public void checkDimensions(int qvDims) {
        if (qvDims != dims) {
            throw new IllegalArgumentException(
                "The query vector has a different number of dimensions ["
                    + qvDims * Byte.SIZE
                    + "] than the document vectors ["
                    + dims * Byte.SIZE
                    + "]."
            );
        }
    }

    @Override
    public int l1Norm(byte[] queryVector) {
        return hamming(queryVector);
    }

    @Override
    public double l1Norm(List<Number> queryVector) {
        return hamming(queryVector);
    }

    @Override
    public double l2Norm(byte[] queryVector) {
        return Math.sqrt(hamming(queryVector));
    }

    @Override
    public double l2Norm(List<Number> queryVector) {
        return Math.sqrt(hamming(queryVector));
    }

    @Override
    public int dotProduct(byte[] queryVector) {
        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
    }

    @Override
    public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
        throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
    }

    @Override
    public double dotProduct(List<Number> queryVector) {
        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
    }

    @Override
    public double cosineSimilarity(byte[] queryVector, float qvMagnitude) {
        throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
    }

    @Override
    public double cosineSimilarity(List<Number> queryVector) {
        throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
    }

    @Override
    public double dotProduct(float[] queryVector) {
        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
    }

    @Override
    public int getDims() {
        return dims * Byte.SIZE;
    }
}
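As a sanity check of the semantics above: for `bit` vectors, `l1Norm` is the hamming distance (popcount of the XOR of the packed bytes) and `l2Norm` is its square root. The following standalone sketch is illustrative only — the class and method names are hypothetical, not the Elasticsearch classes:

```java
// Standalone sketch (illustrative names, not the Elasticsearch classes):
// for packed bit vectors, l1 norm == hamming distance and l2 norm == sqrt(l1).
public class BitDistanceSketch {

    // Hamming distance between two bit vectors packed 8 dimensions per byte.
    static int hamming(byte[] a, byte[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("dimension mismatch");
        }
        int distance = 0;
        for (int i = 0; i < a.length; i++) {
            // XOR leaves 1s exactly where the vectors differ; popcount them
            distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        return distance;
    }

    static double l2Norm(byte[] a, byte[] b) {
        return Math.sqrt(hamming(a, b));
    }

    public static void main(String[] args) {
        byte[] doc = { (byte) 0xAA, 0x00 };   // 16 dims: 10101010 00000000
        byte[] query = { (byte) 0x55, 0x00 }; // 16 dims: 01010101 00000000
        System.out.println(hamming(doc, query)); // 8: every bit of the first byte differs
        System.out.println(l2Norm(doc, query));  // sqrt(8)
    }
}
```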
BitBinaryDenseVectorDocValuesField.java (new file)
@@ -0,0 +1,24 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.script.field.vectors;

import org.apache.lucene.index.BinaryDocValues;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;

public class BitBinaryDenseVectorDocValuesField extends ByteBinaryDenseVectorDocValuesField {

    public BitBinaryDenseVectorDocValuesField(BinaryDocValues input, String name, ElementType elementType, int dims) {
        super(input, name, elementType, dims / 8);
    }

    @Override
    protected DenseVector getVector() {
        return new BitBinaryDenseVector(vectorValue, value, dims);
    }
}
BitKnnDenseVector.java (new file)
@@ -0,0 +1,95 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.script.field.vectors;

import java.util.List;

public class BitKnnDenseVector extends ByteKnnDenseVector {

    public BitKnnDenseVector(byte[] vector) {
        super(vector);
    }

    @Override
    public void checkDimensions(int qvDims) {
        if (qvDims != docVector.length) {
            throw new IllegalArgumentException(
                "The query vector has a different number of dimensions ["
                    + qvDims * Byte.SIZE
                    + "] than the document vectors ["
                    + docVector.length * Byte.SIZE
                    + "]."
            );
        }
    }

    @Override
    public float getMagnitude() {
        if (magnitudeCalculated == false) {
            magnitude = DenseVector.getBitMagnitude(docVector, docVector.length);
            magnitudeCalculated = true;
        }
        return magnitude;
    }

    @Override
    public int l1Norm(byte[] queryVector) {
        return hamming(queryVector);
    }

    @Override
    public double l1Norm(List<Number> queryVector) {
        return hamming(queryVector);
    }

    @Override
    public double l2Norm(byte[] queryVector) {
        return Math.sqrt(hamming(queryVector));
    }

    @Override
    public double l2Norm(List<Number> queryVector) {
        return Math.sqrt(hamming(queryVector));
    }

    @Override
    public int dotProduct(byte[] queryVector) {
        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
    }

    @Override
    public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
        throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
    }

    @Override
    public double dotProduct(List<Number> queryVector) {
        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
    }

    @Override
    public double cosineSimilarity(byte[] queryVector, float qvMagnitude) {
        throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
    }

    @Override
    public double cosineSimilarity(List<Number> queryVector) {
        throw new UnsupportedOperationException("cosineSimilarity is not supported for bit vectors.");
    }

    @Override
    public double dotProduct(float[] queryVector) {
        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
    }

    @Override
    public int getDims() {
        return docVector.length * Byte.SIZE;
    }
}
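Note the dimension accounting running through both classes above: internally a bit vector is a `byte[]` packing 8 logical dimensions per byte, while `checkDimensions` error messages and `getDims()` report logical bit dimensions (`byte length * Byte.SIZE`). A small sketch of that accounting, with illustrative names:

```java
// Sketch of the packed-dimension accounting (illustrative names): a bit vector
// packs 8 logical dimensions per byte, so a byte[] of length n holds n * 8 dims.
public class BitDimsSketch {

    static int logicalDims(byte[] packed) {
        return packed.length * Byte.SIZE;
    }

    static int packedLength(int logicalDims) {
        // dims for the bit element type must be divisible by 8
        if (logicalDims % Byte.SIZE != 0) {
            throw new IllegalArgumentException("dims must be a multiple of 8");
        }
        return logicalDims / Byte.SIZE;
    }

    public static void main(String[] args) {
        System.out.println(logicalDims(new byte[4])); // 32
        System.out.println(packedLength(32768));      // 4096 bytes for the documented 32768-dim maximum
    }
}
```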
BitKnnDenseVectorDocValuesField.java (new file)
@@ -0,0 +1,26 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.script.field.vectors;

import org.apache.lucene.index.ByteVectorValues;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;

public class BitKnnDenseVectorDocValuesField extends ByteKnnDenseVectorDocValuesField {

    public BitKnnDenseVectorDocValuesField(@Nullable ByteVectorValues input, String name, int dims) {
        super(input, name, dims / 8, DenseVectorFieldMapper.ElementType.BIT);
    }

    @Override
    protected DenseVector getVector() {
        return new BitKnnDenseVector(vector);
    }

}
@@ -21,7 +21,7 @@ public class ByteBinaryDenseVector implements DenseVector {
 
     private final BytesRef docVector;
     private final byte[] vectorValue;
-    private final int dims;
+    protected final int dims;
 
     private float[] floatDocVector;
     private boolean magnitudeDecoded;
@@ -17,11 +17,11 @@ import java.io.IOException;
 
 public class ByteBinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {
 
-    private final BinaryDocValues input;
-    private final int dims;
-    private final byte[] vectorValue;
-    private boolean decoded;
-    private BytesRef value;
+    protected final BinaryDocValues input;
+    protected final int dims;
+    protected final byte[] vectorValue;
+    protected boolean decoded;
+    protected BytesRef value;
 
     public ByteBinaryDenseVectorDocValuesField(BinaryDocValues input, String name, ElementType elementType, int dims) {
         super(name, elementType);
@@ -50,13 +50,17 @@ public class ByteBinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {
         return value == null;
     }
 
+    protected DenseVector getVector() {
+        return new ByteBinaryDenseVector(vectorValue, value, dims);
+    }
+
     @Override
     public DenseVector get() {
         if (isEmpty()) {
             return DenseVector.EMPTY;
         }
         decodeVectorIfNecessary();
-        return new ByteBinaryDenseVector(vectorValue, value, dims);
+        return getVector();
     }
 
     @Override
@@ -65,7 +69,7 @@ public class ByteBinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {
             return defaultValue;
         }
         decodeVectorIfNecessary();
-        return new ByteBinaryDenseVector(vectorValue, value, dims);
+        return getVector();
     }
 
     @Override
@@ -23,7 +23,11 @@ public class ByteKnnDenseVectorDocValuesField extends DenseVectorDocValuesField
     protected final int dims;
 
     public ByteKnnDenseVectorDocValuesField(@Nullable ByteVectorValues input, String name, int dims) {
-        super(name, ElementType.BYTE);
+        this(input, name, dims, ElementType.BYTE);
+    }
+
+    protected ByteKnnDenseVectorDocValuesField(@Nullable ByteVectorValues input, String name, int dims, ElementType elementType) {
+        super(name, elementType);
         this.dims = dims;
         this.input = input;
     }
@@ -57,13 +61,17 @@ public class ByteKnnDenseVectorDocValuesField extends DenseVectorDocValuesField
         return vector == null;
     }
 
+    protected DenseVector getVector() {
+        return new ByteKnnDenseVector(vector);
+    }
+
     @Override
     public DenseVector get() {
         if (isEmpty()) {
             return DenseVector.EMPTY;
         }
 
-        return new ByteKnnDenseVector(vector);
+        return getVector();
     }
 
     @Override
@@ -72,7 +80,7 @@ public class ByteKnnDenseVectorDocValuesField extends DenseVectorDocValuesField
             return defaultValue;
         }
 
-        return new ByteKnnDenseVector(vector);
+        return getVector();
     }
 
     @Override
@@ -8,6 +8,7 @@
 
 package org.elasticsearch.script.field.vectors;
 
+import org.apache.lucene.util.BitUtil;
 import org.apache.lucene.util.VectorUtil;
 
 import java.util.List;
@@ -25,6 +26,10 @@ import java.util.List;
  */
 public interface DenseVector {
 
+    default void checkDimensions(int qvDims) {
+        checkDimensions(getDims(), qvDims);
+    }
+
     float[] getVector();
 
     float getMagnitude();
@@ -38,13 +43,13 @@ public interface DenseVector {
     @SuppressWarnings("unchecked")
     default double dotProduct(Object queryVector) {
         if (queryVector instanceof float[] floats) {
-            checkDimensions(getDims(), floats.length);
+            checkDimensions(floats.length);
             return dotProduct(floats);
         } else if (queryVector instanceof List<?> list) {
-            checkDimensions(getDims(), list.size());
+            checkDimensions(list.size());
             return dotProduct((List<Number>) list);
         } else if (queryVector instanceof byte[] bytes) {
-            checkDimensions(getDims(), bytes.length);
+            checkDimensions(bytes.length);
             return dotProduct(bytes);
         }
 
@@ -60,13 +65,13 @@ public interface DenseVector {
     @SuppressWarnings("unchecked")
     default double l1Norm(Object queryVector) {
         if (queryVector instanceof float[] floats) {
-            checkDimensions(getDims(), floats.length);
+            checkDimensions(floats.length);
             return l1Norm(floats);
         } else if (queryVector instanceof List<?> list) {
-            checkDimensions(getDims(), list.size());
+            checkDimensions(list.size());
             return l1Norm((List<Number>) list);
         } else if (queryVector instanceof byte[] bytes) {
-            checkDimensions(getDims(), bytes.length);
+            checkDimensions(bytes.length);
             return l1Norm(bytes);
         }
 
@@ -80,11 +85,11 @@ public interface DenseVector {
     @SuppressWarnings("unchecked")
     default int hamming(Object queryVector) {
         if (queryVector instanceof List<?> list) {
-            checkDimensions(getDims(), list.size());
+            checkDimensions(list.size());
             return hamming((List<Number>) list);
         }
         if (queryVector instanceof byte[] bytes) {
-            checkDimensions(getDims(), bytes.length);
+            checkDimensions(bytes.length);
             return hamming(bytes);
         }
 
@@ -100,13 +105,13 @@ public interface DenseVector {
     @SuppressWarnings("unchecked")
     default double l2Norm(Object queryVector) {
         if (queryVector instanceof float[] floats) {
-            checkDimensions(getDims(), floats.length);
+            checkDimensions(floats.length);
             return l2Norm(floats);
         } else if (queryVector instanceof List<?> list) {
-            checkDimensions(getDims(), list.size());
+            checkDimensions(list.size());
             return l2Norm((List<Number>) list);
         } else if (queryVector instanceof byte[] bytes) {
-            checkDimensions(getDims(), bytes.length);
+            checkDimensions(bytes.length);
             return l2Norm(bytes);
         }
 
@@ -150,13 +155,13 @@ public interface DenseVector {
     @SuppressWarnings("unchecked")
     default double cosineSimilarity(Object queryVector) {
         if (queryVector instanceof float[] floats) {
-            checkDimensions(getDims(), floats.length);
+            checkDimensions(floats.length);
             return cosineSimilarity(floats);
         } else if (queryVector instanceof List<?> list) {
-            checkDimensions(getDims(), list.size());
+            checkDimensions(list.size());
             return cosineSimilarity((List<Number>) list);
         } else if (queryVector instanceof byte[] bytes) {
-            checkDimensions(getDims(), bytes.length);
+            checkDimensions(bytes.length);
             return cosineSimilarity(bytes);
         }
 
@@ -184,6 +189,20 @@ public interface DenseVector {
         return (float) Math.sqrt(mag);
     }
 
+    static float getBitMagnitude(byte[] vector, int dims) {
+        int count = 0;
+        int i = 0;
+        for (int upperBound = dims & -8; i < upperBound; i += 8) {
+            count += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(vector, i));
+        }
+
+        while (i < dims) {
+            count += Integer.bitCount(vector[i] & 255);
+            ++i;
+        }
+        return (float) Math.sqrt(count);
+    }
+
     static float getMagnitude(float[] vector) {
         return (float) Math.sqrt(VectorUtil.dotProduct(vector, vector));
    }
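The `getBitMagnitude` addition above computes the magnitude of a bit vector as the square root of its set-bit count; it reads 8 bytes at a time through `BitUtil.VH_NATIVE_LONG` for speed, then a tail loop covers the remainder. A simplified per-byte equivalent (a sketch with hypothetical names, producing the same result) is:

```java
// Simplified per-byte equivalent of getBitMagnitude: the magnitude of a
// bit vector is the square root of its popcount (number of set bits).
public class BitMagnitudeSketch {

    static float bitMagnitude(byte[] vector) {
        int count = 0;
        for (byte b : vector) {
            count += Integer.bitCount(b & 0xFF); // mask treats the byte as unsigned
        }
        return (float) Math.sqrt(count);
    }

    public static void main(String[] args) {
        byte[] v = { 0x0F, 0x01 }; // 4 + 1 = 5 set bits
        System.out.println(bitMagnitude(v)); // sqrt(5)
    }
}
```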
META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
@@ -1,3 +1,5 @@
 org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat
 org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat
 org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat
+org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat
+org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat
BaseKnnBitVectorsFormatTestCase.java (new file)
@@ -0,0 +1,149 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseIndexFileFormatTestCase;
import org.elasticsearch.common.logging.LogConfigurator;

import java.io.IOException;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

abstract class BaseKnnBitVectorsFormatTestCase extends BaseIndexFileFormatTestCase {

    static {
        LogConfigurator.loadLog4jPlugins();
        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
    }

    @Override
    protected void addRandomFields(Document doc) {
        doc.add(new KnnByteVectorField("v2", randomVector(30), similarityFunction));
    }

    protected VectorSimilarityFunction similarityFunction;

    protected VectorSimilarityFunction randomSimilarity() {
        return VectorSimilarityFunction.values()[random().nextInt(VectorSimilarityFunction.values().length)];
    }

    byte[] randomVector(int dims) {
        byte[] vector = new byte[dims];
        random().nextBytes(vector);
        return vector;
    }

    public void testRandom() throws Exception {
        IndexWriterConfig iwc = newIndexWriterConfig();
        if (random().nextBoolean()) {
            iwc.setIndexSort(new Sort(new SortField("sortkey", SortField.Type.INT)));
        }
        String fieldName = "field";
        try (Directory dir = newDirectory(); IndexWriter iw = new IndexWriter(dir, iwc)) {
            int numDoc = atLeast(100);
            int dimension = atLeast(10);
            if (dimension % 2 != 0) {
                dimension++;
            }
            byte[] scratch = new byte[dimension];
            int numValues = 0;
            byte[][] values = new byte[numDoc][];
            for (int i = 0; i < numDoc; i++) {
                if (random().nextInt(7) != 3) {
                    // usually index a vector value for a doc
                    values[i] = randomVector(dimension);
                    ++numValues;
                }
                if (random().nextBoolean() && values[i] != null) {
                    // sometimes use a shared scratch array
                    System.arraycopy(values[i], 0, scratch, 0, scratch.length);
                    add(iw, fieldName, i, scratch, similarityFunction);
                } else {
                    add(iw, fieldName, i, values[i], similarityFunction);
                }
                if (random().nextInt(10) == 2) {
                    // sometimes delete a random document
                    int idToDelete = random().nextInt(i + 1);
                    iw.deleteDocuments(new Term("id", Integer.toString(idToDelete)));
                    // and remember that it was deleted
                    if (values[idToDelete] != null) {
                        values[idToDelete] = null;
                        --numValues;
                    }
                }
                if (random().nextInt(10) == 3) {
                    iw.commit();
                }
            }
            int numDeletes = 0;
            try (IndexReader reader = DirectoryReader.open(iw)) {
                int valueCount = 0, totalSize = 0;
                for (LeafReaderContext ctx : reader.leaves()) {
                    ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
                    if (vectorValues == null) {
                        continue;
                    }
                    totalSize += vectorValues.size();
                    StoredFields storedFields = ctx.reader().storedFields();
                    int docId;
                    while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
                        byte[] v = vectorValues.vectorValue();
                        assertEquals(dimension, v.length);
                        String idString = storedFields.document(docId).getField("id").stringValue();
                        int id = Integer.parseInt(idString);
                        if (ctx.reader().getLiveDocs() == null || ctx.reader().getLiveDocs().get(docId)) {
                            assertArrayEquals(idString, values[id], v);
                            ++valueCount;
                        } else {
                            ++numDeletes;
                            assertNull(values[id]);
                        }
                    }
                }
                assertEquals(numValues, valueCount);
                assertEquals(numValues, totalSize - numDeletes);
            }
        }
    }

    private void add(IndexWriter iw, String field, int id, byte[] vector, VectorSimilarityFunction similarity) throws IOException {
        add(iw, field, id, random().nextInt(100), vector, similarity);
    }

    private void add(IndexWriter iw, String field, int id, int sortKey, byte[] vector, VectorSimilarityFunction similarityFunction)
        throws IOException {
        Document doc = new Document();
        if (vector != null) {
            doc.add(new KnnByteVectorField(field, vector, similarityFunction));
        }
        doc.add(new NumericDocValuesField("sortkey", sortKey));
        String idString = Integer.toString(id);
        doc.add(new StringField("id", idString, Field.Store.YES));
        Term idTerm = new Term("id", idString);
        iw.updateDocument(idTerm, doc);
    }

}
@@ -12,8 +12,15 @@ import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.lucene99.Lucene99Codec;
 import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.elasticsearch.common.logging.LogConfigurator;
 
 public class ES813FlatVectorFormatTests extends BaseKnnVectorsFormatTestCase {
 
+    static {
+        LogConfigurator.loadLog4jPlugins();
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
     @Override
     protected Codec getCodec() {
         return new Lucene99Codec() {

@@ -12,8 +12,15 @@ import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.lucene99.Lucene99Codec;
 import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.elasticsearch.common.logging.LogConfigurator;
 
 public class ES813Int8FlatVectorFormatTests extends BaseKnnVectorsFormatTestCase {
 
+    static {
+        LogConfigurator.loadLog4jPlugins();
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
     @Override
     protected Codec getCodec() {
         return new Lucene99Codec() {
ES815BitFlatVectorFormatTests.java (new file)
@@ -0,0 +1,34 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.junit.Before;

public class ES815BitFlatVectorFormatTests extends BaseKnnBitVectorsFormatTestCase {

    @Override
    protected Codec getCodec() {
        return new Lucene99Codec() {
            @Override
            public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
                return new ES815BitFlatVectorFormat();
            }
        };
    }

    @Before
    public void init() {
        similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
    }

}
ES815HnswBitVectorsFormatTests.java (new file)
@@ -0,0 +1,33 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.junit.Before;

public class ES815HnswBitVectorsFormatTests extends BaseKnnBitVectorsFormatTestCase {

    @Override
    protected Codec getCodec() {
        return new Lucene99Codec() {
            @Override
            public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
                return new ES815HnswBitVectorsFormat();
            }
        };
    }

    @Before
    public void init() {
        similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
    }
}
@@ -236,8 +236,8 @@ public class BinaryDenseVectorScriptDocValuesTests extends ESTestCase {
 
     public static BytesRef mockEncodeDenseVector(float[] values, ElementType elementType, IndexVersion indexVersion) {
         int numBytes = indexVersion.onOrAfter(DenseVectorFieldMapper.MAGNITUDE_STORED_INDEX_VERSION)
-            ? elementType.elementBytes * values.length + DenseVectorFieldMapper.MAGNITUDE_BYTES
-            : elementType.elementBytes * values.length;
+            ? elementType.getNumBytes(values.length) + DenseVectorFieldMapper.MAGNITUDE_BYTES
+            : elementType.getNumBytes(values.length);
         double dotProduct = 0f;
         ByteBuffer byteBuffer = elementType.createByteBuffer(indexVersion, numBytes);
         for (float value : values) {
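The change above replaces the hand-computed `elementType.elementBytes * values.length` with `elementType.getNumBytes(values.length)`, which matters once `bit` exists: bit vectors do not have a fixed per-dimension byte width, they pack 8 dimensions per byte. A local sketch of the assumed accounting (the enum and method here are illustrative, not the Elasticsearch `ElementType` source):

```java
// Assumed byte accounting per element type (local sketch, not the ES enum):
// FLOAT stores 4 bytes per dimension, BYTE 1, and BIT packs 8 dimensions per byte.
public class ElementTypeBytesSketch {

    enum ElementType { FLOAT, BYTE, BIT }

    static int numBytes(ElementType type, int dims) {
        return switch (type) {
            case FLOAT -> dims * Float.BYTES; // 4 bytes per dimension
            case BYTE -> dims;                // 1 byte per dimension
            case BIT -> dims / Byte.SIZE;     // 8 dimensions per byte
        };
    }

    public static void main(String[] args) {
        System.out.println(numBytes(ElementType.FLOAT, 16)); // 64
        System.out.println(numBytes(ElementType.BYTE, 16));  // 16
        System.out.println(numBytes(ElementType.BIT, 16));   // 2
    }
}
```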
@ -71,11 +71,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
|
|||
private final ElementType elementType;
|
||||
private final boolean indexed;
|
||||
private final boolean indexOptionsSet;
|
||||
private final int dims;
|
||||
|
||||
public DenseVectorFieldMapperTests() {
|
||||
this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT);
|
||||
this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT);
|
||||
this.indexed = randomBoolean();
|
||||
this.indexOptionsSet = this.indexed && randomBoolean();
|
||||
this.dims = ElementType.BIT == elementType ? 4 * Byte.SIZE : 4;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -89,7 +91,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
|
|||
}
|
||||
|
||||
private void indexMapping(XContentBuilder b, IndexVersion indexVersion) throws IOException {
|
||||
b.field("type", "dense_vector").field("dims", 4);
|
||||
b.field("type", "dense_vector").field("dims", dims);
|
||||
if (elementType != ElementType.FLOAT) {
|
||||
b.field("element_type", elementType.toString());
|
||||
}
|
||||
|
@ -108,7 +110,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
|
|||
b.endObject();
|
||||
}
|
||||
if (indexed) {
|
||||
b.field("similarity", "dot_product");
|
||||
b.field("similarity", elementType == ElementType.BIT ? "l2_norm" : "dot_product");
|
||||
if (indexOptionsSet) {
|
||||
b.startObject("index_options");
|
||||
b.field("type", "hnsw");
|
||||
|
@ -121,52 +123,86 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
|
|||
|
||||
@Override
|
||||
protected Object getSampleValueForDocument() {
|
||||
return elementType == ElementType.BYTE ? List.of((byte) 1, (byte) 1, (byte) 1, (byte) 1) : List.of(0.5, 0.5, 0.5, 0.5);
|
||||
return elementType == ElementType.FLOAT ? List.of(0.5, 0.5, 0.5, 0.5) : List.of((byte) 1, (byte) 1, (byte) 1, (byte) 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void registerParameters(ParameterChecker checker) throws IOException {
|
||||
checker.registerConflictCheck(
|
||||
"dims",
|
||||
fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4)),
|
||||
fieldMapping(b -> b.field("type", "dense_vector").field("dims", 5))
|
||||
fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims)),
|
||||
fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims + 8))
|
||||
);
|
||||
checker.registerConflictCheck(
|
||||
"similarity",
|
||||
fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", true).field("similarity", "dot_product")),
|
||||
fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", true).field("similarity", "l2_norm"))
|
||||
fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", true).field("similarity", "dot_product")),
|
||||
fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", true).field("similarity", "l2_norm"))
|
||||
);
|
||||
checker.registerConflictCheck(
|
||||
"index",
|
||||
fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", true).field("similarity", "dot_product")),
|
||||
fieldMapping(b -> b.field("type", "dense_vector").field("dims", 4).field("index", false))
|
||||
fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", true).field("similarity", "dot_product")),
|
||||
fieldMapping(b -> b.field("type", "dense_vector").field("dims", dims).field("index", false))
|
||||
);
|
||||
checker.registerConflictCheck(
|
||||
"element_type",
|
||||
fieldMapping(
|
||||
b -> b.field("type", "dense_vector")
|
||||
.field("dims", 4)
|
||||
.field("dims", dims)
|
||||
.field("index", true)
|
||||
.field("similarity", "dot_product")
|
||||
.field("element_type", "byte")
|
||||
),
|
||||
fieldMapping(
|
||||
b -> b.field("type", "dense_vector")
|
||||
.field("dims", 4)
|
||||
.field("dims", dims)
|
||||
.field("index", true)
|
||||
.field("similarity", "dot_product")
|
||||
.field("element_type", "float")
|
||||
)
|
||||
);
|
||||
checker.registerConflictCheck(
|
||||
"element_type",
|
||||
fieldMapping(
|
||||
b -> b.field("type", "dense_vector")
|
||||
.field("dims", dims)
|
||||
.field("index", true)
|
||||
.field("similarity", "l2_norm")
|
||||
.field("element_type", "float")
|
||||
),
|
||||
fieldMapping(
|
||||
b -> b.field("type", "dense_vector")
|
||||
.field("dims", dims)
|
||||
.field("index", true)
|
||||
.field("similarity", "l2_norm")
|
||||
.field("element_type", "bit")
|
||||
)
|
||||
);
|
||||
checker.registerConflictCheck(
|
||||
"element_type",
|
||||
fieldMapping(
|
||||
b -> b.field("type", "dense_vector")
|
||||
.field("dims", dims)
|
||||
.field("index", true)
|
||||
.field("similarity", "l2_norm")
|
||||
.field("element_type", "byte")
|
||||
),
|
||||
fieldMapping(
|
||||
b -> b.field("type", "dense_vector")
|
||||
.field("dims", dims)
|
||||
.field("index", true)
|
||||
.field("similarity", "l2_norm")
|
||||
.field("element_type", "bit")
|
||||
)
|
||||
);
|
||||
checker.registerUpdateCheck(
|
||||
b -> b.field("type", "dense_vector")
|
||||
.field("dims", 4)
|
||||
.field("dims", dims)
|
||||
.field("index", true)
|
||||
.startObject("index_options")
|
||||
.field("type", "flat")
|
||||
.endObject(),
|
||||
b -> b.field("type", "dense_vector")
|
||||
.field("dims", 4)
|
||||
.field("dims", dims)
|
||||
.field("index", true)
|
||||
.startObject("index_options")
|
||||
.field("type", "int8_flat")
|
||||
|
@@ -175,13 +211,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
     );
     checker.registerUpdateCheck(
         b -> b.field("type", "dense_vector")
-            .field("dims", 4)
+            .field("dims", dims)
             .field("index", true)
             .startObject("index_options")
             .field("type", "flat")
             .endObject(),
         b -> b.field("type", "dense_vector")
-            .field("dims", 4)
+            .field("dims", dims)
             .field("index", true)
             .startObject("index_options")
             .field("type", "hnsw")
@@ -190,13 +226,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
     );
     checker.registerUpdateCheck(
         b -> b.field("type", "dense_vector")
-            .field("dims", 4)
+            .field("dims", dims)
             .field("index", true)
             .startObject("index_options")
             .field("type", "flat")
             .endObject(),
         b -> b.field("type", "dense_vector")
-            .field("dims", 4)
+            .field("dims", dims)
             .field("index", true)
             .startObject("index_options")
             .field("type", "int8_hnsw")
@@ -205,13 +241,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
     );
     checker.registerUpdateCheck(
         b -> b.field("type", "dense_vector")
-            .field("dims", 4)
+            .field("dims", dims)
             .field("index", true)
             .startObject("index_options")
             .field("type", "int8_flat")
             .endObject(),
         b -> b.field("type", "dense_vector")
-            .field("dims", 4)
+            .field("dims", dims)
             .field("index", true)
             .startObject("index_options")
             .field("type", "hnsw")
@@ -220,13 +256,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
     );
     checker.registerUpdateCheck(
         b -> b.field("type", "dense_vector")
-            .field("dims", 4)
+            .field("dims", dims)
             .field("index", true)
             .startObject("index_options")
             .field("type", "int8_flat")
             .endObject(),
         b -> b.field("type", "dense_vector")
-            .field("dims", 4)
+            .field("dims", dims)
             .field("index", true)
             .startObject("index_options")
             .field("type", "int8_hnsw")
@@ -235,13 +271,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
     );
     checker.registerUpdateCheck(
         b -> b.field("type", "dense_vector")
-            .field("dims", 4)
+            .field("dims", dims)
             .field("index", true)
             .startObject("index_options")
             .field("type", "hnsw")
             .endObject(),
         b -> b.field("type", "dense_vector")
-            .field("dims", 4)
+            .field("dims", dims)
             .field("index", true)
             .startObject("index_options")
             .field("type", "int8_hnsw")
@@ -252,7 +288,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         "index_options",
         fieldMapping(
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "hnsw")
@@ -260,7 +296,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         ),
         fieldMapping(
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("index", true)
                 .startObject("index_options")
                 .field("type", "flat")
@@ -353,7 +389,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
     mapping = mapping(b -> {
         b.startObject("field");
         b.field("type", "dense_vector")
-            .field("dims", 4)
+            .field("dims", dims)
             .field("similarity", "cosine")
             .field("index", true)
             .startObject("index_options")
@@ -648,7 +684,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
     () -> createDocumentMapper(
         fieldMapping(
             b -> b.field("type", "dense_vector")
-                .field("dims", 4)
+                .field("dims", dims)
                 .field("element_type", "byte")
                 .field("similarity", "l2_norm")
                 .field("index", true)
@@ -1020,6 +1056,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
             }
             yield floats;
         }
+        case BIT -> randomByteArrayOfLength(vectorFieldType.getVectorDimensions() / 8);
     };
 }

@@ -1196,7 +1233,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
     boolean setEfConstruction = randomBoolean();
     MapperService mapperService = createMapperService(fieldMapping(b -> {
         b.field("type", "dense_vector");
-        b.field("dims", 4);
+        b.field("dims", dims);
         b.field("index", true);
         b.field("similarity", "dot_product");
         b.startObject("index_options");
@@ -1234,7 +1271,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
     for (String quantizedFlatFormat : new String[] { "int8_flat", "int4_flat" }) {
         MapperService mapperService = createMapperService(fieldMapping(b -> {
             b.field("type", "dense_vector");
-            b.field("dims", 4);
+            b.field("dims", dims);
             b.field("index", true);
             b.field("similarity", "dot_product");
             b.startObject("index_options");
@@ -1275,7 +1312,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
     float confidenceInterval = (float) randomDoubleBetween(0.90f, 1.0f, true);
     MapperService mapperService = createMapperService(fieldMapping(b -> {
         b.field("type", "dense_vector");
-        b.field("dims", 4);
+        b.field("dims", dims);
         b.field("index", true);
         b.field("similarity", "dot_product");
         b.startObject("index_options");
@@ -1316,7 +1353,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
     float confidenceInterval = (float) randomDoubleBetween(0.90f, 1.0f, true);
     MapperService mapperService = createMapperService(fieldMapping(b -> {
         b.field("type", "dense_vector");
-        b.field("dims", 4);
+        b.field("dims", dims);
         b.field("index", true);
         b.field("similarity", "dot_product");
         b.startObject("index_options");
@@ -185,10 +185,12 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             queryVector[i] = randomByte();
             floatQueryVector[i] = queryVector[i];
         }
-        Query query = field.createKnnQuery(queryVector, 10, null, null, producer);
+        VectorData vectorData = new VectorData(null, queryVector);
+        Query query = field.createKnnQuery(vectorData, 10, null, null, producer);
         assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));

-        query = field.createKnnQuery(floatQueryVector, 10, null, null, producer);
+        vectorData = new VectorData(floatQueryVector, null);
+        query = field.createKnnQuery(vectorData, 10, null, null, producer);
         assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
     }
 }
@@ -321,7 +323,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
         for (int i = 0; i < 4096; i++) {
             queryVector[i] = randomByte();
         }
-        Query query = fieldWith4096dims.createKnnQuery(queryVector, 10, null, null, null);
+        VectorData vectorData = new VectorData(null, queryVector);
+        Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, null, null, null);
         assertThat(query, instanceOf(KnnByteVectorQuery.class));
     }
 }
@@ -359,7 +362,10 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
         );
         assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));

-        e = expectThrows(IllegalArgumentException.class, () -> cosineField.createKnnQuery(new byte[] { 0, 0, 0 }, 10, null, null, null));
+        e = expectThrows(
+            IllegalArgumentException.class,
+            () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, null, null, null)
+        );
         assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
     }
 }
@@ -114,10 +114,10 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
         );

         e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, queryVector, fieldName));
-        assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors"));
+        assertThat(e.getMessage(), containsString("hamming distance is only supported for byte or bit vectors"));

         e = expectThrows(IllegalArgumentException.class, () -> new Hamming(scoreScript, invalidQueryVector, fieldName));
-        assertThat(e.getMessage(), containsString("hamming distance is only supported for byte vectors"));
+        assertThat(e.getMessage(), containsString("hamming distance is only supported for byte or bit vectors"));

         // Check scripting infrastructure integration
         DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName);
@@ -122,7 +122,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
         Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery;
         // The field should always be resolved to the concrete field
         Query knnVectorQueryBuilt = switch (elementType()) {
-            case BYTE -> new ESKnnByteVectorQuery(
+            case BYTE, BIT -> new ESKnnByteVectorQuery(
                 VECTOR_FIELD,
                 queryBuilder.queryVector().asByteVector(),
                 queryBuilder.numCands(),
@@ -145,7 +145,10 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
         SearchExecutionContext context = createSearchExecutionContext();
         KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 10, null);
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
-        assertThat(e.getMessage(), containsString("the query vector has a different dimension [2] than the index vectors [3]"));
+        assertThat(
+            e.getMessage(),
+            containsString("The query vector has a different number of dimensions [2] than the document vectors [3]")
+        );
     }

     public void testNonexistentField() {
@@ -46,6 +46,7 @@ public class EmbeddingRequestChunker {
         return switch (elementType) {
             case BYTE -> EmbeddingType.BYTE;
             case FLOAT -> EmbeddingType.FLOAT;
+            case BIT -> throw new IllegalArgumentException("Bit vectors are not supported");
         };
     }
 };
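The commit describes `bit` vector comparisons as `xor` and `popcount` summations over packed bytes, i.e. hamming distance, with each `byte` carrying 8 bits of a `dim/8`-length vector. A minimal standalone sketch of that arithmetic, assuming nothing about Elasticsearch internals — the class and method names below are illustrative, not the production implementation:

```java
// Hedged sketch: hamming distance over bit vectors packed as byte[] (dims / 8 bytes),
// computed as xor + popcount per byte, as the commit message describes for `bit` vectors.
public class BitVectorHamming {

    static int hamming(byte[] a, byte[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("bit vectors must have the same packed length");
        }
        int distance = 0;
        for (int i = 0; i < a.length; i++) {
            // xor exposes the differing bits; mask to 8 bits before counting them
            distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        return distance;
    }

    public static void main(String[] args) {
        // 16-dim bit vectors packed into 2 bytes each
        byte[] v1 = new byte[] { (byte) 0b10101010, (byte) 0b11110000 };
        byte[] v2 = new byte[] { (byte) 0b01010101, (byte) 0b11110000 };
        int d = hamming(v1, v2); // first byte differs in all 8 bits, second in none
        System.out.println(d); // prints 8
    }
}
```

This also mirrors the script-side relationship noted in the commit: `l1norm` equals this hamming distance for `bit` vectors, and `l2norm` is its square root.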