Add support for sparse_vector queries against semantic_text fields (#118617)

Kathleen DeRusso 2024-12-17 19:06:54 -05:00 committed by GitHub
parent 7c65a8e5db
commit 15bec3cefa
13 changed files with 889 additions and 78 deletions

View file

@@ -0,0 +1,5 @@
pr: 118617
summary: Add support for `sparse_vector` queries against `semantic_text` fields
area: "Search"
type: enhancement
issues: []
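
For context, a minimal sketch (not part of the diff) of what this change enables, using the SparseVectorQueryBuilder constructors that appear later in this commit. The field name, endpoint ID, and token weights are made up, and WeightedToken(token, weight) is assumed to be the record from the same ML search package:

import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
import java.util.List;

class SparseVectorOnSemanticTextSketch {
    void examples() {
        // Query text only: the inference endpoint configured on the semantic_text
        // mapping ("my-semantic-field") is resolved at query rewrite time.
        SparseVectorQueryBuilder usingFieldInference =
            new SparseVectorQueryBuilder("my-semantic-field", null, "slow furry spider");

        // Query text with an explicit inference endpoint override.
        SparseVectorQueryBuilder usingExplicitEndpoint =
            new SparseVectorQueryBuilder("my-semantic-field", "my-elser-endpoint", "slow furry spider");

        // Precomputed query vectors: no inference call is made at all.
        SparseVectorQueryBuilder usingQueryVectors = new SparseVectorQueryBuilder(
            "my-semantic-field",
            List.of(new WeightedToken("spider", 1.2f), new WeightedToken("furry", 0.8f)),
            null,   // inference ID
            null,   // query text
            false,  // token pruning
            null    // pruning config
        );
    }
}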

View file

@@ -90,7 +90,8 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
: (this.shouldPruneTokens ? new TokenPruningConfig() : null));
this.weightedTokensSupplier = null;
if (queryVectors == null ^ inferenceId == null == false) {
// Preserve BWC error messaging
if (queryVectors != null && inferenceId != null) {
throw new IllegalArgumentException(
"["
+ NAME
@@ -98,18 +99,24 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
+ QUERY_VECTOR_FIELD.getPreferredName()
+ "] or ["
+ INFERENCE_ID_FIELD.getPreferredName()
+ "]"
+ "] for "
+ ALLOWED_FIELD_TYPE
+ " fields"
);
}
if (inferenceId != null && query == null) {
// Preserve BWC error messaging
if ((queryVectors == null) == (query == null)) {
throw new IllegalArgumentException(
"["
+ NAME
+ "] requires ["
+ QUERY_FIELD.getPreferredName()
+ "] when ["
+ "] requires one of ["
+ QUERY_VECTOR_FIELD.getPreferredName()
+ "] or ["
+ INFERENCE_ID_FIELD.getPreferredName()
+ "] is specified"
+ "] for "
+ ALLOWED_FIELD_TYPE
+ " fields"
);
}
}
@@ -143,6 +150,14 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
return queryVectors;
}
public String getInferenceId() {
return inferenceId;
}
public String getQuery() {
return query;
}
public boolean shouldPruneTokens() {
return shouldPruneTokens;
}
@@ -176,7 +191,9 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
}
builder.endObject();
} else {
builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
if (inferenceId != null) {
builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
}
builder.field(QUERY_FIELD.getPreferredName(), query);
}
builder.field(PRUNE_FIELD.getPreferredName(), shouldPruneTokens);
@@ -228,6 +245,11 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
shouldPruneTokens,
tokenPruningConfig
);
} else if (inferenceId == null) {
// Edge case, where inference_id was not specified in the request,
// but we did not intercept this and rewrite to a query on a field with
// pre-configured inference. So we trap here and output a nicer error message.
throw new IllegalArgumentException("inference_id required to perform vector search on query string");
}
// TODO move this to xpack core and use inference APIs
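
The two rewritten checks above keep the old error strings while newly allowing a bare query string; a sketch of the resulting argument combinations (field and endpoint names are illustrative only):

import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;

class SparseVectorValidationSketch {
    void combinations() {
        // Accepted: query text plus an explicit inference endpoint.
        new SparseVectorQueryBuilder("text", "my-elser-endpoint", "octopus comforter smells");

        // Accepted since this change: query text with no inference endpoint. The endpoint is
        // resolved later by the semantic_text interceptor, or rejected at rewrite time with
        // "inference_id required to perform vector search on query string".
        new SparseVectorQueryBuilder("text", null, "octopus comforter smells");

        // Rejected: neither query vectors nor query text.
        try {
            new SparseVectorQueryBuilder("text", null, null);
        } catch (IllegalArgumentException e) {
            // "[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields"
        }
    }
}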

View file

@@ -260,16 +260,16 @@ public class SparseVectorQueryBuilderTests extends AbstractQueryTestCase<SparseV
{
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> new SparseVectorQueryBuilder("field name", null, "model id")
() -> new SparseVectorQueryBuilder("field name", null, null)
);
assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id]", e.getMessage());
assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields", e.getMessage());
}
{
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> new SparseVectorQueryBuilder("field name", "model text", null)
);
assertEquals("[sparse_vector] requires [query] when [inference_id] is specified", e.getMessage());
assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields", e.getMessage());
}
}

View file

@@ -10,13 +10,15 @@ package org.elasticsearch.xpack.inference;
import org.elasticsearch.features.FeatureSpecification;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
import java.util.Set;
import static org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
import static org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor.SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
/**
* Provides inference features.
*/
@@ -45,7 +47,8 @@ public class InferenceFeatures implements FeatureSpecification {
SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX,
SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX,
SEMANTIC_TEXT_HIGHLIGHTER,
SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED
SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED
);
}
}

View file

@@ -80,6 +80,7 @@ import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
@@ -440,7 +441,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
@Override
public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
return List.of(new SemanticMatchQueryRewriteInterceptor());
return List.of(new SemanticMatchQueryRewriteInterceptor(), new SemanticSparseVectorQueryRewriteInterceptor());
}
@Override

View file

@@ -7,24 +7,12 @@
package org.elasticsearch.xpack.inference.queries;
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterceptor {
public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {
public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
"search.semantic_match_query_rewrite_interception_supported"
@@ -33,63 +21,45 @@ public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterce
public SemanticMatchQueryRewriteInterceptor() {}
@Override
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
protected String getFieldName(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
QueryBuilder rewritten = queryBuilder;
ResolvedIndices resolvedIndices = context.getResolvedIndices();
if (resolvedIndices != null) {
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
List<String> inferenceIndices = new ArrayList<>();
List<String> nonInferenceIndices = new ArrayList<>();
for (IndexMetadata indexMetadata : indexMetadataCollection) {
String indexName = indexMetadata.getIndex().getName();
InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(matchQueryBuilder.fieldName());
if (inferenceFieldMetadata != null) {
inferenceIndices.add(indexName);
} else {
nonInferenceIndices.add(indexName);
}
}
return matchQueryBuilder.fieldName();
}
if (inferenceIndices.isEmpty()) {
return rewritten;
} else if (nonInferenceIndices.isEmpty() == false) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
for (String inferenceIndexName : inferenceIndices) {
// Add a separate clause for each semantic query, because they may be using different inference endpoints
// TODO - consolidate this to a single clause once the semantic query supports multiple inference endpoints
boolQueryBuilder.should(
createSemanticSubQuery(inferenceIndexName, matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value())
);
}
boolQueryBuilder.should(createMatchSubQuery(nonInferenceIndices, matchQueryBuilder));
rewritten = boolQueryBuilder;
} else {
rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), false);
}
}
@Override
protected String getQuery(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
return (String) matchQueryBuilder.value();
}
return rewritten;
@Override
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
return new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
}
@Override
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation
) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should(
createSemanticSubQuery(
indexInformation.getInferenceIndices(),
matchQueryBuilder.fieldName(),
(String) matchQueryBuilder.value()
)
);
boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder));
return boolQueryBuilder;
}
@Override
public String getQueryName() {
return MatchQueryBuilder.NAME;
}
private QueryBuilder createSemanticSubQuery(String indexName, String fieldName, String value) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true));
boolQueryBuilder.filter(new TermQueryBuilder(IndexFieldMapper.NAME, indexName));
return boolQueryBuilder;
}
private QueryBuilder createMatchSubQuery(List<String> indices, MatchQueryBuilder matchQueryBuilder) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(matchQueryBuilder);
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
return boolQueryBuilder;
}
}
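
When the intercepted match field is a semantic_text field in one index and a plain text field in another, the combined branch above builds approximately the bool query sketched below (index names and query text are made up for illustration):

import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import java.util.List;

class CombinedMatchRewriteSketch {
    BoolQueryBuilder combinedShape() {
        BoolQueryBuilder combined = new BoolQueryBuilder();
        // Semantic clause, restricted to the inference indices via a _index terms filter.
        combined.should(
            new BoolQueryBuilder().must(new SemanticQueryBuilder("body", "slow furry spider", true))
                .filter(new TermsQueryBuilder(IndexFieldMapper.NAME, List.of("semantic-index")))
        );
        // Lexical clause, restricted to the non-inference indices.
        combined.should(
            new BoolQueryBuilder().must(new MatchQueryBuilder("body", "slow furry spider"))
                .filter(new TermsQueryBuilder(IndexFieldMapper.NAME, List.of("lexical-index")))
        );
        return combined;
    }
}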

View file

@@ -148,6 +148,14 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
return NAME;
}
public String getFieldName() {
return fieldName;
}
public String getQuery() {
return query;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.V_8_15_0;

View file

@@ -0,0 +1,148 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.queries;
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* Intercepts and adapts a query to be rewritten to work seamlessly on a semantic_text field.
*/
public abstract class SemanticQueryRewriteInterceptor implements QueryRewriteInterceptor {
public SemanticQueryRewriteInterceptor() {}
@Override
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
String fieldName = getFieldName(queryBuilder);
ResolvedIndices resolvedIndices = context.getResolvedIndices();
if (resolvedIndices == null) {
// No resolved indices, so return the original query.
return queryBuilder;
}
InferenceIndexInformationForField indexInformation = resolveIndicesForField(fieldName, resolvedIndices);
if (indexInformation.getInferenceIndices().isEmpty()) {
// No inference fields were identified, so return the original query.
return queryBuilder;
} else if (indexInformation.nonInferenceIndices().isEmpty() == false) {
// Combined case where the field name requested by this query maps to both
// semantic_text and non-inference fields across indices, so we have to combine
// queries per index containing each field type.
return buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation);
} else {
// The only fields we've identified are inference fields (e.g. semantic_text),
// so rewrite the entire query to work on a semantic_text field.
return buildInferenceQuery(queryBuilder, indexInformation);
}
}
/**
* @param queryBuilder {@link QueryBuilder}
* @return The singular field name requested by the provided query builder.
*/
protected abstract String getFieldName(QueryBuilder queryBuilder);
/**
* @param queryBuilder {@link QueryBuilder}
* @return The text/query string requested by the provided query builder.
*/
protected abstract String getQuery(QueryBuilder queryBuilder);
/**
* Builds the inference query
*
* @param queryBuilder {@link QueryBuilder}
* @param indexInformation {@link InferenceIndexInformationForField}
* @return {@link QueryBuilder}
*/
protected abstract QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation);
/**
* Builds a combined inference and non-inference query,
* which separates the different queries into appropriate indices based on field type.
* @param queryBuilder {@link QueryBuilder}
* @param indexInformation {@link InferenceIndexInformationForField}
* @return {@link QueryBuilder}
*/
protected abstract QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation
);
private InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) {
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
Map<String, InferenceFieldMetadata> inferenceIndicesMetadata = new HashMap<>();
List<String> nonInferenceIndices = new ArrayList<>();
for (IndexMetadata indexMetadata : indexMetadataCollection) {
String indexName = indexMetadata.getIndex().getName();
InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName);
if (inferenceFieldMetadata != null) {
inferenceIndicesMetadata.put(indexName, inferenceFieldMetadata);
} else {
nonInferenceIndices.add(indexName);
}
}
return new InferenceIndexInformationForField(fieldName, inferenceIndicesMetadata, nonInferenceIndices);
}
protected QueryBuilder createSubQueryForIndices(Collection<String> indices, QueryBuilder queryBuilder) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(queryBuilder);
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
return boolQueryBuilder;
}
protected QueryBuilder createSemanticSubQuery(Collection<String> indices, String fieldName, String value) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true));
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
return boolQueryBuilder;
}
/**
* Represents the indices and associated inference information for a field.
*/
public record InferenceIndexInformationForField(
String fieldName,
Map<String, InferenceFieldMetadata> inferenceIndicesMetadata,
List<String> nonInferenceIndices
) {
public Collection<String> getInferenceIndices() {
return inferenceIndicesMetadata.keySet();
}
public Map<String, List<String>> getInferenceIdsIndices() {
return inferenceIndicesMetadata.entrySet()
.stream()
.collect(
Collectors.groupingBy(
entry -> entry.getValue().getSearchInferenceId(),
Collectors.mapping(Map.Entry::getKey, Collectors.toList())
)
);
}
}
}
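
A small sketch of the InferenceIndexInformationForField record and the grouping that getInferenceIdsIndices() performs, assuming the three-argument InferenceFieldMetadata constructor used by the tests later in this commit; index and endpoint names are made up:

import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.xpack.inference.queries.SemanticQueryRewriteInterceptor.InferenceIndexInformationForField;
import java.util.List;
import java.util.Map;

class InferenceIndexGroupingSketch {
    void example() {
        String field = "body";
        Map<String, InferenceFieldMetadata> inferenceIndices = Map.of(
            "index-a", new InferenceFieldMetadata(field, "elser-1", new String[] { field }),
            "index-b", new InferenceFieldMetadata(field, "elser-1", new String[] { field }),
            "index-c", new InferenceFieldMetadata(field, "elser-2", new String[] { field })
        );
        InferenceIndexInformationForField info =
            new InferenceIndexInformationForField(field, inferenceIndices, List.of("lexical-index"));

        // Indices keyed by their search inference ID, e.g.
        // { "elser-1" -> [index-a, index-b], "elser-2" -> [index-c] }
        Map<String, List<String>> byInferenceId = info.getInferenceIdsIndices();
    }
}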

View file

@@ -0,0 +1,124 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.queries;
import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import java.util.List;
import java.util.Map;
public class SemanticSparseVectorQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {
public static final NodeFeature SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
"search.semantic_sparse_vector_query_rewrite_interception_supported"
);
public SemanticSparseVectorQueryRewriteInterceptor() {}
@Override
protected String getFieldName(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
return sparseVectorQueryBuilder.getFieldName();
}
@Override
protected String getQuery(QueryBuilder queryBuilder) {
assert (queryBuilder instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
return sparseVectorQueryBuilder.getQuery();
}
@Override
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
if (inferenceIdsIndices.size() == 1) {
// Simple case, everything uses the same inference ID
String searchInferenceId = inferenceIdsIndices.keySet().iterator().next();
return buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId);
} else {
// Multiple inference IDs, construct a boolean query
return buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices);
}
}
private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
QueryBuilder queryBuilder,
Map<String, List<String>> inferenceIdsIndices
) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
for (String inferenceId : inferenceIdsIndices.keySet()) {
boolQueryBuilder.should(
createSubQueryForIndices(
inferenceIdsIndices.get(inferenceId),
buildNestedQueryFromSparseVectorQuery(queryBuilder, inferenceId)
)
);
}
return boolQueryBuilder;
}
@Override
protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
QueryBuilder queryBuilder,
InferenceIndexInformationForField indexInformation
) {
assert (queryBuilder instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should(
createSubQueryForIndices(
indexInformation.nonInferenceIndices(),
createSubQueryForIndices(indexInformation.nonInferenceIndices(), sparseVectorQueryBuilder)
)
);
// We always perform nested subqueries on semantic_text fields, to support
// sparse_vector queries using query vectors.
for (String inferenceId : inferenceIdsIndices.keySet()) {
boolQueryBuilder.should(
createSubQueryForIndices(
inferenceIdsIndices.get(inferenceId),
buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder, inferenceId)
)
);
}
return boolQueryBuilder;
}
private QueryBuilder buildNestedQueryFromSparseVectorQuery(QueryBuilder queryBuilder, String searchInferenceId) {
assert (queryBuilder instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
return QueryBuilders.nestedQuery(
SemanticTextField.getChunksFieldName(sparseVectorQueryBuilder.getFieldName()),
new SparseVectorQueryBuilder(
SemanticTextField.getEmbeddingsFieldName(sparseVectorQueryBuilder.getFieldName()),
sparseVectorQueryBuilder.getQueryVectors(),
(sparseVectorQueryBuilder.getInferenceId() == null && sparseVectorQueryBuilder.getQuery() != null)
? searchInferenceId
: sparseVectorQueryBuilder.getInferenceId(),
sparseVectorQueryBuilder.getQuery(),
sparseVectorQueryBuilder.shouldPruneTokens(),
sparseVectorQueryBuilder.getTokenPruningConfig()
),
ScoreMode.Max
);
}
@Override
public String getQueryName() {
return SparseVectorQueryBuilder.NAME;
}
}
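
Concretely, for a semantic_text field named "body" whose resolved endpoint is "my-elser-endpoint" (both names made up), buildNestedQueryFromSparseVectorQuery yields roughly the query sketched below; the .inference.chunks and .embeddings paths follow the SemanticTextField helpers and the REST tests later in this commit:

import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;

class NestedSparseVectorRewriteSketch {
    QueryBuilder rewrittenShape() {
        return QueryBuilders.nestedQuery(
            "body.inference.chunks",
            new SparseVectorQueryBuilder(
                "body.inference.chunks.embeddings",
                null,                  // no precomputed query vectors
                "my-elser-endpoint",   // inference ID resolved from the semantic_text mapping
                "slow furry spider",
                false,                 // token pruning not requested
                null                   // no pruning config
            ),
            ScoreMode.Max
        );
    }
}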

View file

@@ -0,0 +1,110 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.index.query;
import org.elasticsearch.action.MockResolvedIndices;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.Map;
public class SemanticMatchQueryRewriteInterceptorTests extends ESTestCase {
private TestThreadPool threadPool;
private NoOpClient client;
private Index index;
private static final String FIELD_NAME = "fieldName";
private static final String VALUE = "value";
@Before
public void setup() {
threadPool = createThreadPool();
client = new NoOpClient(threadPool);
index = new Index(randomAlphaOfLength(10), randomAlphaOfLength(10));
}
@After
public void cleanup() {
threadPool.close();
}
public void testMatchQueryOnInferenceFieldIsInterceptedAndRewrittenToSemanticQuery() throws IOException {
Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
FIELD_NAME,
new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME })
);
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
QueryBuilder original = createTestQueryBuilder();
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof InterceptedQueryBuilderWrapper
);
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
assertTrue(intercepted.queryBuilder instanceof SemanticQueryBuilder);
SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.queryBuilder;
assertEquals(FIELD_NAME, semanticQueryBuilder.getFieldName());
assertEquals(VALUE, semanticQueryBuilder.getQuery());
}
public void testMatchQueryOnNonInferenceFieldRemainsMatchQuery() throws IOException {
QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
QueryBuilder original = createTestQueryBuilder();
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to remain match but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof MatchQueryBuilder
);
assertEquals(original, rewritten);
}
private MatchQueryBuilder createTestQueryBuilder() {
return new MatchQueryBuilder(FIELD_NAME, VALUE);
}
private QueryRewriteContext createQueryRewriteContext(Map<String, InferenceFieldMetadata> inferenceFields) {
IndexMetadata indexMetadata = IndexMetadata.builder(index.getName())
.settings(
Settings.builder()
.put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())
.put(IndexMetadata.SETTING_INDEX_UUID, index.getUUID())
)
.numberOfShards(1)
.numberOfReplicas(0)
.putInferenceFields(inferenceFields)
.build();
ResolvedIndices resolvedIndices = new MockResolvedIndices(
Map.of(),
new OriginalIndices(new String[] { index.getName() }, IndicesOptions.DEFAULT),
Map.of(index, indexMetadata)
);
return new QueryRewriteContext(null, client, null, resolvedIndices, null, createRewriteInterceptor());
}
private QueryRewriteInterceptor createRewriteInterceptor() {
return new SemanticMatchQueryRewriteInterceptor();
}
}

View file

@@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.index.query;
import org.elasticsearch.action.MockResolvedIndices;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.Map;
public class SemanticSparseVectorQueryRewriteInterceptorTests extends ESTestCase {
private TestThreadPool threadPool;
private NoOpClient client;
private Index index;
private static final String FIELD_NAME = "fieldName";
private static final String INFERENCE_ID = "inferenceId";
private static final String QUERY = "query";
@Before
public void setup() {
threadPool = createThreadPool();
client = new NoOpClient(threadPool);
index = new Index(randomAlphaOfLength(10), randomAlphaOfLength(10));
}
@After
public void cleanup() {
threadPool.close();
}
public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() throws IOException {
Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
FIELD_NAME,
new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME })
);
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof InterceptedQueryBuilderWrapper
);
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
QueryBuilder innerQuery = nestedQueryBuilder.query();
assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
}
public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException {
Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
FIELD_NAME,
new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME })
);
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY);
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof InterceptedQueryBuilderWrapper
);
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
QueryBuilder innerQuery = nestedQueryBuilder.query();
assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
}
public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof SparseVectorQueryBuilder
);
assertEquals(original, rewritten);
}
private QueryRewriteContext createQueryRewriteContext(Map<String, InferenceFieldMetadata> inferenceFields) {
IndexMetadata indexMetadata = IndexMetadata.builder(index.getName())
.settings(
Settings.builder()
.put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())
.put(IndexMetadata.SETTING_INDEX_UUID, index.getUUID())
)
.numberOfShards(1)
.numberOfReplicas(0)
.putInferenceFields(inferenceFields)
.build();
ResolvedIndices resolvedIndices = new MockResolvedIndices(
Map.of(),
new OriginalIndices(new String[] { index.getName() }, IndicesOptions.DEFAULT),
Map.of(index, indexMetadata)
);
return new QueryRewriteContext(null, client, null, resolvedIndices, null, createRewriteInterceptor());
}
private QueryRewriteInterceptor createRewriteInterceptor() {
return new SemanticSparseVectorQueryRewriteInterceptor();
}
}

View file

@@ -0,0 +1,249 @@
setup:
- requires:
cluster_features: "search.semantic_sparse_vector_query_rewrite_interception_supported"
reason: semantic_text sparse_vector support introduced in 8.18.0
- do:
inference.put:
task_type: sparse_embedding
inference_id: sparse-inference-id
body: >
{
"service": "test_service",
"service_settings": {
"model": "my_model",
"api_key": "abc64"
},
"task_settings": {
}
}
- do:
inference.put:
task_type: sparse_embedding
inference_id: sparse-inference-id-2
body: >
{
"service": "test_service",
"service_settings": {
"model": "my_model",
"api_key": "abc64"
},
"task_settings": {
}
}
- do:
indices.create:
index: test-semantic-text-index
body:
mappings:
properties:
inference_field:
type: semantic_text
inference_id: sparse-inference-id
- do:
indices.create:
index: test-semantic-text-index-2
body:
mappings:
properties:
inference_field:
type: semantic_text
inference_id: sparse-inference-id-2
- do:
indices.create:
index: test-sparse-vector-index
body:
mappings:
properties:
inference_field:
type: sparse_vector
- do:
index:
index: test-semantic-text-index
id: doc_1
body:
inference_field: [ "inference test", "another inference test" ]
refresh: true
- do:
index:
index: test-semantic-text-index-2
id: doc_3
body:
inference_field: [ "inference test", "another inference test" ]
refresh: true
- do:
index:
index: test-sparse-vector-index
id: doc_2
body:
inference_field: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 }
refresh: true
---
"Nested sparse_vector queries using the old format on semantic_text embeddings and inference still work":
- skip:
features: [ "headers" ]
- do:
headers:
# Force JSON content type so that we use a parser that interprets the floating-point score as a double
Content-Type: application/json
search:
index: test-semantic-text-index
body:
query:
nested:
path: inference_field.inference.chunks
query:
sparse_vector:
field: inference_field.inference.chunks.embeddings
inference_id: sparse-inference-id
query: test
- match: { hits.total.value: 1 }
- match: { hits.hits.0._id: "doc_1" }
---
"Nested sparse_vector queries using the old format on semantic_text embeddings and query vectors still work":
- skip:
features: [ "headers" ]
- do:
headers:
# Force JSON content type so that we use a parser that interprets the floating-point score as a double
Content-Type: application/json
search:
index: test-semantic-text-index
body:
query:
nested:
path: inference_field.inference.chunks
query:
sparse_vector:
field: inference_field.inference.chunks.embeddings
query_vector: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 }
- match: { hits.total.value: 1 }
- match: { hits.hits.0._id: "doc_1" }
---
"sparse_vector query against semantic_text field using a specified inference ID":
- do:
search:
index: test-semantic-text-index
body:
query:
sparse_vector:
field: inference_field
inference_id: sparse-inference-id
query: "inference test"
- match: { hits.total.value: 1 }
- match: { hits.hits.0._id: "doc_1" }
---
"sparse_vector query against semantic_text field using inference ID configured in semantic_text field":
- do:
search:
index: test-semantic-text-index
body:
query:
sparse_vector:
field: inference_field
query: "inference test"
- match: { hits.total.value: 1 }
- match: { hits.hits.0._id: "doc_1" }
---
"sparse_vector query against semantic_text field using query vectors":
- do:
search:
index: test-semantic-text-index
body:
query:
sparse_vector:
field: inference_field
query_vector: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 }
- match: { hits.total.value: 1 }
- match: { hits.hits.0._id: "doc_1" }
---
"sparse_vector query against combined sparse_vector and semantic_text fields using inference":
- do:
search:
index:
- test-semantic-text-index
- test-sparse-vector-index
body:
query:
sparse_vector:
field: inference_field
inference_id: sparse-inference-id
query: "inference test"
- match: { hits.total.value: 2 }
---
"sparse_vector query against combined sparse_vector and semantic_text fields still requires inference ID":
- do:
catch: bad_request
search:
index:
- test-semantic-text-index
- test-sparse-vector-index
body:
query:
sparse_vector:
field: inference_field
query: "inference test"
- match: { error.type: "illegal_argument_exception" }
- match: { error.reason: "inference_id required to perform vector search on query string" }
---
"sparse_vector query against combined sparse_vector and semantic_text fields using query vectors":
- do:
search:
index:
- test-semantic-text-index
- test-sparse-vector-index
body:
query:
sparse_vector:
field: inference_field
query_vector: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 }
- match: { hits.total.value: 2 }
---
"sparse_vector query against multiple semantic_text fields with multiple inference IDs specified in semantic_text fields":
- do:
search:
index:
- test-semantic-text-index
- test-semantic-text-index-2
body:
query:
sparse_vector:
field: inference_field
query: "inference test"
- match: { hits.total.value: 2 }

View file

@@ -268,7 +268,7 @@ setup:
- match: { hits.hits.0._score: 0.25 }
---
"Test sparse_vector requires one of inference_id or query_vector":
"Test sparse_vector requires one of query or query_vector":
- do:
catch: /\[sparse_vector\] requires one of \[query_vector\] or \[inference_id\]/
search:
@@ -281,7 +281,41 @@ setup:
- match: { status: 400 }
---
"Test sparse_vector only allows one of inference_id or query_vector":
"Test sparse_vector returns an error if inference ID not specified with query":
- do:
catch: bad_request # This is for BWC, the actual error message is tested in a subsequent test
search:
index: index-with-sparse-vector
body:
query:
sparse_vector:
field: text
query: "octopus comforter smells"
- match: { status: 400 }
---
"Test sparse_vector requires an inference ID to be specified on sparse_vector fields":
- requires:
cluster_features: [ "search.semantic_sparse_vector_query_rewrite_interception_supported" ]
reason: "Error message changed in 8.18"
- do:
catch: /inference_id required to perform vector search on query string/
search:
index: index-with-sparse-vector
body:
query:
sparse_vector:
field: text
query: "octopus comforter smells"
- match: { status: 400 }
---
"Test sparse_vector only allows one of query or query_vector (note the error message is misleading)":
- requires:
cluster_features: [ "search.semantic_sparse_vector_query_rewrite_interception_supported" ]
reason: "sparse vector inference checks updated in 8.18 to support sparse_vector on semantic_text fields"
- do:
catch: /\[sparse_vector\] requires one of \[query_vector\] or \[inference_id\]/
search:
@@ -290,7 +324,7 @@ setup:
query:
sparse_vector:
field: text
inference_id: text_expansion_model
query: "octopus comforter smells"
query_vector:
the: 1.0
comforter: 1.0