Add telemetry for retrievers (#114109)

This commit is contained in:
Panagiotis Bailis 2024-10-10 09:57:42 +03:00 committed by GitHub
parent fd40f8b1cb
commit 4eab631e5f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 677 additions and 47 deletions

View file

@ -0,0 +1,5 @@
pr: 114109
summary: Update cluster stats for retrievers
area: Search
type: enhancement
issues: []

View file

@ -762,6 +762,10 @@ Queries are counted once per search request, meaning that if the same query type
(object) Search sections used in selected nodes.
For each section, name and number of times it's been used is reported.
`retrievers`::
(object) Retriever types that were used in selected nodes.
For each retriever, name and number of times it's been used is reported.
=====
`dense_vector`::

View file

@ -0,0 +1,151 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.search.retriever;
import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesRequest;
import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesResponse;
import org.elasticsearch.action.admin.cluster.stats.SearchUsageStats;
import org.elasticsearch.client.Request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.junit.Before;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.Matchers.equalTo;
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0)
public class RetrieverTelemetryIT extends ESIntegTestCase {
private static final String INDEX_NAME = "test_index";
@Override
protected boolean addMockHttpTransport() {
return false; // enable http
}
@Before
public void setup() throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject("vector")
.field("type", "dense_vector")
.field("dims", 1)
.field("index", true)
.field("similarity", "l2_norm")
.startObject("index_options")
.field("type", "hnsw")
.endObject()
.endObject()
.startObject("text")
.field("type", "text")
.endObject()
.startObject("integer")
.field("type", "integer")
.endObject()
.startObject("topic")
.field("type", "keyword")
.endObject()
.endObject()
.endObject();
assertAcked(prepareCreate(INDEX_NAME).setMapping(builder));
ensureGreen(INDEX_NAME);
}
private void performSearch(SearchSourceBuilder source) throws IOException {
Request request = new Request("GET", INDEX_NAME + "/_search");
request.setJsonEntity(Strings.toString(source));
getRestClient().performRequest(request);
}
public void testTelemetryForRetrievers() throws IOException {
if (false == isRetrieverTelemetryEnabled()) {
return;
}
// search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers`
{
performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null)));
}
// search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under
// `queries`
{
performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.rangeQuery("integer").gte(2))));
}
// search#3 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "knn" under
// `queries`
{
performSearch(
new SearchSourceBuilder().retriever(
new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null))
)
);
}
// search#4 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "term"
// under `queries`
{
performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.termQuery("topic", "foo"))));
}
// search#5 - t
// his will record 1 entry for "knn" in `sections`
{
performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null))));
}
// search#6 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries`
{
performSearch(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()));
}
// cluster stats
{
SearchUsageStats stats = clusterAdmin().prepareClusterStats().get().getIndicesStats().getSearchUsageStats();
assertEquals(6, stats.getTotalSearchCount());
assertThat(stats.getSectionsUsage().size(), equalTo(3));
assertThat(stats.getSectionsUsage().get("retriever"), equalTo(4L));
assertThat(stats.getSectionsUsage().get("query"), equalTo(1L));
assertThat(stats.getSectionsUsage().get("knn"), equalTo(1L));
assertThat(stats.getRetrieversUsage().size(), equalTo(2));
assertThat(stats.getRetrieversUsage().get("standard"), equalTo(3L));
assertThat(stats.getRetrieversUsage().get("knn"), equalTo(1L));
assertThat(stats.getQueryUsage().size(), equalTo(4));
assertThat(stats.getQueryUsage().get("range"), equalTo(1L));
assertThat(stats.getQueryUsage().get("term"), equalTo(1L));
assertThat(stats.getQueryUsage().get("match_all"), equalTo(1L));
assertThat(stats.getQueryUsage().get("knn"), equalTo(1L));
}
}
private boolean isRetrieverTelemetryEnabled() throws IOException {
NodesCapabilitiesResponse res = clusterAdmin().nodesCapabilities(
new NodesCapabilitiesRequest().method(RestRequest.Method.GET).path("_cluster/stats").capabilities("retrievers-usage-stats")
).actionGet();
return res != null && res.isSupported().orElse(false);
}
}

View file

