Add support for sparse_vector queries against semantic_text fields (#118617)
This commit is contained in: parent 7c65a8e5db, commit 15bec3cefa
13 changed files with 889 additions and 78 deletions
docs/changelog/118617.yaml (new file, 5 lines)

@@ -0,0 +1,5 @@
pr: 118617
summary: Add support for `sparse_vector` queries against `semantic_text` fields
area: "Search"
type: enhancement
issues: []
@@ -90,7 +90,8 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
            : (this.shouldPruneTokens ? new TokenPruningConfig() : null));
        this.weightedTokensSupplier = null;

        if (queryVectors == null ^ inferenceId == null == false) {
            // Preserve BWC error messaging
            if (queryVectors != null && inferenceId != null) {
                throw new IllegalArgumentException(
                    "["
                        + NAME

@@ -98,18 +99,24 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
                        + QUERY_VECTOR_FIELD.getPreferredName()
                        + "] or ["
                        + INFERENCE_ID_FIELD.getPreferredName()
                        + "]"
                        + "] for "
                        + ALLOWED_FIELD_TYPE
                        + " fields"
                );
            }
        if (inferenceId != null && query == null) {

            // Preserve BWC error messaging
            if ((queryVectors == null) == (query == null)) {
                throw new IllegalArgumentException(
                    "["
                        + NAME
                        + "] requires ["
                        + QUERY_FIELD.getPreferredName()
                        + "] when ["
                        + "] requires one of ["
                        + QUERY_VECTOR_FIELD.getPreferredName()
                        + "] or ["
                        + INFERENCE_ID_FIELD.getPreferredName()
                        + "] is specified"
                        + "] for "
                        + ALLOWED_FIELD_TYPE
                        + " fields"
                );
            }
        }

@@ -143,6 +150,14 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
        return queryVectors;
    }

    public String getInferenceId() {
        return inferenceId;
    }

    public String getQuery() {
        return query;
    }

    public boolean shouldPruneTokens() {
        return shouldPruneTokens;
    }

@@ -176,7 +191,9 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
            }
            builder.endObject();
        } else {
            builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
            if (inferenceId != null) {
                builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
            }
            builder.field(QUERY_FIELD.getPreferredName(), query);
        }
        builder.field(PRUNE_FIELD.getPreferredName(), shouldPruneTokens);

@@ -228,6 +245,11 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
                shouldPruneTokens,
                tokenPruningConfig
            );
        } else if (inferenceId == null) {
            // Edge case, where inference_id was not specified in the request,
            // but we did not intercept this and rewrite to a query on a field with
            // pre-configured inference. So we trap here and output a nicer error message.
            throw new IllegalArgumentException("inference_id required to perform vector search on query string");
        }

        // TODO move this to xpack core and use inference APIs
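For orientation, here is a minimal sketch of how the updated builder can now be constructed. It mirrors the three-argument (field, inference ID, query) constructor exercised by the tests in this diff; the field and endpoint names are illustrative only:

import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;

public class SparseVectorQueryBuilderExample {
    public static void main(String[] args) {
        // Query a semantic_text field and let the query rewrite interceptor pick up the
        // inference ID configured on the field (inference_id is left null here).
        SparseVectorQueryBuilder usingFieldDefault = new SparseVectorQueryBuilder("inference_field", null, "inference test");

        // Query the same field while naming the inference endpoint explicitly.
        SparseVectorQueryBuilder usingExplicitId = new SparseVectorQueryBuilder("inference_field", "sparse-inference-id", "inference test");

        // Supplying neither query vectors nor a query string now fails with
        // "[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields",
        // as asserted by the updated SparseVectorQueryBuilderTests below.
        System.out.println(usingFieldDefault + " " + usingExplicitId);
    }
}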
@@ -260,16 +260,16 @@ public class SparseVectorQueryBuilderTests extends AbstractQueryTestCase<SparseV
        {
            IllegalArgumentException e = expectThrows(
                IllegalArgumentException.class,
                () -> new SparseVectorQueryBuilder("field name", null, "model id")
                () -> new SparseVectorQueryBuilder("field name", null, null)
            );
            assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id]", e.getMessage());
            assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields", e.getMessage());
        }
        {
            IllegalArgumentException e = expectThrows(
                IllegalArgumentException.class,
                () -> new SparseVectorQueryBuilder("field name", "model text", null)
            );
            assertEquals("[sparse_vector] requires [query] when [inference_id] is specified", e.getMessage());
            assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id] for sparse_vector fields", e.getMessage());
        }
    }

@@ -10,13 +10,15 @@ package org.elasticsearch.xpack.inference;
import org.elasticsearch.features.FeatureSpecification;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;

import java.util.Set;

import static org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
import static org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor.SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;

/**
 * Provides inference features.
 */

@@ -45,7 +47,8 @@ public class InferenceFeatures implements FeatureSpecification {
            SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX,
            SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX,
            SEMANTIC_TEXT_HIGHLIGHTER,
            SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED
            SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
            SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED
        );
    }
}

@@ -80,6 +80,7 @@ import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;

@@ -440,7 +441,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP

    @Override
    public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
        return List.of(new SemanticMatchQueryRewriteInterceptor());
        return List.of(new SemanticMatchQueryRewriteInterceptor(), new SemanticSparseVectorQueryRewriteInterceptor());
    }

    @Override
@@ -7,24 +7,12 @@

package org.elasticsearch.xpack.inference.queries;

import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterceptor {
public class SemanticMatchQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {

    public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
        "search.semantic_match_query_rewrite_interception_supported"

@@ -33,63 +21,45 @@ public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterce
    public SemanticMatchQueryRewriteInterceptor() {}

    @Override
    public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
    protected String getFieldName(QueryBuilder queryBuilder) {
        assert (queryBuilder instanceof MatchQueryBuilder);
        MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
        QueryBuilder rewritten = queryBuilder;
        ResolvedIndices resolvedIndices = context.getResolvedIndices();
        if (resolvedIndices != null) {
            Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
            List<String> inferenceIndices = new ArrayList<>();
            List<String> nonInferenceIndices = new ArrayList<>();
            for (IndexMetadata indexMetadata : indexMetadataCollection) {
                String indexName = indexMetadata.getIndex().getName();
                InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(matchQueryBuilder.fieldName());
                if (inferenceFieldMetadata != null) {
                    inferenceIndices.add(indexName);
                } else {
                    nonInferenceIndices.add(indexName);
                }
            }
        return matchQueryBuilder.fieldName();
    }

            if (inferenceIndices.isEmpty()) {
                return rewritten;
            } else if (nonInferenceIndices.isEmpty() == false) {
                BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
                for (String inferenceIndexName : inferenceIndices) {
                    // Add a separate clause for each semantic query, because they may be using different inference endpoints
                    // TODO - consolidate this to a single clause once the semantic query supports multiple inference endpoints
                    boolQueryBuilder.should(
                        createSemanticSubQuery(inferenceIndexName, matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value())
                    );
                }
                boolQueryBuilder.should(createMatchSubQuery(nonInferenceIndices, matchQueryBuilder));
                rewritten = boolQueryBuilder;
            } else {
                rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), false);
            }
        }
    @Override
    protected String getQuery(QueryBuilder queryBuilder) {
        assert (queryBuilder instanceof MatchQueryBuilder);
        MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
        return (String) matchQueryBuilder.value();
    }

        return rewritten;
    @Override
    protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
        return new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
    }

    @Override
    protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
        QueryBuilder queryBuilder,
        InferenceIndexInformationForField indexInformation
    ) {
        assert (queryBuilder instanceof MatchQueryBuilder);
        MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        boolQueryBuilder.should(
            createSemanticSubQuery(
                indexInformation.getInferenceIndices(),
                matchQueryBuilder.fieldName(),
                (String) matchQueryBuilder.value()
            )
        );
        boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder));
        return boolQueryBuilder;
    }

    @Override
    public String getQueryName() {
        return MatchQueryBuilder.NAME;
    }

    private QueryBuilder createSemanticSubQuery(String indexName, String fieldName, String value) {
        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true));
        boolQueryBuilder.filter(new TermQueryBuilder(IndexFieldMapper.NAME, indexName));
        return boolQueryBuilder;
    }

    private QueryBuilder createMatchSubQuery(List<String> indices, MatchQueryBuilder matchQueryBuilder) {
        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        boolQueryBuilder.must(matchQueryBuilder);
        boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
        return boolQueryBuilder;
    }
}

@@ -148,6 +148,14 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
        return NAME;
    }

    public String getFieldName() {
        return fieldName;
    }

    public String getQuery() {
        return query;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.V_8_15_0;

@@ -0,0 +1,148 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.queries;

import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Intercepts and adapts a query to be rewritten to work seamlessly on a semantic_text field.
 */
public abstract class SemanticQueryRewriteInterceptor implements QueryRewriteInterceptor {

    public SemanticQueryRewriteInterceptor() {}

    @Override
    public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
        String fieldName = getFieldName(queryBuilder);
        ResolvedIndices resolvedIndices = context.getResolvedIndices();

        if (resolvedIndices == null) {
            // No resolved indices, so return the original query.
            return queryBuilder;
        }

        InferenceIndexInformationForField indexInformation = resolveIndicesForField(fieldName, resolvedIndices);
        if (indexInformation.getInferenceIndices().isEmpty()) {
            // No inference fields were identified, so return the original query.
            return queryBuilder;
        } else if (indexInformation.nonInferenceIndices().isEmpty() == false) {
            // Combined case where the field name requested by this query contains both
            // semantic_text and non-inference fields, so we have to combine queries per index
            // containing each field type.
            return buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation);
        } else {
            // The only fields we've identified are inference fields (e.g. semantic_text),
            // so rewrite the entire query to work on a semantic_text field.
            return buildInferenceQuery(queryBuilder, indexInformation);
        }
    }

    /**
     * @param queryBuilder {@link QueryBuilder}
     * @return The singular field name requested by the provided query builder.
     */
    protected abstract String getFieldName(QueryBuilder queryBuilder);

    /**
     * @param queryBuilder {@link QueryBuilder}
     * @return The text/query string requested by the provided query builder.
     */
    protected abstract String getQuery(QueryBuilder queryBuilder);

    /**
     * Builds the inference query
     *
     * @param queryBuilder {@link QueryBuilder}
     * @param indexInformation {@link InferenceIndexInformationForField}
     * @return {@link QueryBuilder}
     */
    protected abstract QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation);

    /**
     * Builds a combined inference and non-inference query,
     * which separates the different queries into appropriate indices based on field type.
     * @param queryBuilder {@link QueryBuilder}
     * @param indexInformation {@link InferenceIndexInformationForField}
     * @return {@link QueryBuilder}
     */
    protected abstract QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
        QueryBuilder queryBuilder,
        InferenceIndexInformationForField indexInformation
    );

    private InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) {
        Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
        Map<String, InferenceFieldMetadata> inferenceIndicesMetadata = new HashMap<>();
        List<String> nonInferenceIndices = new ArrayList<>();
        for (IndexMetadata indexMetadata : indexMetadataCollection) {
            String indexName = indexMetadata.getIndex().getName();
            InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName);
            if (inferenceFieldMetadata != null) {
                inferenceIndicesMetadata.put(indexName, inferenceFieldMetadata);
            } else {
                nonInferenceIndices.add(indexName);
            }
        }

        return new InferenceIndexInformationForField(fieldName, inferenceIndicesMetadata, nonInferenceIndices);
    }

    protected QueryBuilder createSubQueryForIndices(Collection<String> indices, QueryBuilder queryBuilder) {
        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        boolQueryBuilder.must(queryBuilder);
        boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
        return boolQueryBuilder;
    }

    protected QueryBuilder createSemanticSubQuery(Collection<String> indices, String fieldName, String value) {
        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true));
        boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
        return boolQueryBuilder;
    }

    /**
     * Represents the indices and associated inference information for a field.
     */
    public record InferenceIndexInformationForField(
        String fieldName,
        Map<String, InferenceFieldMetadata> inferenceIndicesMetadata,
        List<String> nonInferenceIndices
    ) {

        public Collection<String> getInferenceIndices() {
            return inferenceIndicesMetadata.keySet();
        }

        public Map<String, List<String>> getInferenceIdsIndices() {
            return inferenceIndicesMetadata.entrySet()
                .stream()
                .collect(
                    Collectors.groupingBy(
                        entry -> entry.getValue().getSearchInferenceId(),
                        Collectors.mapping(Map.Entry::getKey, Collectors.toList())
                    )
                );
        }
    }
}

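To make the contract of this new abstract class concrete, here is a hypothetical minimal subclass (not part of this commit) that would intercept term queries; it only illustrates which hooks a subclass implements. The real subclasses are SemanticMatchQueryRewriteInterceptor and SemanticSparseVectorQueryRewriteInterceptor:

package org.elasticsearch.xpack.inference.queries;

import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.TermQueryBuilder;

// Hypothetical example, for illustration only.
public class SemanticTermQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {

    @Override
    protected String getFieldName(QueryBuilder queryBuilder) {
        return ((TermQueryBuilder) queryBuilder).fieldName();
    }

    @Override
    protected String getQuery(QueryBuilder queryBuilder) {
        return ((TermQueryBuilder) queryBuilder).value().toString();
    }

    @Override
    protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
        // Every resolved index maps this field to an inference endpoint: replace the query outright.
        return new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
    }

    @Override
    protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
        QueryBuilder queryBuilder,
        InferenceIndexInformationForField indexInformation
    ) {
        // Mixed case: route a semantic query to the inference indices and the original query elsewhere.
        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        boolQueryBuilder.should(
            createSemanticSubQuery(indexInformation.getInferenceIndices(), indexInformation.fieldName(), getQuery(queryBuilder))
        );
        boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), queryBuilder));
        return boolQueryBuilder;
    }

    @Override
    public String getQueryName() {
        return TermQueryBuilder.NAME;
    }
}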
@@ -0,0 +1,124 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.queries;

import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;

import java.util.List;
import java.util.Map;

public class SemanticSparseVectorQueryRewriteInterceptor extends SemanticQueryRewriteInterceptor {

    public static final NodeFeature SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
        "search.semantic_sparse_vector_query_rewrite_interception_supported"
    );

    public SemanticSparseVectorQueryRewriteInterceptor() {}

    @Override
    protected String getFieldName(QueryBuilder queryBuilder) {
        assert (queryBuilder instanceof SparseVectorQueryBuilder);
        SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
        return sparseVectorQueryBuilder.getFieldName();
    }

    @Override
    protected String getQuery(QueryBuilder queryBuilder) {
        assert (queryBuilder instanceof SparseVectorQueryBuilder);
        SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
        return sparseVectorQueryBuilder.getQuery();
    }

    @Override
    protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
        Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
        if (inferenceIdsIndices.size() == 1) {
            // Simple case, everything uses the same inference ID
            String searchInferenceId = inferenceIdsIndices.keySet().iterator().next();
            return buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId);
        } else {
            // Multiple inference IDs, construct a boolean query
            return buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices);
        }
    }

    private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
        QueryBuilder queryBuilder,
        Map<String, List<String>> inferenceIdsIndices
    ) {
        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        for (String inferenceId : inferenceIdsIndices.keySet()) {
            boolQueryBuilder.should(
                createSubQueryForIndices(
                    inferenceIdsIndices.get(inferenceId),
                    buildNestedQueryFromSparseVectorQuery(queryBuilder, inferenceId)
                )
            );
        }
        return boolQueryBuilder;
    }

    @Override
    protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
        QueryBuilder queryBuilder,
        InferenceIndexInformationForField indexInformation
    ) {
        assert (queryBuilder instanceof SparseVectorQueryBuilder);
        SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
        Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();

        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
        boolQueryBuilder.should(
            createSubQueryForIndices(
                indexInformation.nonInferenceIndices(),
                createSubQueryForIndices(indexInformation.nonInferenceIndices(), sparseVectorQueryBuilder)
            )
        );
        // We always perform nested subqueries on semantic_text fields, to support
        // sparse_vector queries using query vectors.
        for (String inferenceId : inferenceIdsIndices.keySet()) {
            boolQueryBuilder.should(
                createSubQueryForIndices(
                    inferenceIdsIndices.get(inferenceId),
                    buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder, inferenceId)
                )
            );
        }
        return boolQueryBuilder;
    }

    private QueryBuilder buildNestedQueryFromSparseVectorQuery(QueryBuilder queryBuilder, String searchInferenceId) {
        assert (queryBuilder instanceof SparseVectorQueryBuilder);
        SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
        return QueryBuilders.nestedQuery(
            SemanticTextField.getChunksFieldName(sparseVectorQueryBuilder.getFieldName()),
            new SparseVectorQueryBuilder(
                SemanticTextField.getEmbeddingsFieldName(sparseVectorQueryBuilder.getFieldName()),
                sparseVectorQueryBuilder.getQueryVectors(),
                (sparseVectorQueryBuilder.getInferenceId() == null && sparseVectorQueryBuilder.getQuery() != null)
                    ? searchInferenceId
                    : sparseVectorQueryBuilder.getInferenceId(),
                sparseVectorQueryBuilder.getQuery(),
                sparseVectorQueryBuilder.shouldPruneTokens(),
                sparseVectorQueryBuilder.getTokenPruningConfig()
            ),
            ScoreMode.Max
        );
    }

    @Override
    public String getQueryName() {
        return SparseVectorQueryBuilder.NAME;
    }
}

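As a rough sketch of what this interceptor produces for a semantic_text field backed by a single inference endpoint (the field, endpoint, and query values below come from the YAML tests added later in this diff; the real rewrite goes through buildNestedQueryFromSparseVectorQuery above and also carries over query vectors and pruning options, so this is an approximation only):

import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;

public class RewrittenSparseVectorQuerySketch {
    // Roughly the nested query produced when all resolved indices share one inference ID.
    static QueryBuilder approximateRewrite() {
        return QueryBuilders.nestedQuery(
            // SemanticTextField.getChunksFieldName("inference_field")
            "inference_field.inference.chunks",
            new SparseVectorQueryBuilder(
                // SemanticTextField.getEmbeddingsFieldName("inference_field")
                "inference_field.inference.chunks.embeddings",
                "sparse-inference-id",   // the field's configured (or explicitly supplied) inference ID
                "inference test"         // the original query string
            ),
            ScoreMode.Max
        );
    }
}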
@@ -0,0 +1,110 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.index.query;

import org.elasticsearch.action.MockResolvedIndices;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.Map;

public class SemanticMatchQueryRewriteInterceptorTests extends ESTestCase {

    private TestThreadPool threadPool;
    private NoOpClient client;
    private Index index;

    private static final String FIELD_NAME = "fieldName";
    private static final String VALUE = "value";

    @Before
    public void setup() {
        threadPool = createThreadPool();
        client = new NoOpClient(threadPool);
        index = new Index(randomAlphaOfLength(10), randomAlphaOfLength(10));
    }

    @After
    public void cleanup() {
        threadPool.close();
    }

    public void testMatchQueryOnInferenceFieldIsInterceptedAndRewrittenToSemanticQuery() throws IOException {
        Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
            FIELD_NAME,
            new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME })
        );
        QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
        QueryBuilder original = createTestQueryBuilder();
        QueryBuilder rewritten = original.rewrite(context);
        assertTrue(
            "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
            rewritten instanceof InterceptedQueryBuilderWrapper
        );
        InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
        assertTrue(intercepted.queryBuilder instanceof SemanticQueryBuilder);
        SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.queryBuilder;
        assertEquals(FIELD_NAME, semanticQueryBuilder.getFieldName());
        assertEquals(VALUE, semanticQueryBuilder.getQuery());
    }

    public void testMatchQueryOnNonInferenceFieldRemainsMatchQuery() throws IOException {
        QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
        QueryBuilder original = createTestQueryBuilder();
        QueryBuilder rewritten = original.rewrite(context);
        assertTrue(
            "Expected query to remain match but was [" + rewritten.getClass().getName() + "]",
            rewritten instanceof MatchQueryBuilder
        );
        assertEquals(original, rewritten);
    }

    private MatchQueryBuilder createTestQueryBuilder() {
        return new MatchQueryBuilder(FIELD_NAME, VALUE);
    }

    private QueryRewriteContext createQueryRewriteContext(Map<String, InferenceFieldMetadata> inferenceFields) {
        IndexMetadata indexMetadata = IndexMetadata.builder(index.getName())
            .settings(
                Settings.builder()
                    .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())
                    .put(IndexMetadata.SETTING_INDEX_UUID, index.getUUID())
            )
            .numberOfShards(1)
            .numberOfReplicas(0)
            .putInferenceFields(inferenceFields)
            .build();

        ResolvedIndices resolvedIndices = new MockResolvedIndices(
            Map.of(),
            new OriginalIndices(new String[] { index.getName() }, IndicesOptions.DEFAULT),
            Map.of(index, indexMetadata)
        );

        return new QueryRewriteContext(null, client, null, resolvedIndices, null, createRewriteInterceptor());
    }

    private QueryRewriteInterceptor createRewriteInterceptor() {
        return new SemanticMatchQueryRewriteInterceptor();
    }
}

@@ -0,0 +1,137 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.index.query;

import org.elasticsearch.action.MockResolvedIndices;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.Map;

public class SemanticSparseVectorQueryRewriteInterceptorTests extends ESTestCase {

    private TestThreadPool threadPool;
    private NoOpClient client;
    private Index index;

    private static final String FIELD_NAME = "fieldName";
    private static final String INFERENCE_ID = "inferenceId";
    private static final String QUERY = "query";

    @Before
    public void setup() {
        threadPool = createThreadPool();
        client = new NoOpClient(threadPool);
        index = new Index(randomAlphaOfLength(10), randomAlphaOfLength(10));
    }

    @After
    public void cleanup() {
        threadPool.close();
    }

    public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() throws IOException {
        Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
            FIELD_NAME,
            new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME })
        );
        QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
        QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
        QueryBuilder rewritten = original.rewrite(context);
        assertTrue(
            "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
            rewritten instanceof InterceptedQueryBuilderWrapper
        );
        InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
        assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
        NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
        assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
        QueryBuilder innerQuery = nestedQueryBuilder.query();
        assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
        SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
        assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
        assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
        assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
    }

    public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException {
        Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
            FIELD_NAME,
            new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME })
        );
        QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
        QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY);
        QueryBuilder rewritten = original.rewrite(context);
        assertTrue(
            "Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
            rewritten instanceof InterceptedQueryBuilderWrapper
        );
        InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
        assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
        NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
        assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
        QueryBuilder innerQuery = nestedQueryBuilder.query();
        assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
        SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
        assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
        assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
        assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
    }

    public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
        QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
        QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
        QueryBuilder rewritten = original.rewrite(context);
        assertTrue(
            "Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]",
            rewritten instanceof SparseVectorQueryBuilder
        );
        assertEquals(original, rewritten);
    }

    private QueryRewriteContext createQueryRewriteContext(Map<String, InferenceFieldMetadata> inferenceFields) {
        IndexMetadata indexMetadata = IndexMetadata.builder(index.getName())
            .settings(
                Settings.builder()
                    .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())
                    .put(IndexMetadata.SETTING_INDEX_UUID, index.getUUID())
            )
            .numberOfShards(1)
            .numberOfReplicas(0)
            .putInferenceFields(inferenceFields)
            .build();

        ResolvedIndices resolvedIndices = new MockResolvedIndices(
            Map.of(),
            new OriginalIndices(new String[] { index.getName() }, IndicesOptions.DEFAULT),
            Map.of(index, indexMetadata)
        );

        return new QueryRewriteContext(null, client, null, resolvedIndices, null, createRewriteInterceptor());
    }

    private QueryRewriteInterceptor createRewriteInterceptor() {
        return new SemanticSparseVectorQueryRewriteInterceptor();
    }
}

@@ -0,0 +1,249 @@
setup:
  - requires:
      cluster_features: "search.semantic_sparse_vector_query_rewrite_interception_supported"
      reason: semantic_text sparse_vector support introduced in 8.18.0

  - do:
      inference.put:
        task_type: sparse_embedding
        inference_id: sparse-inference-id
        body: >
          {
            "service": "test_service",
            "service_settings": {
              "model": "my_model",
              "api_key": "abc64"
            },
            "task_settings": {
            }
          }

  - do:
      inference.put:
        task_type: sparse_embedding
        inference_id: sparse-inference-id-2
        body: >
          {
            "service": "test_service",
            "service_settings": {
              "model": "my_model",
              "api_key": "abc64"
            },
            "task_settings": {
            }
          }

  - do:
      indices.create:
        index: test-semantic-text-index
        body:
          mappings:
            properties:
              inference_field:
                type: semantic_text
                inference_id: sparse-inference-id

  - do:
      indices.create:
        index: test-semantic-text-index-2
        body:
          mappings:
            properties:
              inference_field:
                type: semantic_text
                inference_id: sparse-inference-id-2

  - do:
      indices.create:
        index: test-sparse-vector-index
        body:
          mappings:
            properties:
              inference_field:
                type: sparse_vector

  - do:
      index:
        index: test-semantic-text-index
        id: doc_1
        body:
          inference_field: [ "inference test", "another inference test" ]
        refresh: true

  - do:
      index:
        index: test-semantic-text-index-2
        id: doc_3
        body:
          inference_field: [ "inference test", "another inference test" ]
        refresh: true

  - do:
      index:
        index: test-sparse-vector-index
        id: doc_2
        body:
          inference_field: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 }
        refresh: true

---
"Nested sparse_vector queries using the old format on semantic_text embeddings and inference still work":
  - skip:
      features: [ "headers" ]

  - do:
      headers:
        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
        Content-Type: application/json
      search:
        index: test-semantic-text-index
        body:
          query:
            nested:
              path: inference_field.inference.chunks
              query:
                sparse_vector:
                  field: inference_field.inference.chunks.embeddings
                  inference_id: sparse-inference-id
                  query: test

  - match: { hits.total.value: 1 }
  - match: { hits.hits.0._id: "doc_1" }

---
"Nested sparse_vector queries using the old format on semantic_text embeddings and query vectors still work":
  - skip:
      features: [ "headers" ]

  - do:
      headers:
        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
        Content-Type: application/json
      search:
        index: test-semantic-text-index
        body:
          query:
            nested:
              path: inference_field.inference.chunks
              query:
                sparse_vector:
                  field: inference_field.inference.chunks.embeddings
                  query_vector: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 }

  - match: { hits.total.value: 1 }
  - match: { hits.hits.0._id: "doc_1" }

---
"sparse_vector query against semantic_text field using a specified inference ID":

  - do:
      search:
        index: test-semantic-text-index
        body:
          query:
            sparse_vector:
              field: inference_field
              inference_id: sparse-inference-id
              query: "inference test"

  - match: { hits.total.value: 1 }
  - match: { hits.hits.0._id: "doc_1" }

---
"sparse_vector query against semantic_text field using inference ID configured in semantic_text field":

  - do:
      search:
        index: test-semantic-text-index
        body:
          query:
            sparse_vector:
              field: inference_field
              query: "inference test"

  - match: { hits.total.value: 1 }
  - match: { hits.hits.0._id: "doc_1" }

---
"sparse_vector query against semantic_text field using query vectors":

  - do:
      search:
        index: test-semantic-text-index
        body:
          query:
            sparse_vector:
              field: inference_field
              query_vector: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 }

  - match: { hits.total.value: 1 }
  - match: { hits.hits.0._id: "doc_1" }

---
"sparse_vector query against combined sparse_vector and semantic_text fields using inference":

  - do:
      search:
        index:
          - test-semantic-text-index
          - test-sparse-vector-index
        body:
          query:
            sparse_vector:
              field: inference_field
              inference_id: sparse-inference-id
              query: "inference test"

  - match: { hits.total.value: 2 }

---
"sparse_vector query against combined sparse_vector and semantic_text fields still requires inference ID":

  - do:
      catch: bad_request
      search:
        index:
          - test-semantic-text-index
          - test-sparse-vector-index
        body:
          query:
            sparse_vector:
              field: inference_field
              query: "inference test"

  - match: { error.type: "illegal_argument_exception" }
  - match: { error.reason: "inference_id required to perform vector search on query string" }

---
"sparse_vector query against combined sparse_vector and semantic_text fields using query vectors":

  - do:
      search:
        index:
          - test-semantic-text-index
          - test-sparse-vector-index
        body:
          query:
            sparse_vector:
              field: inference_field
              query_vector: { "feature_0": 1, "feature_1": 2, "feature_2": 3, "feature_3": 4, "feature_4": 5 }

  - match: { hits.total.value: 2 }

---
"sparse_vector query against multiple semantic_text fields with multiple inference IDs specified in semantic_text fields":

  - do:
      search:
        index:
          - test-semantic-text-index
          - test-semantic-text-index-2
        body:
          query:
            sparse_vector:
              field: inference_field
              query: "inference test"

  - match: { hits.total.value: 2 }

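The request the REST tests above issue can also be sketched with the Java query builders. This is an illustrative sketch rather than part of the change; the index, field, and endpoint names are taken from the "using a specified inference ID" test:

import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;

public class SparseVectorOnSemanticTextSearchSketch {
    static SearchRequest buildRequest() {
        // Equivalent of the REST body: query -> sparse_vector -> { field, inference_id, query }.
        SparseVectorQueryBuilder query = new SparseVectorQueryBuilder("inference_field", "sparse-inference-id", "inference test");
        return new SearchRequest("test-semantic-text-index").source(new SearchSourceBuilder().query(query));
    }
}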
@@ -268,7 +268,7 @@ setup:
  - match: { hits.hits.0._score: 0.25 }

---
"Test sparse_vector requires one of inference_id or query_vector":
"Test sparse_vector requires one of query or query_vector":
  - do:
      catch: /\[sparse_vector\] requires one of \[query_vector\] or \[inference_id\]/
      search:

@@ -281,7 +281,41 @@ setup:
  - match: { status: 400 }

---
"Test sparse_vector only allows one of inference_id or query_vector":
"Test sparse_vector returns an error if inference ID not specified with query":
  - do:
      catch: bad_request # This is for BWC, the actual error message is tested in a subsequent test
      search:
        index: index-with-sparse-vector
        body:
          query:
            sparse_vector:
              field: text
              query: "octopus comforter smells"

  - match: { status: 400 }

---
"Test sparse_vector requires an inference ID to be specified on sparse_vector fields":
  - requires:
      cluster_features: [ "search.semantic_sparse_vector_query_rewrite_interception_supported" ]
      reason: "Error message changed in 8.18"
  - do:
      catch: /inference_id required to perform vector search on query string/
      search:
        index: index-with-sparse-vector
        body:
          query:
            sparse_vector:
              field: text
              query: "octopus comforter smells"

  - match: { status: 400 }

---
"Test sparse_vector only allows one of query or query_vector (note the error message is misleading)":
  - requires:
      cluster_features: [ "search.semantic_sparse_vector_query_rewrite_interception_supported" ]
      reason: "sparse vector inference checks updated in 8.18 to support sparse_vector on semantic_text fields"
  - do:
      catch: /\[sparse_vector\] requires one of \[query_vector\] or \[inference_id\]/
      search:

@@ -290,7 +324,7 @@ setup:
        query:
          sparse_vector:
            field: text
            inference_id: text_expansion_model
            query: "octopus comforter smells"
            query_vector:
              the: 1.0
              comforter: 1.0