ES|QL - kNN function initial support (#127322)

This commit is contained in:
Carlos Delgado 2025-06-10 11:11:42 +02:00 committed by GitHub
parent aa87f46681
commit 366e00f5c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 1251 additions and 23 deletions

View file

@ -0,0 +1 @@
<svg version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" width="568" height="61" viewbox="0 0 568 61"><defs><style type="text/css">.c{fill:none;stroke:#222222;}.k{fill:#000000;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}.s{fill:#e4f4ff;stroke:#222222;}.syn{fill:#8D8D8D;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}</style></defs><path class="c" d="M0 31h5m56 0h10m32 0h10m80 0h10m32 0h10m80 0h10m32 0h30m104 0h20m-139 0q5 0 5 5v10q0 5 5 5h114q5 0 5-5v-10q0-5 5-5m5 0h10m32 0h5"/><rect class="s" x="5" y="5" width="56" height="36"/><text class="k" x="15" y="31">KNN</text><rect class="s" x="71" y="5" width="32" height="36" rx="7"/><text class="syn" x="81" y="31">(</text><rect class="s" x="113" y="5" width="80" height="36" rx="7"/><text class="k" x="123" y="31">field</text><rect class="s" x="203" y="5" width="32" height="36" rx="7"/><text class="syn" x="213" y="31">,</text><rect class="s" x="245" y="5" width="80" height="36" rx="7"/><text class="k" x="255" y="31">query</text><rect class="s" x="335" y="5" width="32" height="36" rx="7"/><text class="syn" x="345" y="31">,</text><rect class="s" x="397" y="5" width="104" height="36" rx="7"/><text class="k" x="407" y="31">options</text><rect class="s" x="531" y="5" width="32" height="36" rx="7"/><text class="syn" x="541" y="31">)</text></svg>

After

Width:  |  Height:  |  Size: 1.5 KiB

View file

@ -0,0 +1,13 @@
{
"comment" : "This is generated by ESQLs AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.",
"type" : "scalar",
"name" : "knn",
"description" : "Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors.",
"signatures" : [ ],
"examples" : [
"from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0])\n| sort _score desc",
"from colors metadata _score\n| where knn(rgb_vector, [0,255,255], {\"k\": 4})\n| sort _score desc"
],
"preview" : true,
"snapshot_only" : true
}

View file

@ -0,0 +1,10 @@
% This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.
### KNN
Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors.
```esql
from colors metadata _score
| where knn(rgb_vector, [0, 120, 0])
| sort _score desc
```

View file

@ -2589,6 +2589,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
return null;
}
if (dims == null) {
// No data has been indexed yet
return BlockLoader.CONSTANT_NULLS;
}
if (indexed) {
return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims);
}

View file

@ -112,7 +112,7 @@ public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements
int min = docs.docs().getInt(0);
int max = docs.docs().getInt(docs.getPositionCount() - 1);
int length = max - min + 1;
try (T scoreBuilder = createVectorBuilder(blockFactory, length)) {
try (T scoreBuilder = createVectorBuilder(blockFactory, docs.getPositionCount())) {
if (length == docs.getPositionCount() && length > 1) {
return segmentState.scoreDense(scoreBuilder, min, max);
}

View file

@ -1022,7 +1022,7 @@ public abstract class RestEsqlTestCase extends ESRestTestCase {
var query = requestObjectBuilder().query(format(null, "from * | lookup join {} on integer {}", testIndexName(), sort));
Map<String, Object> result = runEsql(query);
var columns = as(result.get("columns"), List.class);
assertEquals(21, columns.size());
assertEquals(22, columns.size());
var values = as(result.get("values"), List.class);
assertEquals(10, values.size());
}

View file

@ -148,6 +148,7 @@ public class CsvTestsDataLoader {
private static final TestDataset LOGS = new TestDataset("logs");
private static final TestDataset MV_TEXT = new TestDataset("mv_text");
private static final TestDataset DENSE_VECTOR = new TestDataset("dense_vector");
private static final TestDataset COLORS = new TestDataset("colors");
public static final Map<String, TestDataset> CSV_DATASET_MAP = Map.ofEntries(
Map.entry(EMPLOYEES.indexName, EMPLOYEES),
@ -210,7 +211,8 @@ public class CsvTestsDataLoader {
Map.entry(SEMANTIC_TEXT.indexName, SEMANTIC_TEXT),
Map.entry(LOGS.indexName, LOGS),
Map.entry(MV_TEXT.indexName, MV_TEXT),
Map.entry(DENSE_VECTOR.indexName, DENSE_VECTOR)
Map.entry(DENSE_VECTOR.indexName, DENSE_VECTOR),
Map.entry(COLORS.indexName, COLORS)
);
private static final EnrichConfig LANGUAGES_ENRICH = new EnrichConfig("languages_policy", "enrich-policy-languages.json");

View file

@ -0,0 +1,60 @@
color:text,hex_code:keyword,rgb_vector:dense_vector,primary:boolean
maroon, #800000, [128,0,0], false
brown, #A52A2A, [165,42,42], false
firebrick, #B22222, [178,34,34], false
crimson, #DC143C, [220,20,60], false
red, #FF0000, [255,0,0], true
tomato, #FF6347, [255,99,71], false
coral, #FF7F50, [255,127,80], false
salmon, #FA8072, [250,128,114], false
orange, #FFA500, [255,165,0], false
gold, #FFD700, [255,215,0], false
golden rod, #DAA520, [218,165,32], false
khaki, #F0E68C, [240,230,140], false
olive, #808000, [128,128,0], false
yellow, #FFFF00, [255,255,0], true
chartreuse, #7FFF00, [127,255,0], false
green, #008000, [0,128,0], true
lime, #00FF00, [0,255,0], false
teal, #008080, [0,128,128], false
cyan, #00FFFF, [0,255,255], true
turquoise, #40E0D0, [64,224,208], false
aqua marine, #7FFFD4, [127,255,212], false
navy, #000080, [0,0,128], false
blue, #0000FF, [0,0,255], true
indigo, #4B0082, [75,0,130], false
purple, #800080, [128,0,128], false
thistle, #D8BFD8, [216,191,216], false
plum, #DDA0DD, [221,160,221], false
violet, #EE82EE, [238,130,238], false
magenta, #FF00FF, [255,0,255], true
orchid, #DA70D6, [218,112,214], false
pink, #FFC0CB, [255,192,203], false
beige, #F5F5DC, [245,245,220], false
bisque, #FFE4C4, [255,228,196], false
wheat, #F5DEB3, [245,222,179], false
corn silk, #FFF8DC, [255,248,220], false
lemon chiffon, #FFFACD, [255,250,205], false
sienna, #A0522D, [160,82,45], false
chocolate, #D2691E, [210,105,30], false
peru, #CD853F, [205,133,63], false
burly wood, #DEB887, [222,184,135], false
tan, #D2B48C, [210,180,140], false
moccasin, #FFE4B5, [255,228,181], false
peach puff, #FFDAB9, [255,218,185], false
misty rose, #FFE4E1, [255,228,225], false
linen, #FAF0E6, [250,240,230], false
old lace, #FDF5E6, [253,245,230], false
papaya whip, #FFEFD5, [255,239,213], false
sea shell, #FFF5EE, [255,245,238], false
mint cream, #F5FFFA, [245,255,250], false
lavender, #E6E6FA, [230,230,250], false
honeydew, #F0FFF0, [240,255,240], false
ivory, #FFFFF0, [255,255,240], false
azure, #F0FFFF, [240,255,255], false
snow, #FFFAFA, [255,250,250], false
black, #000000, [0,0,0], true
gray, #808080, [128,128,128], true
silver, #C0C0C0, [192,192,192], false
gainsboro, #DCDCDC, [220,220,220], false
white, #FFFFFF, [255,255,255], true
1 color:text,hex_code:keyword,rgb_vector:dense_vector,primary:boolean
2 maroon, #800000, [128,0,0], false
3 brown, #A52A2A, [165,42,42], false
4 firebrick, #B22222, [178,34,34], false
5 crimson, #DC143C, [220,20,60], false
6 red, #FF0000, [255,0,0], true
7 tomato, #FF6347, [255,99,71], false
8 coral, #FF7F50, [255,127,80], false
9 salmon, #FA8072, [250,128,114], false
10 orange, #FFA500, [255,165,0], false
11 gold, #FFD700, [255,215,0], false
12 golden rod, #DAA520, [218,165,32], false
13 khaki, #F0E68C, [240,230,140], false
14 olive, #808000, [128,128,0], false
15 yellow, #FFFF00, [255,255,0], true
16 chartreuse, #7FFF00, [127,255,0], false
17 green, #008000, [0,128,0], true
18 lime, #00FF00, [0,255,0], false
19 teal, #008080, [0,128,128], false
20 cyan, #00FFFF, [0,255,255], true
21 turquoise, #40E0D0, [64,224,208], false
22 aqua marine, #7FFFD4, [127,255,212], false
23 navy, #000080, [0,0,128], false
24 blue, #0000FF, [0,0,255], true
25 indigo, #4B0082, [75,0,130], false
26 purple, #800080, [128,0,128], false
27 thistle, #D8BFD8, [216,191,216], false
28 plum, #DDA0DD, [221,160,221], false
29 violet, #EE82EE, [238,130,238], false
30 magenta, #FF00FF, [255,0,255], true
31 orchid, #DA70D6, [218,112,214], false
32 pink, #FFC0CB, [255,192,203], false
33 beige, #F5F5DC, [245,245,220], false
34 bisque, #FFE4C4, [255,228,196], false
35 wheat, #F5DEB3, [245,222,179], false
36 corn silk, #FFF8DC, [255,248,220], false
37 lemon chiffon, #FFFACD, [255,250,205], false
38 sienna, #A0522D, [160,82,45], false
39 chocolate, #D2691E, [210,105,30], false
40 peru, #CD853F, [205,133,63], false
41 burly wood, #DEB887, [222,184,135], false
42 tan, #D2B48C, [210,180,140], false
43 moccasin, #FFE4B5, [255,228,181], false
44 peach puff, #FFDAB9, [255,218,185], false
45 misty rose, #FFE4E1, [255,228,225], false
46 linen, #FAF0E6, [250,240,230], false
47 old lace, #FDF5E6, [253,245,230], false
48 papaya whip, #FFEFD5, [255,239,213], false
49 sea shell, #FFF5EE, [255,245,238], false
50 mint cream, #F5FFFA, [245,255,250], false
51 lavender, #E6E6FA, [230,230,250], false
52 honeydew, #F0FFF0, [240,255,240], false
53 ivory, #FFFFF0, [255,255,240], false
54 azure, #F0FFFF, [240,255,255], false
55 snow, #FFFAFA, [255,250,250], false
56 black, #000000, [0,0,0], true
57 gray, #808080, [128,128,128], true
58 silver, #C0C0C0, [192,192,192], false
59 gainsboro, #DCDCDC, [220,220,220], false
60 white, #FFFFFF, [255,255,255], true

View file

@ -0,0 +1,285 @@
# TODO Most tests explicitly set k. Until knn function uses LIMIT as k, we need to explicitly set it to all values
# in the dataset to avoid test failures due to docs allocation in different shards, which can impact results for a
# top-n query at the shard level
knnSearch
required_capability: knn_function
// tag::knn-function[]
from colors metadata _score
| where knn(rgb_vector, [0, 120, 0])
| sort _score desc, color asc
// end::knn-function[]
| keep color, rgb_vector
| limit 10
;
// tag::knn-function-result[]
color:text | rgb_vector:dense_vector
green | [0.0, 128.0, 0.0]
black | [0.0, 0.0, 0.0]
olive | [128.0, 128.0, 0.0]
teal | [0.0, 128.0, 128.0]
lime | [0.0, 255.0, 0.0]
sienna | [160.0, 82.0, 45.0]
maroon | [128.0, 0.0, 0.0]
navy | [0.0, 0.0, 128.0]
gray | [128.0, 128.0, 128.0]
chartreuse | [127.0, 255.0, 0.0]
// end::knn-function-result[]
;
knnSearchWithKOption
required_capability: knn_function
// tag::knn-function-options[]
from colors metadata _score
| where knn(rgb_vector, [0,255,255], {"k": 4})
| sort _score desc, color asc
// end::knn-function-options[]
| keep color, rgb_vector
| limit 4
;
color:text | rgb_vector:dense_vector
cyan | [0.0, 255.0, 255.0]
turquoise | [64.0, 224.0, 208.0]
aqua marine | [127.0, 255.0, 212.0]
teal | [0.0, 128.0, 128.0]
;
knnSearchWithSimilarityOption
required_capability: knn_function
from colors metadata _score
| where knn(rgb_vector, [255,192,203], {"k": 140, "similarity": 40})
| sort _score desc, color asc
| keep color, rgb_vector
;
color:text | rgb_vector:dense_vector
pink | [255.0, 192.0, 203.0]
peach puff | [255.0, 218.0, 185.0]
bisque | [255.0, 228.0, 196.0]
wheat | [245.0, 222.0, 179.0]
;
knnHybridSearch
required_capability: knn_function
from colors metadata _score
| where match(color, "blue") or knn(rgb_vector, [65,105,225], {"k": 140})
| where primary == true
| sort _score desc, color asc
| keep color, rgb_vector
| limit 10
;
color:text | rgb_vector:dense_vector
blue | [0.0, 0.0, 255.0]
gray | [128.0, 128.0, 128.0]
cyan | [0.0, 255.0, 255.0]
magenta | [255.0, 0.0, 255.0]
green | [0.0, 128.0, 0.0]
white | [255.0, 255.0, 255.0]
black | [0.0, 0.0, 0.0]
red | [255.0, 0.0, 0.0]
yellow | [255.0, 255.0, 0.0]
;
knnWithMultipleFunctions
required_capability: knn_function
from colors metadata _score
| where knn(rgb_vector, [128,128,0], {"k": 140}) and match(color, "olive")
| sort _score desc, color asc
| keep color, rgb_vector
;
color:text | rgb_vector:dense_vector
olive | [128.0, 128.0, 0.0]
;
knnAfterKeep
required_capability: knn_function
from colors metadata _score
| keep rgb_vector, color, _score
| where knn(rgb_vector, [128,255,0], {"k": 140})
| sort _score desc, color asc
| keep rgb_vector
| limit 5
;
rgb_vector:dense_vector
[127.0, 255.0, 0.0]
[128.0, 128.0, 0.0]
[255.0, 255.0, 0.0]
[0.0, 255.0, 0.0]
[218.0, 165.0, 32.0]
;
knnAfterDrop
required_capability: knn_function
from colors metadata _score
| drop primary
| where knn(rgb_vector, [128,250,0], {"k": 140})
| sort _score desc, color asc
| keep color, rgb_vector
| limit 5
;
color:text | rgb_vector: dense_vector
chartreuse | [127.0, 255.0, 0.0]
olive | [128.0, 128.0, 0.0]
yellow | [255.0, 255.0, 0.0]
golden rod | [218.0, 165.0, 32.0]
lime | [0.0, 255.0, 0.0]
;
knnAfterEval
required_capability: knn_function
from colors metadata _score
| eval composed_name = locate(color, " ") > 0
| where knn(rgb_vector, [128,128,0], {"k": 140})
| sort _score desc, color asc
| keep color, composed_name
| limit 5
;
color:text | composed_name:boolean
olive | false
sienna | false
chocolate | false
peru | false
golden rod | true
;
knnWithConjunction
required_capability: knn_function
# TODO We need kNN prefiltering here so we get more candidates that pass the filter
from colors metadata _score
| where knn(rgb_vector, [255,255,238], {"k": 140}) and hex_code like "#FFF*"
| sort _score desc, color asc
| keep color, hex_code, rgb_vector
| limit 10
;
color:text | hex_code:keyword | rgb_vector:dense_vector
ivory | #FFFFF0 | [255.0, 255.0, 240.0]
sea shell | #FFF5EE | [255.0, 245.0, 238.0]
snow | #FFFAFA | [255.0, 250.0, 250.0]
white | #FFFFFF | [255.0, 255.0, 255.0]
corn silk | #FFF8DC | [255.0, 248.0, 220.0]
lemon chiffon | #FFFACD | [255.0, 250.0, 205.0]
yellow | #FFFF00 | [255.0, 255.0, 0.0]
;
knnWithDisjunctionAndFiltersConjunction
required_capability: knn_function
# TODO We need kNN prefiltering here so we get more candidates that pass the filter
from colors metadata _score
| where (knn(rgb_vector, [0,255,255], {"k": 140}) or knn(rgb_vector, [128, 0, 255], {"k": 140})) and primary == true
| keep color, rgb_vector, _score
| sort _score desc, color asc
| drop _score
| limit 10
;
color:text | rgb_vector:dense_vector
cyan | [0.0, 255.0, 255.0]
blue | [0.0, 0.0, 255.0]
magenta | [255.0, 0.0, 255.0]
gray | [128.0, 128.0, 128.0]
white | [255.0, 255.0, 255.0]
green | [0.0, 128.0, 0.0]
black | [0.0, 0.0, 0.0]
red | [255.0, 0.0, 0.0]
yellow | [255.0, 255.0, 0.0]
;
knnWithNonPushableConjunction
required_capability: knn_function
from colors metadata _score
| eval composed_name = locate(color, " ") > 0
| where knn(rgb_vector, [128,128,0], {"k": 140}) and composed_name == false
| sort _score desc, color asc
| keep color, composed_name
| limit 10
;
color:text | composed_name:boolean
olive | false
sienna | false
chocolate | false
peru | false
brown | false
firebrick | false
chartreuse | false
gray | false
green | false
maroon | false
;
testKnnWithNonPushableDisjunctions
required_capability: knn_function
from colors metadata _score
| where knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 30}) or length(color) > 10
| sort _score desc, color asc
| keep color
;
color:text
olive
aqua marine
lemon chiffon
papaya whip
;
testKnnWithNonPushableDisjunctionsOnComplexExpressions
required_capability: knn_function
from colors metadata _score
| where (knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], {"k": 140, "similarity": 60}) and primary == false)
| sort _score desc, color asc
| keep color, primary
;
color:text | primary:boolean
olive | false
purple | false
indigo | false
;
testKnnInStatsNonPushable
required_capability: knn_function
from colors
| where length(color) < 10
| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140})
;
c: long
50
;
testKnnInStatsWithGrouping
required_capability: knn_function
required_capability: full_text_functions_in_stats_where
from colors
| where length(color) < 10
| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140}) by primary
;
c: long | primary: boolean
41 | false
9 | true
;

View file

@ -63,6 +63,9 @@
"semantic_text": {
"type": "semantic_text",
"inference_id": "foo_inference_id"
},
"dense_vector": {
"type": "dense_vector"
}
}
}

View file

@ -0,0 +1,20 @@
{
"properties": {
"color": {
"type": "text"
},
"hex_code": {
"type": "keyword"
},
"primary": {
"type": "boolean"
},
"rgb_vector": {
"type": "dense_vector",
"similarity": "l2_norm",
"index_options": {
"type": "hnsw"
}
}
}
}

View file

@ -128,10 +128,57 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
}
}
public void testNonIndexedDenseVectorField() throws IOException {
createIndexWithDenseVector("no_dense_vectors");
int numDocs = randomIntBetween(10, 100);
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
for (int i = 0; i < numDocs; i++) {
docs[i] = prepareIndex("no_dense_vectors").setId("" + i).setSource("id", String.valueOf(i));
}
indexRandom(true, docs);
var query = """
FROM no_dense_vectors
| KEEP id, vector
""";
try (var resp = run(query)) {
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
assertEquals(numDocs, valuesList.size());
valuesList.forEach(value -> {
assertEquals(2, value.size());
Integer id = (Integer) value.get(0);
assertNotNull(id);
Object vector = value.get(1);
assertNull(vector);
});
}
}
@Before
public void setup() throws IOException {
assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
var indexName = "test";
createIndexWithDenseVector("test");
int numDims = randomIntBetween(32, 64) * 2; // min 64, even number
int numDocs = randomIntBetween(10, 100);
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
for (int i = 0; i < numDocs; i++) {
List<Float> vector = new ArrayList<>(numDims);
for (int j = 0; j < numDims; j++) {
vector.add(randomFloat());
}
docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector);
indexedVectors.put(i, vector);
}
indexRandom(true, docs);
}
private void createIndexWithDenseVector(String indexName) throws IOException {
var client = client().admin().indices();
XContentBuilder mapping = XContentFactory.jsonBuilder()
.startObject()
@ -161,19 +208,5 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
.setMapping(mapping)
.setSettings(settingsBuilder.build());
assertAcked(CreateRequest);
int numDims = randomIntBetween(32, 64) * 2; // min 64, even number
int numDocs = randomIntBetween(10, 100);
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
for (int i = 0; i < numDocs; i++) {
List<Float> vector = new ArrayList<>(numDims);
for (int j = 0; j < numDims; j++) {
vector.add(randomFloat());
}
docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector);
indexedVectors.put(i, vector);
}
indexRandom(true, docs);
}
}

View file

@ -0,0 +1,156 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.plugin;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
public class KnnFunctionIT extends AbstractEsqlIntegTestCase {
private final Map<Integer, List<Float>> indexedVectors = new HashMap<>();
private int numDocs;
private int numDims;
public void testKnnDefaults() {
float[] queryVector = new float[numDims];
Arrays.fill(queryVector, 1.0f);
var query = String.format(Locale.ROOT, """
FROM test METADATA _score
| WHERE knn(vector, %s)
| KEEP id, floats, _score, vector
| SORT _score DESC
""", Arrays.toString(queryVector));
try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector"));
assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector"));
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
assertEquals(Math.min(indexedVectors.size(), 10), valuesList.size());
for (int i = 0; i < valuesList.size(); i++) {
List<Object> row = valuesList.get(i);
// Vectors should be in order of ID, as they're less similar than the query vector as the ID increases
assertEquals(i, row.getFirst());
@SuppressWarnings("unchecked")
// Vectors should be the same
List<Double> floats = (List<Double>) row.get(1);
for (int j = 0; j < floats.size(); j++) {
assertEquals(floats.get(j).floatValue(), indexedVectors.get(i).get(j), 0f);
}
var score = (Double) row.get(2);
assertNotNull(score);
assertTrue(score > 0.0);
}
}
}
public void testKnnOptions() {
float[] queryVector = new float[numDims];
Arrays.fill(queryVector, 1.0f);
var query = String.format(Locale.ROOT, """
FROM test METADATA _score
| WHERE knn(vector, %s, {"k": 5})
| KEEP id, floats, _score, vector
| SORT _score DESC
""", Arrays.toString(queryVector));
try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector"));
assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector"));
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
assertEquals(5, valuesList.size());
}
}
public void testKnnNonPushedDown() {
float[] queryVector = new float[numDims];
Arrays.fill(queryVector, 1.0f);
// TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query
var query = String.format(Locale.ROOT, """
FROM test METADATA _score
| WHERE knn(vector, %s, {"k": 5}) OR id > 10
| KEEP id, floats, _score, vector
| SORT _score DESC
""", Arrays.toString(queryVector));
try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector"));
assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector"));
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
// K = 5, 1 more for every id > 10
assertEquals(5 + Math.max(0, numDocs - 10 - 1), valuesList.size());
}
}
@Before
public void setup() throws IOException {
assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled());
var indexName = "test";
var client = client().admin().indices();
XContentBuilder mapping = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject("id")
.field("type", "integer")
.endObject()
.startObject("vector")
.field("type", "dense_vector")
.field("similarity", "l2_norm")
.endObject()
.startObject("floats")
.field("type", "float")
.endObject()
.endObject()
.endObject();
Settings.Builder settingsBuilder = Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1);
var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build());
assertAcked(createRequest);
numDocs = randomIntBetween(10, 20);
numDims = randomIntBetween(3, 10);
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
float value = 0.0f;
for (int i = 0; i < numDocs; i++) {
List<Float> vector = new ArrayList<>(numDims);
for (int j = 0; j < numDims; j++) {
vector.add(value++);
}
docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "floats", vector, "vector", vector);
indexedVectors.put(i, vector);
}
indexRandom(true, docs);
}
}

View file

@ -1185,7 +1185,12 @@ public class EsqlCapabilities {
/**
* MATCH PHRASE function
*/
MATCH_PHRASE_FUNCTION;
MATCH_PHRASE_FUNCTION,
/**
* Support knn function
*/
KNN_FUNCTION(Build.current().isSnapshot());
private final boolean enabled;

View file

@ -66,6 +66,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToUnsignedLong;
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
import org.elasticsearch.xpack.esql.expression.function.vector.VectorFunction;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.DateTimeArithmeticOperation;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
@ -1396,6 +1397,9 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) {
return processBinaryOperator((BinaryOperator) f);
}
if (f instanceof VectorFunction vectorFunction) {
return processVectorFunction(f);
}
return f;
}
@ -1595,6 +1599,25 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
return unresolvedAttribute(from, target.toString(), e);
}
}
private static Expression processVectorFunction(org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction) {
List<Expression> args = vectorFunction.arguments();
List<Expression> newArgs = new ArrayList<>();
for (Expression arg : args) {
if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) {
Object folded = arg.fold(FoldContext.small() /* TODO remove me */);
if (folded instanceof List) {
Literal denseVector = new Literal(arg.source(), folded, DataType.DENSE_VECTOR);
newArgs.add(denseVector);
continue;
}
}
newArgs.add(arg);
}
return vectorFunction.replaceChildren(newArgs);
}
}
/**

View file

@ -8,6 +8,7 @@
package org.elasticsearch.xpack.esql.expression;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables;
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
@ -83,6 +84,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLike;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLike;
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull;
@ -115,6 +117,7 @@ public class ExpressionWritables {
entries.addAll(binaryComparisons());
entries.addAll(fullText());
entries.addAll(unaryScalars());
entries.addAll(vector());
return entries;
}
@ -252,4 +255,11 @@ public class ExpressionWritables {
private static List<NamedWriteableRegistry.Entry> fullText() {
return FullTextWritables.getNamedWriteables();
}
private static List<NamedWriteableRegistry.Entry> vector() {
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
return List.of(Knn.ENTRY);
}
return List.of();
}
}

View file

@ -179,6 +179,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToLower;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim;
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import org.elasticsearch.xpack.esql.parser.ParsingException;
import org.elasticsearch.xpack.esql.session.Configuration;
@ -485,7 +486,8 @@ public class EsqlFunctionRegistry {
def(LastOverTime.class, LastOverTime::withUnresolvedTimestamp, "last_over_time"),
def(FirstOverTime.class, FirstOverTime::withUnresolvedTimestamp, "first_over_time"),
def(Term.class, bi(Term::new), "term"),
def(MatchPhrase.class, tri(MatchPhrase::new), "match_phrase") } };
def(MatchPhrase.class, tri(MatchPhrase::new), "match_phrase"),
def(Knn.class, tri(Knn::new), "knn") } };
}
public EsqlFunctionRegistry snapshotRegistry() {

View file

@ -148,7 +148,7 @@ public abstract class FullTextFunction extends Function
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), queryBuilder);
return Objects.hash(super.hashCode(), query, queryBuilder);
}
@Override
@ -157,7 +157,7 @@ public abstract class FullTextFunction extends Function
return false;
}
return Objects.equals(queryBuilder, ((FullTextFunction) obj).queryBuilder);
return Objects.equals(queryBuilder, ((FullTextFunction) obj).queryBuilder) && Objects.equals(query, ((FullTextFunction) obj).query);
}
@Override

View file

@ -0,0 +1,292 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.expression.function.vector;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.Check;
import org.elasticsearch.xpack.esql.expression.function.Example;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.MapParam;
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction;
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import static java.util.Map.entry;
import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT;
import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER;
import static org.elasticsearch.xpack.esql.expression.function.fulltext.Match.getNameFromFieldAttribute;
public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
private final Expression field;
private final Expression options;
public static final Map<String, DataType> ALLOWED_OPTIONS = Map.ofEntries(
entry(K_FIELD.getPreferredName(), INTEGER),
entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER),
entry(VECTOR_SIMILARITY_FIELD.getPreferredName(), FLOAT),
entry(BOOST_FIELD.getPreferredName(), FLOAT),
entry(KnnQuery.RESCORE_OVERSAMPLE_FIELD, FLOAT)
);
@FunctionInfo(
returnType = "boolean",
preview = true,
description = "Finds the k nearest vectors to a query vector, as measured by a similarity metric. "
+ "knn function finds nearest vectors through approximate search on indexed dense_vectors.",
examples = {
@Example(file = "knn-function", tag = "knn-function"),
@Example(file = "knn-function", tag = "knn-function-options"), },
appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
)
public Knn(
Source source,
@Param(name = "field", type = { "dense_vector" }, description = "Field that the query will target.") Expression field,
@Param(
name = "query",
type = { "dense_vector" },
description = "Vector value to find top nearest neighbours for."
) Expression query,
@MapParam(
name = "options",
params = {
@MapParam.MapParamEntry(
name = "boost",
type = "float",
valueHint = { "2.5" },
description = "Floating point number used to decrease or increase the relevance scores of the query."
+ "Defaults to 1.0."
),
@MapParam.MapParamEntry(
name = "k",
type = "integer",
valueHint = { "10" },
description = "The number of nearest neighbors to return from each shard. "
+ "Elasticsearch collects k results from each shard, then merges them to find the global top results. "
+ "This value must be less than or equal to num_candidates. Defaults to 10."
),
@MapParam.MapParamEntry(
name = "num_candidates",
type = "integer",
valueHint = { "10" },
description = "The number of nearest neighbor candidates to consider per shard while doing knn search. "
+ "Cannot exceed 10,000. Increasing num_candidates tends to improve the accuracy of the final results. "
+ "Defaults to 1.5 * k"
),
@MapParam.MapParamEntry(
name = "similarity",
type = "double",
valueHint = { "0.01" },
description = "The minimum similarity required for a document to be considered a match. "
+ "The similarity value calculated relates to the raw similarity used, not the document score."
),
@MapParam.MapParamEntry(
name = "rescore_oversample",
type = "double",
valueHint = { "3.5" },
description = "Applies the specified oversampling for rescoring quantized vectors. "
+ "See [oversampling and rescoring quantized vectors]"
+ "(docs-content://solutions/search/vector/knn.md#dense-vector-knn-search-rescoring) for details."
), },
description = "(Optional) kNN additional options as <<esql-function-named-params,function named parameters>>."
+ " See <<query-dsl-knn-query,knn query>> for more information.",
optional = true
) Expression options
) {
this(source, field, query, options, null);
}
private Knn(Source source, Expression field, Expression query, Expression options, QueryBuilder queryBuilder) {
super(source, query, options == null ? List.of(field, query) : List.of(field, query, options), queryBuilder);
this.field = field;
this.options = options;
}
public Expression field() {
return field;
}
public Expression options() {
return options;
}
@Override
public DataType dataType() {
return DataType.BOOLEAN;
}
@Override
protected TypeResolution resolveParams() {
return resolveField().and(resolveQuery()).and(resolveOptions());
}
private TypeResolution resolveField() {
return isNotNull(field(), sourceText(), FIRST).and(isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector"));
}
private TypeResolution resolveQuery() {
return isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector").and(
isNotNullAndFoldable(query(), sourceText(), SECOND)
);
}
private TypeResolution resolveOptions() {
if (options() != null) {
TypeResolution resolution = isNotNull(options(), sourceText(), THIRD);
if (resolution.unresolved()) {
return resolution;
}
// MapExpression does not have a DataType associated with it
resolution = isMapExpression(options(), sourceText(), THIRD);
if (resolution.unresolved()) {
return resolution;
}
try {
knnQueryOptions();
} catch (InvalidArgumentException e) {
return new TypeResolution(e.getMessage());
}
}
return TypeResolution.TYPE_RESOLVED;
}
private Map<String, Object> knnQueryOptions() throws InvalidArgumentException {
if (options() == null) {
return Map.of();
}
Map<String, Object> matchOptions = new HashMap<>();
populateOptionsMap((MapExpression) options(), matchOptions, THIRD, sourceText(), ALLOWED_OPTIONS);
return matchOptions;
}
@Override
protected Query translate(TranslatorHandler handler) {
var fieldAttribute = Match.fieldAsFieldAttribute(field());
Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument");
String fieldName = getNameFromFieldAttribute(fieldAttribute);
@SuppressWarnings("unchecked")
List<Number> queryFolded = (List<Number>) query().fold(FoldContext.small() /* TODO remove me */);
float[] queryAsFloats = new float[queryFolded.size()];
for (int i = 0; i < queryFolded.size(); i++) {
queryAsFloats[i] = queryFolded.get(i).floatValue();
}
return new KnnQuery(source(), fieldName, queryAsFloats, queryOptions());
}
@Override
public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
return new Knn(source(), field(), query(), options(), queryBuilder);
}
private Map<String, Object> queryOptions() throws InvalidArgumentException {
if (options() == null) {
return Map.of();
}
Map<String, Object> options = new HashMap<>();
populateOptionsMap((MapExpression) options(), options, THIRD, sourceText(), ALLOWED_OPTIONS);
return options;
}
@Override
public Expression replaceChildren(List<Expression> newChildren) {
return new Knn(
source(),
newChildren.get(0),
newChildren.get(1),
newChildren.size() > 2 ? newChildren.get(2) : null,
queryBuilder()
);
}
@Override
protected NodeInfo<? extends Expression> info() {
return NodeInfo.create(this, Knn::new, field(), query(), options());
}
@Override
public String getWriteableName() {
return ENTRY.name;
}
private static Knn readFrom(StreamInput in) throws IOException {
Source source = Source.readFrom((PlanStreamInput) in);
Expression field = in.readNamedWriteable(Expression.class);
Expression query = in.readNamedWriteable(Expression.class);
QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class);
return new Knn(source, field, query, null, queryBuilder);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
source().writeTo(out);
out.writeNamedWriteable(field());
out.writeNamedWriteable(query());
out.writeOptionalNamedWriteable(queryBuilder());
}
@Override
public boolean equals(Object o) {
// Knn does not serialize options, as they get included in the query builder. We need to override equals and hashcode to
// ignore options when comparing two Knn functions
if (o == null || getClass() != o.getClass()) return false;
Knn knn = (Knn) o;
return Objects.equals(field(), knn.field())
&& Objects.equals(query(), knn.query())
&& Objects.equals(queryBuilder(), knn.queryBuilder());
}
@Override
public int hashCode() {
return Objects.hash(field(), query(), queryBuilder());
}
}

View file

@ -0,0 +1,15 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.expression.function.vector;
/**
* Marker interface for vector functions. Makes possible to do implicit casting
* from multi values to dense_vector field types, so parameters are actually
* processed as dense_vectors in vector functions
*/
public interface VectorFunction {}

View file

@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.querydsl.query;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
import org.elasticsearch.xpack.esql.core.tree.Source;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD;
public class KnnQuery extends Query {
private final String field;
private final float[] query;
private final Map<String, Object> options;
public static final String RESCORE_OVERSAMPLE_FIELD = "rescore_oversample";
public KnnQuery(Source source, String field, float[] query, Map<String, Object> options) {
super(source);
assert options != null;
this.field = field;
this.query = query;
this.options = options;
}
@Override
protected QueryBuilder asBuilder() {
Integer k = (Integer) options.get(K_FIELD.getPreferredName());
Integer numCands = (Integer) options.get(NUM_CANDS_FIELD.getPreferredName());
RescoreVectorBuilder rescoreVectorBuilder = null;
Float oversample = (Float) options.get(RESCORE_OVERSAMPLE_FIELD);
if (oversample != null) {
rescoreVectorBuilder = new RescoreVectorBuilder(oversample);
}
Float vectorSimilarity = (Float) options.get(VECTOR_SIMILARITY_FIELD.getPreferredName());
KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(field, query, k, numCands, rescoreVectorBuilder, vectorSimilarity);
Number boost = (Number) options.get(BOOST_FIELD.getPreferredName());
if (boost != null) {
queryBuilder.boost(boost.floatValue());
}
return queryBuilder;
}
@Override
protected String innerToString() {
return "knn(" + field + ", " + Arrays.toString(query) + " options={" + options + "}))";
}
@Override
public boolean equals(Object o) {
if (super.equals(o) == false) return false;
KnnQuery knnQuery = (KnnQuery) o;
return Objects.equals(field, knnQuery.field)
&& Objects.deepEquals(query, knnQuery.query)
&& Objects.equals(options, knnQuery.options);
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), field, Arrays.hashCode(query), options);
}
@Override
public boolean scorable() {
return true;
}
}

