Add match support for semantic_text fields (#117839)

* Added query name to inference field metadata

* Fix build error

* Added query builder service

* Add query builder service to query rewrite context

* Updated match query to support querying semantic text fields

* Fix build error

* Fix NPE

* Update the POC to rewrite to a bool query when combining inference and non-inference fields

* Separate clause for each inference index (to avoid inference ID clashes)

* Simplify query builder service concept to a single default inference query

* Rename QueryBuilderService, remove query name from inference metadata

* Fix too many rewrite rounds error by injecting booleans in constructors for match query builder and semantic text

* Fix test compilation errors

* Fix tests

* Add yaml test for semantic match

* Add NodeFeature

* Fix license headers

* Spotless

* Updated getClass comparison in MatchQueryBuilder

* Cleanup

* Add Mock Inference Query Builder Service

* Spotless

* Cleanup

* Update docs/changelog/117839.yaml

* Update changelog

* Replace the default inference query builder with a query rewrite interceptor

* Cleanup

* Some more cleanup/renames

* Some more cleanup/renames

* Spotless

* Checkstyle

* Convert List<QueryRewriteInterceptor> to Map keyed on query name, error on query name collisions

* PR feedback - remove check on QueryRewriteContext class only

* PR feedback

* Remove intercept flag from MatchQueryBuilder and replace with wrapper

* Move feature to test feature

* Ensure interception happens only once

* Rename InterceptedQueryBuilderWrapper to AbstractQueryBuilderWrapper

* Add lenient field to SemanticQueryBuilder

* Clean up yaml test

* Add TODO comment

* Add comment

* Spotless

* Rename AbstractQueryBuilderWrapper back to InterceptedQueryBuilderWrapper

* Spotless

* Didn't mean to commit that

* Remove static class wrapping the InterceptedQueryBuilderWrapper

* Make InterceptedQueryBuilderWrapper part of QueryRewriteInterceptor

* Refactor the interceptor to be an internal plugin that cannot be used outside inference plugin

* Fix tests

* Spotless

* Minor cleanup

* C'mon spotless

* Test spotless

* Cleanup InternalQueryRewriter

* Change if statement to assert

* Simplify template of InterceptedQueryBuilderWrapper

* Change constructor of InterceptedQueryBuilderWrapper

* Refactor InterceptedQueryBuilderWrapper to extend QueryBuilder

* Cleanup

* Add test

* Spotless

* Rename rewrite to interceptAndRewrite in QueryRewriteInterceptor

* DOESN'T WORK - for testing

* Add comment

* Getting closer - match on single typed fields works now

* Deleted line by mistake

* Checkstyle

* Fix over-aggressive IntelliJ Refactor/Rename

* And another one

* Move SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED to Test feature

* PR feedback

* Require query name with no default

* PR feedback & update test

* Add rewrite test

* Update server/src/main/java/org/elasticsearch/index/query/InnerHitContextBuilder.java

Co-authored-by: Mike Pellegrini <mike.pellegrini@elastic.co>

---------

Co-authored-by: Mike Pellegrini <mike.pellegrini@elastic.co>
This commit is contained in:
Kathleen DeRusso 2024-12-12 10:55:00 -05:00 committed by GitHub
parent eac4731512
commit c9a6a2c841
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
31 changed files with 890 additions and 32 deletions

View file

@ -0,0 +1,5 @@
pr: 117839
summary: Add match support for `semantic_text` fields
area: "Search"
type: enhancement
issues: []

View file

@ -479,5 +479,5 @@ module org.elasticsearch.server {
exports org.elasticsearch.lucene.spatial; exports org.elasticsearch.lucene.spatial;
exports org.elasticsearch.inference.configuration; exports org.elasticsearch.inference.configuration;
exports org.elasticsearch.monitor.metrics; exports org.elasticsearch.monitor.metrics;
exports org.elasticsearch.plugins.internal.rewriter to org.elasticsearch.inference;
} }

View file

@ -137,6 +137,7 @@ public class TransportVersions {
public static final TransportVersion RETRIES_AND_OPERATIONS_IN_BLOBSTORE_STATS = def(8_804_00_0); public static final TransportVersion RETRIES_AND_OPERATIONS_IN_BLOBSTORE_STATS = def(8_804_00_0);
public static final TransportVersion ADD_DATA_STREAM_OPTIONS_TO_TEMPLATES = def(8_805_00_0); public static final TransportVersion ADD_DATA_STREAM_OPTIONS_TO_TEMPLATES = def(8_805_00_0);
public static final TransportVersion KNN_QUERY_RESCORE_OVERSAMPLE = def(8_806_00_0); public static final TransportVersion KNN_QUERY_RESCORE_OVERSAMPLE = def(8_806_00_0);
public static final TransportVersion SEMANTIC_QUERY_LENIENT = def(8_807_00_0);
/* /*
* STOP! READ THIS FIRST! No, really, * STOP! READ THIS FIRST! No, really,

View file

@ -58,6 +58,7 @@ import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.indices.fielddata.cache.IndicesFieldDataCache; import org.elasticsearch.indices.fielddata.cache.IndicesFieldDataCache;
import org.elasticsearch.indices.recovery.RecoveryState; import org.elasticsearch.indices.recovery.RecoveryState;
import org.elasticsearch.plugins.IndexStorePlugin; import org.elasticsearch.plugins.IndexStorePlugin;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.script.ScriptService; import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
@ -478,7 +479,8 @@ public final class IndexModule {
IdFieldMapper idFieldMapper, IdFieldMapper idFieldMapper,
ValuesSourceRegistry valuesSourceRegistry, ValuesSourceRegistry valuesSourceRegistry,
IndexStorePlugin.IndexFoldersDeletionListener indexFoldersDeletionListener, IndexStorePlugin.IndexFoldersDeletionListener indexFoldersDeletionListener,
Map<String, IndexStorePlugin.SnapshotCommitSupplier> snapshotCommitSuppliers Map<String, IndexStorePlugin.SnapshotCommitSupplier> snapshotCommitSuppliers,
QueryRewriteInterceptor queryRewriteInterceptor
) throws IOException { ) throws IOException {
final IndexEventListener eventListener = freeze(); final IndexEventListener eventListener = freeze();
Function<IndexService, CheckedFunction<DirectoryReader, DirectoryReader, IOException>> readerWrapperFactory = indexReaderWrapper Function<IndexService, CheckedFunction<DirectoryReader, DirectoryReader, IOException>> readerWrapperFactory = indexReaderWrapper
@ -540,7 +542,8 @@ public final class IndexModule {
indexFoldersDeletionListener, indexFoldersDeletionListener,
snapshotCommitSupplier, snapshotCommitSupplier,
indexCommitListener.get(), indexCommitListener.get(),
mapperMetrics mapperMetrics,
queryRewriteInterceptor
); );
success = true; success = true;
return indexService; return indexService;

View file

@ -85,6 +85,7 @@ import org.elasticsearch.indices.cluster.IndicesClusterStateService;
import org.elasticsearch.indices.fielddata.cache.IndicesFieldDataCache; import org.elasticsearch.indices.fielddata.cache.IndicesFieldDataCache;
import org.elasticsearch.indices.recovery.RecoveryState; import org.elasticsearch.indices.recovery.RecoveryState;
import org.elasticsearch.plugins.IndexStorePlugin; import org.elasticsearch.plugins.IndexStorePlugin;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.script.ScriptService; import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
@ -162,6 +163,7 @@ public class IndexService extends AbstractIndexComponent implements IndicesClust
private final Supplier<Sort> indexSortSupplier; private final Supplier<Sort> indexSortSupplier;
private final ValuesSourceRegistry valuesSourceRegistry; private final ValuesSourceRegistry valuesSourceRegistry;
private final MapperMetrics mapperMetrics; private final MapperMetrics mapperMetrics;
private final QueryRewriteInterceptor queryRewriteInterceptor;
@SuppressWarnings("this-escape") @SuppressWarnings("this-escape")
public IndexService( public IndexService(
@ -196,7 +198,8 @@ public class IndexService extends AbstractIndexComponent implements IndicesClust
IndexStorePlugin.IndexFoldersDeletionListener indexFoldersDeletionListener, IndexStorePlugin.IndexFoldersDeletionListener indexFoldersDeletionListener,
IndexStorePlugin.SnapshotCommitSupplier snapshotCommitSupplier, IndexStorePlugin.SnapshotCommitSupplier snapshotCommitSupplier,
Engine.IndexCommitListener indexCommitListener, Engine.IndexCommitListener indexCommitListener,
MapperMetrics mapperMetrics MapperMetrics mapperMetrics,
QueryRewriteInterceptor queryRewriteInterceptor
) { ) {
super(indexSettings); super(indexSettings);
assert indexCreationContext != IndexCreationContext.RELOAD_ANALYZERS assert indexCreationContext != IndexCreationContext.RELOAD_ANALYZERS
@ -271,6 +274,7 @@ public class IndexService extends AbstractIndexComponent implements IndicesClust
this.indexingOperationListeners = Collections.unmodifiableList(indexingOperationListeners); this.indexingOperationListeners = Collections.unmodifiableList(indexingOperationListeners);
this.indexCommitListener = indexCommitListener; this.indexCommitListener = indexCommitListener;
this.mapperMetrics = mapperMetrics; this.mapperMetrics = mapperMetrics;
this.queryRewriteInterceptor = queryRewriteInterceptor;
try (var ignored = threadPool.getThreadContext().clearTraceContext()) { try (var ignored = threadPool.getThreadContext().clearTraceContext()) {
// kick off async ops for the first shard in this index // kick off async ops for the first shard in this index
this.refreshTask = new AsyncRefreshTask(this); this.refreshTask = new AsyncRefreshTask(this);
@ -802,6 +806,7 @@ public class IndexService extends AbstractIndexComponent implements IndicesClust
allowExpensiveQueries, allowExpensiveQueries,
scriptService, scriptService,
null, null,
null,
null null
); );
} }

View file

@ -21,6 +21,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.SuggestingErrorOnUnknown; import org.elasticsearch.common.xcontent.SuggestingErrorOnUnknown;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.xcontent.AbstractObjectParser; import org.elasticsearch.xcontent.AbstractObjectParser;
import org.elasticsearch.xcontent.FilterXContentParser; import org.elasticsearch.xcontent.FilterXContentParser;
import org.elasticsearch.xcontent.FilterXContentParserWrapper; import org.elasticsearch.xcontent.FilterXContentParserWrapper;
@ -278,6 +279,14 @@ public abstract class AbstractQueryBuilder<QB extends AbstractQueryBuilder<QB>>
@Override @Override
public final QueryBuilder rewrite(QueryRewriteContext queryRewriteContext) throws IOException { public final QueryBuilder rewrite(QueryRewriteContext queryRewriteContext) throws IOException {
QueryRewriteInterceptor queryRewriteInterceptor = queryRewriteContext.getQueryRewriteInterceptor();
if (queryRewriteInterceptor != null) {
var rewritten = queryRewriteInterceptor.interceptAndRewrite(queryRewriteContext, this);
if (rewritten != this) {
return new InterceptedQueryBuilderWrapper(rewritten);
}
}
QueryBuilder rewritten = doRewrite(queryRewriteContext); QueryBuilder rewritten = doRewrite(queryRewriteContext);
if (rewritten == this) { if (rewritten == this) {
return rewritten; return rewritten;

View file

@ -104,6 +104,7 @@ public class CoordinatorRewriteContext extends QueryRewriteContext {
null, null,
null, null,
null, null,
null,
null null
); );
this.dateFieldRangeInfo = dateFieldRangeInfo; this.dateFieldRangeInfo = dateFieldRangeInfo;

View file

@ -66,6 +66,9 @@ public abstract class InnerHitContextBuilder {
public static void extractInnerHits(QueryBuilder query, Map<String, InnerHitContextBuilder> innerHitBuilders) { public static void extractInnerHits(QueryBuilder query, Map<String, InnerHitContextBuilder> innerHitBuilders) {
if (query instanceof AbstractQueryBuilder) { if (query instanceof AbstractQueryBuilder) {
((AbstractQueryBuilder<?>) query).extractInnerHitBuilders(innerHitBuilders); ((AbstractQueryBuilder<?>) query).extractInnerHitBuilders(innerHitBuilders);
} else if (query instanceof InterceptedQueryBuilderWrapper interceptedQuery) {
// Unwrap an intercepted query here
extractInnerHits(interceptedQuery.queryBuilder, innerHitBuilders);
} else { } else {
throw new IllegalStateException( throw new IllegalStateException(
"provided query builder [" + query.getClass() + "] class should inherit from AbstractQueryBuilder, but it doesn't" "provided query builder [" + query.getClass() + "] class should inherit from AbstractQueryBuilder, but it doesn't"

View file

@ -0,0 +1,109 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.query;
import org.apache.lucene.search.Query;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Objects;
/**
* Wrapper for instances of {@link QueryBuilder} that have been intercepted using the {@link QueryRewriteInterceptor} to
* break out of the rewrite phase. These instances are unwrapped on serialization.
*/
class InterceptedQueryBuilderWrapper implements QueryBuilder {

    /** The wrapped query. Every {@link QueryBuilder} operation on this wrapper is forwarded to it. */
    protected final QueryBuilder queryBuilder;

    /**
     * Wraps the given query.
     *
     * @param queryBuilder the query to wrap; must not be {@code null}
     * @throws NullPointerException if {@code queryBuilder} is {@code null}
     */
    InterceptedQueryBuilderWrapper(QueryBuilder queryBuilder) {
        // Fail fast: a null delegate would otherwise only surface as an NPE in a later, unrelated call.
        this.queryBuilder = Objects.requireNonNull(queryBuilder, "queryBuilder must not be null");
    }

    /**
     * Rewrites the wrapped query with the context's interceptor temporarily removed, so the wrapped query is
     * not intercepted again. The interceptor that was set on entry is restored before returning, even when
     * the rewrite throws.
     *
     * @param queryRewriteContext the context used for the rewrite
     * @return {@code this} if the wrapped query rewrote to itself, otherwise a new wrapper around the rewritten query
     * @throws IOException if rewriting the wrapped query fails
     */
    @Override
    public QueryBuilder rewrite(QueryRewriteContext queryRewriteContext) throws IOException {
        QueryRewriteInterceptor queryRewriteInterceptor = queryRewriteContext.getQueryRewriteInterceptor();
        try {
            queryRewriteContext.setQueryRewriteInterceptor(null);
            QueryBuilder rewritten = queryBuilder.rewrite(queryRewriteContext);
            return rewritten != queryBuilder ? new InterceptedQueryBuilderWrapper(rewritten) : this;
        } finally {
            queryRewriteContext.setQueryRewriteInterceptor(queryRewriteInterceptor);
        }
    }

    /** @return the writeable name of the wrapped query */
    @Override
    public String getWriteableName() {
        return queryBuilder.getWriteableName();
    }

    /** @return the minimal supported transport version of the wrapped query */
    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return queryBuilder.getMinimalSupportedVersion();
    }

    /** Builds the Lucene query by delegating to the wrapped query. */
    @Override
    public Query toQuery(SearchExecutionContext context) throws IOException {
        return queryBuilder.toQuery(context);
    }

    /**
     * Sets the query name on the wrapped query.
     *
     * @return this wrapper, not the wrapped query
     */
    @Override
    public QueryBuilder queryName(String queryName) {
        queryBuilder.queryName(queryName);
        return this;
    }

    /** @return the query name of the wrapped query */
    @Override
    public String queryName() {
        return queryBuilder.queryName();
    }

    /** @return the boost of the wrapped query */
    @Override
    public float boost() {
        return queryBuilder.boost();
    }

    /**
     * Sets the boost on the wrapped query.
     *
     * @return this wrapper, not the wrapped query
     */
    @Override
    public QueryBuilder boost(float boost) {
        queryBuilder.boost(boost);
        return this;
    }

    /** @return the name of the wrapped query */
    @Override
    public String getName() {
        return queryBuilder.getName();
    }

    /** Writes the wrapped query to the stream; the wrapper itself adds nothing to the output. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        queryBuilder.writeTo(out);
    }

    /** Renders the wrapped query; the wrapper itself adds nothing to the output. */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        return queryBuilder.toXContent(builder, params);
    }

    /** Two wrappers are equal when their wrapped queries are equal; a wrapper never equals a non-wrapper. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        return o instanceof InterceptedQueryBuilderWrapper other && Objects.equals(queryBuilder, other.queryBuilder);
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(queryBuilder);
    }
}

View file

@ -28,6 +28,7 @@ import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.MappingLookup; import org.elasticsearch.index.mapper.MappingLookup;
import org.elasticsearch.index.mapper.SourceFieldMapper; import org.elasticsearch.index.mapper.SourceFieldMapper;
import org.elasticsearch.index.mapper.TextFieldMapper; import org.elasticsearch.index.mapper.TextFieldMapper;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.script.ScriptCompiler; import org.elasticsearch.script.ScriptCompiler;
import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry;
import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.PointInTimeBuilder;
@ -70,6 +71,7 @@ public class QueryRewriteContext {
protected Predicate<String> allowedFields; protected Predicate<String> allowedFields;
private final ResolvedIndices resolvedIndices; private final ResolvedIndices resolvedIndices;
private final PointInTimeBuilder pit; private final PointInTimeBuilder pit;
private QueryRewriteInterceptor queryRewriteInterceptor;
public QueryRewriteContext( public QueryRewriteContext(
final XContentParserConfiguration parserConfiguration, final XContentParserConfiguration parserConfiguration,
@ -86,7 +88,8 @@ public class QueryRewriteContext {
final BooleanSupplier allowExpensiveQueries, final BooleanSupplier allowExpensiveQueries,
final ScriptCompiler scriptService, final ScriptCompiler scriptService,
final ResolvedIndices resolvedIndices, final ResolvedIndices resolvedIndices,
final PointInTimeBuilder pit final PointInTimeBuilder pit,
final QueryRewriteInterceptor queryRewriteInterceptor
) { ) {
this.parserConfiguration = parserConfiguration; this.parserConfiguration = parserConfiguration;
@ -105,6 +108,7 @@ public class QueryRewriteContext {
this.scriptService = scriptService; this.scriptService = scriptService;
this.resolvedIndices = resolvedIndices; this.resolvedIndices = resolvedIndices;
this.pit = pit; this.pit = pit;
this.queryRewriteInterceptor = queryRewriteInterceptor;
} }
public QueryRewriteContext(final XContentParserConfiguration parserConfiguration, final Client client, final LongSupplier nowInMillis) { public QueryRewriteContext(final XContentParserConfiguration parserConfiguration, final Client client, final LongSupplier nowInMillis) {
@ -123,6 +127,7 @@ public class QueryRewriteContext {
null, null,
null, null,
null, null,
null,
null null
); );
} }
@ -132,7 +137,8 @@ public class QueryRewriteContext {
final Client client, final Client client,
final LongSupplier nowInMillis, final LongSupplier nowInMillis,
final ResolvedIndices resolvedIndices, final ResolvedIndices resolvedIndices,
final PointInTimeBuilder pit final PointInTimeBuilder pit,
final QueryRewriteInterceptor queryRewriteInterceptor
) { ) {
this( this(
parserConfiguration, parserConfiguration,
@ -149,7 +155,8 @@ public class QueryRewriteContext {
null, null,
null, null,
resolvedIndices, resolvedIndices,
pit pit,
queryRewriteInterceptor
); );
} }
@ -428,4 +435,13 @@ public class QueryRewriteContext {
// It was decided we should only test the first of these potentially multiple preferences. // It was decided we should only test the first of these potentially multiple preferences.
return value.split(",")[0].trim(); return value.split(",")[0].trim();
} }
/**
* Returns the {@link QueryRewriteInterceptor} currently attached to this context.
*
* @return the interceptor, or {@code null} when none is set
*/
public QueryRewriteInterceptor getQueryRewriteInterceptor() {
return queryRewriteInterceptor;
}
/**
* Replaces the {@link QueryRewriteInterceptor} attached to this context.
*
* @param queryRewriteInterceptor the interceptor to use from now on; {@code null} clears it
*/
public void setQueryRewriteInterceptor(QueryRewriteInterceptor queryRewriteInterceptor) {
this.queryRewriteInterceptor = queryRewriteInterceptor;
}
} }

View file

@ -271,6 +271,7 @@ public class SearchExecutionContext extends QueryRewriteContext {
allowExpensiveQueries, allowExpensiveQueries,
scriptService, scriptService,
null, null,
null,
null null
); );
this.shardId = shardId; this.shardId = shardId;

View file

@ -135,6 +135,7 @@ import org.elasticsearch.node.Node;
import org.elasticsearch.plugins.FieldPredicate; import org.elasticsearch.plugins.FieldPredicate;
import org.elasticsearch.plugins.IndexStorePlugin; import org.elasticsearch.plugins.IndexStorePlugin;
import org.elasticsearch.plugins.PluginsService; import org.elasticsearch.plugins.PluginsService;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.repositories.RepositoriesService; import org.elasticsearch.repositories.RepositoriesService;
import org.elasticsearch.script.ScriptService; import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry;
@ -262,6 +263,7 @@ public class IndicesService extends AbstractLifecycleComponent
private final MapperMetrics mapperMetrics; private final MapperMetrics mapperMetrics;
private final PostRecoveryMerger postRecoveryMerger; private final PostRecoveryMerger postRecoveryMerger;
private final List<SearchOperationListener> searchOperationListeners; private final List<SearchOperationListener> searchOperationListeners;
private final QueryRewriteInterceptor queryRewriteInterceptor;
@Override @Override
protected void doStart() { protected void doStart() {
@ -330,6 +332,7 @@ public class IndicesService extends AbstractLifecycleComponent
this.indexFoldersDeletionListeners = new CompositeIndexFoldersDeletionListener(builder.indexFoldersDeletionListeners); this.indexFoldersDeletionListeners = new CompositeIndexFoldersDeletionListener(builder.indexFoldersDeletionListeners);
this.snapshotCommitSuppliers = builder.snapshotCommitSuppliers; this.snapshotCommitSuppliers = builder.snapshotCommitSuppliers;
this.requestCacheKeyDifferentiator = builder.requestCacheKeyDifferentiator; this.requestCacheKeyDifferentiator = builder.requestCacheKeyDifferentiator;
this.queryRewriteInterceptor = builder.queryRewriteInterceptor;
this.mapperMetrics = builder.mapperMetrics; this.mapperMetrics = builder.mapperMetrics;
// doClose() is called when shutting down a node, yet there might still be ongoing requests // doClose() is called when shutting down a node, yet there might still be ongoing requests
// that we need to wait for before closing some resources such as the caches. In order to // that we need to wait for before closing some resources such as the caches. In order to
@ -779,7 +782,8 @@ public class IndicesService extends AbstractLifecycleComponent
idFieldMappers.apply(idxSettings.getMode()), idFieldMappers.apply(idxSettings.getMode()),
valuesSourceRegistry, valuesSourceRegistry,
indexFoldersDeletionListeners, indexFoldersDeletionListeners,
snapshotCommitSuppliers snapshotCommitSuppliers,
queryRewriteInterceptor
); );
} }
@ -1766,7 +1770,7 @@ public class IndicesService extends AbstractLifecycleComponent
* Returns a new {@link QueryRewriteContext} with the given {@code now} provider * Returns a new {@link QueryRewriteContext} with the given {@code now} provider
*/ */
public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit) { public QueryRewriteContext getRewriteContext(LongSupplier nowInMillis, ResolvedIndices resolvedIndices, PointInTimeBuilder pit) {
return new QueryRewriteContext(parserConfig, client, nowInMillis, resolvedIndices, pit); return new QueryRewriteContext(parserConfig, client, nowInMillis, resolvedIndices, pit, queryRewriteInterceptor);
} }
public DataRewriteContext getDataRewriteContext(LongSupplier nowInMillis) { public DataRewriteContext getDataRewriteContext(LongSupplier nowInMillis) {

View file

@ -32,6 +32,8 @@ import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.plugins.EnginePlugin; import org.elasticsearch.plugins.EnginePlugin;
import org.elasticsearch.plugins.IndexStorePlugin; import org.elasticsearch.plugins.IndexStorePlugin;
import org.elasticsearch.plugins.PluginsService; import org.elasticsearch.plugins.PluginsService;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.script.ScriptService; import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry;
import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.search.internal.ShardSearchRequest;
@ -76,6 +78,7 @@ public class IndicesServiceBuilder {
CheckedBiConsumer<ShardSearchRequest, StreamOutput, IOException> requestCacheKeyDifferentiator; CheckedBiConsumer<ShardSearchRequest, StreamOutput, IOException> requestCacheKeyDifferentiator;
MapperMetrics mapperMetrics; MapperMetrics mapperMetrics;
List<SearchOperationListener> searchOperationListener = List.of(); List<SearchOperationListener> searchOperationListener = List.of();
QueryRewriteInterceptor queryRewriteInterceptor = null;
public IndicesServiceBuilder settings(Settings settings) { public IndicesServiceBuilder settings(Settings settings) {
this.settings = settings; this.settings = settings;
@ -239,6 +242,27 @@ public class IndicesServiceBuilder {
.flatMap(m -> m.entrySet().stream()) .flatMap(m -> m.entrySet().stream())
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
var queryRewriteInterceptors = pluginsService.filterPlugins(SearchPlugin.class)
.map(SearchPlugin::getQueryRewriteInterceptors)
.flatMap(List::stream)
.collect(Collectors.toMap(QueryRewriteInterceptor::getQueryName, interceptor -> {
if (interceptor.getQueryName() == null) {
throw new IllegalArgumentException("QueryRewriteInterceptor [" + interceptor.getClass().getName() + "] requires name");
}
return interceptor;
}, (a, b) -> {
throw new IllegalStateException(
"Conflicting rewrite interceptors ["
+ a.getQueryName()
+ "] found in ["
+ a.getClass().getName()
+ "] and ["
+ b.getClass().getName()
+ "]"
);
}));
queryRewriteInterceptor = QueryRewriteInterceptor.multi(queryRewriteInterceptors);
return new IndicesService(this); return new IndicesService(this);
} }
} }

View file

@ -23,6 +23,7 @@ import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryParser; import org.elasticsearch.index.query.QueryParser;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder; import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionParser; import org.elasticsearch.index.query.functionscore.ScoreFunctionParser;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.search.SearchExtBuilder; import org.elasticsearch.search.SearchExtBuilder;
import org.elasticsearch.search.aggregations.Aggregation; import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregationBuilder;
@ -128,6 +129,14 @@ public interface SearchPlugin {
return emptyList(); return emptyList();
} }
/**
* The {@link QueryRewriteInterceptor}s contributed by this plugin. The default implementation contributes none.
*
* @return Applicable {@link QueryRewriteInterceptor}s configured for this plugin.
* Note: This is internal to Elasticsearch's API and not extensible by external plugins.
*/
default List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
return emptyList();
}
/** /**
* The new {@link Aggregation}s added by this plugin. * The new {@link Aggregation}s added by this plugin.
*/ */