@ -238,6 +238,7 @@ public class TransportVersions {
public static final TransportVersion FAST_REFRESH_RCO = def(8_762_00_0);
public static final TransportVersion TEXT_SIMILARITY_RERANKER_QUERY_REWRITE = def(8_763_00_0);
public static final TransportVersion SIMULATE_INDEX_TEMPLATES_SUBSTITUTIONS = def(8_764_00_0);
public static final TransportVersion RETRIEVERS_TELEMETRY_ADDED = def(8_765_00_0);
/*
* STOP! READ THIS FIRST! No, really,

View file

@ -22,6 +22,7 @@ import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.TransportVersions.RETRIEVERS_TELEMETRY_ADDED;
import static org.elasticsearch.TransportVersions.V_8_12_0;
/**
@ -34,6 +35,7 @@ public final class SearchUsageStats implements Writeable, ToXContentFragment {
private final Map<String, Long> queries;
private final Map<String, Long> rescorers;
private final Map<String, Long> sections;
private final Map<String, Long> retrievers;
/**
* Creates a new empty stats instance, that will get additional stats added through {@link #add(SearchUsageStats)}
@ -43,17 +45,25 @@ public final class SearchUsageStats implements Writeable, ToXContentFragment {
this.queries = new HashMap<>();
this.sections = new HashMap<>();
this.rescorers = new HashMap<>();
this.retrievers = new HashMap<>();
}
/**
* Creates a new stats instance with the provided info. The expectation is that when a new instance is created using
* this constructor, the provided stats are final and won't be modified further.
*/
public SearchUsageStats(Map<String, Long> queries, Map<String, Long> rescorers, Map<String, Long> sections, long totalSearchCount) {
public SearchUsageStats(
Map<String, Long> queries,
Map<String, Long> rescorers,
Map<String, Long> sections,
Map<String, Long> retrievers,
long totalSearchCount
) {
this.totalSearchCount = totalSearchCount;
this.queries = queries;
this.sections = sections;
this.rescorers = rescorers;
this.retrievers = retrievers;
}
public SearchUsageStats(StreamInput in) throws IOException {
@ -61,6 +71,7 @@ public final class SearchUsageStats implements Writeable, ToXContentFragment {
this.sections = in.readMap(StreamInput::readLong);
this.totalSearchCount = in.readVLong();
this.rescorers = in.getTransportVersion().onOrAfter(V_8_12_0) ? in.readMap(StreamInput::readLong) : Map.of();
this.retrievers = in.getTransportVersion().onOrAfter(RETRIEVERS_TELEMETRY_ADDED) ? in.readMap(StreamInput::readLong) : Map.of();
}
@Override
@ -72,6 +83,9 @@ public final class SearchUsageStats implements Writeable, ToXContentFragment {
if (out.getTransportVersion().onOrAfter(V_8_12_0)) {
out.writeMap(rescorers, StreamOutput::writeLong);
}
if (out.getTransportVersion().onOrAfter(RETRIEVERS_TELEMETRY_ADDED)) {
out.writeMap(retrievers, StreamOutput::writeLong);
}
}
/**
@ -81,6 +95,7 @@ public final class SearchUsageStats implements Writeable, ToXContentFragment {
stats.queries.forEach((query, count) -> queries.merge(query, count, Long::sum));
stats.rescorers.forEach((rescorer, count) -> rescorers.merge(rescorer, count, Long::sum));
stats.sections.forEach((query, count) -> sections.merge(query, count, Long::sum));
stats.retrievers.forEach((query, count) -> retrievers.merge(query, count, Long::sum));
this.totalSearchCount += stats.totalSearchCount;
}
@ -95,6 +110,8 @@ public final class SearchUsageStats implements Writeable, ToXContentFragment {
builder.map(rescorers);
builder.field("sections");
builder.map(sections);
builder.field("retrievers");
builder.map(retrievers);
}
builder.endObject();
return builder;
@ -112,6 +129,10 @@ public final class SearchUsageStats implements Writeable, ToXContentFragment {
return Collections.unmodifiableMap(sections);
}
public Map<String, Long> getRetrieversUsage() {
return Collections.unmodifiableMap(retrievers);
}
public long getTotalSearchCount() {
return totalSearchCount;
}
@ -128,12 +149,13 @@ public final class SearchUsageStats implements Writeable, ToXContentFragment {
return totalSearchCount == that.totalSearchCount
&& queries.equals(that.queries)
&& rescorers.equals(that.rescorers)
&& sections.equals(that.sections);
&& sections.equals(that.sections)
&& retrievers.equals(that.retrievers);
}
@Override
public int hashCode() {
return Objects.hash(totalSearchCount, queries, rescorers, sections);
return Objects.hash(totalSearchCount, queries, rescorers, sections, retrievers);
}
@Override

View file

@ -32,7 +32,8 @@ public class RestClusterStatsAction extends BaseRestHandler {
private static final Set<String> SUPPORTED_CAPABILITIES = Set.of(
"human-readable-total-docs-size",
"verbose-dense-vector-mapping-stats",
"ccs-stats"
"ccs-stats",
"retrievers-usage-stats"
);
private static final Set<String> SUPPORTED_QUERY_PARAMETERS = Set.of("include_remotes", "nodeId", REST_TIMEOUT_PARAM);

View file

@ -1409,6 +1409,7 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
parser,
new RetrieverParserContext(searchUsage, clusterSupportsFeature)
);
searchUsage.trackSectionUsage(RETRIEVER.getPreferredName());
} else if (QUERY_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
if (subSearchSourceBuilders.isEmpty() == false) {
throw new IllegalArgumentException(

View file

@ -62,11 +62,11 @@ public abstract class RetrieverBuilder implements Rewriteable<RetrieverBuilder>,
String name,
AbstractObjectParser<? extends RetrieverBuilder, RetrieverParserContext> parser
) {
parser.declareObjectArray((r, v) -> r.preFilterQueryBuilders = v, (p, c) -> {
QueryBuilder preFilterQueryBuilder = AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage);
c.trackSectionUsage(name + ":" + PRE_FILTER_FIELD.getPreferredName());
return preFilterQueryBuilder;
}, PRE_FILTER_FIELD);
parser.declareObjectArray(
(r, v) -> r.preFilterQueryBuilders = v,
(p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage),
PRE_FILTER_FIELD
);
parser.declareString(RetrieverBuilder::retrieverName, NAME_FIELD);
parser.declareFloat(RetrieverBuilder::minScore, MIN_SCORE_FIELD);
}
@ -138,7 +138,7 @@ public abstract class RetrieverBuilder implements Rewriteable<RetrieverBuilder>,
throw new ParsingException(new XContentLocation(nonfe.getLineNumber(), nonfe.getColumnNumber()), message, nonfe);
}
context.trackSectionUsage(retrieverName);
context.trackRetrieverUsage(retrieverName);
if (parser.currentToken() != XContentParser.Token.END_OBJECT) {
throw new ParsingException(

View file

@ -37,6 +37,10 @@ public class RetrieverParserContext {
searchUsage.trackRescorerUsage(name);
}
public void trackRetrieverUsage(String name) {
searchUsage.trackRetrieverUsage(name);
}
public boolean clusterSupportsFeature(NodeFeature nodeFeature) {
return clusterSupportsFeature != null && clusterSupportsFeature.test(nodeFeature);
}

View file

@ -55,36 +55,28 @@ public final class StandardRetrieverBuilder extends RetrieverBuilder implements
static {
PARSER.declareObject((r, v) -> r.queryBuilder = v, (p, c) -> {
QueryBuilder queryBuilder = AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage);
c.trackSectionUsage(NAME + ":" + QUERY_FIELD.getPreferredName());
return queryBuilder;
}, QUERY_FIELD);
PARSER.declareField((r, v) -> r.searchAfterBuilder = v, (p, c) -> {
SearchAfterBuilder searchAfterBuilder = SearchAfterBuilder.fromXContent(p);
c.trackSectionUsage(NAME + ":" + SEARCH_AFTER_FIELD.getPreferredName());
return searchAfterBuilder;
}, SEARCH_AFTER_FIELD, ObjectParser.ValueType.OBJECT_ARRAY);
PARSER.declareField((r, v) -> r.terminateAfter = v, (p, c) -> {
int terminateAfter = p.intValue();
c.trackSectionUsage(NAME + ":" + TERMINATE_AFTER_FIELD.getPreferredName());
return terminateAfter;
}, TERMINATE_AFTER_FIELD, ObjectParser.ValueType.INT);
PARSER.declareField((r, v) -> r.sortBuilders = v, (p, c) -> {
List<SortBuilder<?>> sortBuilders = SortBuilder.fromXContent(p);
c.trackSectionUsage(NAME + ":" + SORT_FIELD.getPreferredName());
return sortBuilders;
}, SORT_FIELD, ObjectParser.ValueType.OBJECT_ARRAY);
PARSER.declareField((r, v) -> r.collapseBuilder = v, (p, c) -> {
CollapseBuilder collapseBuilder = CollapseBuilder.fromXContent(p);
if (collapseBuilder.getField() != null) {
c.trackSectionUsage(COLLAPSE_FIELD.getPreferredName());
}
return collapseBuilder;
}, COLLAPSE_FIELD, ObjectParser.ValueType.OBJECT);
PARSER.declareField(
(r, v) -> r.searchAfterBuilder = v,
(p, c) -> SearchAfterBuilder.fromXContent(p),
SEARCH_AFTER_FIELD,
ObjectParser.ValueType.OBJECT_ARRAY
);
PARSER.declareField((r, v) -> r.terminateAfter = v, (p, c) -> p.intValue(), TERMINATE_AFTER_FIELD, ObjectParser.ValueType.INT);
PARSER.declareField(
(r, v) -> r.sortBuilders = v,
(p, c) -> SortBuilder.fromXContent(p),
SORT_FIELD,
ObjectParser.ValueType.OBJECT_ARRAY
);
PARSER.declareField(
(r, v) -> r.collapseBuilder = v,
(p, c) -> CollapseBuilder.fromXContent(p),
COLLAPSE_FIELD,
ObjectParser.ValueType.OBJECT
);
RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
}

View file

@ -20,6 +20,7 @@ public final class SearchUsage {
private final Set<String> queries = new HashSet<>();
private final Set<String> rescorers = new HashSet<>();
private final Set<String> sections = new HashSet<>();
private final Set<String> retrievers = new HashSet<>();
/**
* Track the usage of the provided query
@ -42,6 +43,13 @@ public final class SearchUsage {
rescorers.add(name);
}
/**
* Track retrieve usage
*/
public void trackRetrieverUsage(String retriever) {
retrievers.add(retriever);
}
/**
* Returns the query types that have been used at least once in the tracked search request
*/
@ -62,4 +70,11 @@ public final class SearchUsage {
public Set<String> getSectionsUsage() {
return Collections.unmodifiableSet(sections);
}
/**
* Returns the retriever names that have been used at least once in the tracked search request
*/
public Set<String> getRetrieverUsage() {
return Collections.unmodifiableSet(retrievers);
}
}

View file

@ -27,6 +27,7 @@ public final class SearchUsageHolder {
private final Map<String, LongAdder> queriesUsage = new ConcurrentHashMap<>();
private final Map<String, LongAdder> rescorersUsage = new ConcurrentHashMap<>();
private final Map<String, LongAdder> sectionsUsage = new ConcurrentHashMap<>();
private final Map<String, LongAdder> retrieversUsage = new ConcurrentHashMap<>();
SearchUsageHolder() {}
@ -44,6 +45,9 @@ public final class SearchUsageHolder {
for (String rescorer : searchUsage.getRescorerUsage()) {
rescorersUsage.computeIfAbsent(rescorer, q -> new LongAdder()).increment();
}
for (String retriever : searchUsage.getRetrieverUsage()) {
retrieversUsage.computeIfAbsent(retriever, q -> new LongAdder()).increment();
}
}
/**
@ -56,10 +60,13 @@ public final class SearchUsageHolder {
sectionsUsage.forEach((query, adder) -> sectionsUsageMap.put(query, adder.longValue()));
Map<String, Long> rescorersUsageMap = Maps.newMapWithExpectedSize(rescorersUsage.size());
rescorersUsage.forEach((query, adder) -> rescorersUsageMap.put(query, adder.longValue()));
Map<String, Long> retrieversUsageMap = Maps.newMapWithExpectedSize(retrieversUsage.size());
retrieversUsage.forEach((retriever, adder) -> retrieversUsageMap.put(retriever, adder.longValue()));
return new SearchUsageStats(
Collections.unmodifiableMap(queriesUsageMap),
Collections.unmodifiableMap(rescorersUsageMap),
Collections.unmodifiableMap(sectionsUsageMap),
Collections.unmodifiableMap(retrieversUsageMap),
totalSearchCount.longValue()
);
}

View file

@ -10,6 +10,7 @@
package org.elasticsearch.action.admin.cluster.stats;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
@ -43,9 +44,12 @@ public class SearchUsageStatsTests extends AbstractWireSerializingTestCase<Searc
"terminate_after",
"indices_boost",
"range",
"script_score"
"script_score",
"retrievers"
);
private static final List<String> RETRIEVERS = List.of("standard", "knn", "rrf", "random", "text_similarity_reranker");
@Override
protected Reader<SearchUsageStats> instanceReader() {
return SearchUsageStats::new;
@ -75,6 +79,14 @@ public class SearchUsageStatsTests extends AbstractWireSerializingTestCase<Searc
return rescorerUsage;
}
private static Map<String, Long> randomRetrieversUsage(int size) {
Map<String, Long> retrieversUsage = new HashMap<>();
while (retrieversUsage.size() < size) {
retrieversUsage.put(randomFrom(RETRIEVERS), randomLongBetween(1, Long.MAX_VALUE));
}
return retrieversUsage;
}
@Override
protected SearchUsageStats createTestInstance() {
if (randomBoolean()) {
@ -84,6 +96,7 @@ public class SearchUsageStatsTests extends AbstractWireSerializingTestCase<Searc
randomQueryUsage(randomIntBetween(0, QUERY_TYPES.size())),
randomRescorerUsage(randomIntBetween(0, RESCORER_TYPES.size())),
randomSectionsUsage(randomIntBetween(0, SECTIONS.size())),
randomRetrieversUsage(randomIntBetween(0, RETRIEVERS.size())),
randomLongBetween(10, Long.MAX_VALUE)
);
}
@ -96,26 +109,38 @@ public class SearchUsageStatsTests extends AbstractWireSerializingTestCase<Searc
randomValueOtherThan(instance.getQueryUsage(), () -> randomQueryUsage(randomIntBetween(0, QUERY_TYPES.size()))),
instance.getRescorerUsage(),
instance.getSectionsUsage(),
instance.getRetrieversUsage(),
instance.getTotalSearchCount()
);
case 1 -> new SearchUsageStats(
instance.getQueryUsage(),
randomValueOtherThan(instance.getRescorerUsage(), () -> randomRescorerUsage(randomIntBetween(0, RESCORER_TYPES.size()))),
instance.getSectionsUsage(),
instance.getRetrieversUsage(),
instance.getTotalSearchCount()
);
case 2 -> new SearchUsageStats(
instance.getQueryUsage(),
instance.getRescorerUsage(),
randomValueOtherThan(instance.getSectionsUsage(), () -> randomSectionsUsage(randomIntBetween(0, SECTIONS.size()))),
instance.getRetrieversUsage(),
instance.getTotalSearchCount()
);
default -> new SearchUsageStats(
case 3 -> new SearchUsageStats(
instance.getQueryUsage(),
instance.getRescorerUsage(),
instance.getSectionsUsage(),
randomLongBetween(10, Long.MAX_VALUE)
randomValueOtherThan(instance.getRetrieversUsage(), () -> randomSectionsUsage(randomIntBetween(0, SECTIONS.size()))),
instance.getTotalSearchCount()
);
case 4 -> new SearchUsageStats(
instance.getQueryUsage(),
instance.getRescorerUsage(),
instance.getSectionsUsage(),
instance.getRetrieversUsage(),
randomValueOtherThan(instance.getTotalSearchCount(), () -> randomLongBetween(10, Long.MAX_VALUE))
);
default -> throw new IllegalStateException("Unexpected value: " + i);
};
}
@ -126,7 +151,9 @@ public class SearchUsageStatsTests extends AbstractWireSerializingTestCase<Searc
assertEquals(Map.of(), searchUsageStats.getSectionsUsage());
assertEquals(0, searchUsageStats.getTotalSearchCount());
searchUsageStats.add(new SearchUsageStats(Map.of("match", 10L), Map.of("query", 5L), Map.of("query", 10L), 10L));
searchUsageStats.add(
new SearchUsageStats(Map.of("match", 10L), Map.of("query", 5L), Map.of("query", 10L), Map.of("knn", 10L), 10L)
);
assertEquals(Map.of("match", 10L), searchUsageStats.getQueryUsage());
assertEquals(Map.of("query", 10L), searchUsageStats.getSectionsUsage());
assertEquals(Map.of("query", 5L), searchUsageStats.getRescorerUsage());
@ -137,19 +164,28 @@ public class SearchUsageStatsTests extends AbstractWireSerializingTestCase<Searc
Map.of("term", 1L, "match", 1L),
Map.of("query", 5L, "learning_to_rank", 2L),
Map.of("query", 10L, "knn", 1L),
Map.of("knn", 10L, "rrf", 2L),
10L
)
);
assertEquals(Map.of("match", 11L, "term", 1L), searchUsageStats.getQueryUsage());
assertEquals(Map.of("query", 20L, "knn", 1L), searchUsageStats.getSectionsUsage());
assertEquals(Map.of("query", 10L, "learning_to_rank", 2L), searchUsageStats.getRescorerUsage());
assertEquals(Map.of("knn", 20L, "rrf", 2L), searchUsageStats.getRetrieversUsage());
assertEquals(20L, searchUsageStats.getTotalSearchCount());
}
public void testToXContent() throws IOException {
SearchUsageStats searchUsageStats = new SearchUsageStats(Map.of("term", 1L), Map.of("query", 2L), Map.of("query", 10L), 10L);
SearchUsageStats searchUsageStats = new SearchUsageStats(
Map.of("term", 1L),
Map.of("query", 2L),
Map.of("query", 10L),
Map.of("knn", 10L),
10L
);
assertEquals(
"{\"search\":{\"total\":10,\"queries\":{\"term\":1},\"rescorers\":{\"query\":2},\"sections\":{\"query\":10}}}",
"{\"search\":{\"total\":10,\"queries\":{\"term\":1},\"rescorers\":{\"query\":2},"
+ "\"sections\":{\"query\":10},\"retrievers\":{\"knn\":10}}}",
Strings.toString(searchUsageStats)
);
}
@ -161,8 +197,9 @@ public class SearchUsageStatsTests extends AbstractWireSerializingTestCase<Searc
for (TransportVersion version : TransportVersionUtils.allReleasedVersions()) {
SearchUsageStats testInstance = new SearchUsageStats(
randomQueryUsage(QUERY_TYPES.size()),
Map.of(),
version.onOrAfter(TransportVersions.V_8_12_0) ? randomRescorerUsage(RESCORER_TYPES.size()) : Map.of(),
randomSectionsUsage(SECTIONS.size()),
version.onOrAfter(TransportVersions.RETRIEVERS_TELEMETRY_ADDED) ? randomRetrieversUsage(RETRIEVERS.size()) : Map.of(),
randomLongBetween(0, Long.MAX_VALUE)
);
assertSerialization(testInstance, version);

View file

@ -64,7 +64,11 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
});
static {
PARSER.declareNamedObject(constructorArg(), (p, c, n) -> p.namedObject(RetrieverBuilder.class, n, c), RETRIEVER_FIELD);
PARSER.declareNamedObject(constructorArg(), (p, c, n) -> {
RetrieverBuilder innerRetriever = p.namedObject(RetrieverBuilder.class, n, c);
c.trackRetrieverUsage(innerRetriever.getName());
return innerRetriever;
}, RETRIEVER_FIELD);
PARSER.declareString(constructorArg(), INFERENCE_ID_FIELD);
PARSER.declareString(constructorArg(), INFERENCE_TEXT_FIELD);
PARSER.declareString(constructorArg(), FIELD_FIELD);

View file

@ -117,7 +117,10 @@ public class TextSimilarityRankRetrieverBuilderTests extends AbstractXContentTes
}""";
try (XContentParser parser = createParser(JsonXContent.jsonXContent, json)) {
TextSimilarityRankRetrieverBuilder parsed = TextSimilarityRankRetrieverBuilder.PARSER.parse(parser, null);
TextSimilarityRankRetrieverBuilder parsed = TextSimilarityRankRetrieverBuilder.PARSER.parse(
parser,
new RetrieverParserContext(new SearchUsage(), nf -> true)
);
assertEquals(DEFAULT_RANK_WINDOW_SIZE, parsed.rankWindowSize());
}
}

View file

@ -0,0 +1,187 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.rank.textsimilarity;
import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesRequest;
import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesResponse;
import org.elasticsearch.action.admin.cluster.stats.SearchUsageStats;
import org.elasticsearch.client.Request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.junit.Before;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.Matchers.equalTo;
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0)
public class TextSimilarityRankRetrieverTelemetryTests extends ESIntegTestCase {
private static final String INDEX_NAME = "test_index";
@Override
protected boolean addMockHttpTransport() {
return false; // enable http
}
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return List.of(InferencePlugin.class, XPackPlugin.class, TextSimilarityTestPlugin.class);
}
@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
return Settings.builder()
.put(super.nodeSettings(nodeOrdinal, otherSettings))
.put("xpack.license.self_generated.type", "trial")
.build();
}
@Before
public void setup() throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject("vector")
.field("type", "dense_vector")
.field("dims", 1)
.field("index", true)
.field("similarity", "l2_norm")
.startObject("index_options")
.field("type", "hnsw")
.endObject()
.endObject()
.startObject("text")
.field("type", "text")
.endObject()
.startObject("integer")
.field("type", "integer")
.endObject()
.startObject("topic")
.field("type", "keyword")
.endObject()
.endObject()
.endObject();
assertAcked(prepareCreate(INDEX_NAME).setMapping(builder));
ensureGreen(INDEX_NAME);
}
private void performSearch(SearchSourceBuilder source) throws IOException {
Request request = new Request("GET", INDEX_NAME + "/_search");
request.setJsonEntity(Strings.toString(source));
getRestClient().performRequest(request);
}
public void testTelemetryForRRFRetriever() throws IOException {
if (false == isRetrieverTelemetryEnabled()) {
return;
}
// search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers`
{
performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null)));
}
// search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under
// `queries`
{
performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.rangeQuery("integer").gte(2))));
}
// search#3 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "knn" under
// `queries`
{
performSearch(
new SearchSourceBuilder().retriever(
new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null))
)
);
}
// search#4 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "term"
// under `queries`
{
performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.termQuery("topic", "foo"))));
}
// search#5 - this will record 1 entry for "retriever" in `sections`, and 1 for "text_similarity_reranker" under `retrievers`, as
// well as
// 1 "standard" under `retrievers`, and eventually 1 for "match" under `queries`
{
performSearch(
new SearchSourceBuilder().retriever(
new TextSimilarityRankRetrieverBuilder(
new StandardRetrieverBuilder(QueryBuilders.matchQuery("text", "foo")),
"some_inference_id",
"some_inference_text",
"some_field",
10
)
)
);
}
// search#6 - this will record 1 entry for "knn" in `sections`
{
performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null))));
}
// search#7 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries`
{
performSearch(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()));
}
// cluster stats
{
SearchUsageStats stats = clusterAdmin().prepareClusterStats().get().getIndicesStats().getSearchUsageStats();
assertEquals(7, stats.getTotalSearchCount());
assertThat(stats.getSectionsUsage().size(), equalTo(3));
assertThat(stats.getSectionsUsage().get("retriever"), equalTo(5L));
assertThat(stats.getSectionsUsage().get("query"), equalTo(1L));
assertThat(stats.getSectionsUsage().get("knn"), equalTo(1L));
assertThat(stats.getRetrieversUsage().size(), equalTo(3));
assertThat(stats.getRetrieversUsage().get("standard"), equalTo(4L));
assertThat(stats.getRetrieversUsage().get("knn"), equalTo(1L));
assertThat(stats.getRetrieversUsage().get("text_similarity_reranker"), equalTo(1L));
assertThat(stats.getQueryUsage().size(), equalTo(5));
assertThat(stats.getQueryUsage().get("range"), equalTo(1L));
assertThat(stats.getQueryUsage().get("term"), equalTo(1L));
assertThat(stats.getQueryUsage().get("match"), equalTo(1L));
assertThat(stats.getQueryUsage().get("match_all"), equalTo(1L));
assertThat(stats.getQueryUsage().get("knn"), equalTo(1L));
}
}
private boolean isRetrieverTelemetryEnabled() throws IOException {
NodesCapabilitiesResponse res = clusterAdmin().nodesCapabilities(
new NodesCapabilitiesRequest().method(RestRequest.Method.GET).path("_cluster/stats").capabilities("retrievers-usage-stats")
).actionGet();
return res != null && res.isSupported().orElse(false);
}
}

View file

@ -590,7 +590,8 @@ public class ClusterStatsMonitoringDocTests extends BaseMonitoringDocTestCase<Cl
"total": 0,
"queries": {},
"rescorers": {},
"sections": {}
"sections": {},
"retrievers": {}
},
"dense_vector": {
"value_count": 0

View file

@ -0,0 +1,194 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.rank.rrf;
import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesRequest;
import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesResponse;
import org.elasticsearch.action.admin.cluster.stats.SearchUsageStats;
import org.elasticsearch.client.Request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.junit.Before;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.Matchers.equalTo;
@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0)
public class RRFRetrieverTelemetryIT extends ESIntegTestCase {
private static final String INDEX_NAME = "test_index";
@Override
protected boolean addMockHttpTransport() {
return false; // enable http
}
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return List.of(RRFRankPlugin.class, XPackPlugin.class);
}
@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
return Settings.builder()
.put(super.nodeSettings(nodeOrdinal, otherSettings))
.put("xpack.license.self_generated.type", "trial")
.build();
}
@Before
public void setup() throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject("vector")
.field("type", "dense_vector")
.field("dims", 1)
.field("index", true)
.field("similarity", "l2_norm")
.startObject("index_options")
.field("type", "hnsw")
.endObject()
.endObject()
.startObject("text")
.field("type", "text")
.endObject()
.startObject("integer")
.field("type", "integer")
.endObject()
.startObject("topic")
.field("type", "keyword")
.endObject()
.endObject()
.endObject();
assertAcked(prepareCreate(INDEX_NAME).setMapping(builder));
ensureGreen(INDEX_NAME);
}
private void performSearch(SearchSourceBuilder source) throws IOException {
Request request = new Request("GET", INDEX_NAME + "/_search");
request.setJsonEntity(Strings.toString(source));
getRestClient().performRequest(request);
}
public void testTelemetryForRRFRetriever() throws IOException {
if (false == isRetrieverTelemetryEnabled()) {
return;
}
// search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers`
{
performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null)));
}
// search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under
// `queries`
{
performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.rangeQuery("integer").gte(2))));
}
// search#3 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "knn" under
// `queries`
{
performSearch(
new SearchSourceBuilder().retriever(
new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[] { 1.0f }, 10, 15, null))
)
);
}
// search#4 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "term"
// under `queries`
{
performSearch(new SearchSourceBuilder().retriever(new StandardRetrieverBuilder(QueryBuilders.termQuery("topic", "foo"))));
}
// search#5 - this will record 1 entry for "retriever" in `sections`, and 1 for "rrf" under `retrievers`, as well as
// 1 for "knn" and 1 for "standard" under `retrievers`, and eventually 1 for "match" under `queries`
{
performSearch(
new SearchSourceBuilder().retriever(
new RRFRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(
new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null),
null
),
new CompoundRetrieverBuilder.RetrieverSource(
new StandardRetrieverBuilder(QueryBuilders.matchQuery("text", "foo")),
null
)
),
10,
10
)
)
);
}
// search#6 - this will record 1 entry for "knn" in `sections`
{
performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null))));
}
// search#7 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries`
{
performSearch(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery()));
}
// cluster stats
{
SearchUsageStats stats = clusterAdmin().prepareClusterStats().get().getIndicesStats().getSearchUsageStats();
assertEquals(7, stats.getTotalSearchCount());
assertThat(stats.getSectionsUsage().size(), equalTo(3));
assertThat(stats.getSectionsUsage().get("retriever"), equalTo(5L));
assertThat(stats.getSectionsUsage().get("query"), equalTo(1L));
assertThat(stats.getSectionsUsage().get("knn"), equalTo(1L));
assertThat(stats.getRetrieversUsage().size(), equalTo(3));
assertThat(stats.getRetrieversUsage().get("standard"), equalTo(4L));
assertThat(stats.getRetrieversUsage().get("knn"), equalTo(2L));
assertThat(stats.getRetrieversUsage().get("rrf"), equalTo(1L));
assertThat(stats.getQueryUsage().size(), equalTo(5));
assertThat(stats.getQueryUsage().get("range"), equalTo(1L));
assertThat(stats.getQueryUsage().get("term"), equalTo(1L));
assertThat(stats.getQueryUsage().get("match"), equalTo(1L));
assertThat(stats.getQueryUsage().get("match_all"), equalTo(1L));
assertThat(stats.getQueryUsage().get("knn"), equalTo(1L));
}
}
private boolean isRetrieverTelemetryEnabled() throws IOException {
NodesCapabilitiesResponse res = clusterAdmin().nodesCapabilities(
new NodesCapabilitiesRequest().method(RestRequest.Method.GET).path("_cluster/stats").capabilities("retrievers-usage-stats")
).actionGet();
return res != null && res.isSupported().orElse(false);
}
}

View file

@ -68,6 +68,7 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
p.nextToken();
String name = p.currentName();
RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
c.trackRetrieverUsage(retrieverBuilder.getName());
p.nextToken();
return retrieverBuilder;
}, RETRIEVERS_FIELD);