mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-04-24 15:17:30 -04:00
Merge remote-tracking branch 'upstream-main/main' into update-main-10-10-24
This commit is contained in:
commit
f981d1f9e2
486 changed files with 9344 additions and 2819 deletions
|
@ -8,7 +8,11 @@
|
|||
package org.elasticsearch.xpack.rank.rrf;
|
||||
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.search.SearchRequestBuilder;
|
||||
import org.elasticsearch.client.internal.Client;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.index.query.InnerHitBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
|
@ -24,16 +28,23 @@ import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
|
|||
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
|
||||
import org.elasticsearch.search.sort.FieldSortBuilder;
|
||||
import org.elasticsearch.search.sort.SortOrder;
|
||||
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
|
||||
import org.elasticsearch.search.vectors.QueryVectorBuilder;
|
||||
import org.elasticsearch.test.ESIntegTestCase;
|
||||
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
|
||||
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.greaterThan;
|
||||
|
@ -652,4 +663,55 @@ public class RRFRetrieverBuilderIT extends ESIntegTestCase {
|
|||
source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
|
||||
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
|
||||
}
|
||||
|
||||
public void testRewriteOnce() {
|
||||
final float[] vector = new float[] { 1 };
|
||||
AtomicInteger numAsyncCalls = new AtomicInteger();
|
||||
QueryVectorBuilder vectorBuilder = new QueryVectorBuilder() {
|
||||
@Override
|
||||
public void buildVector(Client client, ActionListener<float[]> listener) {
|
||||
numAsyncCalls.incrementAndGet();
|
||||
listener.onResponse(vector);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
throw new IllegalStateException("Should not be called");
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
throw new IllegalStateException("Should not be called");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
throw new IllegalStateException("Should not be called");
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
throw new IllegalStateException("Should not be called");
|
||||
}
|
||||
};
|
||||
var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null);
|
||||
var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null));
|
||||
var rrf = new RRFRetrieverBuilder(
|
||||
List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)),
|
||||
10,
|
||||
10
|
||||
);
|
||||
assertResponse(
|
||||
client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf)),
|
||||
searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value, is(4L))
|
||||
);
|
||||
assertThat(numAsyncCalls.get(), equalTo(2));
|
||||
|
||||
// check that we use the rewritten vector to build the explain query
|
||||
assertResponse(
|
||||
client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf).explain(true)),
|
||||
searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value, is(4L))
|
||||
);
|
||||
assertThat(numAsyncCalls.get(), equalTo(4));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,194 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.rank.rrf;
|
||||
|
||||
import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesRequest;
|
||||
import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesResponse;
|
||||
import org.elasticsearch.action.admin.cluster.stats.SearchUsageStats;
|
||||
import org.elasticsearch.client.Request;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.plugins.Plugin;
|
||||
import org.elasticsearch.rest.RestRequest;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
|
||||
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
|
||||
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
|
||||
import org.elasticsearch.search.vectors.KnnSearchBuilder;
|
||||
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
|
||||
import org.elasticsearch.test.ESIntegTestCase;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xpack.core.XPackPlugin;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0)
|
||||
public class RRFRetrieverTelemetryIT extends ESIntegTestCase {
|
||||
|
||||
private static final String INDEX_NAME = "test_index";
|
||||
|
||||
@Override
|
||||
protected boolean addMockHttpTransport() {
|
||||
return false; // enable http
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Collection<Class<? extends Plugin>> nodePlugins() {
|
||||
return List.of(RRFRankPlugin.class, XPackPlugin.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
|
||||
return Settings.builder()
|
||||
.put(super.nodeSettings(nodeOrdinal, otherSettings))
|
||||
.put("xpack.license.self_generated.type", "trial")
|
||||
.build();
|
||||
}
|
||||
|
||||
@Before
|
||||
public void setup() throws IOException {
|
||||
XContentBuilder builder = XContentFactory.jsonBuilder()
|
||||
.startObject()
|
||||
.startObject("properties")
|
||||
.startObject("vector")
|
||||
.field("type", "dense_vector")
|
||||
.field("dims", 1)
|
||||
.field("index", true)
|
||||
.field("similarity", "l2_norm")
|
||||
.startObject("index_options")
|
||||
.field("type", "hnsw")
|
||||
.endObject()
|
||||
.endObject()
|
||||
.startObject("text")
|
||||
.field("type", "text")
|
||||
.endObject()
|
||||
.startObject("integer")
|
||||
.field("type", "integer")
|
||||
.endObject()
|
||||
.startObject("topic")
|
||||
.field("type", "keyword")
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject();
|
||||
|
||||
assertAcked(prepareCreate(INDEX_NAME).setMapping(builder));
|
||||
ensureGreen(INDEX_NAME);
|
||||
}
|
||||
|
||||
private void performSearch(SearchSourceBuilder source) throws IOException {
|
||||
Request request = new Request("GET", INDEX_NAME + "/_search");
|
||||
request.setJsonEntity(Strings.toString(source));
|
||||
getRestClient().performRequest(request);
|
||||
}
|
||||
|
||||
public void testTelemetryForRRFRetriever() throws IOException {
|
||||
|
||||
if (false == isRetrieverTelemetryEnabled()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers`
|
||||
{
|
||||
performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null)));
|
||||
}
|
||||
|
||||
// search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under
|
||||
// `queries`
|
||||
{
|
||||
performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.rangeQuery("integer").gte(2))));
|
||||
}
|
||||
|
||||
// search#3 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "knn" under
|
||||
// `queries`
|
||||
{
|
||||
performSearch(
|
||||
new SearchSourceBuilder().retriever(
|
||||
new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null))
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// search#4 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "term"
|
||||
// under `queries`
|
||||
{
|
||||
performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.termQuery("topic", "foo"))));
|
||||
}
|
||||
|
||||
// search#5 - this will record 1 entry for "retriever" in `sections`, and 1 for "rrf" under `retrievers`, as well as
|
||||
// 1 for "knn" and 1 for "standard" under `retrievers`, and eventually 1 for "match" under `queries`
|
||||
{
|
||||
performSearch(
|
||||
new SearchSourceBuilder().retriever(
|
||||
new RRFRetrieverBuilder(
|
||||
Arrays.asList(
|
||||
new CompoundRetrieverBuilder.RetrieverSource(
|
||||
new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null),
|
||||
null
|
||||
),
|
||||
new CompoundRetrieverBuilder.RetrieverSource(
|
||||
new StandardRetrieverBuilder(QueryBuilders.matchQuery("text", "foo")),
|
||||
null
|
||||
)
|
||||
),
|
||||
10,
|
||||
10
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// search#6 - this will record 1 entry for "knn" in `sections`
|
||||
{
|
||||
performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null))));
|
||||
}
|
||||
|
||||
// search#7 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries`
|
||||
{
|
||||
performSearch(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()));
|
||||
}
|
||||
|
||||
// cluster stats
|
||||
{
|
||||
SearchUsageStats stats = clusterAdmin().prepareClusterStats().get().getIndicesStats().getSearchUsageStats();
|
||||
assertEquals(7, stats.getTotalSearchCount());
|
||||
|
||||
assertThat(stats.getSectionsUsage().size(), equalTo(3));
|
||||
assertThat(stats.getSectionsUsage().get("retriever"), equalTo(5L));
|
||||
assertThat(stats.getSectionsUsage().get("query"), equalTo(1L));
|
||||
assertThat(stats.getSectionsUsage().get("knn"), equalTo(1L));
|
||||
|
||||
assertThat(stats.getRetrieversUsage().size(), equalTo(3));
|
||||
assertThat(stats.getRetrieversUsage().get("standard"), equalTo(4L));
|
||||
assertThat(stats.getRetrieversUsage().get("knn"), equalTo(2L));
|
||||
assertThat(stats.getRetrieversUsage().get("rrf"), equalTo(1L));
|
||||
|
||||
assertThat(stats.getQueryUsage().size(), equalTo(5));
|
||||
assertThat(stats.getQueryUsage().get("range"), equalTo(1L));
|
||||
assertThat(stats.getQueryUsage().get("term"), equalTo(1L));
|
||||
assertThat(stats.getQueryUsage().get("match"), equalTo(1L));
|
||||
assertThat(stats.getQueryUsage().get("match_all"), equalTo(1L));
|
||||
assertThat(stats.getQueryUsage().get("knn"), equalTo(1L));
|
||||
}
|
||||
}
|
||||
|
||||
private boolean isRetrieverTelemetryEnabled() throws IOException {
|
||||
NodesCapabilitiesResponse res = clusterAdmin().nodesCapabilities(
|
||||
new NodesCapabilitiesRequest().method(RestRequest.Method.GET).path("_cluster/stats").capabilities("retrievers-usage-stats")
|
||||
).actionGet();
|
||||
return res != null && res.isSupported().orElse(false);
|
||||
}
|
||||
}
|
|
@ -68,6 +68,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
|
|||
p.nextToken();
|
||||
String name = p.currentName();
|
||||
RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
|
||||
c.trackRetrieverUsage(retrieverBuilder.getName());
|
||||
p.nextToken();
|
||||
return retrieverBuilder;
|
||||
}, RETRIEVERS_FIELD);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue