mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
ES|QL - kNN function initial support (#127322)
This commit is contained in:
parent
aa87f46681
commit
366e00f5c5
28 changed files with 1251 additions and 23 deletions
1
docs/reference/query-languages/esql/images/functions/knn.svg
generated
Normal file
1
docs/reference/query-languages/esql/images/functions/knn.svg
generated
Normal file
|
@ -0,0 +1 @@
|
|||
<svg version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" width="568" height="61" viewbox="0 0 568 61"><defs><style type="text/css">.c{fill:none;stroke:#222222;}.k{fill:#000000;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}.s{fill:#e4f4ff;stroke:#222222;}.syn{fill:#8D8D8D;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}</style></defs><path class="c" d="M0 31h5m56 0h10m32 0h10m80 0h10m32 0h10m80 0h10m32 0h30m104 0h20m-139 0q5 0 5 5v10q0 5 5 5h114q5 0 5-5v-10q0-5 5-5m5 0h10m32 0h5"/><rect class="s" x="5" y="5" width="56" height="36"/><text class="k" x="15" y="31">KNN</text><rect class="s" x="71" y="5" width="32" height="36" rx="7"/><text class="syn" x="81" y="31">(</text><rect class="s" x="113" y="5" width="80" height="36" rx="7"/><text class="k" x="123" y="31">field</text><rect class="s" x="203" y="5" width="32" height="36" rx="7"/><text class="syn" x="213" y="31">,</text><rect class="s" x="245" y="5" width="80" height="36" rx="7"/><text class="k" x="255" y="31">query</text><rect class="s" x="335" y="5" width="32" height="36" rx="7"/><text class="syn" x="345" y="31">,</text><rect class="s" x="397" y="5" width="104" height="36" rx="7"/><text class="k" x="407" y="31">options</text><rect class="s" x="531" y="5" width="32" height="36" rx="7"/><text class="syn" x="541" y="31">)</text></svg>
|
After Width: | Height: | Size: 1.5 KiB |
13
docs/reference/query-languages/esql/kibana/definition/functions/knn.json
generated
Normal file
13
docs/reference/query-languages/esql/kibana/definition/functions/knn.json
generated
Normal file
|
@ -0,0 +1,13 @@
|
|||
{
|
||||
"comment" : "This is generated by ESQL’s AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.",
|
||||
"type" : "scalar",
|
||||
"name" : "knn",
|
||||
"description" : "Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors.",
|
||||
"signatures" : [ ],
|
||||
"examples" : [
|
||||
"from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0])\n| sort _score desc",
|
||||
"from colors metadata _score\n| where knn(rgb_vector, [0,255,255], {\"k\": 4})\n| sort _score desc"
|
||||
],
|
||||
"preview" : true,
|
||||
"snapshot_only" : true
|
||||
}
|
10
docs/reference/query-languages/esql/kibana/docs/functions/knn.md
generated
Normal file
10
docs/reference/query-languages/esql/kibana/docs/functions/knn.md
generated
Normal file
|
@ -0,0 +1,10 @@
|
|||
% This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.
|
||||
|
||||
### KNN
|
||||
Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors.
|
||||
|
||||
```esql
|
||||
from colors metadata _score
|
||||
| where knn(rgb_vector, [0, 120, 0])
|
||||
| sort _score desc
|
||||
```
|
|
@ -2589,6 +2589,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|||
return null;
|
||||
}
|
||||
|
||||
if (dims == null) {
|
||||
// No data has been indexed yet
|
||||
return BlockLoader.CONSTANT_NULLS;
|
||||
}
|
||||
|
||||
if (indexed) {
|
||||
return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims);
|
||||
}
|
||||
|
|
|
@ -112,7 +112,7 @@ public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements
|
|||
int min = docs.docs().getInt(0);
|
||||
int max = docs.docs().getInt(docs.getPositionCount() - 1);
|
||||
int length = max - min + 1;
|
||||
try (T scoreBuilder = createVectorBuilder(blockFactory, length)) {
|
||||
try (T scoreBuilder = createVectorBuilder(blockFactory, docs.getPositionCount())) {
|
||||
if (length == docs.getPositionCount() && length > 1) {
|
||||
return segmentState.scoreDense(scoreBuilder, min, max);
|
||||
}
|
||||
|
|
|
@ -1022,7 +1022,7 @@ public abstract class RestEsqlTestCase extends ESRestTestCase {
|
|||
var query = requestObjectBuilder().query(format(null, "from * | lookup join {} on integer {}", testIndexName(), sort));
|
||||
Map<String, Object> result = runEsql(query);
|
||||
var columns = as(result.get("columns"), List.class);
|
||||
assertEquals(21, columns.size());
|
||||
assertEquals(22, columns.size());
|
||||
var values = as(result.get("values"), List.class);
|
||||
assertEquals(10, values.size());
|
||||
}
|
||||
|
|
|
@ -148,6 +148,7 @@ public class CsvTestsDataLoader {
|
|||
private static final TestDataset LOGS = new TestDataset("logs");
|
||||
private static final TestDataset MV_TEXT = new TestDataset("mv_text");
|
||||
private static final TestDataset DENSE_VECTOR = new TestDataset("dense_vector");
|
||||
private static final TestDataset COLORS = new TestDataset("colors");
|
||||
|
||||
public static final Map<String, TestDataset> CSV_DATASET_MAP = Map.ofEntries(
|
||||
Map.entry(EMPLOYEES.indexName, EMPLOYEES),
|
||||
|
@ -210,7 +211,8 @@ public class CsvTestsDataLoader {
|
|||
Map.entry(SEMANTIC_TEXT.indexName, SEMANTIC_TEXT),
|
||||
Map.entry(LOGS.indexName, LOGS),
|
||||
Map.entry(MV_TEXT.indexName, MV_TEXT),
|
||||
Map.entry(DENSE_VECTOR.indexName, DENSE_VECTOR)
|
||||
Map.entry(DENSE_VECTOR.indexName, DENSE_VECTOR),
|
||||
Map.entry(COLORS.indexName, COLORS)
|
||||
);
|
||||
|
||||
private static final EnrichConfig LANGUAGES_ENRICH = new EnrichConfig("languages_policy", "enrich-policy-languages.json");
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
color:text,hex_code:keyword,rgb_vector:dense_vector,primary:boolean
|
||||
maroon, #800000, [128,0,0], false
|
||||
brown, #A52A2A, [165,42,42], false
|
||||
firebrick, #B22222, [178,34,34], false
|
||||
crimson, #DC143C, [220,20,60], false
|
||||
red, #FF0000, [255,0,0], true
|
||||
tomato, #FF6347, [255,99,71], false
|
||||
coral, #FF7F50, [255,127,80], false
|
||||
salmon, #FA8072, [250,128,114], false
|
||||
orange, #FFA500, [255,165,0], false
|
||||
gold, #FFD700, [255,215,0], false
|
||||
golden rod, #DAA520, [218,165,32], false
|
||||
khaki, #F0E68C, [240,230,140], false
|
||||
olive, #808000, [128,128,0], false
|
||||
yellow, #FFFF00, [255,255,0], true
|
||||
chartreuse, #7FFF00, [127,255,0], false
|
||||
green, #008000, [0,128,0], true
|
||||
lime, #00FF00, [0,255,0], false
|
||||
teal, #008080, [0,128,128], false
|
||||
cyan, #00FFFF, [0,255,255], true
|
||||
turquoise, #40E0D0, [64,224,208], false
|
||||
aqua marine, #7FFFD4, [127,255,212], false
|
||||
navy, #000080, [0,0,128], false
|
||||
blue, #0000FF, [0,0,255], true
|
||||
indigo, #4B0082, [75,0,130], false
|
||||
purple, #800080, [128,0,128], false
|
||||
thistle, #D8BFD8, [216,191,216], false
|
||||
plum, #DDA0DD, [221,160,221], false
|
||||
violet, #EE82EE, [238,130,238], false
|
||||
magenta, #FF00FF, [255,0,255], true
|
||||
orchid, #DA70D6, [218,112,214], false
|
||||
pink, #FFC0CB, [255,192,203], false
|
||||
beige, #F5F5DC, [245,245,220], false
|
||||
bisque, #FFE4C4, [255,228,196], false
|
||||
wheat, #F5DEB3, [245,222,179], false
|
||||
corn silk, #FFF8DC, [255,248,220], false
|
||||
lemon chiffon, #FFFACD, [255,250,205], false
|
||||
sienna, #A0522D, [160,82,45], false
|
||||
chocolate, #D2691E, [210,105,30], false
|
||||
peru, #CD853F, [205,133,63], false
|
||||
burly wood, #DEB887, [222,184,135], false
|
||||
tan, #D2B48C, [210,180,140], false
|
||||
moccasin, #FFE4B5, [255,228,181], false
|
||||
peach puff, #FFDAB9, [255,218,185], false
|
||||
misty rose, #FFE4E1, [255,228,225], false
|
||||
linen, #FAF0E6, [250,240,230], false
|
||||
old lace, #FDF5E6, [253,245,230], false
|
||||
papaya whip, #FFEFD5, [255,239,213], false
|
||||
sea shell, #FFF5EE, [255,245,238], false
|
||||
mint cream, #F5FFFA, [245,255,250], false
|
||||
lavender, #E6E6FA, [230,230,250], false
|
||||
honeydew, #F0FFF0, [240,255,240], false
|
||||
ivory, #FFFFF0, [255,255,240], false
|
||||
azure, #F0FFFF, [240,255,255], false
|
||||
snow, #FFFAFA, [255,250,250], false
|
||||
black, #000000, [0,0,0], true
|
||||
gray, #808080, [128,128,128], true
|
||||
silver, #C0C0C0, [192,192,192], false
|
||||
gainsboro, #DCDCDC, [220,220,220], false
|
||||
white, #FFFFFF, [255,255,255], true
|
|
|
@ -0,0 +1,285 @@
|
|||
# TODO Most tests explicitly set k. Until knn function uses LIMIT as k, we need to explicitly set it to all values
|
||||
# in the dataset to avoid test failures due to docs allocation in different shards, which can impact results for a
|
||||
# top-n query at the shard level
|
||||
|
||||
knnSearch
|
||||
required_capability: knn_function
|
||||
|
||||
// tag::knn-function[]
|
||||
from colors metadata _score
|
||||
| where knn(rgb_vector, [0, 120, 0])
|
||||
| sort _score desc, color asc
|
||||
// end::knn-function[]
|
||||
| keep color, rgb_vector
|
||||
| limit 10
|
||||
;
|
||||
|
||||
// tag::knn-function-result[]
|
||||
color:text | rgb_vector:dense_vector
|
||||
green | [0.0, 128.0, 0.0]
|
||||
black | [0.0, 0.0, 0.0]
|
||||
olive | [128.0, 128.0, 0.0]
|
||||
teal | [0.0, 128.0, 128.0]
|
||||
lime | [0.0, 255.0, 0.0]
|
||||
sienna | [160.0, 82.0, 45.0]
|
||||
maroon | [128.0, 0.0, 0.0]
|
||||
navy | [0.0, 0.0, 128.0]
|
||||
gray | [128.0, 128.0, 128.0]
|
||||
chartreuse | [127.0, 255.0, 0.0]
|
||||
// end::knn-function-result[]
|
||||
;
|
||||
|
||||
knnSearchWithKOption
|
||||
required_capability: knn_function
|
||||
|
||||
// tag::knn-function-options[]
|
||||
from colors metadata _score
|
||||
| where knn(rgb_vector, [0,255,255], {"k": 4})
|
||||
| sort _score desc, color asc
|
||||
// end::knn-function-options[]
|
||||
| keep color, rgb_vector
|
||||
| limit 4
|
||||
;
|
||||
|
||||
color:text | rgb_vector:dense_vector
|
||||
cyan | [0.0, 255.0, 255.0]
|
||||
turquoise | [64.0, 224.0, 208.0]
|
||||
aqua marine | [127.0, 255.0, 212.0]
|
||||
teal | [0.0, 128.0, 128.0]
|
||||
;
|
||||
|
||||
knnSearchWithSimilarityOption
|
||||
required_capability: knn_function
|
||||
|
||||
from colors metadata _score
|
||||
| where knn(rgb_vector, [255,192,203], {"k": 140, "similarity": 40})
|
||||
| sort _score desc, color asc
|
||||
| keep color, rgb_vector
|
||||
;
|
||||
|
||||
color:text | rgb_vector:dense_vector
|
||||
pink | [255.0, 192.0, 203.0]
|
||||
peach puff | [255.0, 218.0, 185.0]
|
||||
bisque | [255.0, 228.0, 196.0]
|
||||
wheat | [245.0, 222.0, 179.0]
|
||||
|
||||
;
|
||||
|
||||
knnHybridSearch
|
||||
required_capability: knn_function
|
||||
|
||||
from colors metadata _score
|
||||
| where match(color, "blue") or knn(rgb_vector, [65,105,225], {"k": 140})
|
||||
| where primary == true
|
||||
| sort _score desc, color asc
|
||||
| keep color, rgb_vector
|
||||
| limit 10
|
||||
;
|
||||
|
||||
color:text | rgb_vector:dense_vector
|
||||
blue | [0.0, 0.0, 255.0]
|
||||
gray | [128.0, 128.0, 128.0]
|
||||
cyan | [0.0, 255.0, 255.0]
|
||||
magenta | [255.0, 0.0, 255.0]
|
||||
green | [0.0, 128.0, 0.0]
|
||||
white | [255.0, 255.0, 255.0]
|
||||
black | [0.0, 0.0, 0.0]
|
||||
red | [255.0, 0.0, 0.0]
|
||||
yellow | [255.0, 255.0, 0.0]
|
||||
;
|
||||
|
||||
knnWithMultipleFunctions
|
||||
required_capability: knn_function
|
||||
|
||||
from colors metadata _score
|
||||
| where knn(rgb_vector, [128,128,0], {"k": 140}) and match(color, "olive")
|
||||
| sort _score desc, color asc
|
||||
| keep color, rgb_vector
|
||||
;
|
||||
|
||||
color:text | rgb_vector:dense_vector
|
||||
olive | [128.0, 128.0, 0.0]
|
||||
;
|
||||
|
||||
knnAfterKeep
|
||||
required_capability: knn_function
|
||||
|
||||
from colors metadata _score
|
||||
| keep rgb_vector, color, _score
|
||||
| where knn(rgb_vector, [128,255,0], {"k": 140})
|
||||
| sort _score desc, color asc
|
||||
| keep rgb_vector
|
||||
| limit 5
|
||||
;
|
||||
|
||||
rgb_vector:dense_vector
|
||||
[127.0, 255.0, 0.0]
|
||||
[128.0, 128.0, 0.0]
|
||||
[255.0, 255.0, 0.0]
|
||||
[0.0, 255.0, 0.0]
|
||||
[218.0, 165.0, 32.0]
|
||||
;
|
||||
|
||||
knnAfterDrop
|
||||
required_capability: knn_function
|
||||
|
||||
from colors metadata _score
|
||||
| drop primary
|
||||
| where knn(rgb_vector, [128,250,0], {"k": 140})
|
||||
| sort _score desc, color asc
|
||||
| keep color, rgb_vector
|
||||
| limit 5
|
||||
;
|
||||
|
||||
color:text | rgb_vector: dense_vector
|
||||
chartreuse | [127.0, 255.0, 0.0]
|
||||
olive | [128.0, 128.0, 0.0]
|
||||
yellow | [255.0, 255.0, 0.0]
|
||||
golden rod | [218.0, 165.0, 32.0]
|
||||
lime | [0.0, 255.0, 0.0]
|
||||
;
|
||||
|
||||
knnAfterEval
|
||||
required_capability: knn_function
|
||||
|
||||
from colors metadata _score
|
||||
| eval composed_name = locate(color, " ") > 0
|
||||
| where knn(rgb_vector, [128,128,0], {"k": 140})
|
||||
| sort _score desc, color asc
|
||||
| keep color, composed_name
|
||||
| limit 5
|
||||
;
|
||||
|
||||
color:text | composed_name:boolean
|
||||
olive | false
|
||||
sienna | false
|
||||
chocolate | false
|
||||
peru | false
|
||||
golden rod | true
|
||||
;
|
||||
|
||||
knnWithConjunction
|
||||
required_capability: knn_function
|
||||
|
||||
# TODO We need kNN prefiltering here so we get more candidates that pass the filter
|
||||
from colors metadata _score
|
||||
| where knn(rgb_vector, [255,255,238], {"k": 140}) and hex_code like "#FFF*"
|
||||
| sort _score desc, color asc
|
||||
| keep color, hex_code, rgb_vector
|
||||
| limit 10
|
||||
;
|
||||
|
||||
color:text | hex_code:keyword | rgb_vector:dense_vector
|
||||
ivory | #FFFFF0 | [255.0, 255.0, 240.0]
|
||||
sea shell | #FFF5EE | [255.0, 245.0, 238.0]
|
||||
snow | #FFFAFA | [255.0, 250.0, 250.0]
|
||||
white | #FFFFFF | [255.0, 255.0, 255.0]
|
||||
corn silk | #FFF8DC | [255.0, 248.0, 220.0]
|
||||
lemon chiffon | #FFFACD | [255.0, 250.0, 205.0]
|
||||
yellow | #FFFF00 | [255.0, 255.0, 0.0]
|
||||
;
|
||||
|
||||
knnWithDisjunctionAndFiltersConjunction
|
||||
required_capability: knn_function
|
||||
|
||||
# TODO We need kNN prefiltering here so we get more candidates that pass the filter
|
||||
from colors metadata _score
|
||||
| where (knn(rgb_vector, [0,255,255], {"k": 140}) or knn(rgb_vector, [128, 0, 255], {"k": 140})) and primary == true
|
||||
| keep color, rgb_vector, _score
|
||||
| sort _score desc, color asc
|
||||
| drop _score
|
||||
| limit 10
|
||||
;
|
||||
|
||||
color:text | rgb_vector:dense_vector
|
||||
cyan | [0.0, 255.0, 255.0]
|
||||
blue | [0.0, 0.0, 255.0]
|
||||
magenta | [255.0, 0.0, 255.0]
|
||||
gray | [128.0, 128.0, 128.0]
|
||||
white | [255.0, 255.0, 255.0]
|
||||
green | [0.0, 128.0, 0.0]
|
||||
black | [0.0, 0.0, 0.0]
|
||||
red | [255.0, 0.0, 0.0]
|
||||
yellow | [255.0, 255.0, 0.0]
|
||||
;
|
||||
|
||||
knnWithNonPushableConjunction
|
||||
required_capability: knn_function
|
||||
|
||||
from colors metadata _score
|
||||
| eval composed_name = locate(color, " ") > 0
|
||||
| where knn(rgb_vector, [128,128,0], {"k": 140}) and composed_name == false
|
||||
| sort _score desc, color asc
|
||||
| keep color, composed_name
|
||||
| limit 10
|
||||
;
|
||||
|
||||
color:text | composed_name:boolean
|
||||
olive | false
|
||||
sienna | false
|
||||
chocolate | false
|
||||
peru | false
|
||||
brown | false
|
||||
firebrick | false
|
||||
chartreuse | false
|
||||
gray | false
|
||||
green | false
|
||||
maroon | false
|
||||
;
|
||||
|
||||
testKnnWithNonPushableDisjunctions
|
||||
required_capability: knn_function
|
||||
|
||||
from colors metadata _score
|
||||
| where knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 30}) or length(color) > 10
|
||||
| sort _score desc, color asc
|
||||
| keep color
|
||||
;
|
||||
|
||||
color:text
|
||||
olive
|
||||
aqua marine
|
||||
lemon chiffon
|
||||
papaya whip
|
||||
;
|
||||
|
||||
testKnnWithNonPushableDisjunctionsOnComplexExpressions
|
||||
required_capability: knn_function
|
||||
|
||||
from colors metadata _score
|
||||
| where (knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], {"k": 140, "similarity": 60}) and primary == false)
|
||||
| sort _score desc, color asc
|
||||
| keep color, primary
|
||||
;
|
||||
|
||||
color:text | primary:boolean
|
||||
olive | false
|
||||
purple | false
|
||||
indigo | false
|
||||
;
|
||||
|
||||
testKnnInStatsNonPushable
|
||||
required_capability: knn_function
|
||||
|
||||
from colors
|
||||
| where length(color) < 10
|
||||
| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140})
|
||||
;
|
||||
|
||||
c: long
|
||||
50
|
||||
;
|
||||
|
||||
testKnnInStatsWithGrouping
|
||||
required_capability: knn_function
|
||||
required_capability: full_text_functions_in_stats_where
|
||||
|
||||
from colors
|
||||
| where length(color) < 10
|
||||
| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140}) by primary
|
||||
;
|
||||
|
||||
c: long | primary: boolean
|
||||
41 | false
|
||||
9 | true
|
||||
;
|
|
@ -63,6 +63,9 @@
|
|||
"semantic_text": {
|
||||
"type": "semantic_text",
|
||||
"inference_id": "foo_inference_id"
|
||||
},
|
||||
"dense_vector": {
|
||||
"type": "dense_vector"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
{
|
||||
"properties": {
|
||||
"color": {
|
||||
"type": "text"
|
||||
},
|
||||
"hex_code": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"primary": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"rgb_vector": {
|
||||
"type": "dense_vector",
|
||||
"similarity": "l2_norm",
|
||||
"index_options": {
|
||||
"type": "hnsw"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -128,10 +128,57 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testNonIndexedDenseVectorField() throws IOException {
|
||||
createIndexWithDenseVector("no_dense_vectors");
|
||||
|
||||
int numDocs = randomIntBetween(10, 100);
|
||||
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
|
||||
for (int i = 0; i < numDocs; i++) {
|
||||
docs[i] = prepareIndex("no_dense_vectors").setId("" + i).setSource("id", String.valueOf(i));
|
||||
}
|
||||
|
||||
indexRandom(true, docs);
|
||||
|
||||
var query = """
|
||||
FROM no_dense_vectors
|
||||
| KEEP id, vector
|
||||
""";
|
||||
|
||||
try (var resp = run(query)) {
|
||||
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
|
||||
assertEquals(numDocs, valuesList.size());
|
||||
valuesList.forEach(value -> {
|
||||
assertEquals(2, value.size());
|
||||
Integer id = (Integer) value.get(0);
|
||||
assertNotNull(id);
|
||||
Object vector = value.get(1);
|
||||
assertNull(vector);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@Before
|
||||
public void setup() throws IOException {
|
||||
assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
|
||||
var indexName = "test";
|
||||
|
||||
createIndexWithDenseVector("test");
|
||||
|
||||
int numDims = randomIntBetween(32, 64) * 2; // min 64, even number
|
||||
int numDocs = randomIntBetween(10, 100);
|
||||
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
|
||||
for (int i = 0; i < numDocs; i++) {
|
||||
List<Float> vector = new ArrayList<>(numDims);
|
||||
for (int j = 0; j < numDims; j++) {
|
||||
vector.add(randomFloat());
|
||||
}
|
||||
docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector);
|
||||
indexedVectors.put(i, vector);
|
||||
}
|
||||
|
||||
indexRandom(true, docs);
|
||||
}
|
||||
|
||||
private void createIndexWithDenseVector(String indexName) throws IOException {
|
||||
var client = client().admin().indices();
|
||||
XContentBuilder mapping = XContentFactory.jsonBuilder()
|
||||
.startObject()
|
||||
|
@ -161,19 +208,5 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
|
|||
.setMapping(mapping)
|
||||
.setSettings(settingsBuilder.build());
|
||||
assertAcked(CreateRequest);
|
||||
|
||||
int numDims = randomIntBetween(32, 64) * 2; // min 64, even number
|
||||
int numDocs = randomIntBetween(10, 100);
|
||||
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
|
||||
for (int i = 0; i < numDocs; i++) {
|
||||
List<Float> vector = new ArrayList<>(numDims);
|
||||
for (int j = 0; j < numDims; j++) {
|
||||
vector.add(randomFloat());
|
||||
}
|
||||
docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector);
|
||||
indexedVectors.put(i, vector);
|
||||
}
|
||||
|
||||
indexRandom(true, docs);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,156 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.esql.plugin;
|
||||
|
||||
import org.elasticsearch.action.index.IndexRequestBuilder;
|
||||
import org.elasticsearch.cluster.metadata.IndexMetadata;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xpack.esql.EsqlTestUtils;
|
||||
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
|
||||
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
|
||||
|
||||
public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
|
||||
|
||||
private final Map<Integer, List<Float>> indexedVectors = new HashMap<>();
|
||||
private int numDocs;
|
||||
private int numDims;
|
||||
|
||||
public void testKnnDefaults() {
|
||||
float[] queryVector = new float[numDims];
|
||||
Arrays.fill(queryVector, 1.0f);
|
||||
|
||||
var query = String.format(Locale.ROOT, """
|
||||
FROM test METADATA _score
|
||||
| WHERE knn(vector, %s)
|
||||
| KEEP id, floats, _score, vector
|
||||
| SORT _score DESC
|
||||
""", Arrays.toString(queryVector));
|
||||
|
||||
try (var resp = run(query)) {
|
||||
assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector"));
|
||||
assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector"));
|
||||
|
||||
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
|
||||
assertEquals(Math.min(indexedVectors.size(), 10), valuesList.size());
|
||||
for (int i = 0; i < valuesList.size(); i++) {
|
||||
List<Object> row = valuesList.get(i);
|
||||
// Vectors should be in order of ID, as they're less similar than the query vector as the ID increases
|
||||
assertEquals(i, row.getFirst());
|
||||
@SuppressWarnings("unchecked")
|
||||
// Vectors should be the same
|
||||
List<Double> floats = (List<Double>) row.get(1);
|
||||
for (int j = 0; j < floats.size(); j++) {
|
||||
assertEquals(floats.get(j).floatValue(), indexedVectors.get(i).get(j), 0f);
|
||||
}
|
||||
var score = (Double) row.get(2);
|
||||
assertNotNull(score);
|
||||
assertTrue(score > 0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testKnnOptions() {
|
||||
float[] queryVector = new float[numDims];
|
||||
Arrays.fill(queryVector, 1.0f);
|
||||
|
||||
var query = String.format(Locale.ROOT, """
|
||||
FROM test METADATA _score
|
||||
| WHERE knn(vector, %s, {"k": 5})
|
||||
| KEEP id, floats, _score, vector
|
||||
| SORT _score DESC
|
||||
""", Arrays.toString(queryVector));
|
||||
|
||||
try (var resp = run(query)) {
|
||||
assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector"));
|
||||
assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector"));
|
||||
|
||||
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
|
||||
assertEquals(5, valuesList.size());
|
||||
}
|
||||
}
|
||||
|
||||
public void testKnnNonPushedDown() {
|
||||
float[] queryVector = new float[numDims];
|
||||
Arrays.fill(queryVector, 1.0f);
|
||||
|
||||
// TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query
|
||||
var query = String.format(Locale.ROOT, """
|
||||
FROM test METADATA _score
|
||||
| WHERE knn(vector, %s, {"k": 5}) OR id > 10
|
||||
| KEEP id, floats, _score, vector
|
||||
| SORT _score DESC
|
||||
""", Arrays.toString(queryVector));
|
||||
|
||||
try (var resp = run(query)) {
|
||||
assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector"));
|
||||
assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector"));
|
||||
|
||||
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
|
||||
// K = 5, 1 more for every id > 10
|
||||
assertEquals(5 + Math.max(0, numDocs - 10 - 1), valuesList.size());
|
||||
}
|
||||
}
|
||||
|
||||
@Before
|
||||
public void setup() throws IOException {
|
||||
assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled());
|
||||
|
||||
var indexName = "test";
|
||||
var client = client().admin().indices();
|
||||
XContentBuilder mapping = XContentFactory.jsonBuilder()
|
||||
.startObject()
|
||||
.startObject("properties")
|
||||
.startObject("id")
|
||||
.field("type", "integer")
|
||||
.endObject()
|
||||
.startObject("vector")
|
||||
.field("type", "dense_vector")
|
||||
.field("similarity", "l2_norm")
|
||||
.endObject()
|
||||
.startObject("floats")
|
||||
.field("type", "float")
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject();
|
||||
|
||||
Settings.Builder settingsBuilder = Settings.builder()
|
||||
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
|
||||
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1);
|
||||
|
||||
var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build());
|
||||
assertAcked(createRequest);
|
||||
|
||||
numDocs = randomIntBetween(10, 20);
|
||||
numDims = randomIntBetween(3, 10);
|
||||
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
|
||||
float value = 0.0f;
|
||||
for (int i = 0; i < numDocs; i++) {
|
||||
List<Float> vector = new ArrayList<>(numDims);
|
||||
for (int j = 0; j < numDims; j++) {
|
||||
vector.add(value++);
|
||||
}
|
||||
docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "floats", vector, "vector", vector);
|
||||
indexedVectors.put(i, vector);
|
||||
}
|
||||
|
||||
indexRandom(true, docs);
|
||||
}
|
||||
}
|
|
@ -1185,7 +1185,12 @@ public class EsqlCapabilities {
|
|||
/**
|
||||
* MATCH PHRASE function
|
||||
*/
|
||||
MATCH_PHRASE_FUNCTION;
|
||||
MATCH_PHRASE_FUNCTION,
|
||||
|
||||
/**
|
||||
* Support knn function
|
||||
*/
|
||||
KNN_FUNCTION(Build.current().isSnapshot());
|
||||
|
||||
private final boolean enabled;
|
||||
|
||||
|
|
|
@ -66,6 +66,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger
|
|||
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToUnsignedLong;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
|
||||
import org.elasticsearch.xpack.esql.expression.function.vector.VectorFunction;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.DateTimeArithmeticOperation;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
|
||||
|
@ -1396,6 +1397,9 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
|
|||
if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) {
|
||||
return processBinaryOperator((BinaryOperator) f);
|
||||
}
|
||||
if (f instanceof VectorFunction vectorFunction) {
|
||||
return processVectorFunction(f);
|
||||
}
|
||||
return f;
|
||||
}
|
||||
|
||||
|
@ -1595,6 +1599,25 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
|
|||
return unresolvedAttribute(from, target.toString(), e);
|
||||
}
|
||||
}
|
||||
|
||||
private static Expression processVectorFunction(org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction) {
|
||||
List<Expression> args = vectorFunction.arguments();
|
||||
List<Expression> newArgs = new ArrayList<>();
|
||||
for (Expression arg : args) {
|
||||
if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) {
|
||||
Object folded = arg.fold(FoldContext.small() /* TODO remove me */);
|
||||
if (folded instanceof List) {
|
||||
Literal denseVector = new Literal(arg.source(), folded, DataType.DENSE_VECTOR);
|
||||
newArgs.add(denseVector);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
newArgs.add(arg);
|
||||
}
|
||||
|
||||
return vectorFunction.replaceChildren(newArgs);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
package org.elasticsearch.xpack.esql.expression;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
|
||||
import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables;
|
||||
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
|
||||
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
|
||||
|
@ -83,6 +84,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim;
|
|||
import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLike;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLike;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
|
||||
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull;
|
||||
|
@ -115,6 +117,7 @@ public class ExpressionWritables {
|
|||
entries.addAll(binaryComparisons());
|
||||
entries.addAll(fullText());
|
||||
entries.addAll(unaryScalars());
|
||||
entries.addAll(vector());
|
||||
return entries;
|
||||
}
|
||||
|
||||
|
@ -252,4 +255,11 @@ public class ExpressionWritables {
|
|||
private static List<NamedWriteableRegistry.Entry> fullText() {
|
||||
return FullTextWritables.getNamedWriteables();
|
||||
}
|
||||
|
||||
private static List<NamedWriteableRegistry.Entry> vector() {
|
||||
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
|
||||
return List.of(Knn.ENTRY);
|
||||
}
|
||||
return List.of();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -179,6 +179,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToLower;
|
|||
import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
|
||||
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
|
||||
import org.elasticsearch.xpack.esql.parser.ParsingException;
|
||||
import org.elasticsearch.xpack.esql.session.Configuration;
|
||||
|
||||
|
@ -485,7 +486,8 @@ public class EsqlFunctionRegistry {
|
|||
def(LastOverTime.class, LastOverTime::withUnresolvedTimestamp, "last_over_time"),
|
||||
def(FirstOverTime.class, FirstOverTime::withUnresolvedTimestamp, "first_over_time"),
|
||||
def(Term.class, bi(Term::new), "term"),
|
||||
def(MatchPhrase.class, tri(MatchPhrase::new), "match_phrase") } };
|
||||
def(MatchPhrase.class, tri(MatchPhrase::new), "match_phrase"),
|
||||
def(Knn.class, tri(Knn::new), "knn") } };
|
||||
}
|
||||
|
||||
public EsqlFunctionRegistry snapshotRegistry() {
|
||||
|
|
|
@ -148,7 +148,7 @@ public abstract class FullTextFunction extends Function
|
|||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(super.hashCode(), queryBuilder);
|
||||
return Objects.hash(super.hashCode(), query, queryBuilder);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -157,7 +157,7 @@ public abstract class FullTextFunction extends Function
|
|||
return false;
|
||||
}
|
||||
|
||||
return Objects.equals(queryBuilder, ((FullTextFunction) obj).queryBuilder);
|
||||
return Objects.equals(queryBuilder, ((FullTextFunction) obj).queryBuilder) && Objects.equals(query, ((FullTextFunction) obj).query);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -0,0 +1,292 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.esql.expression.function.vector;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
|
||||
import org.elasticsearch.xpack.esql.core.expression.Expression;
|
||||
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
|
||||
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
|
||||
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
|
||||
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
|
||||
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
|
||||
import org.elasticsearch.xpack.esql.core.tree.Source;
|
||||
import org.elasticsearch.xpack.esql.core.type.DataType;
|
||||
import org.elasticsearch.xpack.esql.core.util.Check;
|
||||
import org.elasticsearch.xpack.esql.expression.function.Example;
|
||||
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
|
||||
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
|
||||
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
|
||||
import org.elasticsearch.xpack.esql.expression.function.MapParam;
|
||||
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
|
||||
import org.elasticsearch.xpack.esql.expression.function.Param;
|
||||
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction;
|
||||
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
|
||||
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
|
||||
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
|
||||
import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static java.util.Map.entry;
|
||||
import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
|
||||
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
|
||||
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
|
||||
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
|
||||
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
|
||||
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
|
||||
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
|
||||
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression;
|
||||
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
|
||||
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable;
|
||||
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT;
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER;
|
||||
import static org.elasticsearch.xpack.esql.expression.function.fulltext.Match.getNameFromFieldAttribute;
|
||||
|
||||
public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction {
|
||||
|
||||
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
|
||||
|
||||
private final Expression field;
|
||||
private final Expression options;
|
||||
|
||||
public static final Map<String, DataType> ALLOWED_OPTIONS = Map.ofEntries(
|
||||
entry(K_FIELD.getPreferredName(), INTEGER),
|
||||
entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER),
|
||||
entry(VECTOR_SIMILARITY_FIELD.getPreferredName(), FLOAT),
|
||||
entry(BOOST_FIELD.getPreferredName(), FLOAT),
|
||||
entry(KnnQuery.RESCORE_OVERSAMPLE_FIELD, FLOAT)
|
||||
);
|
||||
|
||||
@FunctionInfo(
|
||||
returnType = "boolean",
|
||||
preview = true,
|
||||
description = "Finds the k nearest vectors to a query vector, as measured by a similarity metric. "
|
||||
+ "knn function finds nearest vectors through approximate search on indexed dense_vectors.",
|
||||
examples = {
|
||||
@Example(file = "knn-function", tag = "knn-function"),
|
||||
@Example(file = "knn-function", tag = "knn-function-options"), },
|
||||
appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
|
||||
)
|
||||
public Knn(
|
||||
Source source,
|
||||
@Param(name = "field", type = { "dense_vector" }, description = "Field that the query will target.") Expression field,
|
||||
@Param(
|
||||
name = "query",
|
||||
type = { "dense_vector" },
|
||||
description = "Vector value to find top nearest neighbours for."
|
||||
) Expression query,
|
||||
@MapParam(
|
||||
name = "options",
|
||||
params = {
|
||||
@MapParam.MapParamEntry(
|
||||
name = "boost",
|
||||
type = "float",
|
||||
valueHint = { "2.5" },
|
||||
description = "Floating point number used to decrease or increase the relevance scores of the query."
|
||||
+ "Defaults to 1.0."
|
||||
),
|
||||
@MapParam.MapParamEntry(
|
||||
name = "k",
|
||||
type = "integer",
|
||||
valueHint = { "10" },
|
||||
description = "The number of nearest neighbors to return from each shard. "
|
||||
+ "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
|
||||
+ "This value must be less than or equal to num_candidates. Defaults to 10."
|
||||
),
|
||||
@MapParam.MapParamEntry(
|
||||
name = "num_candidates",
|
||||
type = "integer",
|
||||
valueHint = { "10" },
|
||||
description = "The number of nearest neighbor candidates to consider per shard while doing knn search. "
|
||||
+ "Cannot exceed 10,000. Increasing num_candidates tends to improve the accuracy of the final results. "
|
||||
+ "Defaults to 1.5 * k"
|
||||
),
|
||||
@MapParam.MapParamEntry(
|
||||
name = "similarity",
|
||||
type = "double",
|
||||
valueHint = { "0.01" },
|
||||
description = "The minimum similarity required for a document to be considered a match. "
|
||||
+ "The similarity value calculated relates to the raw similarity used, not the document score."
|
||||
),
|
||||
@MapParam.MapParamEntry(
|
||||
name = "rescore_oversample",
|
||||
type = "double",
|
||||
valueHint = { "3.5" },
|
||||
description = "Applies the specified oversampling for rescoring quantized vectors. "
|
||||
+ "See [oversampling and rescoring quantized vectors]"
|
||||
+ "(docs-content://solutions/search/vector/knn.md#dense-vector-knn-search-rescoring) for details."
|
||||
), },
|
||||
description = "(Optional) kNN additional options as <<esql-function-named-params,function named parameters>>."
|
||||
+ " See <<query-dsl-knn-query,knn query>> for more information.",
|
||||
optional = true
|
||||
) Expression options
|
||||
) {
|
||||
this(source, field, query, options, null);
|
||||
}
|
||||
|
||||
private Knn(Source source, Expression field, Expression query, Expression options, QueryBuilder queryBuilder) {
|
||||
super(source, query, options == null ? List.of(field, query) : List.of(field, query, options), queryBuilder);
|
||||
this.field = field;
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
public Expression field() {
|
||||
return field;
|
||||
}
|
||||
|
||||
public Expression options() {
|
||||
return options;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType dataType() {
|
||||
return DataType.BOOLEAN;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TypeResolution resolveParams() {
|
||||
return resolveField().and(resolveQuery()).and(resolveOptions());
|
||||
}
|
||||
|
||||
private TypeResolution resolveField() {
|
||||
return isNotNull(field(), sourceText(), FIRST).and(isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector"));
|
||||
}
|
||||
|
||||
private TypeResolution resolveQuery() {
|
||||
return isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector").and(
|
||||
isNotNullAndFoldable(query(), sourceText(), SECOND)
|
||||
);
|
||||
}
|
||||
|
||||
private TypeResolution resolveOptions() {
|
||||
if (options() != null) {
|
||||
TypeResolution resolution = isNotNull(options(), sourceText(), THIRD);
|
||||
if (resolution.unresolved()) {
|
||||
return resolution;
|
||||
}
|
||||
// MapExpression does not have a DataType associated with it
|
||||
resolution = isMapExpression(options(), sourceText(), THIRD);
|
||||
if (resolution.unresolved()) {
|
||||
return resolution;
|
||||
}
|
||||
|
||||
try {
|
||||
knnQueryOptions();
|
||||
} catch (InvalidArgumentException e) {
|
||||
return new TypeResolution(e.getMessage());
|
||||
}
|
||||
}
|
||||
return TypeResolution.TYPE_RESOLVED;
|
||||
}
|
||||
|
||||
private Map<String, Object> knnQueryOptions() throws InvalidArgumentException {
|
||||
if (options() == null) {
|
||||
return Map.of();
|
||||
}
|
||||
|
||||
Map<String, Object> matchOptions = new HashMap<>();
|
||||
populateOptionsMap((MapExpression) options(), matchOptions, THIRD, sourceText(), ALLOWED_OPTIONS);
|
||||
return matchOptions;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Query translate(TranslatorHandler handler) {
|
||||
var fieldAttribute = Match.fieldAsFieldAttribute(field());
|
||||
|
||||
Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument");
|
||||
String fieldName = getNameFromFieldAttribute(fieldAttribute);
|
||||
@SuppressWarnings("unchecked")
|
||||
List<Number> queryFolded = (List<Number>) query().fold(FoldContext.small() /* TODO remove me */);
|
||||
float[] queryAsFloats = new float[queryFolded.size()];
|
||||
for (int i = 0; i < queryFolded.size(); i++) {
|
||||
queryAsFloats[i] = queryFolded.get(i).floatValue();
|
||||
}
|
||||
|
||||
return new KnnQuery(source(), fieldName, queryAsFloats, queryOptions());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
|
||||
return new Knn(source(), field(), query(), options(), queryBuilder);
|
||||
}
|
||||
|
||||
private Map<String, Object> queryOptions() throws InvalidArgumentException {
|
||||
if (options() == null) {
|
||||
return Map.of();
|
||||
}
|
||||
|
||||
Map<String, Object> options = new HashMap<>();
|
||||
populateOptionsMap((MapExpression) options(), options, THIRD, sourceText(), ALLOWED_OPTIONS);
|
||||
return options;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression replaceChildren(List<Expression> newChildren) {
|
||||
return new Knn(
|
||||
source(),
|
||||
newChildren.get(0),
|
||||
newChildren.get(1),
|
||||
newChildren.size() > 2 ? newChildren.get(2) : null,
|
||||
queryBuilder()
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NodeInfo<? extends Expression> info() {
|
||||
return NodeInfo.create(this, Knn::new, field(), query(), options());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return ENTRY.name;
|
||||
}
|
||||
|
||||
private static Knn readFrom(StreamInput in) throws IOException {
|
||||
Source source = Source.readFrom((PlanStreamInput) in);
|
||||
Expression field = in.readNamedWriteable(Expression.class);
|
||||
Expression query = in.readNamedWriteable(Expression.class);
|
||||
QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class);
|
||||
|
||||
return new Knn(source, field, query, null, queryBuilder);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
source().writeTo(out);
|
||||
out.writeNamedWriteable(field());
|
||||
out.writeNamedWriteable(query());
|
||||
out.writeOptionalNamedWriteable(queryBuilder());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
// Knn does not serialize options, as they get included in the query builder. We need to override equals and hashcode to
|
||||
// ignore options when comparing two Knn functions
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
Knn knn = (Knn) o;
|
||||
return Objects.equals(field(), knn.field())
|
||||
&& Objects.equals(query(), knn.query())
|
||||
&& Objects.equals(queryBuilder(), knn.queryBuilder());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(field(), query(), queryBuilder());
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.esql.expression.function.vector;
|
||||
|
||||
/**
|
||||
* Marker interface for vector functions. Makes possible to do implicit casting
|
||||
* from multi values to dense_vector field types, so parameters are actually
|
||||
* processed as dense_vectors in vector functions
|
||||
*/
|
||||
public interface VectorFunction {}
|
|
@ -0,0 +1,84 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.esql.querydsl.query;
|
||||
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
|
||||
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
|
||||
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
|
||||
import org.elasticsearch.xpack.esql.core.tree.Source;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
|
||||
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
|
||||
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
|
||||
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
|
||||
|
||||
public class KnnQuery extends Query {
|
||||
|
||||
private final String field;
|
||||
private final float[] query;
|
||||
private final Map<String, Object> options;
|
||||
|
||||
public static final String RESCORE_OVERSAMPLE_FIELD = "rescore_oversample";
|
||||
|
||||
public KnnQuery(Source source, String field, float[] query, Map<String, Object> options) {
|
||||
super(source);
|
||||
assert options != null;
|
||||
this.field = field;
|
||||
this.query = query;
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected QueryBuilder asBuilder() {
|
||||
Integer k = (Integer) options.get(K_FIELD.getPreferredName());
|
||||
Integer numCands = (Integer) options.get(NUM_CANDS_FIELD.getPreferredName());
|
||||
RescoreVectorBuilder rescoreVectorBuilder = null;
|
||||
Float oversample = (Float) options.get(RESCORE_OVERSAMPLE_FIELD);
|
||||
if (oversample != null) {
|
||||
rescoreVectorBuilder = new RescoreVectorBuilder(oversample);
|
||||
}
|
||||
Float vectorSimilarity = (Float) options.get(VECTOR_SIMILARITY_FIELD.getPreferredName());
|
||||
|
||||
KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(field, query, k, numCands, rescoreVectorBuilder, vectorSimilarity);
|
||||
Number boost = (Number) options.get(BOOST_FIELD.getPreferredName());
|
||||
if (boost != null) {
|
||||
queryBuilder.boost(boost.floatValue());
|
||||
}
|
||||
return queryBuilder;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String innerToString() {
|
||||
return "knn(" + field + ", " + Arrays.toString(query) + " options={" + options + "}))";
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (super.equals(o) == false) return false;
|
||||
|
||||
KnnQuery knnQuery = (KnnQuery) o;
|
||||
return Objects.equals(field, knnQuery.field)
|
||||
&& Objects.deepEquals(query, knnQuery.query)
|
||||
&& Objects.equals(options, knnQuery.options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(super.hashCode(), field, Arrays.hashCode(query), options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean scorable() {
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -296,6 +296,10 @@ public class CsvTests extends ESTestCase {
|
|||
"can't use KQL function in csv tests",
|
||||
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KQL_FUNCTION.capabilityName())
|
||||
);
|
||||
assumeFalse(
|
||||
"can't use KNN function in csv tests",
|
||||
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION.capabilityName())
|
||||
);
|
||||
assumeFalse(
|
||||
"lookup join disabled for csv tests",
|
||||
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.JOIN_LOOKUP_V12.capabilityName())
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.elasticsearch.index.query.RegexpQueryBuilder;
|
|||
import org.elasticsearch.index.query.TermQueryBuilder;
|
||||
import org.elasticsearch.index.query.TermsQueryBuilder;
|
||||
import org.elasticsearch.index.query.WildcardQueryBuilder;
|
||||
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
|
||||
import org.elasticsearch.test.EqualsHashCodeTestUtils;
|
||||
import org.elasticsearch.xpack.esql.core.expression.Expression;
|
||||
import org.elasticsearch.xpack.esql.expression.ExpressionWritables;
|
||||
|
@ -111,6 +112,7 @@ public class SerializationTestUtils {
|
|||
entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, WildcardQueryBuilder.NAME, WildcardQueryBuilder::new));
|
||||
entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, RegexpQueryBuilder.NAME, RegexpQueryBuilder::new));
|
||||
entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, ExistsQueryBuilder.NAME, ExistsQueryBuilder::new));
|
||||
entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, KnnVectorQueryBuilder.NAME, KnnVectorQueryBuilder::new));
|
||||
entries.add(SingleValueQuery.ENTRY);
|
||||
entries.addAll(ExpressionWritables.getNamedWriteables());
|
||||
entries.addAll(PlanWritables.getNamedWriteables());
|
||||
|
|
|
@ -55,6 +55,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger
|
|||
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat;
|
||||
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring;
|
||||
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
|
||||
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
|
||||
|
@ -2363,6 +2364,22 @@ public class AnalyzerTests extends ESTestCase {
|
|||
assertThat(e.getMessage(), containsString("[+] has arguments with incompatible types [datetime] and [datetime]"));
|
||||
}
|
||||
|
||||
public void testDenseVectorImplicitCasting() {
|
||||
Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors"));
|
||||
|
||||
var plan = analyze("""
|
||||
from test | where knn(vector, [0.342, 0.164, 0.234])
|
||||
""", "mapping-dense_vector.json");
|
||||
|
||||
var limit = as(plan, Limit.class);
|
||||
var filter = as(limit.child(), Filter.class);
|
||||
var knn = as(filter.condition(), Knn.class);
|
||||
var field = knn.field();
|
||||
var queryVector = as(knn.query(), Literal.class);
|
||||
assertEquals(DataType.DENSE_VECTOR, queryVector.dataType());
|
||||
assertThat(queryVector.value(), equalTo(List.of(0.342, 0.164, 0.234)));
|
||||
}
|
||||
|
||||
public void testRateRequiresCounterTypes() {
|
||||
assumeTrue("rate requires snapshot builds", Build.current().isSnapshot());
|
||||
Analyzer analyzer = analyzer(tsdbIndexResolution());
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
|
|||
import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchPhrase;
|
||||
import org.elasticsearch.xpack.esql.expression.function.fulltext.MultiMatch;
|
||||
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
|
||||
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
|
||||
import org.elasticsearch.xpack.esql.index.EsIndex;
|
||||
import org.elasticsearch.xpack.esql.index.IndexResolution;
|
||||
import org.elasticsearch.xpack.esql.parser.EsqlParser;
|
||||
|
@ -1234,6 +1235,9 @@ public class VerifierTests extends ESTestCase {
|
|||
checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function");
|
||||
checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")");
|
||||
}
|
||||
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
|
||||
checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3])");
|
||||
}
|
||||
}
|
||||
|
||||
private void checkFieldBasedFunctionNotAllowedAfterCommands(String functionName, String functionType, String functionInvocation) {
|
||||
|
@ -1364,6 +1368,9 @@ public class VerifierTests extends ESTestCase {
|
|||
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
|
||||
checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function");
|
||||
}
|
||||
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
|
||||
checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2])", "function");
|
||||
}
|
||||
}
|
||||
|
||||
private void checkFullTextFunctionsOnlyAllowedInWhere(String functionName, String functionInvocation, String functionType)
|
||||
|
@ -1400,6 +1407,9 @@ public class VerifierTests extends ESTestCase {
|
|||
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
|
||||
checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")");
|
||||
}
|
||||
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
|
||||
checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3])");
|
||||
}
|
||||
}
|
||||
|
||||
private void checkWithFullTextFunctionsDisjunctions(String functionInvocation) {
|
||||
|
@ -1462,6 +1472,9 @@ public class VerifierTests extends ESTestCase {
|
|||
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
|
||||
checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, \"Meditation\")", "function");
|
||||
}
|
||||
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
|
||||
checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3])", "function");
|
||||
}
|
||||
}
|
||||
|
||||
private void checkFullTextFunctionsWithNonBooleanFunctions(String functionName, String functionInvocation, String functionType) {
|
||||
|
@ -1530,6 +1543,9 @@ public class VerifierTests extends ESTestCase {
|
|||
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
|
||||
testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")");
|
||||
}
|
||||
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
|
||||
testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2])");
|
||||
}
|
||||
}
|
||||
|
||||
private void testFullTextFunctionTargetsExistingField(String functionInvocation) throws Exception {
|
||||
|
@ -2055,6 +2071,9 @@ public class VerifierTests extends ESTestCase {
|
|||
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
|
||||
checkOptionDataTypes(MultiMatch.OPTIONS, "FROM test | WHERE MULTI_MATCH(\"Jean\", title, body, {\"%s\": %s})");
|
||||
}
|
||||
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
|
||||
checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], {\"%s\": %s})");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -2140,6 +2159,10 @@ public class VerifierTests extends ESTestCase {
|
|||
checkFullTextFunctionNullArgs("term(null, \"query\")", "first");
|
||||
checkFullTextFunctionNullArgs("term(title, null)", "second");
|
||||
}
|
||||
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
|
||||
checkFullTextFunctionNullArgs("knn(null, [0, 1, 2])", "first");
|
||||
checkFullTextFunctionNullArgs("knn(vector, null)", "second");
|
||||
}
|
||||
}
|
||||
|
||||
private void checkFullTextFunctionNullArgs(String functionInvocation, String argOrdinal) throws Exception {
|
||||
|
@ -2161,6 +2184,9 @@ public class VerifierTests extends ESTestCase {
|
|||
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
|
||||
checkFullTextFunctionsConstantQuery("term(title, tags)", "second");
|
||||
}
|
||||
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
|
||||
checkFullTextFunctionsConstantQuery("knn(vector, vector)", "second");
|
||||
}
|
||||
}
|
||||
|
||||
private void checkFullTextFunctionsConstantQuery(String functionInvocation, String argOrdinal) throws Exception {
|
||||
|
@ -2188,6 +2214,9 @@ public class VerifierTests extends ESTestCase {
|
|||
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
|
||||
checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)");
|
||||
}
|
||||
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
|
||||
checkFullTextFunctionsInStats("knn(vector, [0, 1, 2])");
|
||||
}
|
||||
}
|
||||
|
||||
private void checkFullTextFunctionsInStats(String functionInvocation) {
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.esql.expression.function.fulltext;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.annotations.Name;
|
||||
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
|
||||
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
|
||||
import org.elasticsearch.xpack.esql.core.expression.Expression;
|
||||
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
|
||||
import org.elasticsearch.xpack.esql.core.expression.Literal;
|
||||
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
|
||||
import org.elasticsearch.xpack.esql.core.tree.Source;
|
||||
import org.elasticsearch.xpack.esql.core.type.DataType;
|
||||
import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
|
||||
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
|
||||
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
|
||||
import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
|
||||
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import static org.elasticsearch.xpack.esql.SerializationTestUtils.serializeDeserialize;
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD;
|
||||
import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED;
|
||||
import static org.elasticsearch.xpack.esql.planner.TranslatorHandler.TRANSLATOR_HANDLER;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class KnnTests extends AbstractFunctionTestCase {
|
||||
|
||||
public KnnTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
|
||||
this.testCase = testCaseSupplier.get();
|
||||
}
|
||||
|
||||
@ParametersFactory
|
||||
public static Iterable<Object[]> parameters() {
|
||||
return parameterSuppliersFromTypedData(addFunctionNamedParams(testCaseSuppliers()));
|
||||
}
|
||||
|
||||
@Before
|
||||
public void checkCapability() {
|
||||
assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled());
|
||||
}
|
||||
|
||||
private static List<TestCaseSupplier> testCaseSuppliers() {
|
||||
List<TestCaseSupplier> suppliers = new ArrayList<>();
|
||||
|
||||
suppliers.add(
|
||||
TestCaseSupplier.testCaseSupplier(
|
||||
new TestCaseSupplier.TypedDataSupplier("dense_vector field", KnnTests::randomDenseVector, DENSE_VECTOR),
|
||||
new TestCaseSupplier.TypedDataSupplier("query", KnnTests::randomDenseVector, DENSE_VECTOR, true),
|
||||
(d1, d2) -> equalTo("string"),
|
||||
BOOLEAN,
|
||||
(o1, o2) -> true
|
||||
)
|
||||
);
|
||||
|
||||
return suppliers;
|
||||
}
|
||||
|
||||
private static List<Float> randomDenseVector() {
|
||||
int dimensions = randomIntBetween(64, 128);
|
||||
List<Float> vector = new ArrayList<>();
|
||||
for (int i = 0; i < dimensions; i++) {
|
||||
vector.add(randomFloat());
|
||||
}
|
||||
return vector;
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds function named parameters to all the test case suppliers provided
|
||||
*/
|
||||
private static List<TestCaseSupplier> addFunctionNamedParams(List<TestCaseSupplier> suppliers) {
|
||||
// TODO get to a common class with MatchTests
|
||||
List<TestCaseSupplier> result = new ArrayList<>();
|
||||
for (TestCaseSupplier supplier : suppliers) {
|
||||
List<DataType> dataTypes = new ArrayList<>(supplier.types());
|
||||
dataTypes.add(UNSUPPORTED);
|
||||
result.add(new TestCaseSupplier(supplier.name() + ", options", dataTypes, () -> {
|
||||
List<TestCaseSupplier.TypedData> values = new ArrayList<>(supplier.get().getData());
|
||||
values.add(
|
||||
new TestCaseSupplier.TypedData(
|
||||
new MapExpression(Source.EMPTY, List.of(new Literal(Source.EMPTY, randomAlphaOfLength(10), KEYWORD))),
|
||||
UNSUPPORTED,
|
||||
"options"
|
||||
).forceLiteral()
|
||||
);
|
||||
|
||||
return new TestCaseSupplier.TestCase(values, equalTo("KnnEvaluator"), BOOLEAN, equalTo(true));
|
||||
}));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Expression build(Source source, List<Expression> args) {
|
||||
Knn knn = new Knn(source, args.get(0), args.get(1), args.size() > 2 ? args.get(2) : null);
|
||||
// We need to add the QueryBuilder to the match expression, as it is used to implement equals() and hashCode() and
|
||||
// thus test the serialization methods. But we can only do this if the parameters make sense .
|
||||
if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) {
|
||||
QueryBuilder queryBuilder = TRANSLATOR_HANDLER.asQuery(LucenePushdownPredicates.DEFAULT, knn).toQueryBuilder();
|
||||
knn = (Knn) knn.replaceQueryBuilder(queryBuilder);
|
||||
}
|
||||
return knn;
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy of the overridden method that doesn't check for children size, as the {@code options} child isn't serialized in Match.
|
||||
*/
|
||||
@Override
|
||||
protected Expression serializeDeserializeExpression(Expression expression) {
|
||||
Expression newExpression = serializeDeserialize(
|
||||
expression,
|
||||
PlanStreamOutput::writeNamedWriteable,
|
||||
in -> in.readNamedWriteable(Expression.class),
|
||||
testCase.getConfiguration() // The configuration query should be == to the source text of the function for this to work
|
||||
);
|
||||
// Fields use synthetic sources, which can't be serialized. So we use the originals instead.
|
||||
return newExpression.replaceChildren(expression.children());
|
||||
}
|
||||
}
|
|
@ -82,7 +82,7 @@ public class MatchTests extends AbstractMatchFullTextFunctionTests {
|
|||
// thus test the serialization methods. But we can only do this if the parameters make sense .
|
||||
if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) {
|
||||
QueryBuilder queryBuilder = TRANSLATOR_HANDLER.asQuery(LucenePushdownPredicates.DEFAULT, match).toQueryBuilder();
|
||||
match.replaceQueryBuilder(queryBuilder);
|
||||
match = (Match) match.replaceQueryBuilder(queryBuilder);
|
||||
}
|
||||
return match;
|
||||
}
|
||||
|
|
|
@ -28,6 +28,8 @@ import org.elasticsearch.index.query.QueryStringQueryBuilder;
|
|||
import org.elasticsearch.index.query.RangeQueryBuilder;
|
||||
import org.elasticsearch.index.query.SearchExecutionContext;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
|
||||
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
|
||||
import org.elasticsearch.test.VersionUtils;
|
||||
import org.elasticsearch.xpack.core.enrich.EnrichPolicy;
|
||||
import org.elasticsearch.xpack.esql.EsqlTestUtils;
|
||||
|
@ -1359,6 +1361,29 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
|
|||
assertThat(expectedQuery.toString(), is(planStr.get()));
|
||||
}
|
||||
|
||||
public void testKnnOptionsPushDown() {
|
||||
String query = """
|
||||
from test
|
||||
| where KNN(dense_vector, [0.1, 0.2, 0.3],
|
||||
{ "k": 5, "similarity": 0.001, "num_candidates": 10, "rescore_oversample": 7, "boost": 3.5 })
|
||||
""";
|
||||
var analyzer = makeAnalyzer("mapping-all-types.json");
|
||||
var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer);
|
||||
|
||||
AtomicReference<String> planStr = new AtomicReference<>();
|
||||
plan.forEachDown(EsQueryExec.class, result -> planStr.set(result.query().toString()));
|
||||
|
||||
var expectedQuery = new KnnVectorQueryBuilder(
|
||||
"dense_vector",
|
||||
new float[] { 0.1f, 0.2f, 0.3f },
|
||||
5,
|
||||
10,
|
||||
new RescoreVectorBuilder(7),
|
||||
0.001f
|
||||
).boost(3.5f);
|
||||
assertThat(expectedQuery.toString(), is(planStr.get()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Expecting
|
||||
* LimitExec[1000[INTEGER]]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue