Merge remote-tracking branch 'upstream-main/main' into update-main-10-10-24

This commit is contained in:
Simon Cooper 2024-10-10 13:27:33 +01:00
commit f981d1f9e2
486 changed files with 9344 additions and 2819 deletions

View file

@ -8,7 +8,11 @@
package org.elasticsearch.xpack.rank.rrf;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.InnerHitBuilder;
import org.elasticsearch.index.query.QueryBuilder;
@ -24,16 +28,23 @@ import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentType;
import org.junit.Before;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
@ -652,4 +663,55 @@ public class RRFRetrieverBuilderIT extends ESIntegTestCase {
source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
}
/**
 * Runs an rrf search whose two child retrievers (a knn retriever and a standard retriever wrapping a knn query)
 * share one {@link QueryVectorBuilder}, and asserts that {@code buildVector} is invoked exactly twice per search
 * request, both for a plain search and for a search with {@code explain} enabled.
 */
public void testRewriteOnce() {
    final float[] queryVector = new float[] { 1 };
    final AtomicInteger buildVectorInvocations = new AtomicInteger();
    // Only buildVector is expected to run; every other hook fails loudly if it is ever reached.
    final QueryVectorBuilder countingVectorBuilder = new QueryVectorBuilder() {
        @Override
        public void buildVector(Client client, ActionListener<float[]> listener) {
            buildVectorInvocations.incrementAndGet();
            listener.onResponse(queryVector);
        }

        @Override
        public String getWriteableName() {
            throw new IllegalStateException("Should not be called");
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            throw new IllegalStateException("Should not be called");
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            throw new IllegalStateException("Should not be called");
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            throw new IllegalStateException("Should not be called");
        }
    };

    // Both child retrievers reference the same vector builder instance.
    final KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder("vector", null, countingVectorBuilder, 10, 10, null);
    final StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(
        new KnnVectorQueryBuilder("vector", countingVectorBuilder, 10, 10, null)
    );
    final RRFRetrieverBuilder rrfRetriever = new RRFRetrieverBuilder(
        List.of(
            new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null),
            new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null)
        ),
        10,
        10
    );

    // Plain search: one buildVector call for each of the two retrievers.
    final SearchSourceBuilder plainSource = new SearchSourceBuilder().retriever(rrfRetriever);
    assertResponse(client().prepareSearch(INDEX).setSource(plainSource), response -> {
        assertThat(response.getHits().getTotalHits().value, is(4L));
    });
    assertThat(buildVectorInvocations.get(), equalTo(2));

    // check that we use the rewritten vector to build the explain query
    final SearchSourceBuilder explainSource = new SearchSourceBuilder().retriever(rrfRetriever).explain(true);
    assertResponse(client().prepareSearch(INDEX).setSource(explainSource), response -> {
        assertThat(response.getHits().getTotalHits().value, is(4L));
    });
    assertThat(buildVectorInvocations.get(), equalTo(4));
}
}

View file

@ -0,0 +1,194 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.rrf;
import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesRequest;
import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesResponse;
import org.elasticsearch.action.admin.cluster.stats.SearchUsageStats;
import org.elasticsearch.client.Request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.junit.Before;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.Matchers.equalTo;
/**
 * Integration test checking that searches issued with retrievers, including the {@code rrf} retriever, are reflected
 * in the {@link SearchUsageStats} exposed through cluster stats.
 */
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0)
public class RRFRetrieverTelemetryIT extends ESIntegTestCase {

    private static final String INDEX_NAME = "test_index";

    @Override
    protected boolean addMockHttpTransport() {
        return false; // enable http
    }

    @Override
    protected Collection<Class<? extends Plugin>> nodePlugins() {
        return List.of(RRFRankPlugin.class, XPackPlugin.class);
    }

