diff --git a/docs/changelog/126035.yaml b/docs/changelog/126035.yaml new file mode 100644 index 000000000000..1a96205e476e --- /dev/null +++ b/docs/changelog/126035.yaml @@ -0,0 +1,5 @@ +pr: 126035 +summary: Fix top level knn search with scroll +area: Vector Search +type: bug +issues: [] diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/KnnSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/KnnSearchIT.java new file mode 100644 index 000000000000..91409e5e7018 --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/KnnSearchIT.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search; + +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.vectors.KnnSearchBuilder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; + +import java.util.List; + +import static org.hamcrest.Matchers.notNullValue; + +@ESIntegTestCase.ClusterScope(minNumDataNodes = 2) +public class KnnSearchIT extends ESIntegTestCase { + + private static final String INDEX_NAME = "test_knn_index"; + private static final String VECTOR_FIELD = "vector"; + + private XContentBuilder createKnnMapping() throws Exception { + return XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(VECTOR_FIELD) + .field("type", "dense_vector") + .field("dims", 2) + .field("index", true) + .field("similarity", "l2_norm") + .startObject("index_options") + .field("type", "hnsw") + .endObject() + .endObject() + .startObject("category") + .field("type", "keyword") + .endObject() + .endObject() + .endObject(); + } + + public void testKnnSearchWithScroll() throws Exception { + final int numShards = randomIntBetween(1, 3); + Client client = client(); + client.admin() + .indices() + .prepareCreate(INDEX_NAME) + .setSettings(Settings.builder().put("index.number_of_shards", numShards)) + .setMapping(createKnnMapping()) + .get(); + + final int count = 100; + for (int i = 0; i < count; i++) { + XContentBuilder source = XContentFactory.jsonBuilder() + .startObject() + .field(VECTOR_FIELD, new float[] { i * 0.1f, i * 0.1f }) + .field("category", i >= 90 ? "last_ten" : null) + .endObject(); + client.prepareIndex(INDEX_NAME).setSource(source).get(); + } + refresh(INDEX_NAME); + + final int k = randomIntBetween(11, 15); + // test top level knn search + { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.knnSearch(List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null))); + executeScrollSearch(client, sourceBuilder, k); + } + // test top level knn search + another query + { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.knnSearch(List.of(new KnnSearchBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null))); + sourceBuilder.query(QueryBuilders.existsQuery("category").boost(10)); + executeScrollSearch(client, sourceBuilder, k + 10); + } + + // test knn query + { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null)); + executeScrollSearch(client, sourceBuilder, k * numShards); + } + // test knn query + another query + { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query( + QueryBuilders.boolQuery() + .should(new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 0, 0 }, k, 100, null, null)) + .should(QueryBuilders.existsQuery("category").boost(10)) + ); + executeScrollSearch(client, sourceBuilder, k * numShards + 10); + } + + } + + private static void executeScrollSearch(Client client, SearchSourceBuilder sourceBuilder, int expectedNumHits) { + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.source(sourceBuilder).scroll(TimeValue.timeValueMinutes(1)); + + SearchResponse searchResponse = client.search(searchRequest).actionGet(); + int hitsCollected = 0; + float prevScore = Float.POSITIVE_INFINITY; + try { + do { + assertThat(searchResponse.getScrollId(), notNullValue()); + assertEquals(expectedNumHits, searchResponse.getHits().getTotalHits().value()); + // assert correct order of returned hits + for (var searchHit : searchResponse.getHits()) { + assert (searchHit.getScore() <= prevScore); + prevScore = searchHit.getScore(); + hitsCollected += 1; + } + searchResponse.decRef(); + searchResponse = client().prepareSearchScroll(searchResponse.getScrollId()).setScroll(TimeValue.timeValueMinutes(1)).get(); + } while (searchResponse.getHits().getHits().length > 0); + } finally { + assertEquals(expectedNumHits, hitsCollected); + clearScroll(searchResponse.getScrollId()); + searchResponse.decRef(); + } + } + +} diff --git a/server/src/main/java/org/elasticsearch/search/internal/LegacyReaderContext.java b/server/src/main/java/org/elasticsearch/search/internal/LegacyReaderContext.java index e028c3739053..3d0e63868db5 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/LegacyReaderContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/LegacyReaderContext.java @@ -72,6 +72,11 @@ public final class LegacyReaderContext extends ReaderContext { @Override public ShardSearchRequest getShardSearchRequest(ShardSearchRequest other) { + if (other != null) { + // The top level knn search modifies the source after the DFS phase. + // so we need to update the source stored in the context. + shardSearchRequest.source(other.source()); + } return shardSearchRequest; }