View file

@ -0,0 +1,75 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.plugins.internal.rewriter;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import java.util.Map;
/**
* Enables modules and plugins to intercept and rewrite queries during the query rewrite phase on the coordinator node.
*/
public interface QueryRewriteInterceptor {

    /**
     * Intercepts and returns a rewritten query if modifications are required; otherwise,
     * returns the same provided {@link QueryBuilder} instance unchanged.
     *
     * @param context the {@link QueryRewriteContext} providing the context for the rewrite operation
     * @param queryBuilder the original {@link QueryBuilder} to potentially rewrite
     * @return the rewritten {@link QueryBuilder}, or the original instance if no rewrite was needed
     */
    QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder);

    /**
     * Name of the query to be intercepted and rewritten.
     */
    String getQueryName();

    /**
     * Combines several interceptors into a single one.
     *
     * @param interceptors the interceptors to combine, keyed by the name of the query each one handles;
     *                     the map is used as-is and is not copied
     * @return a no-op interceptor when {@code interceptors} is empty, otherwise an interceptor that
     *         dispatches to the entry registered under the incoming query's name
     */
    static QueryRewriteInterceptor multi(Map<String, QueryRewriteInterceptor> interceptors) {
        return interceptors.isEmpty() ? new NoOpQueryRewriteInterceptor() : new CompositeQueryRewriteInterceptor(interceptors);
    }

    /**
     * Dispatches each query to the interceptor registered under {@link QueryBuilder#getName()},
     * and returns the query untouched when no interceptor is registered for that name.
     */
    class CompositeQueryRewriteInterceptor implements QueryRewriteInterceptor {
        // A true constant: shared by all instances instead of being stored once per instance.
        static final String NAME = "composite";

        private final Map<String, QueryRewriteInterceptor> interceptors;

        private CompositeQueryRewriteInterceptor(Map<String, QueryRewriteInterceptor> interceptors) {
            this.interceptors = interceptors;
        }

        @Override
        public String getQueryName() {
            return NAME;
        }

        @Override
        public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
            QueryRewriteInterceptor interceptor = interceptors.get(queryBuilder.getName());
            if (interceptor != null) {
                return interceptor.interceptAndRewrite(context, queryBuilder);
            }
            return queryBuilder;
        }
    }

    /**
     * Interceptor that never rewrites: it always hands back the query it was given and has no query name.
     */
    class NoOpQueryRewriteInterceptor implements QueryRewriteInterceptor {
        @Override
        public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
            return queryBuilder;
        }

        @Override
        public String getQueryName() {
            // There is no specific query to intercept, hence no name.
            return null;
        }
    }
}

View file

@ -1744,7 +1744,9 @@ public class TransportSearchActionTests extends ESTestCase {
NodeClient client = new NodeClient(settings, threadPool); NodeClient client = new NodeClient(settings, threadPool);
SearchService searchService = mock(SearchService.class); SearchService searchService = mock(SearchService.class);
when(searchService.getRewriteContext(any(), any(), any())).thenReturn(new QueryRewriteContext(null, null, null, null, null)); when(searchService.getRewriteContext(any(), any(), any())).thenReturn(
new QueryRewriteContext(null, null, null, null, null, null)
);
ClusterService clusterService = new ClusterService( ClusterService clusterService = new ClusterService(
settings, settings,
new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS),

View file

@ -690,7 +690,12 @@ public class IndexMetadataTests extends ESTestCase {
} }
private static InferenceFieldMetadata randomInferenceFieldMetadata(String name) { private static InferenceFieldMetadata randomInferenceFieldMetadata(String name) {
return new InferenceFieldMetadata(name, randomIdentifier(), randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new)); return new InferenceFieldMetadata(
name,
randomIdentifier(),
randomIdentifier(),
randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new)
);
} }
private IndexMetadataStats randomIndexStats(int numberOfShards) { private IndexMetadataStats randomIndexStats(int numberOfShards) {

View file

@ -86,6 +86,7 @@ import org.elasticsearch.indices.cluster.IndicesClusterStateService.AllocatedInd
import org.elasticsearch.indices.fielddata.cache.IndicesFieldDataCache; import org.elasticsearch.indices.fielddata.cache.IndicesFieldDataCache;
import org.elasticsearch.indices.recovery.RecoveryState; import org.elasticsearch.indices.recovery.RecoveryState;
import org.elasticsearch.plugins.IndexStorePlugin; import org.elasticsearch.plugins.IndexStorePlugin;
import org.elasticsearch.plugins.internal.rewriter.MockQueryRewriteInterceptor;
import org.elasticsearch.script.ScriptService; import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.internal.ReaderContext; import org.elasticsearch.search.internal.ReaderContext;
import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.test.ClusterServiceUtils;
@ -223,7 +224,8 @@ public class IndexModuleTests extends ESTestCase {
module.indexSettings().getMode().idFieldMapperWithoutFieldData(), module.indexSettings().getMode().idFieldMapperWithoutFieldData(),
null, null,
indexDeletionListener, indexDeletionListener,
emptyMap() emptyMap(),
new MockQueryRewriteInterceptor()
); );
} }

View file

@ -15,7 +15,6 @@ import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.MapperPlugin;
import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.Plugin;
import java.io.IOException;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -94,6 +93,7 @@ public class MappingLookupInferenceFieldMapperTests extends MapperServiceTestCas
public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n)); public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n));
public static final String INFERENCE_ID = "test_inference_id"; public static final String INFERENCE_ID = "test_inference_id";
public static final String SEARCH_INFERENCE_ID = "test_search_inference_id";
public static final String CONTENT_TYPE = "test_inference_field"; public static final String CONTENT_TYPE = "test_inference_field";
TestInferenceFieldMapper(String simpleName) { TestInferenceFieldMapper(String simpleName) {
@ -102,7 +102,7 @@ public class MappingLookupInferenceFieldMapperTests extends MapperServiceTestCas
@Override @Override
public InferenceFieldMetadata getMetadata(Set<String> sourcePaths) { public InferenceFieldMetadata getMetadata(Set<String> sourcePaths) {
return new InferenceFieldMetadata(fullPath(), INFERENCE_ID, sourcePaths.toArray(new String[0])); return new InferenceFieldMetadata(fullPath(), INFERENCE_ID, SEARCH_INFERENCE_ID, sourcePaths.toArray(new String[0]));
} }
@Override @Override
@ -111,7 +111,7 @@ public class MappingLookupInferenceFieldMapperTests extends MapperServiceTestCas
} }
@Override @Override
protected void parseCreateField(DocumentParserContext context) throws IOException {} protected void parseCreateField(DocumentParserContext context) {}
@Override @Override
public Builder getMergeBuilder() { public Builder getMergeBuilder() {

View file

@ -0,0 +1,92 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.index.query;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.TestThreadPool;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
public class InterceptedQueryBuilderWrapperTests extends ESTestCase {
private TestThreadPool threadPool;
private NoOpClient client;
@Before
public void setup() {
threadPool = createThreadPool();
client = new NoOpClient(threadPool);
}
@After
public void cleanup() {
threadPool.close();
}
public void testQueryNameReturnsWrappedQueryBuilder() {
MatchAllQueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder();
InterceptedQueryBuilderWrapper interceptedQueryBuilderWrapper = new InterceptedQueryBuilderWrapper(matchAllQueryBuilder);
String queryName = randomAlphaOfLengthBetween(5, 10);
QueryBuilder namedQuery = interceptedQueryBuilderWrapper.queryName(queryName);
assertTrue(namedQuery instanceof InterceptedQueryBuilderWrapper);
assertEquals(queryName, namedQuery.queryName());
}
public void testQueryBoostReturnsWrappedQueryBuilder() {
MatchAllQueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder();
InterceptedQueryBuilderWrapper interceptedQueryBuilderWrapper = new InterceptedQueryBuilderWrapper(matchAllQueryBuilder);
float boost = randomFloat();
QueryBuilder boostedQuery = interceptedQueryBuilderWrapper.boost(boost);
assertTrue(boostedQuery instanceof InterceptedQueryBuilderWrapper);
assertEquals(boost, boostedQuery.boost(), 0.0001f);
}
public void testRewrite() throws IOException {
QueryRewriteContext context = new QueryRewriteContext(null, client, null);
context.setQueryRewriteInterceptor(myMatchInterceptor);
// Queries that are not intercepted behave normally
TermQueryBuilder termQueryBuilder = new TermQueryBuilder("field", "value");
QueryBuilder rewritten = termQueryBuilder.rewrite(context);
assertTrue(rewritten instanceof TermQueryBuilder);
// Queries that should be intercepted are and the right thing happens
MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("field", "value");
rewritten = matchQueryBuilder.rewrite(context);
assertTrue(rewritten instanceof InterceptedQueryBuilderWrapper);
assertTrue(((InterceptedQueryBuilderWrapper) rewritten).queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder rewrittenMatchQueryBuilder = (MatchQueryBuilder) ((InterceptedQueryBuilderWrapper) rewritten).queryBuilder;
assertEquals("intercepted", rewrittenMatchQueryBuilder.value());
// An additional rewrite on an already intercepted query returns the same query
QueryBuilder rewrittenAgain = rewritten.rewrite(context);
assertTrue(rewrittenAgain instanceof InterceptedQueryBuilderWrapper);
assertEquals(rewritten, rewrittenAgain);
}
private final QueryRewriteInterceptor myMatchInterceptor = new QueryRewriteInterceptor() {
@Override
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
if (queryBuilder instanceof MatchQueryBuilder matchQueryBuilder) {
return new MatchQueryBuilder(matchQueryBuilder.fieldName(), "intercepted");
}
return queryBuilder;
}
@Override
public String getQueryName() {
return MatchQueryBuilder.NAME;
}
};
}

View file

@ -52,6 +52,7 @@ public class QueryRewriteContextTests extends ESTestCase {
null, null,
null, null,
null, null,
null,
null null
); );
@ -79,6 +80,7 @@ public class QueryRewriteContextTests extends ESTestCase {
null, null,
null, null,
null, null,
null,
null null
); );

View file

@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.plugins.internal.rewriter;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
/**
 * Test double for {@link QueryRewriteInterceptor} that passes every query builder through unchanged.
 */
public class MockQueryRewriteInterceptor implements QueryRewriteInterceptor {

    /**
     * @return the simple name of this object's runtime class
     */
    @Override
    public String getQueryName() {
        return getClass().getSimpleName();
    }

    /**
     * @return {@code queryBuilder}, untouched
     */
    @Override
    public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
        return queryBuilder;
    }
}

View file

@ -71,6 +71,8 @@ import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.PluginsService; import org.elasticsearch.plugins.PluginsService;
import org.elasticsearch.plugins.ScriptPlugin; import org.elasticsearch.plugins.ScriptPlugin;
import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.plugins.internal.rewriter.MockQueryRewriteInterceptor;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.plugins.scanners.StablePluginsRegistry; import org.elasticsearch.plugins.scanners.StablePluginsRegistry;
import org.elasticsearch.script.MockScriptEngine; import org.elasticsearch.script.MockScriptEngine;
import org.elasticsearch.script.MockScriptService; import org.elasticsearch.script.MockScriptService;
@ -629,7 +631,8 @@ public abstract class AbstractBuilderTestCase extends ESTestCase {
() -> true, () -> true,
scriptService, scriptService,
createMockResolvedIndices(), createMockResolvedIndices(),
null null,
createMockQueryRewriteInterceptor()
); );
} }
@ -670,5 +673,9 @@ public abstract class AbstractBuilderTestCase extends ESTestCase {
Map.of(index, indexMetadata) Map.of(index, indexMetadata)
); );
} }
private QueryRewriteInterceptor createMockQueryRewriteInterceptor() {
return new MockQueryRewriteInterceptor();
}
} }
} }