    /**
     * Extends the default node settings with a self-generated trial license.
     */
    @Override
    protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
        return Settings.builder()
            .put(super.nodeSettings(nodeOrdinal, otherSettings))
            .put("xpack.license.self_generated.type", "trial")
            .build();
    }

    /**
     * Creates the test index with a one-dimensional {@code dense_vector} field plus {@code text}, {@code integer}
     * and {@code keyword} fields, then waits for the index to become green.
     *
     * @throws IOException if the mapping cannot be built
     */
    @Before
    public void setup() throws IOException {
        XContentBuilder builder = XContentFactory.jsonBuilder()
            .startObject()
            .startObject("properties")
            .startObject("vector")
            .field("type", "dense_vector")
            .field("dims", 1)
            .field("index", true)
            .field("similarity", "l2_norm")
            .startObject("index_options")
            .field("type", "hnsw")
            .endObject()
            .endObject()
            .startObject("text")
            .field("type", "text")
            .endObject()
            .startObject("integer")
            .field("type", "integer")
            .endObject()
            .startObject("topic")
            .field("type", "keyword")
            .endObject()
            .endObject()
            .endObject();

        assertAcked(prepareCreate(INDEX_NAME).setMapping(builder));
        ensureGreen(INDEX_NAME);
    }

    /**
     * Serializes {@code source} to JSON and sends it as the body of a {@code GET test_index/_search} request through
     * the REST client.
     *
     * @param source the search source to execute
     * @throws IOException if the REST request fails
     */
    private void performSearch(SearchSourceBuilder source) throws IOException {
        // NOTE(review): presumably routed through REST so the body is parsed, which is where usage appears to be
        // recorded - confirm.
        Request request = new Request("GET", INDEX_NAME + "/_search");
        request.setJsonEntity(Strings.toString(source));
        getRestClient().performRequest(request);
    }

    /**
     * Issues seven searches covering the knn, standard and rrf retrievers as well as plain knn and query sections,
     * then asserts the exact section, retriever and query counters reported by cluster stats.
     *
     * @throws IOException if any REST search request fails
     */
    public void testTelemetryForRRFRetriever() throws IOException {
        // Report the test as skipped, instead of letting it pass without asserting anything, when the cluster does
        // not advertise the capability.
        assumeTrue("cluster does not support the [retrievers-usage-stats] capability", isRetrieverTelemetryEnabled());

        // search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers`
        {
            performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null)));
        }

        // search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under
        // `queries`
        {
            performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.rangeQuery("integer").gte(2))));
        }

        // search#3 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "knn" under
        // `queries`
        {
            performSearch(
                new SearchSourceBuilder().retriever(
                    new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null))
                )
            );
        }

        // search#4 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "term"
        // under `queries`
        {
            performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.termQuery("topic", "foo"))));
        }

        // search#5 - this will record 1 entry for "retriever" in `sections`, and 1 for "rrf" under `retrievers`, as well as
        // 1 for "knn" and 1 for "standard" under `retrievers`, and eventually 1 for "match" under `queries`
        {
            performSearch(
                new SearchSourceBuilder().retriever(
                    new RRFRetrieverBuilder(
                        Arrays.asList(
                            new CompoundRetrieverBuilder.RetrieverSource(
                                new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null),
                                null
                            ),
                            new CompoundRetrieverBuilder.RetrieverSource(
                                new StandardRetrieverBuilder(QueryBuilders.matchQuery("text", "foo")),
                                null
                            )
                        ),
                        10,
                        10
                    )
                )
            );
        }

        // search#6 - this will record 1 entry for "knn" in `sections`
        {
            performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null))));
        }

        // search#7 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries`
        {
            performSearch(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()));
        }

        // cluster stats
        {
            SearchUsageStats stats = clusterAdmin().prepareClusterStats().get().getIndicesStats().getSearchUsageStats();
            assertEquals(7, stats.getTotalSearchCount());

            // searches #1-#5 used a retriever, #6 the knn section, #7 the query section
            assertThat(stats.getSectionsUsage().size(), equalTo(3));
            assertThat(stats.getSectionsUsage().get("retriever"), equalTo(5L));
            assertThat(stats.getSectionsUsage().get("query"), equalTo(1L));
            assertThat(stats.getSectionsUsage().get("knn"), equalTo(1L));

            // standard: searches #2, #3, #4 and the rrf child in #5; knn: search #1 and the rrf child in #5
            assertThat(stats.getRetrieversUsage().size(), equalTo(3));
            assertThat(stats.getRetrieversUsage().get("standard"), equalTo(4L));
            assertThat(stats.getRetrieversUsage().get("knn"), equalTo(2L));
            assertThat(stats.getRetrieversUsage().get("rrf"), equalTo(1L));

            assertThat(stats.getQueryUsage().size(), equalTo(5));
            assertThat(stats.getQueryUsage().get("range"), equalTo(1L));
            assertThat(stats.getQueryUsage().get("term"), equalTo(1L));
            assertThat(stats.getQueryUsage().get("match"), equalTo(1L));
            assertThat(stats.getQueryUsage().get("match_all"), equalTo(1L));
            assertThat(stats.getQueryUsage().get("knn"), equalTo(1L));
        }
    }

    /**
     * Asks the cluster whether {@code GET _cluster/stats} advertises the {@code retrievers-usage-stats} capability.
     *
     * @return {@code true} only when a response is returned and it reports the capability as supported
     */
    private boolean isRetrieverTelemetryEnabled() throws IOException {
        NodesCapabilitiesResponse res = clusterAdmin().nodesCapabilities(
            new NodesCapabilitiesRequest().method(RestRequest.Method.GET).path("_cluster/stats").capabilities("retrievers-usage-stats")
        ).actionGet();
        // An absent "supported" value is treated the same as unsupported.
        return res != null && res.isSupported().orElse(false);
    }
}

View file

@ -68,6 +68,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
p.nextToken();
String name = p.currentName();
RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
c.trackRetrieverUsage(retrieverBuilder.getName());
p.nextToken();
return retrieverBuilder;
}, RETRIEVERS_FIELD);