diff --git a/docs/reference/query-languages/esql/images/functions/knn.svg b/docs/reference/query-languages/esql/images/functions/knn.svg new file mode 100644 index 000000000000..75a104a7cdcf --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/knn.svg @@ -0,0 +1 @@ +KNN(field,query,options) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/knn.json b/docs/reference/query-languages/esql/kibana/definition/functions/knn.json new file mode 100644 index 000000000000..48d3e582eec5 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/knn.json @@ -0,0 +1,13 @@ +{ + "comment" : "This is generated by ESQL’s AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", + "type" : "scalar", + "name" : "knn", + "description" : "Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors.", + "signatures" : [ ], + "examples" : [ + "from colors metadata _score\n| where knn(rgb_vector, [0, 120, 0])\n| sort _score desc", + "from colors metadata _score\n| where knn(rgb_vector, [0,255,255], {\"k\": 4})\n| sort _score desc" + ], + "preview" : true, + "snapshot_only" : true +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/knn.md b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md new file mode 100644 index 000000000000..45d1f294ea0a --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/knn.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +### KNN +Finds the k nearest vectors to a query vector, as measured by a similarity metric. knn function finds nearest vectors through approximate search on indexed dense_vectors. + +```esql +from colors metadata _score +| where knn(rgb_vector, [0, 120, 0]) +| sort _score desc +``` diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 40c69ec6e7fd..9360d13c7833 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2589,6 +2589,11 @@ public class DenseVectorFieldMapper extends FieldMapper { return null; } + if (dims == null) { + // No data has been indexed yet + return BlockLoader.CONSTANT_NULLS; + } + if (indexed) { return new BlockDocValuesReader.DenseVectorBlockLoader(name(), dims); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java index 4644dd31f204..d91df60621fc 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java @@ -112,7 +112,7 @@ public abstract class LuceneQueryEvaluator implements int min = docs.docs().getInt(0); int max = docs.docs().getInt(docs.getPositionCount() - 1); int length = max - min + 1; - try (T scoreBuilder = createVectorBuilder(blockFactory, length)) { + try (T scoreBuilder = createVectorBuilder(blockFactory, docs.getPositionCount())) { if (length == docs.getPositionCount() && length > 1) { return segmentState.scoreDense(scoreBuilder, min, max); } diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java index ce13a655966c..293200b22791 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java @@ -1022,7 +1022,7 @@ public abstract class RestEsqlTestCase extends ESRestTestCase { var query = requestObjectBuilder().query(format(null, "from * | lookup join {} on integer {}", testIndexName(), sort)); Map result = runEsql(query); var columns = as(result.get("columns"), List.class); - assertEquals(21, columns.size()); + assertEquals(22, columns.size()); var values = as(result.get("values"), List.class); assertEquals(10, values.size()); } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java index 2d06431806ab..f1eb32afbdfe 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java @@ -148,6 +148,7 @@ public class CsvTestsDataLoader { private static final TestDataset LOGS = new TestDataset("logs"); private static final TestDataset MV_TEXT = new TestDataset("mv_text"); private static final TestDataset DENSE_VECTOR = new TestDataset("dense_vector"); + private static final TestDataset COLORS = new TestDataset("colors"); public static final Map CSV_DATASET_MAP = Map.ofEntries( Map.entry(EMPLOYEES.indexName, EMPLOYEES), @@ -210,7 +211,8 @@ public class CsvTestsDataLoader { Map.entry(SEMANTIC_TEXT.indexName, SEMANTIC_TEXT), Map.entry(LOGS.indexName, LOGS), Map.entry(MV_TEXT.indexName, MV_TEXT), - Map.entry(DENSE_VECTOR.indexName, DENSE_VECTOR) + Map.entry(DENSE_VECTOR.indexName, DENSE_VECTOR), + Map.entry(COLORS.indexName, COLORS) ); private static final EnrichConfig LANGUAGES_ENRICH = new EnrichConfig("languages_policy", "enrich-policy-languages.json"); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/colors.csv b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/colors.csv new file mode 100644 index 000000000000..b82ef7087a54 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/colors.csv @@ -0,0 +1,60 @@ +color:text,hex_code:keyword,rgb_vector:dense_vector,primary:boolean +maroon, #800000, [128,0,0], false +brown, #A52A2A, [165,42,42], false +firebrick, #B22222, [178,34,34], false +crimson, #DC143C, [220,20,60], false +red, #FF0000, [255,0,0], true +tomato, #FF6347, [255,99,71], false +coral, #FF7F50, [255,127,80], false +salmon, #FA8072, [250,128,114], false +orange, #FFA500, [255,165,0], false +gold, #FFD700, [255,215,0], false +golden rod, #DAA520, [218,165,32], false +khaki, #F0E68C, [240,230,140], false +olive, #808000, [128,128,0], false +yellow, #FFFF00, [255,255,0], true +chartreuse, #7FFF00, [127,255,0], false +green, #008000, [0,128,0], true +lime, #00FF00, [0,255,0], false +teal, #008080, [0,128,128], false +cyan, #00FFFF, [0,255,255], true +turquoise, #40E0D0, [64,224,208], false +aqua marine, #7FFFD4, [127,255,212], false +navy, #000080, [0,0,128], false +blue, #0000FF, [0,0,255], true +indigo, #4B0082, [75,0,130], false +purple, #800080, [128,0,128], false +thistle, #D8BFD8, [216,191,216], false +plum, #DDA0DD, [221,160,221], false +violet, #EE82EE, [238,130,238], false +magenta, #FF00FF, [255,0,255], true +orchid, #DA70D6, [218,112,214], false +pink, #FFC0CB, [255,192,203], false +beige, #F5F5DC, [245,245,220], false +bisque, #FFE4C4, [255,228,196], false +wheat, #F5DEB3, [245,222,179], false +corn silk, #FFF8DC, [255,248,220], false +lemon chiffon, #FFFACD, [255,250,205], false +sienna, #A0522D, [160,82,45], false +chocolate, #D2691E, [210,105,30], false +peru, #CD853F, [205,133,63], false +burly wood, #DEB887, [222,184,135], false +tan, #D2B48C, [210,180,140], false +moccasin, #FFE4B5, [255,228,181], false +peach puff, #FFDAB9, [255,218,185], false +misty rose, #FFE4E1, [255,228,225], false +linen, #FAF0E6, [250,240,230], false +old lace, #FDF5E6, [253,245,230], false +papaya whip, #FFEFD5, [255,239,213], false +sea shell, #FFF5EE, [255,245,238], false +mint cream, #F5FFFA, [245,255,250], false +lavender, #E6E6FA, [230,230,250], false +honeydew, #F0FFF0, [240,255,240], false +ivory, #FFFFF0, [255,255,240], false +azure, #F0FFFF, [240,255,255], false +snow, #FFFAFA, [255,250,250], false +black, #000000, [0,0,0], true +gray, #808080, [128,128,128], true +silver, #C0C0C0, [192,192,192], false +gainsboro, #DCDCDC, [220,220,220], false +white, #FFFFFF, [255,255,255], true diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec new file mode 100644 index 000000000000..5e65e6269e65 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec @@ -0,0 +1,285 @@ +# TODO Most tests explicitly set k. Until knn function uses LIMIT as k, we need to explicitly set it to all values +# in the dataset to avoid test failures due to docs allocation in different shards, which can impact results for a +# top-n query at the shard level + +knnSearch +required_capability: knn_function + +// tag::knn-function[] +from colors metadata _score +| where knn(rgb_vector, [0, 120, 0]) +| sort _score desc, color asc +// end::knn-function[] +| keep color, rgb_vector +| limit 10 +; + +// tag::knn-function-result[] +color:text | rgb_vector:dense_vector +green | [0.0, 128.0, 0.0] +black | [0.0, 0.0, 0.0] +olive | [128.0, 128.0, 0.0] +teal | [0.0, 128.0, 128.0] +lime | [0.0, 255.0, 0.0] +sienna | [160.0, 82.0, 45.0] +maroon | [128.0, 0.0, 0.0] +navy | [0.0, 0.0, 128.0] +gray | [128.0, 128.0, 128.0] +chartreuse | [127.0, 255.0, 0.0] +// end::knn-function-result[] +; + +knnSearchWithKOption +required_capability: knn_function + +// tag::knn-function-options[] +from colors metadata _score +| where knn(rgb_vector, [0,255,255], {"k": 4}) +| sort _score desc, color asc +// end::knn-function-options[] +| keep color, rgb_vector +| limit 4 +; + +color:text | rgb_vector:dense_vector +cyan | [0.0, 255.0, 255.0] +turquoise | [64.0, 224.0, 208.0] +aqua marine | [127.0, 255.0, 212.0] +teal | [0.0, 128.0, 128.0] +; + +knnSearchWithSimilarityOption +required_capability: knn_function + +from colors metadata _score +| where knn(rgb_vector, [255,192,203], {"k": 140, "similarity": 40}) +| sort _score desc, color asc +| keep color, rgb_vector +; + +color:text | rgb_vector:dense_vector +pink | [255.0, 192.0, 203.0] +peach puff | [255.0, 218.0, 185.0] +bisque | [255.0, 228.0, 196.0] +wheat | [245.0, 222.0, 179.0] + +; + +knnHybridSearch +required_capability: knn_function + +from colors metadata _score +| where match(color, "blue") or knn(rgb_vector, [65,105,225], {"k": 140}) +| where primary == true +| sort _score desc, color asc +| keep color, rgb_vector +| limit 10 +; + +color:text | rgb_vector:dense_vector +blue | [0.0, 0.0, 255.0] +gray | [128.0, 128.0, 128.0] +cyan | [0.0, 255.0, 255.0] +magenta | [255.0, 0.0, 255.0] +green | [0.0, 128.0, 0.0] +white | [255.0, 255.0, 255.0] +black | [0.0, 0.0, 0.0] +red | [255.0, 0.0, 0.0] +yellow | [255.0, 255.0, 0.0] +; + +knnWithMultipleFunctions +required_capability: knn_function + +from colors metadata _score +| where knn(rgb_vector, [128,128,0], {"k": 140}) and match(color, "olive") +| sort _score desc, color asc +| keep color, rgb_vector +; + +color:text | rgb_vector:dense_vector +olive | [128.0, 128.0, 0.0] +; + +knnAfterKeep +required_capability: knn_function + +from colors metadata _score +| keep rgb_vector, color, _score +| where knn(rgb_vector, [128,255,0], {"k": 140}) +| sort _score desc, color asc +| keep rgb_vector +| limit 5 +; + +rgb_vector:dense_vector +[127.0, 255.0, 0.0] +[128.0, 128.0, 0.0] +[255.0, 255.0, 0.0] +[0.0, 255.0, 0.0] +[218.0, 165.0, 32.0] +; + +knnAfterDrop +required_capability: knn_function + +from colors metadata _score +| drop primary +| where knn(rgb_vector, [128,250,0], {"k": 140}) +| sort _score desc, color asc +| keep color, rgb_vector +| limit 5 +; + +color:text | rgb_vector: dense_vector +chartreuse | [127.0, 255.0, 0.0] +olive | [128.0, 128.0, 0.0] +yellow | [255.0, 255.0, 0.0] +golden rod | [218.0, 165.0, 32.0] +lime | [0.0, 255.0, 0.0] +; + +knnAfterEval +required_capability: knn_function + +from colors metadata _score +| eval composed_name = locate(color, " ") > 0 +| where knn(rgb_vector, [128,128,0], {"k": 140}) +| sort _score desc, color asc +| keep color, composed_name +| limit 5 +; + +color:text | composed_name:boolean +olive | false +sienna | false +chocolate | false +peru | false +golden rod | true +; + +knnWithConjunction +required_capability: knn_function + +# TODO We need kNN prefiltering here so we get more candidates that pass the filter +from colors metadata _score +| where knn(rgb_vector, [255,255,238], {"k": 140}) and hex_code like "#FFF*" +| sort _score desc, color asc +| keep color, hex_code, rgb_vector +| limit 10 +; + +color:text | hex_code:keyword | rgb_vector:dense_vector +ivory | #FFFFF0 | [255.0, 255.0, 240.0] +sea shell | #FFF5EE | [255.0, 245.0, 238.0] +snow | #FFFAFA | [255.0, 250.0, 250.0] +white | #FFFFFF | [255.0, 255.0, 255.0] +corn silk | #FFF8DC | [255.0, 248.0, 220.0] +lemon chiffon | #FFFACD | [255.0, 250.0, 205.0] +yellow | #FFFF00 | [255.0, 255.0, 0.0] +; + +knnWithDisjunctionAndFiltersConjunction +required_capability: knn_function + +# TODO We need kNN prefiltering here so we get more candidates that pass the filter +from colors metadata _score +| where (knn(rgb_vector, [0,255,255], {"k": 140}) or knn(rgb_vector, [128, 0, 255], {"k": 140})) and primary == true +| keep color, rgb_vector, _score +| sort _score desc, color asc +| drop _score +| limit 10 +; + +color:text | rgb_vector:dense_vector +cyan | [0.0, 255.0, 255.0] +blue | [0.0, 0.0, 255.0] +magenta | [255.0, 0.0, 255.0] +gray | [128.0, 128.0, 128.0] +white | [255.0, 255.0, 255.0] +green | [0.0, 128.0, 0.0] +black | [0.0, 0.0, 0.0] +red | [255.0, 0.0, 0.0] +yellow | [255.0, 255.0, 0.0] +; + +knnWithNonPushableConjunction +required_capability: knn_function + +from colors metadata _score +| eval composed_name = locate(color, " ") > 0 +| where knn(rgb_vector, [128,128,0], {"k": 140}) and composed_name == false +| sort _score desc, color asc +| keep color, composed_name +| limit 10 +; + +color:text | composed_name:boolean +olive | false +sienna | false +chocolate | false +peru | false +brown | false +firebrick | false +chartreuse | false +gray | false +green | false +maroon | false +; + +testKnnWithNonPushableDisjunctions +required_capability: knn_function + +from colors metadata _score +| where knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 30}) or length(color) > 10 +| sort _score desc, color asc +| keep color +; + +color:text +olive +aqua marine +lemon chiffon +papaya whip +; + +testKnnWithNonPushableDisjunctionsOnComplexExpressions +required_capability: knn_function + +from colors metadata _score +| where (knn(rgb_vector, [128,128,0], {"k": 140, "similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], {"k": 140, "similarity": 60}) and primary == false) +| sort _score desc, color asc +| keep color, primary +; + +color:text | primary:boolean +olive | false +purple | false +indigo | false +; + +testKnnInStatsNonPushable +required_capability: knn_function + +from colors +| where length(color) < 10 +| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140}) +; + +c: long +50 +; + +testKnnInStatsWithGrouping +required_capability: knn_function +required_capability: full_text_functions_in_stats_where + +from colors +| where length(color) < 10 +| stats c = count(*) where knn(rgb_vector, [128,128,255], {"k": 140}) by primary +; + +c: long | primary: boolean +41 | false +9 | true +; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-all-types.json b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-all-types.json index 17348adb6af4..a7ef2f484070 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-all-types.json +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-all-types.json @@ -63,6 +63,9 @@ "semantic_text": { "type": "semantic_text", "inference_id": "foo_inference_id" + }, + "dense_vector": { + "type": "dense_vector" } } } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-colors.json b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-colors.json new file mode 100644 index 000000000000..24c4102e428f --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-colors.json @@ -0,0 +1,20 @@ +{ + "properties": { + "color": { + "type": "text" + }, + "hex_code": { + "type": "keyword" + }, + "primary": { + "type": "boolean" + }, + "rgb_vector": { + "type": "dense_vector", + "similarity": "l2_norm", + "index_options": { + "type": "hnsw" + } + } + } +} diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java index 905cf413fb48..a130b026cd88 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java @@ -128,10 +128,57 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { } } + public void testNonIndexedDenseVectorField() throws IOException { + createIndexWithDenseVector("no_dense_vectors"); + + int numDocs = randomIntBetween(10, 100); + IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; + for (int i = 0; i < numDocs; i++) { + docs[i] = prepareIndex("no_dense_vectors").setId("" + i).setSource("id", String.valueOf(i)); + } + + indexRandom(true, docs); + + var query = """ + FROM no_dense_vectors + | KEEP id, vector + """; + + try (var resp = run(query)) { + List> valuesList = EsqlTestUtils.getValuesList(resp); + assertEquals(numDocs, valuesList.size()); + valuesList.forEach(value -> { + assertEquals(2, value.size()); + Integer id = (Integer) value.get(0); + assertNotNull(id); + Object vector = value.get(1); + assertNull(vector); + }); + } + } + @Before public void setup() throws IOException { assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled()); - var indexName = "test"; + + createIndexWithDenseVector("test"); + + int numDims = randomIntBetween(32, 64) * 2; // min 64, even number + int numDocs = randomIntBetween(10, 100); + IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; + for (int i = 0; i < numDocs; i++) { + List vector = new ArrayList<>(numDims); + for (int j = 0; j < numDims; j++) { + vector.add(randomFloat()); + } + docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector); + indexedVectors.put(i, vector); + } + + indexRandom(true, docs); + } + + private void createIndexWithDenseVector(String indexName) throws IOException { var client = client().admin().indices(); XContentBuilder mapping = XContentFactory.jsonBuilder() .startObject() @@ -161,19 +208,5 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase { .setMapping(mapping) .setSettings(settingsBuilder.build()); assertAcked(CreateRequest); - - int numDims = randomIntBetween(32, 64) * 2; // min 64, even number - int numDocs = randomIntBetween(10, 100); - IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; - for (int i = 0; i < numDocs; i++) { - List vector = new ArrayList<>(numDims); - for (int j = 0; j < numDims; j++) { - vector.add(randomFloat()); - } - docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector); - indexedVectors.put(i, vector); - } - - indexRandom(true, docs); } } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java new file mode 100644 index 000000000000..a26294390993 --- /dev/null +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plugin; + +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; + +public class KnnFunctionIT extends AbstractEsqlIntegTestCase { + + private final Map> indexedVectors = new HashMap<>(); + private int numDocs; + private int numDims; + + public void testKnnDefaults() { + float[] queryVector = new float[numDims]; + Arrays.fill(queryVector, 1.0f); + + var query = String.format(Locale.ROOT, """ + FROM test METADATA _score + | WHERE knn(vector, %s) + | KEEP id, floats, _score, vector + | SORT _score DESC + """, Arrays.toString(queryVector)); + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + + List> valuesList = EsqlTestUtils.getValuesList(resp); + assertEquals(Math.min(indexedVectors.size(), 10), valuesList.size()); + for (int i = 0; i < valuesList.size(); i++) { + List row = valuesList.get(i); + // Vectors should be in order of ID, as they're less similar than the query vector as the ID increases + assertEquals(i, row.getFirst()); + @SuppressWarnings("unchecked") + // Vectors should be the same + List floats = (List) row.get(1); + for (int j = 0; j < floats.size(); j++) { + assertEquals(floats.get(j).floatValue(), indexedVectors.get(i).get(j), 0f); + } + var score = (Double) row.get(2); + assertNotNull(score); + assertTrue(score > 0.0); + } + } + } + + public void testKnnOptions() { + float[] queryVector = new float[numDims]; + Arrays.fill(queryVector, 1.0f); + + var query = String.format(Locale.ROOT, """ + FROM test METADATA _score + | WHERE knn(vector, %s, {"k": 5}) + | KEEP id, floats, _score, vector + | SORT _score DESC + """, Arrays.toString(queryVector)); + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + + List> valuesList = EsqlTestUtils.getValuesList(resp); + assertEquals(5, valuesList.size()); + } + } + + public void testKnnNonPushedDown() { + float[] queryVector = new float[numDims]; + Arrays.fill(queryVector, 1.0f); + + // TODO we need to decide what to do when / if user uses k for limit, as no more than k results will be returned from knn query + var query = String.format(Locale.ROOT, """ + FROM test METADATA _score + | WHERE knn(vector, %s, {"k": 5}) OR id > 10 + | KEEP id, floats, _score, vector + | SORT _score DESC + """, Arrays.toString(queryVector)); + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector")); + assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector")); + + List> valuesList = EsqlTestUtils.getValuesList(resp); + // K = 5, 1 more for every id > 10 + assertEquals(5 + Math.max(0, numDocs - 10 - 1), valuesList.size()); + } + } + + @Before + public void setup() throws IOException { + assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()); + + var indexName = "test"; + var client = client().admin().indices(); + XContentBuilder mapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject("id") + .field("type", "integer") + .endObject() + .startObject("vector") + .field("type", "dense_vector") + .field("similarity", "l2_norm") + .endObject() + .startObject("floats") + .field("type", "float") + .endObject() + .endObject() + .endObject(); + + Settings.Builder settingsBuilder = Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0) + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1); + + var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build()); + assertAcked(createRequest); + + numDocs = randomIntBetween(10, 20); + numDims = randomIntBetween(3, 10); + IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs]; + float value = 0.0f; + for (int i = 0; i < numDocs; i++) { + List vector = new ArrayList<>(numDims); + for (int j = 0; j < numDims; j++) { + vector.add(value++); + } + docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "floats", vector, "vector", vector); + indexedVectors.put(i, vector); + } + + indexRandom(true, docs); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 77fb1b077b3a..d25f5b9e81d5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1185,7 +1185,12 @@ public class EsqlCapabilities { /** * MATCH PHRASE function */ - MATCH_PHRASE_FUNCTION; + MATCH_PHRASE_FUNCTION, + + /** + * Support knn function + */ + KNN_FUNCTION(Build.current().isSnapshot()); private final boolean enabled; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 259062cd14f5..494e801eb1f7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -66,6 +66,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToUnsignedLong; import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; +import org.elasticsearch.xpack.esql.expression.function.vector.VectorFunction; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.DateTimeArithmeticOperation; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; @@ -1396,6 +1397,9 @@ public class Analyzer extends ParameterizedRuleExecutor args = vectorFunction.arguments(); + List newArgs = new ArrayList<>(); + for (Expression arg : args) { + if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) { + Object folded = arg.fold(FoldContext.small() /* TODO remove me */); + if (folded instanceof List) { + Literal denseVector = new Literal(arg.source(), folded, DataType.DENSE_VECTOR); + newArgs.add(denseVector); + continue; + } + } + newArgs.add(arg); + } + + return vectorFunction.replaceChildren(newArgs); + } + } /** diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java index 2608ed96eb10..b0856a3090ba 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.expression; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables; @@ -83,6 +84,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLike; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLike; import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.expression.predicate.logical.Not; import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull; import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull; @@ -115,6 +117,7 @@ public class ExpressionWritables { entries.addAll(binaryComparisons()); entries.addAll(fullText()); entries.addAll(unaryScalars()); + entries.addAll(vector()); return entries; } @@ -252,4 +255,11 @@ public class ExpressionWritables { private static List fullText() { return FullTextWritables.getNamedWriteables(); } + + private static List vector() { + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + return List.of(Knn.ENTRY); + } + return List.of(); + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index 5d6bccc77ca5..ecc9dbd27132 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -179,6 +179,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToLower; import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim; import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.parser.ParsingException; import org.elasticsearch.xpack.esql.session.Configuration; @@ -485,7 +486,8 @@ public class EsqlFunctionRegistry { def(LastOverTime.class, LastOverTime::withUnresolvedTimestamp, "last_over_time"), def(FirstOverTime.class, FirstOverTime::withUnresolvedTimestamp, "first_over_time"), def(Term.class, bi(Term::new), "term"), - def(MatchPhrase.class, tri(MatchPhrase::new), "match_phrase") } }; + def(MatchPhrase.class, tri(MatchPhrase::new), "match_phrase"), + def(Knn.class, tri(Knn::new), "knn") } }; } public EsqlFunctionRegistry snapshotRegistry() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java index bf7564ee40db..b7bf29276511 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java @@ -148,7 +148,7 @@ public abstract class FullTextFunction extends Function @Override public int hashCode() { - return Objects.hash(super.hashCode(), queryBuilder); + return Objects.hash(super.hashCode(), query, queryBuilder); } @Override @@ -157,7 +157,7 @@ public abstract class FullTextFunction extends Function return false; } - return Objects.equals(queryBuilder, ((FullTextFunction) obj).queryBuilder); + return Objects.equals(queryBuilder, ((FullTextFunction) obj).queryBuilder) && Objects.equals(query, ((FullTextFunction) obj).query); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java new file mode 100644 index 000000000000..ecce0b069693 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java @@ -0,0 +1,292 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xpack.esql.core.InvalidArgumentException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.querydsl.query.Query; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.Check; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.MapParam; +import org.elasticsearch.xpack.esql.expression.function.OptionalArgument; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction; +import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.planner.TranslatorHandler; +import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static java.util.Map.entry; +import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD; +import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD; +import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD; +import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.THIRD; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; +import static org.elasticsearch.xpack.esql.expression.function.fulltext.Match.getNameFromFieldAttribute; + +public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom); + + private final Expression field; + private final Expression options; + + public static final Map ALLOWED_OPTIONS = Map.ofEntries( + entry(K_FIELD.getPreferredName(), INTEGER), + entry(NUM_CANDS_FIELD.getPreferredName(), INTEGER), + entry(VECTOR_SIMILARITY_FIELD.getPreferredName(), FLOAT), + entry(BOOST_FIELD.getPreferredName(), FLOAT), + entry(KnnQuery.RESCORE_OVERSAMPLE_FIELD, FLOAT) + ); + + @FunctionInfo( + returnType = "boolean", + preview = true, + description = "Finds the k nearest vectors to a query vector, as measured by a similarity metric. " + + "knn function finds nearest vectors through approximate search on indexed dense_vectors.", + examples = { + @Example(file = "knn-function", tag = "knn-function"), + @Example(file = "knn-function", tag = "knn-function-options"), }, + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) } + ) + public Knn( + Source source, + @Param(name = "field", type = { "dense_vector" }, description = "Field that the query will target.") Expression field, + @Param( + name = "query", + type = { "dense_vector" }, + description = "Vector value to find top nearest neighbours for." + ) Expression query, + @MapParam( + name = "options", + params = { + @MapParam.MapParamEntry( + name = "boost", + type = "float", + valueHint = { "2.5" }, + description = "Floating point number used to decrease or increase the relevance scores of the query." + + "Defaults to 1.0." + ), + @MapParam.MapParamEntry( + name = "k", + type = "integer", + valueHint = { "10" }, + description = "The number of nearest neighbors to return from each shard. " + + "Elasticsearch collects k results from each shard, then merges them to find the global top results. " + + "This value must be less than or equal to num_candidates. Defaults to 10." + ), + @MapParam.MapParamEntry( + name = "num_candidates", + type = "integer", + valueHint = { "10" }, + description = "The number of nearest neighbor candidates to consider per shard while doing knn search. " + + "Cannot exceed 10,000. Increasing num_candidates tends to improve the accuracy of the final results. " + + "Defaults to 1.5 * k" + ), + @MapParam.MapParamEntry( + name = "similarity", + type = "double", + valueHint = { "0.01" }, + description = "The minimum similarity required for a document to be considered a match. " + + "The similarity value calculated relates to the raw similarity used, not the document score." + ), + @MapParam.MapParamEntry( + name = "rescore_oversample", + type = "double", + valueHint = { "3.5" }, + description = "Applies the specified oversampling for rescoring quantized vectors. " + + "See [oversampling and rescoring quantized vectors]" + + "(docs-content://solutions/search/vector/knn.md#dense-vector-knn-search-rescoring) for details." + ), }, + description = "(Optional) kNN additional options as <>." + + " See <> for more information.", + optional = true + ) Expression options + ) { + this(source, field, query, options, null); + } + + private Knn(Source source, Expression field, Expression query, Expression options, QueryBuilder queryBuilder) { + super(source, query, options == null ? List.of(field, query) : List.of(field, query, options), queryBuilder); + this.field = field; + this.options = options; + } + + public Expression field() { + return field; + } + + public Expression options() { + return options; + } + + @Override + public DataType dataType() { + return DataType.BOOLEAN; + } + + @Override + protected TypeResolution resolveParams() { + return resolveField().and(resolveQuery()).and(resolveOptions()); + } + + private TypeResolution resolveField() { + return isNotNull(field(), sourceText(), FIRST).and(isType(field(), dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector")); + } + + private TypeResolution resolveQuery() { + return isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector").and( + isNotNullAndFoldable(query(), sourceText(), SECOND) + ); + } + + private TypeResolution resolveOptions() { + if (options() != null) { + TypeResolution resolution = isNotNull(options(), sourceText(), THIRD); + if (resolution.unresolved()) { + return resolution; + } + // MapExpression does not have a DataType associated with it + resolution = isMapExpression(options(), sourceText(), THIRD); + if (resolution.unresolved()) { + return resolution; + } + + try { + knnQueryOptions(); + } catch (InvalidArgumentException e) { + return new TypeResolution(e.getMessage()); + } + } + return TypeResolution.TYPE_RESOLVED; + } + + private Map knnQueryOptions() throws InvalidArgumentException { + if (options() == null) { + return Map.of(); + } + + Map matchOptions = new HashMap<>(); + populateOptionsMap((MapExpression) options(), matchOptions, THIRD, sourceText(), ALLOWED_OPTIONS); + return matchOptions; + } + + @Override + protected Query translate(TranslatorHandler handler) { + var fieldAttribute = Match.fieldAsFieldAttribute(field()); + + Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument"); + String fieldName = getNameFromFieldAttribute(fieldAttribute); + @SuppressWarnings("unchecked") + List queryFolded = (List) query().fold(FoldContext.small() /* TODO remove me */); + float[] queryAsFloats = new float[queryFolded.size()]; + for (int i = 0; i < queryFolded.size(); i++) { + queryAsFloats[i] = queryFolded.get(i).floatValue(); + } + + return new KnnQuery(source(), fieldName, queryAsFloats, queryOptions()); + } + + @Override + public Expression replaceQueryBuilder(QueryBuilder queryBuilder) { + return new Knn(source(), field(), query(), options(), queryBuilder); + } + + private Map queryOptions() throws InvalidArgumentException { + if (options() == null) { + return Map.of(); + } + + Map options = new HashMap<>(); + populateOptionsMap((MapExpression) options(), options, THIRD, sourceText(), ALLOWED_OPTIONS); + return options; + } + + @Override + public Expression replaceChildren(List newChildren) { + return new Knn( + source(), + newChildren.get(0), + newChildren.get(1), + newChildren.size() > 2 ? newChildren.get(2) : null, + queryBuilder() + ); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, Knn::new, field(), query(), options()); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + private static Knn readFrom(StreamInput in) throws IOException { + Source source = Source.readFrom((PlanStreamInput) in); + Expression field = in.readNamedWriteable(Expression.class); + Expression query = in.readNamedWriteable(Expression.class); + QueryBuilder queryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class); + + return new Knn(source, field, query, null, queryBuilder); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + out.writeNamedWriteable(field()); + out.writeNamedWriteable(query()); + out.writeOptionalNamedWriteable(queryBuilder()); + } + + @Override + public boolean equals(Object o) { + // Knn does not serialize options, as they get included in the query builder. We need to override equals and hashcode to + // ignore options when comparing two Knn functions + if (o == null || getClass() != o.getClass()) return false; + Knn knn = (Knn) o; + return Objects.equals(field(), knn.field()) + && Objects.equals(query(), knn.query()) + && Objects.equals(queryBuilder(), knn.queryBuilder()); + } + + @Override + public int hashCode() { + return Objects.hash(field(), query(), queryBuilder()); + } + +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorFunction.java new file mode 100644 index 000000000000..dc0be7a29fee --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorFunction.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.vector; + +/** + * Marker interface for vector functions. Makes possible to do implicit casting + * from multi values to dense_vector field types, so parameters are actually + * processed as dense_vectors in vector functions + */ +public interface VectorFunction {} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java new file mode 100644 index 000000000000..aa0e896dfc01 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.querydsl.query; + +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; +import org.elasticsearch.xpack.esql.core.querydsl.query.Query; +import org.elasticsearch.xpack.esql.core.tree.Source; + +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD; +import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD; +import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD; +import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.VECTOR_SIMILARITY_FIELD; + +public class KnnQuery extends Query { + + private final String field; + private final float[] query; + private final Map options; + + public static final String RESCORE_OVERSAMPLE_FIELD = "rescore_oversample"; + + public KnnQuery(Source source, String field, float[] query, Map options) { + super(source); + assert options != null; + this.field = field; + this.query = query; + this.options = options; + } + + @Override + protected QueryBuilder asBuilder() { + Integer k = (Integer) options.get(K_FIELD.getPreferredName()); + Integer numCands = (Integer) options.get(NUM_CANDS_FIELD.getPreferredName()); + RescoreVectorBuilder rescoreVectorBuilder = null; + Float oversample = (Float) options.get(RESCORE_OVERSAMPLE_FIELD); + if (oversample != null) { + rescoreVectorBuilder = new RescoreVectorBuilder(oversample); + } + Float vectorSimilarity = (Float) options.get(VECTOR_SIMILARITY_FIELD.getPreferredName()); + + KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(field, query, k, numCands, rescoreVectorBuilder, vectorSimilarity); + Number boost = (Number) options.get(BOOST_FIELD.getPreferredName()); + if (boost != null) { + queryBuilder.boost(boost.floatValue()); + } + return queryBuilder; + } + + @Override + protected String innerToString() { + return "knn(" + field + ", " + Arrays.toString(query) + " options={" + options + "}))"; + } + + @Override + public boolean equals(Object o) { + if (super.equals(o) == false) return false; + + KnnQuery knnQuery = (KnnQuery) o; + return Objects.equals(field, knnQuery.field) + && Objects.deepEquals(query, knnQuery.query) + && Objects.equals(options, knnQuery.options); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), field, Arrays.hashCode(query), options); + } + + @Override + public boolean scorable() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index 58860a163d94..0ab1b9c41896 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -296,6 +296,10 @@ public class CsvTests extends ESTestCase { "can't use KQL function in csv tests", testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KQL_FUNCTION.capabilityName()) ); + assumeFalse( + "can't use KNN function in csv tests", + testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.KNN_FUNCTION.capabilityName()) + ); assumeFalse( "lookup join disabled for csv tests", testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.JOIN_LOOKUP_V12.capabilityName()) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java index 8e396e4753f0..e55a1b039258 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java @@ -24,6 +24,7 @@ import org.elasticsearch.index.query.RegexpQueryBuilder; import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.index.query.TermsQueryBuilder; import org.elasticsearch.index.query.WildcardQueryBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.test.EqualsHashCodeTestUtils; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.expression.ExpressionWritables; @@ -111,6 +112,7 @@ public class SerializationTestUtils { entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, WildcardQueryBuilder.NAME, WildcardQueryBuilder::new)); entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, RegexpQueryBuilder.NAME, RegexpQueryBuilder::new)); entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, ExistsQueryBuilder.NAME, ExistsQueryBuilder::new)); + entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, KnnVectorQueryBuilder.NAME, KnnVectorQueryBuilder::new)); entries.add(SingleValueQuery.ENTRY); entries.addAll(ExpressionWritables.getNamedWriteables()); entries.addAll(PlanWritables.getNamedWriteables()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index d53a392650f8..19fb563488e6 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -55,6 +55,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; @@ -2363,6 +2364,22 @@ public class AnalyzerTests extends ESTestCase { assertThat(e.getMessage(), containsString("[+] has arguments with incompatible types [datetime] and [datetime]")); } + public void testDenseVectorImplicitCasting() { + Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors")); + + var plan = analyze(""" + from test | where knn(vector, [0.342, 0.164, 0.234]) + """, "mapping-dense_vector.json"); + + var limit = as(plan, Limit.class); + var filter = as(limit.child(), Filter.class); + var knn = as(filter.condition(), Knn.class); + var field = knn.field(); + var queryVector = as(knn.query(), Literal.class); + assertEquals(DataType.DENSE_VECTOR, queryVector.dataType()); + assertThat(queryVector.value(), equalTo(List.of(0.342, 0.164, 0.234))); + } + public void testRateRequiresCounterTypes() { assumeTrue("rate requires snapshot builds", Build.current().isSnapshot()); Analyzer analyzer = analyzer(tsdbIndexResolution()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index b32a74cd66f4..14e5c0615e2b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchPhrase; import org.elasticsearch.xpack.esql.expression.function.fulltext.MultiMatch; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.index.IndexResolution; import org.elasticsearch.xpack.esql.parser.EsqlParser; @@ -1234,6 +1235,9 @@ public class VerifierTests extends ESTestCase { checkFieldBasedWithNonIndexedColumn("Term", "term(text, \"cat\")", "function"); checkFieldBasedFunctionNotAllowedAfterCommands("Term", "function", "term(title, \"Meditation\")"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkFieldBasedFunctionNotAllowedAfterCommands("KNN", "function", "knn(vector, [1, 2, 3])"); + } } private void checkFieldBasedFunctionNotAllowedAfterCommands(String functionName, String functionType, String functionInvocation) { @@ -1364,6 +1368,9 @@ public class VerifierTests extends ESTestCase { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsOnlyAllowedInWhere("MultiMatch", "multi_match(\"Meditation\", title, body)", "function"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkFullTextFunctionsOnlyAllowedInWhere("KNN", "knn(vector, [0, 1, 2])", "function"); + } } private void checkFullTextFunctionsOnlyAllowedInWhere(String functionName, String functionInvocation, String functionType) @@ -1400,6 +1407,9 @@ public class VerifierTests extends ESTestCase { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkWithFullTextFunctionsDisjunctions("term(title, \"Meditation\")"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkWithFullTextFunctionsDisjunctions("knn(vector, [1, 2, 3])"); + } } private void checkWithFullTextFunctionsDisjunctions(String functionInvocation) { @@ -1462,6 +1472,9 @@ public class VerifierTests extends ESTestCase { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(title, \"Meditation\")", "function"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkFullTextFunctionsWithNonBooleanFunctions("KNN", "knn(vector, [1, 2, 3])", "function"); + } } private void checkFullTextFunctionsWithNonBooleanFunctions(String functionName, String functionInvocation, String functionType) { @@ -1530,6 +1543,9 @@ public class VerifierTests extends ESTestCase { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { testFullTextFunctionTargetsExistingField("term(fist_name, \"Meditation\")"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + testFullTextFunctionTargetsExistingField("knn(vector, [0, 1, 2])"); + } } private void testFullTextFunctionTargetsExistingField(String functionInvocation) throws Exception { @@ -2055,6 +2071,9 @@ public class VerifierTests extends ESTestCase { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkOptionDataTypes(MultiMatch.OPTIONS, "FROM test | WHERE MULTI_MATCH(\"Jean\", title, body, {\"%s\": %s})"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkOptionDataTypes(Knn.ALLOWED_OPTIONS, "FROM test | WHERE KNN(vector, [0.1, 0.2, 0.3], {\"%s\": %s})"); + } } /** @@ -2140,6 +2159,10 @@ public class VerifierTests extends ESTestCase { checkFullTextFunctionNullArgs("term(null, \"query\")", "first"); checkFullTextFunctionNullArgs("term(title, null)", "second"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkFullTextFunctionNullArgs("knn(null, [0, 1, 2])", "first"); + checkFullTextFunctionNullArgs("knn(vector, null)", "second"); + } } private void checkFullTextFunctionNullArgs(String functionInvocation, String argOrdinal) throws Exception { @@ -2161,6 +2184,9 @@ public class VerifierTests extends ESTestCase { if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { checkFullTextFunctionsConstantQuery("term(title, tags)", "second"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkFullTextFunctionsConstantQuery("knn(vector, vector)", "second"); + } } private void checkFullTextFunctionsConstantQuery(String functionInvocation, String argOrdinal) throws Exception { @@ -2188,6 +2214,9 @@ public class VerifierTests extends ESTestCase { if (EsqlCapabilities.Cap.MULTI_MATCH_FUNCTION.isEnabled()) { checkFullTextFunctionsInStats("multi_match(\"Meditation\", title, body)"); } + if (EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()) { + checkFullTextFunctionsInStats("knn(vector, [0, 1, 2])"); + } } private void checkFullTextFunctionsInStats(String functionInvocation) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java new file mode 100644 index 000000000000..c2bc381e2663 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/KnnTests.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.fulltext; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.MapExpression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.elasticsearch.xpack.esql.expression.function.vector.Knn; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.esql.SerializationTestUtils.serializeDeserialize; +import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; +import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; +import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; +import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED; +import static org.elasticsearch.xpack.esql.planner.TranslatorHandler.TRANSLATOR_HANDLER; +import static org.hamcrest.Matchers.equalTo; + +public class KnnTests extends AbstractFunctionTestCase { + + public KnnTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + return parameterSuppliersFromTypedData(addFunctionNamedParams(testCaseSuppliers())); + } + + @Before + public void checkCapability() { + assumeTrue("KNN is not enabled", EsqlCapabilities.Cap.KNN_FUNCTION.isEnabled()); + } + + private static List testCaseSuppliers() { + List suppliers = new ArrayList<>(); + + suppliers.add( + TestCaseSupplier.testCaseSupplier( + new TestCaseSupplier.TypedDataSupplier("dense_vector field", KnnTests::randomDenseVector, DENSE_VECTOR), + new TestCaseSupplier.TypedDataSupplier("query", KnnTests::randomDenseVector, DENSE_VECTOR, true), + (d1, d2) -> equalTo("string"), + BOOLEAN, + (o1, o2) -> true + ) + ); + + return suppliers; + } + + private static List randomDenseVector() { + int dimensions = randomIntBetween(64, 128); + List vector = new ArrayList<>(); + for (int i = 0; i < dimensions; i++) { + vector.add(randomFloat()); + } + return vector; + } + + /** + * Adds function named parameters to all the test case suppliers provided + */ + private static List addFunctionNamedParams(List suppliers) { + // TODO get to a common class with MatchTests + List result = new ArrayList<>(); + for (TestCaseSupplier supplier : suppliers) { + List dataTypes = new ArrayList<>(supplier.types()); + dataTypes.add(UNSUPPORTED); + result.add(new TestCaseSupplier(supplier.name() + ", options", dataTypes, () -> { + List values = new ArrayList<>(supplier.get().getData()); + values.add( + new TestCaseSupplier.TypedData( + new MapExpression(Source.EMPTY, List.of(new Literal(Source.EMPTY, randomAlphaOfLength(10), KEYWORD))), + UNSUPPORTED, + "options" + ).forceLiteral() + ); + + return new TestCaseSupplier.TestCase(values, equalTo("KnnEvaluator"), BOOLEAN, equalTo(true)); + })); + } + return result; + } + + @Override + protected Expression build(Source source, List args) { + Knn knn = new Knn(source, args.get(0), args.get(1), args.size() > 2 ? args.get(2) : null); + // We need to add the QueryBuilder to the match expression, as it is used to implement equals() and hashCode() and + // thus test the serialization methods. But we can only do this if the parameters make sense . + if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) { + QueryBuilder queryBuilder = TRANSLATOR_HANDLER.asQuery(LucenePushdownPredicates.DEFAULT, knn).toQueryBuilder(); + knn = (Knn) knn.replaceQueryBuilder(queryBuilder); + } + return knn; + } + + /** + * Copy of the overridden method that doesn't check for children size, as the {@code options} child isn't serialized in Match. + */ + @Override + protected Expression serializeDeserializeExpression(Expression expression) { + Expression newExpression = serializeDeserialize( + expression, + PlanStreamOutput::writeNamedWriteable, + in -> in.readNamedWriteable(Expression.class), + testCase.getConfiguration() // The configuration query should be == to the source text of the function for this to work + ); + // Fields use synthetic sources, which can't be serialized. So we use the originals instead. + return newExpression.replaceChildren(expression.children()); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchTests.java index 6993f7583dd0..301cbd6844af 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchTests.java @@ -82,7 +82,7 @@ public class MatchTests extends AbstractMatchFullTextFunctionTests { // thus test the serialization methods. But we can only do this if the parameters make sense . if (args.get(0) instanceof FieldAttribute && args.get(1).foldable()) { QueryBuilder queryBuilder = TRANSLATOR_HANDLER.asQuery(LucenePushdownPredicates.DEFAULT, match).toQueryBuilder(); - match.replaceQueryBuilder(queryBuilder); + match = (Match) match.replaceQueryBuilder(queryBuilder); } return match; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index 08c9a2f8b4fe..b0f16b384a48 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -28,6 +28,8 @@ import org.elasticsearch.index.query.QueryStringQueryBuilder; import org.elasticsearch.index.query.RangeQueryBuilder; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.search.vectors.RescoreVectorBuilder; import org.elasticsearch.test.VersionUtils; import org.elasticsearch.xpack.core.enrich.EnrichPolicy; import org.elasticsearch.xpack.esql.EsqlTestUtils; @@ -1359,6 +1361,29 @@ public class LocalPhysicalPlanOptimizerTests extends MapperServiceTestCase { assertThat(expectedQuery.toString(), is(planStr.get())); } + public void testKnnOptionsPushDown() { + String query = """ + from test + | where KNN(dense_vector, [0.1, 0.2, 0.3], + { "k": 5, "similarity": 0.001, "num_candidates": 10, "rescore_oversample": 7, "boost": 3.5 }) + """; + var analyzer = makeAnalyzer("mapping-all-types.json"); + var plan = plannerOptimizer.plan(query, IS_SV_STATS, analyzer); + + AtomicReference planStr = new AtomicReference<>(); + plan.forEachDown(EsQueryExec.class, result -> planStr.set(result.query().toString())); + + var expectedQuery = new KnnVectorQueryBuilder( + "dense_vector", + new float[] { 0.1f, 0.2f, 0.3f }, + 5, + 10, + new RescoreVectorBuilder(7), + 0.001f + ).boost(3.5f); + assertThat(expectedQuery.toString(), is(planStr.get())); + } + /** * Expecting * LimitExec[1000[INTEGER]]