View file

@ -296,6 +296,10 @@ public class CsvTests extends ESTestCase {
"can't use KQL function in csv tests",
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KQL_FUNCTION.capabilityName())
);
assumeFalse(
"can't use KNN function in csv tests",
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION.capabilityName())
);
assumeFalse(
"lookup join disabled for csv tests",
testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.JOIN_LOOKUP_V12.capabilityName())

View file

@ -24,6 +24,7 @@ import org.elasticsearch.index.query.RegexpQueryBuilder;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.index.query.WildcardQueryBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.test.EqualsHashCodeTestUtils;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.expression.ExpressionWritables;
@ -111,6 +112,7 @@ public class SerializationTestUtils {
entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, WildcardQueryBuilder.NAME, WildcardQueryBuilder::new));
entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, RegexpQueryBuilder.NAME, RegexpQueryBuilder::new));
entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, ExistsQueryBuilder.NAME, ExistsQueryBuilder::new));
entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, KnnVectorQueryBuilder.NAME, KnnVectorQueryBuilder::new));
entries.add(SingleValueQuery.ENTRY);
entries.addAll(ExpressionWritables.getNamedWriteables());
entries.addAll(PlanWritables.getNamedWriteables());

View file

@ -55,6 +55,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat;
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
@ -2363,6 +2364,22 @@ public class AnalyzerTests extends ESTestCase {
assertThat(e.getMessage(), containsString("[+] has arguments with incompatible types [datetime] and [datetime]"));
}
public void testDenseVectorImplicitCasting() {
Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors"));
var plan = analyze("""
from test | where knn(vector, [0.342, 0.164, 0.234])
""", "mapping-dense_vector.json");
var limit = as(plan, Limit.class);
var filter = as(limit.child(), Filter.class);
var knn = as(filter.condition(), Knn.class);
var field = knn.field();
var queryVector = as(knn.query(), Literal.class);
assertEquals(DataType.DENSE_VECTOR, queryVector.dataType());
assertThat(queryVector.value(), equalTo(List.of(0.342, 0.164, 0.234)));
}
public void testRateRequiresCounterTypes() {
assumeTrue("rate requires snapshot builds", Build.current().isSnapshot());
Analyzer analyzer = analyzer(tsdbIndexResolution());

View file

@ -22,6 +22,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchPhrase;
import org.elasticsearch.xpack.esql.expression.function.fulltext.MultiMatch;
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.index.IndexResolution;
import org.elasticsearch.xpack.esql.parser.EsqlParser;
@ -1234,6 +1235,9 @@ public class VerifierTests extends ESTestCase {
checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function");
checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3])");
}
}
private void checkFieldBasedFunctionNotAllowedAfterCommands(String functionName, String functionType, String functionInvocation) {
@ -1364,6 +1368,9 @@ public class VerifierTests extends ESTestCase {
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2])", "function");
}
}
private void checkFullTextFunctionsOnlyAllowedInWhere(String functionName, String functionInvocation, String functionType)
@ -1400,6 +1407,9 @@ public class VerifierTests extends ESTestCase {
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3])");
}
}
private void checkWithFullTextFunctionsDisjunctions(String functionInvocation) {
@ -1462,6 +1472,9 @@ public class VerifierTests extends ESTestCase {
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, \"Meditation\")", "function");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3])", "function");
}
}
private void checkFullTextFunctionsWithNonBooleanFunctions(String functionName, String functionInvocation, String functionType) {
@ -1530,6 +1543,9 @@ public class VerifierTests extends ESTestCase {
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2])");
}
}
private void testFullTextFunctionTargetsExistingField(String functionInvocation) throws Exception {
@ -2055,6 +2071,9 @@ public class VerifierTests extends ESTestCase {
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
checkOptionDataTypes(MultiMatch.OPTIONS, "FROM test | WHERE MULTI_MATCH(\"Jean\", title, body, {\"%s\": %s})");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], {\"%s\": %s})");
}
}
/**
@ -2140,6 +2159,10 @@ public class VerifierTests extends ESTestCase {
checkFullTextFunctionNullArgs("term(null, \"query\")", "first");
checkFullTextFunctionNullArgs("term(title, null)", "second");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkFullTextFunctionNullArgs("knn(null, [0, 1, 2])", "first");
checkFullTextFunctionNullArgs("knn(vector, null)", "second");
}
}
private void checkFullTextFunctionNullArgs(String functionInvocation, String argOrdinal) throws Exception {
@ -2161,6 +2184,9 @@ public class VerifierTests extends ESTestCase {
if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) {
checkFullTextFunctionsConstantQuery("term(title, tags)", "second");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkFullTextFunctionsConstantQuery("knn(vector, vector)", "second");
}
}
private void checkFullTextFunctionsConstantQuery(String functionInvocation, String argOrdinal) throws Exception {
@ -2188,6 +2214,9 @@ public class VerifierTests extends ESTestCase {
if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) {
checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)");
}
if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) {
checkFullTextFunctionsInStats("knn(vector, [0, 1, 2])");
}
}
private void checkFullTextFunctionsInStats(String functionInvocation) {

View file

@ -0,0 +1,132 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.expression.function.fulltext;
import com.carrotsearch.randomizedtesting.annotations.Name;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase;
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
import org.junit.Before;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import static org.elasticsearch.xpack.esql.SerializationTestUtils.serializeDeserialize;
import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN;
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD;
import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED;
import static org.elasticsearch.xpack.esql.planner.TranslatorHandler.TRANSLATOR_HANDLER;
import static org.hamcrest.Matchers.equalTo;
public class KnnTests extends AbstractFunctionTestCase {
public KnnTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
this.testCase = testCaseSupplier.get();
}
@ParametersFactory
public static Iterable<Object[]> parameters() {
return parameterSuppliersFromTypedData(addFunctionNamedParams(testCaseSuppliers()));
}
@Before
public void checkCapability() {
assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled());
}
private static List<TestCaseSupplier> testCaseSuppliers() {
List<TestCaseSupplier> suppliers = new ArrayList<>();
suppliers.add(
TestCaseSupplier.testCaseSupplier(
new TestCaseSupplier.TypedDataSupplier("dense_vector field", KnnTests::randomDenseVector, DENSE_VECTOR),
new TestCaseSupplier.TypedDataSupplier("query", KnnTests::randomDenseVector, DENSE_VECTOR, true),
(d1, d2) -> equalTo("string"),
BOOLEAN,
(o1, o2) -> true
)
);
return suppliers;
}
private static List<Float> randomDenseVector() {
int dimensions = randomIntBetween(64, 128);
List<Float> vector = new ArrayList<>();
for (int i = 0; i < dimensions; i++) {
vector.add(randomFloat());
}
return vector;
}
/**
* Adds function named parameters to all the test case suppliers provided
*/
private static List<TestCaseSupplier> addFunctionNamedParams(List<TestCaseSupplier> suppliers) {
// TODO get to a common class with MatchTests
List<TestCaseSupplier> result = new ArrayList<>();
for (TestCaseSupplier supplier : suppliers) {
List<DataType> dataTypes = new ArrayList<>(supplier.types());
dataTypes.add(UNSUPPORTED);
result.add(new TestCaseSupplier(supplier.name() + ", options", dataTypes, () -> {
List<TestCaseSupplier.TypedData> values = new ArrayList<>(supplier.get().getData());
values.add(
new TestCaseSupplier.TypedData(
new MapExpression(Source.EMPTY, List.of(new Literal(Source.EMPTY, randomAlphaOfLength(10), KEYWORD))),
UNSUPPORTED,
"options"
).forceLiteral()
);
return new TestCaseSupplier.TestCase(values, equalTo("KnnEvaluator"), BOOLEAN, equalTo(true));
}));
}
return result;
}
@Override
protected Expression build(Source source, List<Expression> args) {
Knn knn = new Knn(source, args.get(0), args.get(1), args.size() > 2 ? args.get(2) : null);
// We need to add the QueryBuilder to the match expression, as it is used to implement equals() and hashCode() and
// thus test the serialization methods. But we can only do this if the parameters make sense .
if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) {
QueryBuilder queryBuilder = TRANSLATOR_HANDLER.asQuery(LucenePushdownPredicates.DEFAULT, knn).toQueryBuilder();
knn = (Knn) knn.replaceQueryBuilder(queryBuilder);
}
return knn;
}
/**
* Copy of the overridden method that doesn't check for children size, as the {@code options} child isn't serialized in Match.
*/
@Override
protected Expression serializeDeserializeExpression(Expression expression) {
Expression newExpression = serializeDeserialize(
expression,
PlanStreamOutput::writeNamedWriteable,
in -> in.readNamedWriteable(Expression.class),
testCase.getConfiguration() // The configuration query should be == to the source text of the function for this to work
);
// Fields use synthetic sources, which can't be serialized. So we use the originals instead.
return newExpression.replaceChildren(expression.children());
}
}