View file

@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference;
import org.elasticsearch.features.FeatureSpecification; import org.elasticsearch.features.FeatureSpecification;
import org.elasticsearch.features.NodeFeature; import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder; import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder; import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
@ -43,7 +44,8 @@ public class InferenceFeatures implements FeatureSpecification {
SemanticTextFieldMapper.SEMANTIC_TEXT_DELETE_FIX, SemanticTextFieldMapper.SEMANTIC_TEXT_DELETE_FIX,
SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX, SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX,
SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX, SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX,
SEMANTIC_TEXT_HIGHLIGHTER SEMANTIC_TEXT_HIGHLIGHTER,
SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED
); );
} }
} }

View file

@ -36,6 +36,7 @@ import org.elasticsearch.plugins.MapperPlugin;
import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.plugins.SystemIndexPlugin;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; import org.elasticsearch.search.fetch.subphase.highlight.Highlighter;
@ -77,6 +78,7 @@ import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper; import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder; import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder; import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
@ -436,6 +438,11 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
return List.of(new QuerySpec<>(SemanticQueryBuilder.NAME, SemanticQueryBuilder::new, SemanticQueryBuilder::fromXContent)); return List.of(new QuerySpec<>(SemanticQueryBuilder.NAME, SemanticQueryBuilder::new, SemanticQueryBuilder::fromXContent));
} }
@Override
public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
return List.of(new SemanticMatchQueryRewriteInterceptor());
}
@Override @Override
public List<RetrieverSpec<?>> getRetrievers() { public List<RetrieverSpec<?>> getRetrievers() {
return List.of( return List.of(

View file

@ -0,0 +1,95 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.queries;
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
/**
 * Intercepts {@code match} queries and, when the targeted field is registered as an inference
 * field in at least one of the resolved local indices, rewrites the query so that those indices
 * are searched with a {@link SemanticQueryBuilder}.
 * <ul>
 *   <li>No resolved indices, or no index declares the field as an inference field: the original
 *       query is returned unchanged.</li>
 *   <li>Every index declares the field as an inference field: a single semantic query is returned.</li>
 *   <li>Mixed: a {@code bool} query is returned with one {@code should} clause per inference index
 *       (semantic query filtered to that index) plus one {@code should} clause holding the original
 *       match query filtered to the remaining indices.</li>
 * </ul>
 */
public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterceptor {

    public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
        "search.semantic_match_query_rewrite_interception_supported"
    );

    public SemanticMatchQueryRewriteInterceptor() {}

    /**
     * Rewrites the given match query as described on the class.
     *
     * @param context      rewrite context; its resolved indices decide which rewrite applies
     * @param queryBuilder the query to intercept; asserted to be a {@link MatchQueryBuilder}
     * @return the original query builder, a semantic query, or a bool query combining both
     */
    @Override
    public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
        assert (queryBuilder instanceof MatchQueryBuilder);
        MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;

        ResolvedIndices resolvedIndices = context.getResolvedIndices();
        if (resolvedIndices == null) {
            return queryBuilder;
        }

        String fieldName = matchQueryBuilder.fieldName();
        Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
        List<String> inferenceIndices = new ArrayList<>();
        List<String> nonInferenceIndices = new ArrayList<>();
        for (IndexMetadata indexMetadata : indexMetadataCollection) {
            String indexName = indexMetadata.getIndex().getName();
            InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName);
            if (inferenceFieldMetadata != null) {
                inferenceIndices.add(indexName);
            } else {
                nonInferenceIndices.add(indexName);
            }
        }

        if (inferenceIndices.isEmpty()) {
            // Nothing semantic to do: leave the match query exactly as it was
            return queryBuilder;
        }

        String queryText = queryText(matchQueryBuilder);
        if (nonInferenceIndices.isEmpty()) {
            // Third argument is the semantic query's lenient flag
            return new SemanticQueryBuilder(fieldName, queryText, false);
        }

        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        for (String inferenceIndexName : inferenceIndices) {
            // Add a separate clause for each semantic query, because they may be using different inference endpoints
            // TODO - consolidate this to a single clause once the semantic query supports multiple inference endpoints
            boolQueryBuilder.should(createSemanticSubQuery(inferenceIndexName, fieldName, queryText));
        }
        boolQueryBuilder.should(createMatchSubQuery(nonInferenceIndices, matchQueryBuilder));
        return boolQueryBuilder;
    }

    /**
     * @return the name of the query this interceptor applies to, i.e. the match query name
     */
    @Override
    public String getQueryName() {
        return MatchQueryBuilder.NAME;
    }

    /**
     * Converts the match query's value to the text passed to the semantic query. The value is an
     * {@code Object} and is not guaranteed to be a {@code String} (a match query may carry a number
     * or boolean), so it is converted with {@code toString()} rather than cast, which would throw
     * {@link ClassCastException}. A {@code null} value is passed through as {@code null}.
     */
    private static String queryText(MatchQueryBuilder matchQueryBuilder) {
        Object value = matchQueryBuilder.value();
        return value == null ? null : value.toString();
    }

    /** Builds a bool query running a lenient semantic query restricted to the single given index. */
    private QueryBuilder createSemanticSubQuery(String indexName, String fieldName, String value) {
        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true));
        boolQueryBuilder.filter(new TermQueryBuilder(IndexFieldMapper.NAME, indexName));
        return boolQueryBuilder;
    }

    /** Builds a bool query running the original match query restricted to the given indices. */
    private QueryBuilder createMatchSubQuery(List<String> indices, MatchQueryBuilder matchQueryBuilder) {
        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        boolQueryBuilder.must(matchQueryBuilder);
        boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
        return boolQueryBuilder;
    }
}

View file

@ -46,6 +46,7 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
@ -57,16 +58,18 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
private static final ParseField FIELD_FIELD = new ParseField("field"); private static final ParseField FIELD_FIELD = new ParseField("field");
private static final ParseField QUERY_FIELD = new ParseField("query"); private static final ParseField QUERY_FIELD = new ParseField("query");
private static final ParseField LENIENT_FIELD = new ParseField("lenient");
private static final ConstructingObjectParser<SemanticQueryBuilder, Void> PARSER = new ConstructingObjectParser<>( private static final ConstructingObjectParser<SemanticQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
NAME, NAME,
false, false,
args -> new SemanticQueryBuilder((String) args[0], (String) args[1]) args -> new SemanticQueryBuilder((String) args[0], (String) args[1], (Boolean) args[2])
); );
static { static {
PARSER.declareString(constructorArg(), FIELD_FIELD); PARSER.declareString(constructorArg(), FIELD_FIELD);
PARSER.declareString(constructorArg(), QUERY_FIELD); PARSER.declareString(constructorArg(), QUERY_FIELD);
PARSER.declareBoolean(optionalConstructorArg(), LENIENT_FIELD);
declareStandardFields(PARSER); declareStandardFields(PARSER);
} }
@ -75,8 +78,13 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
private final SetOnce<InferenceServiceResults> inferenceResultsSupplier; private final SetOnce<InferenceServiceResults> inferenceResultsSupplier;
private final InferenceResults inferenceResults; private final InferenceResults inferenceResults;
private final boolean noInferenceResults; private final boolean noInferenceResults;
private final Boolean lenient;
public SemanticQueryBuilder(String fieldName, String query) { public SemanticQueryBuilder(String fieldName, String query) {
this(fieldName, query, null);
}
public SemanticQueryBuilder(String fieldName, String query, Boolean lenient) {
if (fieldName == null) { if (fieldName == null) {
throw new IllegalArgumentException("[" + NAME + "] requires a " + FIELD_FIELD.getPreferredName() + " value"); throw new IllegalArgumentException("[" + NAME + "] requires a " + FIELD_FIELD.getPreferredName() + " value");
} }
@ -88,6 +96,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
this.inferenceResults = null; this.inferenceResults = null;
this.inferenceResultsSupplier = null; this.inferenceResultsSupplier = null;
this.noInferenceResults = false; this.noInferenceResults = false;
this.lenient = lenient;
} }
public SemanticQueryBuilder(StreamInput in) throws IOException { public SemanticQueryBuilder(StreamInput in) throws IOException {
@ -97,6 +106,11 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
this.inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class); this.inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class);
this.noInferenceResults = in.readBoolean(); this.noInferenceResults = in.readBoolean();
this.inferenceResultsSupplier = null; this.inferenceResultsSupplier = null;
if (in.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_LENIENT)) {
this.lenient = in.readOptionalBoolean();
} else {
this.lenient = null;
}
} }
@Override @Override
@ -108,6 +122,9 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
out.writeString(query); out.writeString(query);
out.writeOptionalNamedWriteable(inferenceResults); out.writeOptionalNamedWriteable(inferenceResults);
out.writeBoolean(noInferenceResults); out.writeBoolean(noInferenceResults);
if (out.getTransportVersion().onOrAfter(TransportVersions.SEMANTIC_QUERY_LENIENT)) {
out.writeOptionalBoolean(lenient);
}
} }
private SemanticQueryBuilder( private SemanticQueryBuilder(
@ -123,6 +140,7 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
this.inferenceResultsSupplier = inferenceResultsSupplier; this.inferenceResultsSupplier = inferenceResultsSupplier;
this.inferenceResults = inferenceResults; this.inferenceResults = inferenceResults;
this.noInferenceResults = noInferenceResults; this.noInferenceResults = noInferenceResults;
this.lenient = other.lenient;
} }
@Override @Override
@ -144,6 +162,9 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
builder.startObject(NAME); builder.startObject(NAME);
builder.field(FIELD_FIELD.getPreferredName(), fieldName); builder.field(FIELD_FIELD.getPreferredName(), fieldName);
builder.field(QUERY_FIELD.getPreferredName(), query); builder.field(QUERY_FIELD.getPreferredName(), query);
if (lenient != null) {
builder.field(LENIENT_FIELD.getPreferredName(), lenient);
}
boostAndQueryNameToXContent(builder); boostAndQueryNameToXContent(builder);
builder.endObject(); builder.endObject();
} }
@ -171,6 +192,8 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
} }
return semanticTextFieldType.semanticQuery(inferenceResults, searchExecutionContext.requestSize(), boost(), queryName()); return semanticTextFieldType.semanticQuery(inferenceResults, searchExecutionContext.requestSize(), boost(), queryName());
} else if (lenient != null && lenient) {
return new MatchNoneQueryBuilder();
} else { } else {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Field [" + fieldName + "] of type [" + fieldType.typeName() + "] does not support " + NAME + " queries" "Field [" + fieldName + "] of type [" + fieldType.typeName() + "] does not support " + NAME + " queries"

View file

@ -102,7 +102,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
new BulkItemRequest[0] new BulkItemRequest[0]
); );
request.setInferenceFieldMap( request.setInferenceFieldMap(
Map.of("foo", new InferenceFieldMetadata("foo", "bar", generateRandomStringArray(5, 10, false, false))) Map.of("foo", new InferenceFieldMetadata("foo", "bar", "baz", generateRandomStringArray(5, 10, false, false)))
); );
filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);

View file

@ -101,7 +101,7 @@ setup:
index: test-sparse-index index: test-sparse-index
id: doc_1 id: doc_1
body: body:
inference_field: ["inference test", "another inference test"] inference_field: [ "inference test", "another inference test" ]
non_inference_field: "non inference test" non_inference_field: "non inference test"
refresh: true refresh: true
@ -132,7 +132,7 @@ setup:
index: test-sparse-index index: test-sparse-index
id: doc_1 id: doc_1
body: body:
inference_field: [40, 49.678] inference_field: [ 40, 49.678 ]
refresh: true refresh: true
- do: - do:
@ -229,7 +229,7 @@ setup:
index: test-dense-index index: test-dense-index
id: doc_1 id: doc_1
body: body:
inference_field: ["inference test", "another inference test"] inference_field: [ "inference test", "another inference test" ]
non_inference_field: "non inference test" non_inference_field: "non inference test"
refresh: true refresh: true
@ -260,7 +260,7 @@ setup:
index: test-dense-index index: test-dense-index
id: doc_1 id: doc_1
body: body:
inference_field: [45.1, 100] inference_field: [ 45.1, 100 ]
refresh: true refresh: true
- do: - do:
@ -387,7 +387,7 @@ setup:
index: test-dense-index index: test-dense-index
id: doc_1 id: doc_1
body: body:
inference_field: ["inference test", "another inference test"] inference_field: [ "inference test", "another inference test" ]
non_inference_field: "non inference test" non_inference_field: "non inference test"
refresh: true refresh: true
@ -418,7 +418,7 @@ setup:
index: test-sparse-index index: test-sparse-index
id: doc_1 id: doc_1
body: body:
inference_field: ["inference test", "another inference test"] inference_field: [ "inference test", "another inference test" ]
non_inference_field: "non inference test" non_inference_field: "non inference test"
refresh: true refresh: true
@ -440,7 +440,7 @@ setup:
- match: { hits.hits.0._id: "doc_1" } - match: { hits.hits.0._id: "doc_1" }
- close_to: { hits.hits.0._score: { value: 3.783733e19, error: 1e13 } } - close_to: { hits.hits.0._score: { value: 3.783733e19, error: 1e13 } }
- length: { hits.hits.0._source.inference_field.inference.chunks: 2 } - length: { hits.hits.0._source.inference_field.inference.chunks: 2 }
- match: { hits.hits.0.matched_queries: ["i-like-naming-my-queries"] } - match: { hits.hits.0.matched_queries: [ "i-like-naming-my-queries" ] }
--- ---
"Query an index alias": "Query an index alias":
@ -452,7 +452,7 @@ setup:
index: test-sparse-index index: test-sparse-index
id: doc_1 id: doc_1
body: body:
inference_field: ["inference test", "another inference test"] inference_field: [ "inference test", "another inference test" ]
non_inference_field: "non inference test" non_inference_field: "non inference test"
refresh: true refresh: true
@ -503,6 +503,48 @@ setup:
- match: { error.root_cause.0.type: "illegal_argument_exception" } - match: { error.root_cause.0.type: "illegal_argument_exception" }
- match: { error.root_cause.0.reason: "Field [non_inference_field] of type [text] does not support semantic queries" } - match: { error.root_cause.0.reason: "Field [non_inference_field] of type [text] does not support semantic queries" }
---
"Query the wrong field type with lenient: true":
- requires:
cluster_features: "search.semantic_match_query_rewrite_interception_supported"
reason: lenient introduced in 8.18.0
- do:
index:
index: test-sparse-index
id: doc_1
body:
inference_field: "inference test"
non_inference_field: "non inference test"
refresh: true
- do:
catch: bad_request
search:
index: test-sparse-index
body:
query:
semantic:
field: "non_inference_field"
query: "inference test"
- match: { error.type: "search_phase_execution_exception" }
- match: { error.root_cause.0.type: "illegal_argument_exception" }
- match: { error.root_cause.0.reason: "Field [non_inference_field] of type [text] does not support semantic queries" }
- do:
search:
index: test-sparse-index
body:
query:
semantic:
field: "non_inference_field"
query: "inference test"
lenient: true
- match: { hits.total.value: 0 }
--- ---
"Query a missing field": "Query a missing field":
- do: - do:
@ -783,7 +825,7 @@ setup:
index: test-dense-index index: test-dense-index
id: doc_1 id: doc_1
body: body:
inference_field: ["inference test", "another inference test"] inference_field: [ "inference test", "another inference test" ]
non_inference_field: "non inference test" non_inference_field: "non inference test"
refresh: true refresh: true
@ -844,11 +886,11 @@ setup:
"Query a field that uses the default ELSER 2 endpoint": "Query a field that uses the default ELSER 2 endpoint":
- requires: - requires:
reason: "default ELSER 2 inference ID is enabled via a capability" reason: "default ELSER 2 inference ID is enabled via a capability"
test_runner_features: [capabilities] test_runner_features: [ capabilities ]
capabilities: capabilities:
- method: GET - method: GET
path: /_inference path: /_inference
capabilities: [default_elser_2] capabilities: [ default_elser_2 ]
- do: - do:
indices.create: indices.create:

View file

@ -0,0 +1,284 @@
setup:
- requires:
cluster_features: "search.semantic_match_query_rewrite_interception_supported"
reason: semantic_text match support introduced in 8.18.0
- do:
inference.put:
task_type: sparse_embedding
inference_id: sparse-inference-id
body: >
{
"service": "test_service",
"service_settings": {
"model": "my_model",
"api_key": "abc64"
},
"task_settings": {
}
}
- do:
inference.put:
task_type: sparse_embedding
inference_id: sparse-inference-id-2
body: >
{
"service": "test_service",
"service_settings": {
"model": "my_model",
"api_key": "abc64"
},
"task_settings": {
}
}
- do:
inference.put:
task_type: text_embedding
inference_id: dense-inference-id
body: >
{
"service": "text_embedding_test_service",
"service_settings": {
"model": "my_model",
"dimensions": 10,
"api_key": "abc64",
"similarity": "COSINE"
},
"task_settings": {
}
}
- do:
indices.create:
index: test-sparse-index
body:
mappings:
properties:
inference_field:
type: semantic_text
inference_id: sparse-inference-id
non_inference_field:
type: text
- do:
indices.create:
index: test-dense-index
body:
mappings:
properties:
inference_field:
type: semantic_text
inference_id: dense-inference-id
non_inference_field:
type: text
- do:
indices.create:
index: test-text-only-index
body:
mappings:
properties:
inference_field:
type: text
non_inference_field:
type: text
---
"Query using a sparse embedding model":
- skip:
features: [ "headers", "close_to" ]
- do:
index:
index: test-sparse-index
id: doc_1
body:
inference_field: [ "inference test", "another inference test" ]
non_inference_field: "non inference test"
refresh: true
- do:
headers:
# Force JSON content type so that we use a parser that interprets the floating-point score as a double
Content-Type: application/json
search:
index: test-sparse-index
body:
query:
match:
inference_field:
query: "inference test"
- match: { hits.total.value: 1 }
- match: { hits.hits.0._id: "doc_1" }
---
"Query using a dense embedding model":
- skip:
features: [ "headers", "close_to" ]
- do:
index:
index: test-dense-index
id: doc_1
body:
inference_field: [ "inference test", "another inference test" ]
non_inference_field: "non inference test"
refresh: true
- do:
headers:
# Force JSON content type so that we use a parser that interprets the floating-point score as a double
Content-Type: application/json
search:
index: test-dense-index
body:
query:
match:
inference_field:
query: "inference test"
- match: { hits.total.value: 1 }
- match: { hits.hits.0._id: "doc_1" }
---
"Query an index alias":
- skip:
features: [ "headers", "close_to" ]
- do:
index:
index: test-sparse-index
id: doc_1
body:
inference_field: [ "inference test", "another inference test" ]
non_inference_field: "non inference test"
refresh: true
- do:
indices.put_alias:
index: test-sparse-index
name: my-alias
- do:
headers:
# Force JSON content type so that we use a parser that interprets the floating-point score as a double
Content-Type: application/json
search:
index: my-alias
body:
query:
match:
inference_field:
query: "inference test"
- match: { hits.total.value: 1 }
- match: { hits.hits.0._id: "doc_1" }
---
"Query indices with both semantic_text and regular text content":
- do:
index:
index: test-sparse-index
id: doc_1
body:
inference_field: [ "inference test", "another inference test" ]
non_inference_field: "non inference test"
refresh: true
- do:
index:
index: test-text-only-index
id: doc_2
body:
inference_field: [ "inference test", "not an inference field" ]
non_inference_field: "non inference test"
refresh: true
- do:
search:
index:
- test-sparse-index
- test-text-only-index
body:
query:
match:
inference_field:
query: "inference test"
- match: { hits.total.value: 2 }
- match: { hits.hits.0._id: "doc_1" }
- match: { hits.hits.1._id: "doc_2" }
# Test querying multiple indices that either use the same inference ID or combine semantic_text with lexical search
- do:
indices.create:
index: test-sparse-index-2
body:
mappings:
properties:
inference_field:
type: semantic_text
inference_id: sparse-inference-id
non_inference_field:
type: text
- do:
index:
index: test-sparse-index-2
id: doc_3
body:
inference_field: "another inference test"
refresh: true
- do:
search:
index:
- test-sparse-index*
- test-text-only-index
body:
query:
match:
inference_field:
query: "inference test"
- match: { hits.total.value: 3 }
- match: { hits.hits.0._id: "doc_1" }
- match: { hits.hits.1._id: "doc_3" }
- match: { hits.hits.2._id: "doc_2" }
---
"Query a field that has no indexed inference results":
- skip:
features: [ "headers" ]
- do:
headers:
# Force JSON content type so that we use a parser that interprets the floating-point score as a double
Content-Type: application/json
search:
index: test-sparse-index
body:
query:
match:
inference_field:
query: "inference test"
- match: { hits.total.value: 0 }
- do:
headers:
# Force JSON content type so that we use a parser that interprets the floating-point score as a double
Content-Type: application/json
search:
index: test-dense-index
body:
query:
match:
inference_field:
query: "inference test"
- match: { hits.total.value: 0 }

View file

@ -54,7 +54,9 @@ public class RRFRetrieverBuilderTests extends ESTestCase {
IllegalArgumentException iae = expectThrows( IllegalArgumentException iae = expectThrows(
IllegalArgumentException.class, IllegalArgumentException.class,
() -> ssb.parseXContent(parser, true, nf -> true) () -> ssb.parseXContent(parser, true, nf -> true)
.rewrite(new QueryRewriteContext(parserConfig(), null, null, null, new PointInTimeBuilder(new BytesArray("pitid")))) .rewrite(
new QueryRewriteContext(parserConfig(), null, null, null, new PointInTimeBuilder(new BytesArray("pitid")), null)
)
); );
assertEquals("[search_after] cannot be used in children of compound retrievers", iae.getMessage()); assertEquals("[search_after] cannot be used in children of compound retrievers", iae.getMessage());
} }
@ -70,7 +72,9 @@ public class RRFRetrieverBuilderTests extends ESTestCase {
IllegalArgumentException iae = expectThrows( IllegalArgumentException iae = expectThrows(
IllegalArgumentException.class, IllegalArgumentException.class,
() -> ssb.parseXContent(parser, true, nf -> true) () -> ssb.parseXContent(parser, true, nf -> true)
.rewrite(new QueryRewriteContext(parserConfig(), null, null, null, new PointInTimeBuilder(new BytesArray("pitid")))) .rewrite(
new QueryRewriteContext(parserConfig(), null, null, null, new PointInTimeBuilder(new BytesArray("pitid")), null)
)
); );
assertEquals("[terminate_after] cannot be used in children of compound retrievers", iae.getMessage()); assertEquals("[terminate_after] cannot be used in children of compound retrievers", iae.getMessage());
} }