[ES|QL] RERANK command - Updating the syntax and behavior (#129488)

This commit is contained in:
Aurélien FOUCRET 2025-06-19 15:46:33 +02:00 committed by GitHub
parent a0109bb0fe
commit 34ccaba56d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 1844 additions and 1154 deletions

View file

@ -3,21 +3,62 @@
// This makes the output more predictable which is helpful here.
reranker using a single field
reranker using a single field, overwrite existing _score column
required_capability: rerank
required_capability: match_operator_colon
FROM books METADATA _score
| WHERE title:"war and peace" AND author:"Tolstoy"
| RERANK "war and peace" ON title WITH test_reranker
| KEEP book_no, title, author
| SORT _score DESC, book_no ASC
| RERANK "war and peace" ON title WITH inferenceId=test_reranker
| EVAL _score=ROUND(_score, 2)
| KEEP book_no, title, author, _score
;
book_no:keyword | title:text | author:text
5327 | War and Peace | Leo Tolstoy
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy]
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy
book_no:keyword | title:text | author:text | _score:double
5327 | War and Peace | Leo Tolstoy | 0.08
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.03
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.03
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.02
;
reranker using a single field, create a new column
required_capability: rerank
required_capability: match_operator_colon
FROM books METADATA _score
| WHERE title:"war and peace" AND author:"Tolstoy"
| SORT _score DESC, book_no ASC
| RERANK "war and peace" ON title WITH inferenceId=test_reranker, scoreColumn=rerank_score
| EVAL _score=ROUND(_score, 2), rerank_score=ROUND(rerank_score, 2)
| KEEP book_no, title, author, rerank_score
;
book_no:keyword | title:text | author:text | rerank_score:double
5327 | War and Peace | Leo Tolstoy | 0.08
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.03
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.03
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.02
;
reranker using a single field, create a new column, sort by rerank_score
required_capability: rerank
required_capability: match_operator_colon
FROM books METADATA _score
| WHERE title:"war and peace" AND author:"Tolstoy"
| SORT _score DESC
| RERANK "war and peace" ON title WITH inferenceId=test_reranker, scoreColumn=rerank_score
| EVAL _score=ROUND(_score, 2), rerank_score=ROUND(rerank_score, 2)
| SORT rerank_score, _score ASC, book_no ASC
| KEEP book_no, title, author, rerank_score
;
book_no:keyword | title:text | author:text | rerank_score:double
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.02
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.03
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.03
5327 | War and Peace | Leo Tolstoy | 0.08
;
@ -27,15 +68,17 @@ required_capability: match_operator_colon
FROM books METADATA _score
| WHERE title:"war and peace" AND author:"Tolstoy"
| RERANK "war and peace" ON title, author WITH test_reranker
| KEEP book_no, title, author
| RERANK "war and peace" ON title, author WITH inferenceId=test_reranker
| EVAL _score=ROUND(_score, 2)
| SORT _score DESC, book_no ASC
| KEEP book_no, title, author, _score
;
book_no:keyword | title:text | author:text
5327 | War and Peace | Leo Tolstoy
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy]
book_no:keyword | title:text | author:text | _score:double
5327 | War and Peace | Leo Tolstoy | 0.02
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.01
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.01
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.01
;
@ -45,16 +88,18 @@ required_capability: match_operator_colon
FROM books METADATA _score
| WHERE title:"war and peace" AND author:"Tolstoy"
| SORT _score DESC
| SORT _score DESC, book_no ASC
| LIMIT 3
| RERANK "war and peace" ON title WITH test_reranker
| KEEP book_no, title, author
| RERANK "war and peace" ON title WITH inferenceId=test_reranker
| EVAL _score=ROUND(_score, 2)
| SORT _score DESC, book_no ASC
| KEEP book_no, title, author, _score
;
book_no:keyword | title:text | author:text
5327 | War and Peace | Leo Tolstoy
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy]
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo
book_no:keyword | title:text | author:text | _score:double
5327 | War and Peace | Leo Tolstoy | 0.08
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.03
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.03
;
@ -64,15 +109,17 @@ required_capability: match_operator_colon
FROM books METADATA _score
| WHERE title:"war and peace" AND author:"Tolstoy"
| RERANK "war and peace" ON title WITH test_reranker
| KEEP book_no, title, author
| RERANK "war and peace" ON title WITH inferenceId=test_reranker
| EVAL _score=ROUND(_score, 2)
| SORT _score DESC, book_no ASC
| KEEP book_no, title, author, _score
| LIMIT 3
;
book_no:keyword | title:text | author:text
5327 | War and Peace | Leo Tolstoy
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy]
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo
book_no:keyword | title:text | author:text | _score:double
5327 | War and Peace | Leo Tolstoy | 0.08
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.03
9032 | War and Peace: A Novel (6 Volumes) | Tolstoy Leo | 0.03
;
@ -82,16 +129,17 @@ required_capability: match_operator_colon
FROM books
| WHERE title:"war and peace" AND author:"Tolstoy"
| RERANK "war and peace" ON title WITH test_reranker
| KEEP book_no, title, author
| RERANK "war and peace" ON title WITH inferenceId=test_reranker
| EVAL _score=ROUND(_score, 2)
| KEEP book_no, title, author, _score
| SORT author, title
| LIMIT 3
;
book_no:keyword | title:text | author:text
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy]
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy
5327 | War and Peace | Leo Tolstoy
book_no:keyword | title:text | author:text | _score:double
4536 | War and Peace (Signet Classics) | [John Hockenberry, Leo Tolstoy, Pat Conroy] | 0.03
2776 | The Devil and Other Stories (Oxford World's Classics) | Leo Tolstoy | 0.02
5327 | War and Peace | Leo Tolstoy | 0.08
;
@ -105,12 +153,14 @@ FROM books METADATA _id, _index, _score
| FORK ( WHERE title:"Tolkien" | SORT _score, _id DESC | LIMIT 3 )
( WHERE author:"Tolkien" | SORT _score, _id DESC | LIMIT 3 )
| RRF
| RERANK "Tolkien" ON title WITH test_reranker
| RERANK "Tolkien" ON title WITH inferenceId=test_reranker
| EVAL _score=ROUND(_score, 2)
| SORT _score DESC, book_no ASC
| LIMIT 2
| KEEP book_no, title, author
| KEEP book_no, title, author, _score
;
book_no:keyword | title:keyword | author:keyword
5335 | Letters of J R R Tolkien | J.R.R. Tolkien
2130 | The J. R. R. Tolkien Audio Collection | [Christopher Tolkien, John Ronald Reuel Tolkien]
book_no:keyword | title:keyword | author:keyword | _score:double
5335 | Letters of J R R Tolkien | J.R.R. Tolkien | 0.04
2130 | The J. R. R. Tolkien Audio Collection | [Christopher Tolkien, John Ronald Reuel Tolkien] | 0.03
;

View file

@ -306,8 +306,21 @@ rrfCommand
: DEV_RRF
;
inferenceCommandOptions
: inferenceCommandOption (COMMA inferenceCommandOption)*
;
inferenceCommandOption
: identifier ASSIGN inferenceCommandOptionValue
;
inferenceCommandOptionValue
: constant
| identifier
;
rerankCommand
: DEV_RERANK queryText=constant ON rerankFields (WITH inferenceId=identifierOrParameter)?
: DEV_RERANK queryText=constant ON rerankFields (WITH inferenceCommandOptions)?
;
completionCommand

View file

@ -841,7 +841,11 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
if (rerank.scoreAttribute() instanceof UnresolvedAttribute ua) {
Attribute resolved = resolveAttribute(ua, childrenOutput);
if (resolved.resolved() == false || resolved.dataType() != DOUBLE) {
if (ua.name().equals(MetadataAttribute.SCORE)) {
resolved = MetadataAttribute.create(Source.EMPTY, MetadataAttribute.SCORE);
} else {
resolved = new ReferenceAttribute(resolved.source(), resolved.name(), DOUBLE);
}
}
rerank = rerank.withScoreAttribute(resolved);
}

View file

@ -38,9 +38,9 @@ import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineFi
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineLimits;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineOrderBy;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineSample;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownCompletion;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEnrich;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEval;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownInferencePlan;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownRegexExtract;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.RemoveStatsOverride;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceAggregateAggExpressionWithEval;
@ -192,7 +192,7 @@ public class LogicalPlanOptimizer extends ParameterizedRuleExecutor<LogicalPlan,
new PushDownAndCombineLimits(),
new PushDownAndCombineFilters(),
new PushDownAndCombineSample(),
new PushDownCompletion(),
new PushDownInferencePlan(),
new PushDownEval(),
new PushDownRegexExtract(),
new PushDownEnrich(),

View file

@ -24,7 +24,7 @@ import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.RegexExtract;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
@ -72,10 +72,10 @@ public final class PushDownAndCombineFilters extends OptimizerRules.OptimizerRul
// Push down filters that do not rely on attributes created by RegexExtract
var attributes = AttributeSet.of(Expressions.asAttributes(re.extractedFields()));
plan = maybePushDownPastUnary(filter, re, attributes::contains, NO_OP);
} else if (child instanceof Completion completion) {
} else if (child instanceof InferencePlan<?> inferencePlan) {
// Push down filters that do not rely on attributes created by Completion
var attributes = AttributeSet.of(completion.generatedAttributes());
plan = maybePushDownPastUnary(filter, completion, attributes::contains, NO_OP);
var attributes = AttributeSet.of(inferencePlan.generatedAttributes());
plan = maybePushDownPastUnary(filter, inferencePlan, attributes::contains, NO_OP);
} else if (child instanceof Enrich enrich) {
// Push down filters that do not rely on attributes created by Enrich
var attributes = AttributeSet.of(Expressions.asAttributes(enrich.enrichFields()));

View file

@ -17,7 +17,7 @@ import org.elasticsearch.xpack.esql.plan.logical.MvExpand;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.RegexExtract;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes;
@ -43,7 +43,7 @@ public final class PushDownAndCombineLimits extends OptimizerRules.Parameterized
|| unary instanceof Project
|| unary instanceof RegexExtract
|| unary instanceof Enrich
|| unary instanceof Completion) {
|| unary instanceof InferencePlan<?>) {
return unary.replaceChild(limit.replaceChild(unary.child()));
} else if (unary instanceof MvExpand) {
// MV_EXPAND can increase the number of rows, so we cannot just push the limit down

View file

@ -8,11 +8,11 @@
package org.elasticsearch.xpack.esql.optimizer.rules.logical;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
public final class PushDownCompletion extends OptimizerRules.OptimizerRule<Completion> {
public final class PushDownInferencePlan extends OptimizerRules.OptimizerRule<InferencePlan<?>> {
@Override
protected LogicalPlan rule(Completion p) {
protected LogicalPlan rule(InferencePlan<?> p) {
return PushDownUtils.pushGeneratingPlanPastProjectAndOrderBy(p);
}
}

File diff suppressed because one or more lines are too long

View file

@ -788,6 +788,42 @@ public class EsqlBaseParserBaseListener implements EsqlBaseParserListener {
* <p>The default implementation does nothing.</p>
*/
@Override public void exitRrfCommand(EsqlBaseParser.RrfCommandContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterInferenceCommandOptions(EsqlBaseParser.InferenceCommandOptionsContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitInferenceCommandOptions(EsqlBaseParser.InferenceCommandOptionsContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterInferenceCommandOption(EsqlBaseParser.InferenceCommandOptionContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitInferenceCommandOption(EsqlBaseParser.InferenceCommandOptionContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void enterInferenceCommandOptionValue(EsqlBaseParser.InferenceCommandOptionValueContext ctx) { }
/**
* {@inheritDoc}
*
* <p>The default implementation does nothing.</p>
*/
@Override public void exitInferenceCommandOptionValue(EsqlBaseParser.InferenceCommandOptionValueContext ctx) { }
/**
* {@inheritDoc}
*

View file

@ -468,6 +468,27 @@ public class EsqlBaseParserBaseVisitor<T> extends AbstractParseTreeVisitor<T> im
* {@link #visitChildren} on {@code ctx}.</p>
*/
@Override public T visitRrfCommand(EsqlBaseParser.RrfCommandContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*
* <p>The default implementation returns the result of calling
* {@link #visitChildren} on {@code ctx}.</p>
*/
@Override public T visitInferenceCommandOptions(EsqlBaseParser.InferenceCommandOptionsContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*
* <p>The default implementation returns the result of calling
* {@link #visitChildren} on {@code ctx}.</p>
*/
@Override public T visitInferenceCommandOption(EsqlBaseParser.InferenceCommandOptionContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*
* <p>The default implementation returns the result of calling
* {@link #visitChildren} on {@code ctx}.</p>
*/
@Override public T visitInferenceCommandOptionValue(EsqlBaseParser.InferenceCommandOptionValueContext ctx) { return visitChildren(ctx); }
/**
* {@inheritDoc}
*

View file

@ -675,6 +675,36 @@ public interface EsqlBaseParserListener extends ParseTreeListener {
* @param ctx the parse tree
*/
void exitRrfCommand(EsqlBaseParser.RrfCommandContext ctx);
/**
* Enter a parse tree produced by {@link EsqlBaseParser#inferenceCommandOptions}.
* @param ctx the parse tree
*/
void enterInferenceCommandOptions(EsqlBaseParser.InferenceCommandOptionsContext ctx);
/**
* Exit a parse tree produced by {@link EsqlBaseParser#inferenceCommandOptions}.
* @param ctx the parse tree
*/
void exitInferenceCommandOptions(EsqlBaseParser.InferenceCommandOptionsContext ctx);
/**
* Enter a parse tree produced by {@link EsqlBaseParser#inferenceCommandOption}.
* @param ctx the parse tree
*/
void enterInferenceCommandOption(EsqlBaseParser.InferenceCommandOptionContext ctx);
/**
* Exit a parse tree produced by {@link EsqlBaseParser#inferenceCommandOption}.
* @param ctx the parse tree
*/
void exitInferenceCommandOption(EsqlBaseParser.InferenceCommandOptionContext ctx);
/**
* Enter a parse tree produced by {@link EsqlBaseParser#inferenceCommandOptionValue}.
* @param ctx the parse tree
*/
void enterInferenceCommandOptionValue(EsqlBaseParser.InferenceCommandOptionValueContext ctx);
/**
* Exit a parse tree produced by {@link EsqlBaseParser#inferenceCommandOptionValue}.
* @param ctx the parse tree
*/
void exitInferenceCommandOptionValue(EsqlBaseParser.InferenceCommandOptionValueContext ctx);
/**
* Enter a parse tree produced by {@link EsqlBaseParser#rerankCommand}.
* @param ctx the parse tree

View file

@ -412,6 +412,24 @@ public interface EsqlBaseParserVisitor<T> extends ParseTreeVisitor<T> {
* @return the visitor result
*/
T visitRrfCommand(EsqlBaseParser.RrfCommandContext ctx);
/**
* Visit a parse tree produced by {@link EsqlBaseParser#inferenceCommandOptions}.
* @param ctx the parse tree
* @return the visitor result
*/
T visitInferenceCommandOptions(EsqlBaseParser.InferenceCommandOptionsContext ctx);
/**
* Visit a parse tree produced by {@link EsqlBaseParser#inferenceCommandOption}.
* @param ctx the parse tree
* @return the visitor result
*/
T visitInferenceCommandOption(EsqlBaseParser.InferenceCommandOptionContext ctx);
/**
* Visit a parse tree produced by {@link EsqlBaseParser#inferenceCommandOptionValue}.
* @param ctx the parse tree
* @return the visitor result
*/
T visitInferenceCommandOptionValue(EsqlBaseParser.InferenceCommandOptionValueContext ctx);
/**
* Visit a parse tree produced by {@link EsqlBaseParser#rerankCommand}.
* @param ctx the parse tree

View file

@ -734,7 +734,9 @@ public class LogicalPlanBuilder extends ExpressionBuilder {
@Override
public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
Source source = source(ctx);
List<Alias> rerankFields = visitRerankFields(ctx.rerankFields());
Expression queryText = expression(ctx.queryText);
if (queryText instanceof Literal queryTextLiteral && DataType.isString(queryText.dataType())) {
if (queryTextLiteral.value() == null) {
throw new ParsingException(
@ -751,19 +753,72 @@ public class LogicalPlanBuilder extends ExpressionBuilder {
);
}
Literal inferenceId = ctx.inferenceId != null ? inferenceId(ctx.inferenceId) : Literal.keyword(source, Rerank.DEFAULT_INFERENCE_ID);
return p -> {
checkForRemoteClusters(p, source, "RERANK");
return new Rerank(source, p, inferenceId, queryText, visitRerankFields(ctx.rerankFields()));
return visitRerankOptions(new Rerank(source, p, queryText, rerankFields), ctx.inferenceCommandOptions());
};
}
private Rerank visitRerankOptions(Rerank rerank, EsqlBaseParser.InferenceCommandOptionsContext ctx) {
if (ctx == null) {
return rerank;
}
Rerank.Builder rerankBuilder = new Rerank.Builder(rerank);
for (var option : ctx.inferenceCommandOption()) {
String optionName = visitIdentifier(option.identifier());
EsqlBaseParser.InferenceCommandOptionValueContext optionValue = option.inferenceCommandOptionValue();
if (optionName.equals(Rerank.INFERENCE_ID_OPTION_NAME)) {
rerankBuilder.withInferenceId(visitInferenceId(optionValue));
} else if (optionName.equals(Rerank.SCORE_COLUMN_OPTION_NAME)) {
rerankBuilder.withScoreAttribute(visitRerankScoreAttribute(optionName, optionValue));
} else {
throw new ParsingException(
source(option.identifier()),
"Unknowm parameter [{}] in RERANK command",
option.identifier().getText()
);
}
}
return rerankBuilder.build();
}
private UnresolvedAttribute visitRerankScoreAttribute(String optionName, EsqlBaseParser.InferenceCommandOptionValueContext ctx) {
if (ctx.constant() == null && ctx.identifier() == null) {
throw new ParsingException(source(ctx), "Parameter [{}] is null or undefined", optionName);
}
Expression optionValue = ctx.identifier() != null
? Literal.keyword(source(ctx.identifier()), visitIdentifier(ctx.identifier()))
: expression(ctx.constant());
if (optionValue instanceof UnresolvedAttribute scoreAttribute) {
return scoreAttribute;
} else if (optionValue instanceof Literal literal) {
if (literal.value() == null) {
throw new ParsingException(optionValue.source(), "Parameter [{}] is null or undefined", optionName);
}
if (literal.value() instanceof BytesRef attributeName) {
return new UnresolvedAttribute(literal.source(), BytesRefs.toString(attributeName));
}
}
throw new ParsingException(
source(ctx),
"Option [{}] expects a valid attribute in RERANK command. [{}] provided.",
optionName,
ctx.constant().getText()
);
}
@Override
public PlanFactory visitCompletionCommand(EsqlBaseParser.CompletionCommandContext ctx) {
Source source = source(ctx);
Expression prompt = expression(ctx.prompt);
Literal inferenceId = inferenceId(ctx.inferenceId);
Literal inferenceId = visitInferenceId(ctx.inferenceId);
Attribute targetField = ctx.targetField == null
? new UnresolvedAttribute(source, Completion.DEFAULT_OUTPUT_FIELD_NAME)
: visitQualifiedName(ctx.targetField);
@ -774,27 +829,43 @@ public class LogicalPlanBuilder extends ExpressionBuilder {
};
}
public Literal inferenceId(EsqlBaseParser.IdentifierOrParameterContext ctx) {
private Literal visitInferenceId(EsqlBaseParser.IdentifierOrParameterContext ctx) {
if (ctx.identifier() != null) {
return Literal.keyword(source(ctx), visitIdentifier(ctx.identifier()));
}
if (expression(ctx.parameter()) instanceof Literal literalParam) {
if (literalParam.value() != null) {
return literalParam;
return visitInferenceId(expression(ctx.parameter()));
}
private Literal visitInferenceId(EsqlBaseParser.InferenceCommandOptionValueContext ctx) {
if (ctx.identifier() != null) {
return Literal.keyword(source(ctx), visitIdentifier(ctx.identifier()));
}
return visitInferenceId(expression(ctx.constant()));
}
private Literal visitInferenceId(Expression expression) {
if (expression instanceof Literal literal) {
if (literal.value() == null) {
throw new ParsingException(
source(ctx.parameter()),
"Query parameter [{}] is null or undefined and cannot be used as inference id",
ctx.parameter().getText()
expression.source(),
"Parameter [{}] is null or undefined and cannot be used as inference id",
expression.source().text()
);
}
return literal;
} else if (expression instanceof UnresolvedAttribute attribute) {
// Support for unquoted inference id
return new Literal(expression.source(), attribute.name(), KEYWORD);
}
throw new ParsingException(
source(ctx.parameter()),
"Query parameter [{}] is not a string and cannot be used as inference id",
ctx.parameter().getText()
expression.source(),
"Query parameter [{}] is not a string and cannot be used as inference id [{}]",
expression.source().text(),
expression.getClass()
);
}

View file

@ -22,9 +22,7 @@ import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.plan.GeneratingPlan;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.SortAgnostic;
import java.io.IOException;
import java.util.List;
@ -34,12 +32,7 @@ import static org.elasticsearch.xpack.esql.common.Failure.fail;
import static org.elasticsearch.xpack.esql.core.type.DataType.TEXT;
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
public class Completion extends InferencePlan<Completion>
implements
GeneratingPlan<Completion>,
SortAgnostic,
TelemetryAware,
PostAnalysisVerificationAware {
public class Completion extends InferencePlan<Completion> implements TelemetryAware, PostAnalysisVerificationAware {
public static final String DEFAULT_OUTPUT_FIELD_NAME = "completion";

View file

@ -12,13 +12,18 @@ import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.plan.GeneratingPlan;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.SortAgnostic;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import java.io.IOException;
import java.util.Objects;
public abstract class InferencePlan<PlanType extends InferencePlan<PlanType>> extends UnaryPlan {
public abstract class InferencePlan<PlanType extends InferencePlan<PlanType>> extends UnaryPlan
implements
SortAgnostic,
GeneratingPlan<InferencePlan<PlanType>> {
private final Expression inferenceId;

View file

@ -18,17 +18,14 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.expression.NameId;
import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.Order;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.plan.QueryPlan;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.SortAgnostic;
import org.elasticsearch.xpack.esql.plan.logical.SurrogateLogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import java.io.IOException;
@ -37,21 +34,29 @@ import java.util.Objects;
import static org.elasticsearch.xpack.esql.core.expression.Expressions.asAttributes;
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
import static org.elasticsearch.xpack.esql.parser.ParserUtils.source;
public class Rerank extends InferencePlan<Rerank> implements SortAgnostic, SurrogateLogicalPlan, TelemetryAware {
public class Rerank extends InferencePlan<Rerank> implements TelemetryAware {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Rerank", Rerank::new);
public static final String DEFAULT_INFERENCE_ID = ".rerank-v1-elasticsearch";
public static final String INFERENCE_ID_OPTION_NAME = "inferenceId";
public static final String SCORE_COLUMN_OPTION_NAME = "scoreColumn";
private final Attribute scoreAttribute;
private final Expression queryText;
private final List<Alias> rerankFields;
private List<Attribute> lazyOutput;
public Rerank(Source source, LogicalPlan child, Expression inferenceId, Expression queryText, List<Alias> rerankFields) {
super(source, child, inferenceId);
this.queryText = queryText;
this.rerankFields = rerankFields;
this.scoreAttribute = new UnresolvedAttribute(source, MetadataAttribute.SCORE);
public Rerank(Source source, LogicalPlan child, Expression queryText, List<Alias> rerankFields) {
this(
source,
child,
Literal.keyword(Source.EMPTY, DEFAULT_INFERENCE_ID),
queryText,
rerankFields,
new UnresolvedAttribute(Source.EMPTY, MetadataAttribute.SCORE)
);
}
public Rerank(
@ -129,13 +134,25 @@ public class Rerank extends InferencePlan<Rerank> implements SortAgnostic, Surro
@Override
protected AttributeSet computeReferences() {
AttributeSet.Builder refs = computeReferences(rerankFields).asBuilder();
if (planHasAttribute(child(), scoreAttribute)) {
refs.add(scoreAttribute);
return computeReferences(rerankFields);
}
return refs.build();
public List<Attribute> generatedAttributes() {
return List.of(scoreAttribute);
}
@Override
public Rerank withGeneratedNames(List<String> newNames) {
checkNumberOfNewNames(newNames);
return new Rerank(source(), child(), inferenceId(), queryText, rerankFields, this.renameScoreAttribute(newNames.get(0)));
}
private Attribute renameScoreAttribute(String newName) {
if (newName.equals(scoreAttribute.name())) {
return scoreAttribute;
}
return scoreAttribute.withName(newName).withId(new NameId());
}
public static AttributeSet computeReferences(List<Alias> fields) {
@ -169,24 +186,33 @@ public class Rerank extends InferencePlan<Rerank> implements SortAgnostic, Surro
return Objects.hash(super.hashCode(), queryText, rerankFields, scoreAttribute);
}
@Override
public LogicalPlan surrogate() {
Order sortOrder = new Order(source(), scoreAttribute, Order.OrderDirection.DESC, Order.NullsPosition.ANY);
return new OrderBy(source(), this, List.of(sortOrder));
}
@Override
public List<Attribute> output() {
if (lazyOutput == null) {
lazyOutput = planHasAttribute(child(), scoreAttribute)
? child().output()
: mergeOutputAttributes(List.of(scoreAttribute), child().output());
lazyOutput = mergeOutputAttributes(List.of(scoreAttribute), child().output());
}
return lazyOutput;
}
public static boolean planHasAttribute(QueryPlan<?> plan, Attribute attribute) {
return plan.outputSet().stream().anyMatch(attr -> attr.equals(attribute));
public static class Builder {
private Rerank rerank;
public Builder(Rerank rerank) {
this.rerank = rerank;
}
public Rerank build() {
return rerank;
}
public Builder withInferenceId(Expression inferenceId) {
this.rerank = this.rerank.withInferenceId(inferenceId);
return this;
}
public Builder withScoreAttribute(Attribute scoreAttribute) {
this.rerank = this.rerank.withScoreAttribute(scoreAttribute);
return this;
}
}
}

View file

@ -26,7 +26,6 @@ import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
import static org.elasticsearch.xpack.esql.plan.logical.inference.Rerank.planHasAttribute;
public class RerankExec extends InferenceExec {
@ -39,6 +38,7 @@ public class RerankExec extends InferenceExec {
private final Expression queryText;
private final List<Alias> rerankFields;
private final Attribute scoreAttribute;
private List<Attribute> lazyOutput;
public RerankExec(
Source source,
@ -102,22 +102,15 @@ public class RerankExec extends InferenceExec {
@Override
public List<Attribute> output() {
if (planHasAttribute(child(), scoreAttribute)) {
return child().output();
if (lazyOutput == null) {
lazyOutput = mergeOutputAttributes(List.of(scoreAttribute), child().output());
}
return mergeOutputAttributes(List.of(scoreAttribute), child().output());
return lazyOutput;
}
@Override
protected AttributeSet computeReferences() {
AttributeSet.Builder refs = Rerank.computeReferences(rerankFields).asBuilder();
if (planHasAttribute(child(), scoreAttribute)) {
refs.add(scoreAttribute);
}
return refs.build();
return Rerank.computeReferences(rerankFields);
}
@Override

View file

@ -618,6 +618,15 @@ public class LocalExecutionPlanner {
private PhysicalOperation planRerank(RerankExec rerank, LocalExecutionPlannerContext context) {
PhysicalOperation source = plan(rerank.child(), context);
EvalOperator.ExpressionEvaluator.Factory rowEncoderFactory;
if (rerank.rerankFields().size() > 1) {
// If there is more than one field used for reranking we encode the input in a YAML doc, using field names as keys.
// The input value will look like
// text_field: foo bar
// multivalue_text_field:
// - value 1
// - value 2
// integer_field: 132
Map<ColumnInfoImpl, EvalOperator.ExpressionEvaluator.Factory> rerankFieldsEvaluatorSuppliers = Maps
.newLinkedHashMapWithExpectedSize(rerank.rerankFields().size());
@ -627,8 +636,10 @@ public class LocalExecutionPlanner {
EvalMapper.toEvaluator(context.foldCtx(), rerankField.child(), source.layout)
);
}
XContentRowEncoder.Factory rowEncoderFactory = XContentRowEncoder.yamlRowEncoderFactory(rerankFieldsEvaluatorSuppliers);
rowEncoderFactory = XContentRowEncoder.yamlRowEncoderFactory(rerankFieldsEvaluatorSuppliers);
} else {
rowEncoderFactory = EvalMapper.toEvaluator(context.foldCtx(), rerank.rerankFields().get(0).child(), source.layout);
}
String inferenceId = BytesRefs.toString(rerank.inferenceId().fold(context.foldCtx));
String queryText = BytesRefs.toString(rerank.queryText().fold(context.foldCtx));

View file

@ -124,6 +124,7 @@ import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME;
import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS;
import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD;
import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
import static org.elasticsearch.xpack.esql.core.type.DataType.LONG;
import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED;
import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.dateTimeToString;
@ -3546,7 +3547,7 @@ public class AnalyzerTests extends ESTestCase {
{
LogicalPlan plan = analyze(
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `reranking-inference-id`",
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH inferenceId=`reranking-inference-id`",
"mapping-books.json"
);
Rerank rerank = as(as(plan, Limit.class).child(), Rerank.class);
@ -3557,7 +3558,7 @@ public class AnalyzerTests extends ESTestCase {
VerificationException ve = expectThrows(
VerificationException.class,
() -> analyze(
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `completion-inference-id`",
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH inferenceId=`completion-inference-id`",
"mapping-books.json"
)
@ -3575,7 +3576,7 @@ public class AnalyzerTests extends ESTestCase {
VerificationException ve = expectThrows(
VerificationException.class,
() -> analyze(
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `error-inference-id`",
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH inferenceId=`error-inference-id`",
"mapping-books.json"
)
@ -3587,7 +3588,7 @@ public class AnalyzerTests extends ESTestCase {
VerificationException ve = expectThrows(
VerificationException.class,
() -> analyze(
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `unknown-inference-id`",
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH inferenceId=`unknown-inference-id`",
"mapping-books.json"
)
@ -3606,7 +3607,7 @@ public class AnalyzerTests extends ESTestCase {
| WHERE title:"italian food recipe" OR description:"italian food recipe"
| KEEP description, title, year, _score
| DROP description
| RERANK "italian food recipe" ON title WITH `reranking-inference-id`
| RERANK "italian food recipe" ON title WITH inferenceId=`reranking-inference-id`
""", "mapping-books.json");
Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
@ -3630,7 +3631,8 @@ public class AnalyzerTests extends ESTestCase {
LogicalPlan plan = analyze("""
FROM books METADATA _score
| WHERE title:"food"
| RERANK "food" ON title, description=SUBSTRING(description, 0, 100), yearRenamed=year WITH `reranking-inference-id`
| RERANK "food" ON title, description=SUBSTRING(description, 0, 100), yearRenamed=year
WITH inferenceId=`reranking-inference-id`
""", "mapping-books.json");
Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
@ -3668,7 +3670,7 @@ public class AnalyzerTests extends ESTestCase {
LogicalPlan plan = analyze("""
FROM books METADATA _score
| WHERE title:"food"
| RERANK "food" ON title, SUBSTRING(description, 0, 100), yearRenamed=year WITH `reranking-inference-id`
| RERANK "food" ON title, SUBSTRING(description, 0, 100), yearRenamed=year WITH inferenceId=`reranking-inference-id`
""", "mapping-books.json");
} catch (ParsingException ex) {
assertThat(
@ -3682,7 +3684,7 @@ public class AnalyzerTests extends ESTestCase {
VerificationException ve = expectThrows(
VerificationException.class,
() -> analyze(
"FROM books METADATA _score | RERANK \"italian food recipe\" ON missingField WITH `reranking-inference-id`",
"FROM books METADATA _score | RERANK \"italian food recipe\" ON missingField WITH inferenceId=`reranking-inference-id`",
"mapping-books.json"
)
@ -3699,7 +3701,7 @@ public class AnalyzerTests extends ESTestCase {
LogicalPlan plan = analyze("""
FROM books METADATA _score
| WHERE title:"italian food recipe" OR description:"italian food recipe"
| RERANK "italian food recipe" ON title WITH `reranking-inference-id`
| RERANK "italian food recipe" ON title WITH inferenceId=`reranking-inference-id`
""", "mapping-books.json");
Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
@ -3717,7 +3719,7 @@ public class AnalyzerTests extends ESTestCase {
LogicalPlan plan = analyze("""
FROM books
| WHERE title:"italian food recipe" OR description:"italian food recipe"
| RERANK "italian food recipe" ON title WITH `reranking-inference-id`
| RERANK "italian food recipe" ON title WITH inferenceId=`reranking-inference-id`
""", "mapping-books.json");
Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
@ -3729,6 +3731,42 @@ public class AnalyzerTests extends ESTestCase {
assertThat(rerank.scoreAttribute(), equalTo(MetadataAttribute.create(EMPTY, MetadataAttribute.SCORE)));
assertThat(rerank.output(), hasItem(rerank.scoreAttribute()));
}
{
// When using a custom score column that does not already exist
LogicalPlan plan = analyze("""
FROM books METADATA _score
| WHERE title:"italian food recipe" OR description:"italian food recipe"
| RERANK "italian food recipe" ON title WITH inferenceId=`reranking-inference-id`, scoreColumn=rerank_score
""", "mapping-books.json");
Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
Rerank rerank = as(limit.child(), Rerank.class);
Attribute scoreAttribute = rerank.scoreAttribute();
assertThat(scoreAttribute.name(), equalTo("rerank_score"));
assertThat(scoreAttribute.dataType(), equalTo(DOUBLE));
assertThat(rerank.output(), hasItem(scoreAttribute));
}
{
// When using a custom score column that already exists
LogicalPlan plan = analyze("""
FROM books METADATA _score
| WHERE title:"italian food recipe" OR description:"italian food recipe"
| EVAL rerank_score = _score
| RERANK "italian food recipe" ON title WITH inferenceId=`reranking-inference-id`, scoreColumn=rerank_score
""", "mapping-books.json");
Limit limit = as(plan, Limit.class); // Implicit limit added by AddImplicitLimit rule.
Rerank rerank = as(limit.child(), Rerank.class);
Attribute scoreAttribute = rerank.scoreAttribute();
assertThat(scoreAttribute.name(), equalTo("rerank_score"));
assertThat(scoreAttribute.dataType(), equalTo(DOUBLE));
assertThat(rerank.output(), hasItem(scoreAttribute));
assertThat(rerank.child().output().stream().anyMatch(scoreAttribute::equals), is(true));
}
}
public void testResolveCompletionInferenceId() {

View file

@ -100,9 +100,9 @@ import org.elasticsearch.xpack.esql.optimizer.rules.logical.LiteralsOnTheRight;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PruneRedundantOrderBy;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownAndCombineLimits;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownCompletion;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEnrich;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownEval;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownInferencePlan;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.PushDownRegexExtract;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.SplitInWithFoldableValue;
import org.elasticsearch.xpack.esql.parser.EsqlParser;
@ -127,6 +127,7 @@ import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin;
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig;
@ -5655,8 +5656,20 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
new Concat(EMPTY, randomLiteral(TEXT), List.of(attr)),
new ReferenceAttribute(EMPTY, "y", KEYWORD)
),
new PushDownCompletion()
) };
new PushDownInferencePlan()
),
// | RERANK "some text" ON x WITH inferenceId=inferenceId, scoreColumn=y
new PushdownShadowingGeneratingPlanTestCase(
(plan, attr) -> new Rerank(
EMPTY,
plan,
randomLiteral(TEXT),
randomLiteral(TEXT),
List.of(new Alias(EMPTY, attr.name(), attr)),
new ReferenceAttribute(EMPTY, "y", KEYWORD)
),
new PushDownInferencePlan()
), };
/**
* Consider

View file

@ -14,6 +14,7 @@ import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
@ -32,6 +33,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Filter;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.Project;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
import java.util.ArrayList;
@ -245,8 +247,8 @@ public class PushDownAndCombineFiltersTests extends ESTestCase {
assertEquals(expected, new PushDownAndCombineFilters().apply(fb));
}
// from ... | where a > 1 | COMPLETION completion="some prompt" WITH reranker | where b < 2 and match(completion, some text)
// => ... | where a > 1 AND b < 2| COMPLETION completion="some prompt" WITH reranker | match(completion, some text)
// from ... | where a > 1 | COMPLETION completion="some prompt" WITH inferenceId | where b < 2 and match(completion, some text)
// => ... | where a > 1 AND b < 2| COMPLETION completion="some prompt" WITH inferenceId | where match(completion, some text)
public void testPushDownFilterPastCompletion() {
FieldAttribute a = getFieldAttribute("a");
FieldAttribute b = getFieldAttribute("b");
@ -282,13 +284,57 @@ public class PushDownAndCombineFiltersTests extends ESTestCase {
assertEquals(expectedOptimizedPlan, new PushDownAndCombineFilters().apply(filterB));
}
// from ... | where a > 1 | RERANK "query" ON title WITH inferenceId | where b < 2 and _score > 1
// => ... | where a > 1 AND b < 2| RERANK "query" ON title WITH inferenceId | where _score > 1
/**
 * A filter below a RERANK that only references the rerank input is pushed under the Rerank node,
 * while a condition on the rerank score column must stay above it.
 */
public void testPushDownFilterPastRerank() {
    FieldAttribute fieldA = getFieldAttribute("a");
    FieldAttribute fieldB = getFieldAttribute("b");
    EsRelation relation = relation(List.of(fieldA, fieldB));

    GreaterThan aboveOne = greaterThanOf(getFieldAttribute("a"), ONE);
    Filter bottomFilter = new Filter(EMPTY, relation, aboveOne);
    Rerank rerank = rerank(bottomFilter);

    LessThan belowTwo = lessThanOf(getFieldAttribute("b"), TWO);
    GreaterThan scoreAboveOne = greaterThanOf(rerank.scoreAttribute(), ONE);
    Filter topFilter = new Filter(EMPTY, rerank, new And(EMPTY, belowTwo, scoreAboveOne));

    // b < 2 is merged with a > 1 below the rerank; the score predicate remains on top.
    LogicalPlan expected = new Filter(
        EMPTY,
        new Rerank(
            EMPTY,
            new Filter(EMPTY, relation, new And(EMPTY, aboveOne, belowTwo)),
            rerank.inferenceId(),
            rerank.queryText(),
            rerank.rerankFields(),
            rerank.scoreAttribute()
        ),
        scoreAboveOne
    );

    assertEquals(expected, new PushDownAndCombineFilters().apply(topFilter));
}
private static Completion completion(LogicalPlan child) {
return new Completion(
EMPTY,
child,
randomLiteral(DataType.TEXT),
randomLiteral(DataType.TEXT),
referenceAttribute(randomIdentifier(), DataType.TEXT)
randomLiteral(DataType.KEYWORD),
randomLiteral(randomBoolean() ? DataType.TEXT : DataType.KEYWORD),
referenceAttribute(randomIdentifier(), DataType.KEYWORD)
);
}
/**
 * Creates a Rerank plan node on top of {@code child} with randomized constructor inputs.
 * The score column is sometimes the _score metadata attribute and sometimes a custom name.
 */
private static Rerank rerank(LogicalPlan child) {
    Expression inferenceId = randomLiteral(DataType.KEYWORD);
    Expression queryText = randomLiteral(randomBoolean() ? DataType.TEXT : DataType.KEYWORD);
    List<Alias> rerankFields = randomList(1, 10, () -> new Alias(EMPTY, randomIdentifier(), randomLiteral(DataType.KEYWORD)));
    Attribute scoreAttribute = referenceAttribute(randomBoolean() ? MetadataAttribute.SCORE : randomIdentifier(), DataType.DOUBLE);
    return new Rerank(EMPTY, child, inferenceId, queryText, rerankFields, scoreAttribute);
}

View file

@ -24,6 +24,7 @@ import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.OrderBy;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
import java.util.List;
import java.util.function.BiConsumer;
@ -35,6 +36,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral;
import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext;
import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER;
import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD;
import static org.elasticsearch.xpack.esql.core.type.DataType.TEXT;
import static org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizerTests.relation;
@ -75,13 +77,31 @@ public class PushDownAndCombineLimitsTests extends ESTestCase {
),
new PushDownLimitTestCase<>(
Completion.class,
(plan, attr) -> new Completion(EMPTY, plan, randomLiteral(TEXT), randomLiteral(TEXT), attr),
(plan, attr) -> new Completion(EMPTY, plan, randomLiteral(KEYWORD), randomLiteral(KEYWORD), attr),
(basePlan, optimizedPlan) -> {
assertEquals(basePlan.source(), optimizedPlan.source());
assertEquals(basePlan.inferenceId(), optimizedPlan.inferenceId());
assertEquals(basePlan.prompt(), optimizedPlan.prompt());
assertEquals(basePlan.targetField(), optimizedPlan.targetField());
}
),
new PushDownLimitTestCase<>(
Rerank.class,
(plan, attr) -> new Rerank(
EMPTY,
plan,
randomLiteral(KEYWORD),
randomLiteral(KEYWORD),
randomList(1, 10, () -> new Alias(EMPTY, randomIdentifier(), randomLiteral(KEYWORD))),
attr
),
(basePlan, optimizedPlan) -> {
assertEquals(basePlan.source(), optimizedPlan.source());
assertEquals(basePlan.inferenceId(), optimizedPlan.inferenceId());
assertEquals(basePlan.queryText(), optimizedPlan.queryText());
assertEquals(basePlan.rerankFields(), optimizedPlan.rerankFields());
assertEquals(basePlan.scoreAttribute(), optimizedPlan.scoreAttribute());
}
)
);

View file

@ -3481,13 +3481,74 @@ public class StatementParserTests extends AbstractStatementParserTests {
expectError("explain [row x = 1", "line 1:19: missing ']' at '<EOF>'");
}
public void testRerankDefaultInferenceId() {
/** Without a WITH clause, RERANK falls back to the default inference endpoint and writes to _score. */
public void testRerankDefaultInferenceIdAndScoreAttribute() {
    assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());

    var rerank = as(processingCommand("RERANK \"query text\" ON title"), Rerank.class);

    assertThat(rerank.queryText(), equalTo(literalString("query text")));
    assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
    // Defaults: the built-in Elasticsearch reranker endpoint and the _score column.
    assertThat(rerank.inferenceId(), equalTo(literalString(".rerank-v1-elasticsearch")));
    assertThat(rerank.scoreAttribute(), equalTo(attribute("_score")));
}
/** WITH inferenceId=&lt;id&gt; selects the endpoint; the score column stays the default _score. */
public void testRerankInferenceId() {
    assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());

    var rerank = as(processingCommand("RERANK \"query text\" ON title WITH inferenceId=inferenceId"), Rerank.class);

    assertThat(rerank.queryText(), equalTo(literalString("query text")));
    assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
    assertThat(rerank.inferenceId(), equalTo(literalString("inferenceId")));
    assertThat(rerank.scoreAttribute(), equalTo(attribute("_score")));
}
/** A double-quoted inferenceId value parses to the same literal as the unquoted form. */
public void testRerankQuotedInferenceId() {
    assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());

    var rerank = as(processingCommand("RERANK \"query text\" ON title WITH inferenceId=\"inferenceId\""), Rerank.class);

    assertThat(rerank.queryText(), equalTo(literalString("query text")));
    assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
    assertThat(rerank.inferenceId(), equalTo(literalString("inferenceId")));
    assertThat(rerank.scoreAttribute(), equalTo(attribute("_score")));
}
/** WITH scoreColumn=&lt;name&gt; overrides the output column; the inference id keeps its default. */
public void testRerankScoreAttribute() {
    assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());

    var rerank = as(processingCommand("RERANK \"query text\" ON title WITH scoreColumn=rerank_score"), Rerank.class);

    assertThat(rerank.queryText(), equalTo(literalString("query text")));
    assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
    assertThat(rerank.inferenceId(), equalTo(literalString(".rerank-v1-elasticsearch")));
    assertThat(rerank.scoreAttribute(), equalTo(attribute("rerank_score")));
}
/** A double-quoted scoreColumn value parses to the same attribute as the unquoted form. */
public void testRerankQuotedScoreAttribute() {
    assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());

    var rerank = as(processingCommand("RERANK \"query text\" ON title WITH scoreColumn=\"rerank_score\""), Rerank.class);

    assertThat(rerank.queryText(), equalTo(literalString("query text")));
    assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
    assertThat(rerank.inferenceId(), equalTo(literalString(".rerank-v1-elasticsearch")));
    assertThat(rerank.scoreAttribute(), equalTo(attribute("rerank_score")));
}
/**
 * Both inferenceId and scoreColumn can be supplied in the same WITH clause.
 * (Method renamed from testRerankInferenceIdAnddScoreAttribute to fix the "Andd" typo;
 * JUnit test methods have no callers, so the rename is safe.)
 */
public void testRerankInferenceIdAndScoreAttribute() {
    assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());

    var plan = processingCommand("RERANK \"query text\" ON title WITH inferenceId=inferenceId, scoreColumn=rerank_score");
    var rerank = as(plan, Rerank.class);

    assertThat(rerank.inferenceId(), equalTo(literalString("inferenceId")));
    assertThat(rerank.scoreAttribute(), equalTo(attribute("rerank_score")));
    assertThat(rerank.queryText(), equalTo(literalString("query text")));
    assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
}
@ -3495,18 +3556,19 @@ public class StatementParserTests extends AbstractStatementParserTests {
public void testRerankSingleField() {
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
var plan = processingCommand("RERANK \"query text\" ON title WITH inferenceID");
var plan = processingCommand("RERANK \"query text\" ON title WITH inferenceId=inferenceID");
var rerank = as(plan, Rerank.class);
assertThat(rerank.queryText(), equalTo(literalString("query text")));
assertThat(rerank.inferenceId(), equalTo(literalString("inferenceID")));
assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
assertThat(rerank.scoreAttribute(), equalTo(attribute("_score")));
}
public void testRerankMultipleFields() {
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
var plan = processingCommand("RERANK \"query text\" ON title, description, authors_renamed=authors WITH inferenceID");
var plan = processingCommand("RERANK \"query text\" ON title, description, authors_renamed=authors WITH inferenceId=inferenceID");
var rerank = as(plan, Rerank.class);
assertThat(rerank.queryText(), equalTo(literalString("query text")));
@ -3521,12 +3583,15 @@ public class StatementParserTests extends AbstractStatementParserTests {
)
)
);
assertThat(rerank.scoreAttribute(), equalTo(attribute("_score")));
}
public void testRerankComputedFields() {
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
var plan = processingCommand("RERANK \"query text\" ON title, short_description = SUBSTRING(description, 0, 100) WITH inferenceID");
var plan = processingCommand(
"RERANK \"query text\" ON title, short_description = SUBSTRING(description, 0, 100) WITH inferenceId=inferenceID"
);
var rerank = as(plan, Rerank.class);
assertThat(rerank.queryText(), equalTo(literalString("query text")));
@ -3540,24 +3605,43 @@ public class StatementParserTests extends AbstractStatementParserTests {
)
)
);
assertThat(rerank.scoreAttribute(), equalTo(attribute("_score")));
}
public void testRerankWithPositionalParameters() {
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
var queryParams = new QueryParams(List.of(paramAsConstant(null, "query text"), paramAsConstant(null, "reranker")));
var rerank = as(parser.createStatement("row a = 1 | RERANK ? ON title WITH ?", queryParams), Rerank.class);
var queryParams = new QueryParams(
List.of(paramAsConstant(null, "query text"), paramAsConstant(null, "reranker"), paramAsConstant(null, "rerank_score"))
);
var rerank = as(
parser.createStatement("row a = 1 | RERANK ? ON title WITH inferenceId=?, scoreColumn=? ", queryParams),
Rerank.class
);
assertThat(rerank.queryText(), equalTo(literalString("query text")));
assertThat(rerank.inferenceId(), equalTo(literalString("reranker")));
assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", attribute("title")))));
assertThat(rerank.scoreAttribute(), equalTo(attribute("rerank_score")));
}
public void testRerankWithNamedParameters() {
assumeTrue("RERANK requires corresponding capability", EsqlCapabilities.Cap.RERANK.isEnabled());
var queryParams = new QueryParams(List.of(paramAsConstant("queryText", "query text"), paramAsConstant("inferenceId", "reranker")));
var rerank = as(parser.createStatement("row a = 1 | RERANK ?queryText ON title WITH ?inferenceId", queryParams), Rerank.class);
var queryParams = new QueryParams(
List.of(
paramAsConstant("queryText", "query text"),
paramAsConstant("inferenceId", "reranker"),
paramAsConstant("scoreColumnName", "rerank_score")
)
);
var rerank = as(
parser.createStatement(
"row a = 1 | RERANK ?queryText ON title WITH inferenceId=?inferenceId, scoreColumn=?scoreColumnName",
queryParams
),
Rerank.class
);
assertThat(rerank.queryText(), equalTo(literalString("query text")));
assertThat(rerank.inferenceId(), equalTo(literalString("reranker")));
@ -3571,7 +3655,7 @@ public class StatementParserTests extends AbstractStatementParserTests {
var fromPatterns = randomIndexPatterns(CROSS_CLUSTER);
expectError(
"FROM " + fromPatterns + " | RERANK \"query text\" ON title WITH inferenceId",
"FROM " + fromPatterns + " | RERANK \"query text\" ON title WITH inferenceId=inferenceId",
"invalid index pattern [" + unquoteIndexPattern(fromPatterns) + "], remote clusters are not supported with RERANK"
);
}