View file

@ -82,7 +82,7 @@ public class MatchTests extends AbstractMatchFullTextFunctionTests {
// thus test the serialization methods. But we can only do this if the parameters make sense .
if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) {
QueryBuilder queryBuilder = TRANSLATOR_HANDLER.asQuery(LucenePushdownPredicates.DEFAULT, match).toQueryBuilder();
match.replaceQueryBuilder(queryBuilder);
match = (Match) match.replaceQueryBuilder(queryBuilder);
}
return match;
}

View file

@ -28,6 +28,8 @@ import org.elasticsearch.index.query.QueryStringQueryBuilder;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.RescoreVectorBuilder;
import org.elasticsearch.test.VersionUtils;
import org.elasticsearch.xpack.core.enrich.EnrichPolicy;
import org.elasticsearch.xpack.esql.EsqlTestUtils;
@ -1359,6 +1361,29 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase {
assertThat(expectedQuery.toString(), is(planStr.get()));
}
public void testKnnOptionsPushDown() {
String query = """
from test
| where KNN(dense_vector, [0.1, 0.2, 0.3],
{ "k": 5, "similarity": 0.001, "num_candidates": 10, "rescore_oversample": 7, "boost": 3.5 })
""";
var analyzer = makeAnalyzer("mapping-all-types.json");
var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer);
AtomicReference<String> planStr = new AtomicReference<>();
plan.forEachDown(EsQueryExec.class, result -> planStr.set(result.query().toString()));
var expectedQuery = new KnnVectorQueryBuilder(
"dense_vector",
new float[] { 0.1f, 0.2f, 0.3f },
5,
10,
new RescoreVectorBuilder(7),
0.001f
).boost(3.5f);
assertThat(expectedQuery.toString(), is(planStr.get()));
}
/**
* Expecting
* LimitExec[1000[INTEGER]]