From a6f685cc2ac1e40039e84a7747985ec51585bbf5 Mon Sep 17 00:00:00 2001 From: Ying Mao Date: Tue, 25 Mar 2025 12:32:18 -0400 Subject: [PATCH] Adding common rerank options to Perform Inference API (#125239) * wip * Adding rerank common options * Linting * Linting * [CI] Auto commit changes from spotless * Update docs/changelog/125239.yaml * PR feedback --------- Co-authored-by: elasticsearchmachine --- docs/changelog/125239.yaml | 6 + .../org/elasticsearch/TransportVersions.java | 2 + .../inference/InferenceService.java | 20 +- .../inference/action/InferenceAction.java | 82 +++++- .../action/InferenceActionRequestTests.java | 239 +++++++++++++++++- .../TestDenseInferenceServiceExtension.java | 2 + .../mock/TestRerankingServiceExtension.java | 2 + .../TestSparseInferenceServiceExtension.java | 2 + ...stStreamingCompletionServiceExtension.java | 3 + .../action/TransportInferenceAction.java | 2 + .../voyageai/VoyageAIActionCreator.java | 8 +- ...libabaCloudSearchRerankRequestManager.java | 2 + .../sender/CohereRerankRequestManager.java | 8 +- .../GoogleVertexAiRerankRequestManager.java | 8 +- .../sender/JinaAIRerankRequestManager.java | 8 +- .../http/sender/QueryAndDocsInputs.java | 24 +- .../AlibabaCloudSearchRerankRequest.java | 10 +- ...AlibabaCloudSearchRerankRequestEntity.java | 11 +- .../request/cohere/CohereRerankRequest.java | 16 +- .../cohere/CohereRerankRequestEntity.java | 33 ++- .../GoogleVertexAiRerankRequest.java | 23 +- .../GoogleVertexAiRerankRequestEntity.java | 17 +- .../request/jinaai/JinaAIRerankRequest.java | 16 +- .../jinaai/JinaAIRerankRequestEntity.java | 36 ++- .../voyageai/VoyageAIRerankRequest.java | 25 +- .../voyageai/VoyageAIRerankRequestEntity.java | 34 ++- .../GoogleVertexAiRerankResponseEntity.java | 4 - .../queries/SemanticQueryBuilder.java | 2 + ...ankFeaturePhaseRankCoordinatorContext.java | 2 + .../inference/services/SenderService.java | 17 +- .../inference/services/ServiceUtils.java | 2 + .../AlibabaCloudSearchService.java | 19 ++ .../ElasticsearchInternalService.java | 19 +- .../SimpleServiceIntegrationValidator.java | 2 + .../BaseTransportInferenceActionTestCase.java | 4 +- .../http/sender/InferenceInputsTests.java | 2 +- ...baCloudSearchRerankRequestEntityTests.java | 8 +- .../CohereRerankRequestEntityTests.java | 95 +++++++ ...oogleVertexAiRerankRequestEntityTests.java | 23 +- .../GoogleVertexAiRerankRequestTests.java | 77 +++++- .../JinaAIRerankRequestEntityTests.java | 165 ++++++------ .../jinaai/JinaAIRerankRequestTests.java | 29 ++- .../VoyageAIRerankRequestEntityTests.java | 118 +++++---- .../voyageai/VoyageAIRerankRequestTests.java | 29 ++- ...ogleVertexAiRerankResponseEntityTests.java | 78 +++--- .../TextSimilarityRankTests.java | 2 + .../TextSimilarityTestPlugin.java | 2 + .../inference/services/ServiceUtilsTests.java | 16 +- .../AlibabaCloudSearchServiceTests.java | 51 ++++ .../AmazonBedrockServiceTests.java | 10 + .../anthropic/AnthropicServiceTests.java | 6 + .../AzureAiStudioServiceTests.java | 10 + .../azureopenai/AzureOpenAiServiceTests.java | 8 + .../services/cohere/CohereServiceTests.java | 14 + .../deepseek/DeepSeekServiceTests.java | 4 +- .../elastic/ElasticInferenceServiceTests.java | 10 + .../GoogleAiStudioServiceTests.java | 10 + .../HuggingFaceBaseServiceTests.java | 2 + .../huggingface/HuggingFaceServiceTests.java | 6 + .../ibmwatsonx/IbmWatsonxServiceTests.java | 8 + .../services/jinaai/JinaAIServiceTests.java | 35 ++- .../services/mistral/MistralServiceTests.java | 6 + .../services/openai/OpenAiServiceTests.java | 14 + ...impleServiceIntegrationValidatorTests.java | 8 +- .../voyageai/VoyageAIServiceTests.java | 35 ++- .../TransportCoordinatedInferenceAction.java | 2 + 66 files changed, 1306 insertions(+), 287 deletions(-) create mode 100644 docs/changelog/125239.yaml create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntityTests.java diff --git a/docs/changelog/125239.yaml b/docs/changelog/125239.yaml new file mode 100644 index 000000000000..60ec9bb0b717 --- /dev/null +++ b/docs/changelog/125239.yaml @@ -0,0 +1,6 @@ +pr: 125239 +summary: Adding common rerank options to Perform Inference API +area: Machine Learning +type: enhancement +issues: + - 111273 diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 2c1a32ab30e5..07c100af3543 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -155,6 +155,7 @@ public class TransportVersions { public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL_8_19 = def(8_841_0_12); public static final TransportVersion INFERENCE_MODEL_REGISTRY_METADATA_8_19 = def(8_841_0_13); public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14); + public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01); public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02); @@ -201,6 +202,7 @@ public class TransportVersions { public static final TransportVersion INDEXING_STATS_INCLUDES_RECENT_WRITE_LOAD = def(9_034_0_00); public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL = def(9_035_0_00); public static final TransportVersion INDEX_METADATA_INCLUDES_RECENT_WRITE_LOAD = def(9_036_0_00); + public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED = def(9_037_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index de1925cb641e..309db20083ec 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -91,18 +91,22 @@ public interface InferenceService extends Closeable { /** * Perform inference on the model. * - * @param model The model - * @param query Inference query, mainly for re-ranking - * @param input Inference input - * @param stream Stream inference results - * @param taskSettings Settings in the request to override the model's defaults - * @param inputType For search, ingest etc - * @param timeout The timeout for the request - * @param listener Inference result listener + * @param model The model + * @param query Inference query, mainly for re-ranking + * @param returnDocuments For re-ranking task type, whether to return documents + * @param topN For re-ranking task type, how many docs to return + * @param input Inference input + * @param stream Stream inference results + * @param taskSettings Settings in the request to override the model's defaults + * @param inputType For search, ingest etc + * @param timeout The timeout for the request + * @param listener Inference result listener */ void infer( Model model, @Nullable String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, List input, boolean stream, Map taskSettings, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index b6d9689086dc..e9ccb1baeb8b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -60,6 +60,8 @@ public class InferenceAction extends ActionType { public static final ParseField INPUT_TYPE = new ParseField("input_type"); public static final ParseField TASK_SETTINGS = new ParseField("task_settings"); public static final ParseField QUERY = new ParseField("query"); + public static final ParseField RETURN_DOCUMENTS = new ParseField("return_documents"); + public static final ParseField TOP_N = new ParseField("top_n"); public static final ParseField TIMEOUT = new ParseField("timeout"); static final ObjectParser PARSER = new ObjectParser<>(NAME, Request.Builder::new); @@ -68,6 +70,8 @@ public class InferenceAction extends ActionType { PARSER.declareString(Request.Builder::setInputType, INPUT_TYPE); PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS); PARSER.declareString(Request.Builder::setQuery, QUERY); + PARSER.declareBoolean(Request.Builder::setReturnDocuments, RETURN_DOCUMENTS); + PARSER.declareInt(Request.Builder::setTopN, TOP_N); PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT); } @@ -89,6 +93,8 @@ public class InferenceAction extends ActionType { private final TaskType taskType; private final String inferenceEntityId; private final String query; + private final Boolean returnDocuments; + private final Integer topN; private final List input; private final Map taskSettings; private final InputType inputType; @@ -99,6 +105,8 @@ public class InferenceAction extends ActionType { TaskType taskType, String inferenceEntityId, String query, + Boolean returnDocuments, + Integer topN, List input, Map taskSettings, InputType inputType, @@ -109,6 +117,8 @@ public class InferenceAction extends ActionType { taskType, inferenceEntityId, query, + returnDocuments, + topN, input, taskSettings, inputType, @@ -122,6 +132,8 @@ public class InferenceAction extends ActionType { TaskType taskType, String inferenceEntityId, String query, + Boolean returnDocuments, + Integer topN, List input, Map taskSettings, InputType inputType, @@ -133,6 +145,8 @@ public class InferenceAction extends ActionType { this.taskType = taskType; this.inferenceEntityId = inferenceEntityId; this.query = query; + this.returnDocuments = returnDocuments; + this.topN = topN; this.input = input; this.taskSettings = taskSettings; this.inputType = inputType; @@ -164,6 +178,15 @@ public class InferenceAction extends ActionType { this.inferenceTimeout = DEFAULT_TIMEOUT; } + if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED) + || in.getTransportVersion().isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) { + this.returnDocuments = in.readOptionalBoolean(); + this.topN = in.readOptionalInt(); + } else { + this.returnDocuments = null; + this.topN = null; + } + // streaming is not supported yet for transport traffic this.stream = false; } @@ -184,6 +207,14 @@ public class InferenceAction extends ActionType { return query; } + public Boolean getReturnDocuments() { + return returnDocuments; + } + + public Integer getTopN() { + return topN; + } + public Map getTaskSettings() { return taskSettings; } @@ -225,6 +256,17 @@ public class InferenceAction extends ActionType { e.addValidationError(format("Field [query] cannot be empty for task type [%s]", TaskType.RERANK)); return e; } + } else if (taskType.equals(TaskType.ANY) == false) { + if (returnDocuments != null) { + var e = new ActionRequestValidationException(); + e.addValidationError(format("Field [return_documents] cannot be specified for task type [%s]", taskType)); + return e; + } + if (topN != null) { + var e = new ActionRequestValidationException(); + e.addValidationError(format("Field [top_n] cannot be specified for task type [%s]", taskType)); + return e; + } } if (taskType.equals(TaskType.TEXT_EMBEDDING) == false @@ -258,6 +300,12 @@ public class InferenceAction extends ActionType { out.writeOptionalString(query); out.writeTimeValue(inferenceTimeout); } + + if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED) + || out.getTransportVersion().isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) { + out.writeOptionalBoolean(returnDocuments); + out.writeOptionalInt(topN); + } } // default for easier testing @@ -283,6 +331,8 @@ public class InferenceAction extends ActionType { && taskType == request.taskType && Objects.equals(inferenceEntityId, request.inferenceEntityId) && Objects.equals(query, request.query) + && Objects.equals(returnDocuments, request.returnDocuments) + && Objects.equals(topN, request.topN) && Objects.equals(input, request.input) && Objects.equals(taskSettings, request.taskSettings) && inputType == request.inputType @@ -296,6 +346,8 @@ public class InferenceAction extends ActionType { taskType, inferenceEntityId, query, + returnDocuments, + topN, input, taskSettings, inputType, @@ -312,6 +364,8 @@ public class InferenceAction extends ActionType { private InputType inputType = InputType.UNSPECIFIED; private Map taskSettings = Map.of(); private String query; + private Boolean returnDocuments; + private Integer topN; private TimeValue timeout = DEFAULT_TIMEOUT; private boolean stream = false; private InferenceContext context; @@ -338,6 +392,16 @@ public class InferenceAction extends ActionType { return this; } + public Builder setReturnDocuments(Boolean returnDocuments) { + this.returnDocuments = returnDocuments; + return this; + } + + public Builder setTopN(Integer topN) { + this.topN = topN; + return this; + } + public Builder setInputType(InputType inputType) { this.inputType = inputType; return this; @@ -373,7 +437,19 @@ public class InferenceAction extends ActionType { } public Request build() { - return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, context); + return new Request( + taskType, + inferenceEntityId, + query, + returnDocuments, + topN, + input, + taskSettings, + inputType, + timeout, + stream, + context + ); } } @@ -384,6 +460,10 @@ public class InferenceAction extends ActionType { + this.getInferenceEntityId() + ", query=" + this.getQuery() + + ", returnDocuments=" + + this.getReturnDocuments() + + ", topN=" + + this.getTopN() + ", input=" + this.getInput() + ", taskSettings=" diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index 024205b365a7..2e2b9bf9b0d2 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -44,6 +44,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes randomFrom(TaskType.values()), randomAlphaOfLength(6), randomAlphaOfLengthOrNull(10), + randomBoolean(), + randomIntBetween(0, 10), randomList(1, 5, () -> randomAlphaOfLength(8)), randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))), randomFrom(InputType.values()), @@ -85,6 +87,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes TaskType.TEXT_EMBEDDING, "model", null, + null, + null, List.of("input"), null, null, @@ -100,6 +104,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes TaskType.RERANK, "model", "query", + Boolean.TRUE, + 34, List.of("input"), null, null, @@ -119,6 +125,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes null, null, null, + null, + null, false ); ActionRequestValidationException inputNullError = inputNullRequest.validate(); @@ -131,6 +139,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes TaskType.TEXT_EMBEDDING, "model", null, + null, + null, List.of(), null, null, @@ -142,11 +152,52 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes assertThat(inputEmptyError.getMessage(), is("Validation Failed: 1: Field [input] cannot be an empty array;")); } + public void testValidation_TextEmbedding_WithReturnDocument() { + InferenceAction.Request inputRequest = new InferenceAction.Request( + TaskType.TEXT_EMBEDDING, + "model", + null, + Boolean.TRUE, + null, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException inputError = inputRequest.validate(); + assertNotNull(inputError); + assertThat( + inputError.getMessage(), + is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [text_embedding];") + ); + } + + public void testValidation_TextEmbedding_WithTopN() { + InferenceAction.Request inputRequest = new InferenceAction.Request( + TaskType.TEXT_EMBEDDING, + "model", + null, + null, + 12, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException inputError = inputRequest.validate(); + assertNotNull(inputError); + assertThat(inputError.getMessage(), is("Validation Failed: 1: Field [top_n] cannot be specified for task type [text_embedding];")); + } + public void testValidation_Rerank_Null() { InferenceAction.Request queryNullRequest = new InferenceAction.Request( TaskType.RERANK, "model", null, + null, + null, List.of("input"), null, null, @@ -163,6 +214,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes TaskType.RERANK, "model", "", + null, + null, List.of("input"), null, null, @@ -179,6 +232,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes TaskType.RERANK, "model", "query", + null, + null, List.of("input"), null, InputType.SEARCH, @@ -195,6 +250,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes TaskType.SPARSE_EMBEDDING, "model", "", + null, + null, List.of("input"), null, InputType.SEARCH, @@ -209,11 +266,56 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes ); } + public void testValidation_SparseEmbedding_WithReturnDocument() { + InferenceAction.Request queryRequest = new InferenceAction.Request( + TaskType.SPARSE_EMBEDDING, + "model", + "", + Boolean.FALSE, + null, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException queryError = queryRequest.validate(); + assertNotNull(queryError); + assertThat( + queryError.getMessage(), + is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [sparse_embedding];") + ); + + } + + public void testValidation_SparseEmbedding_WithTopN() { + InferenceAction.Request queryRequest = new InferenceAction.Request( + TaskType.SPARSE_EMBEDDING, + "model", + "", + null, + 22, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException queryError = queryRequest.validate(); + assertNotNull(queryError); + assertThat( + queryError.getMessage(), + is("Validation Failed: 1: Field [top_n] cannot be specified for task type [sparse_embedding];") + ); + } + public void testValidation_Completion_WithInputType() { InferenceAction.Request queryRequest = new InferenceAction.Request( TaskType.COMPLETION, "model", "", + null, + null, List.of("input"), null, InputType.SEARCH, @@ -225,11 +327,52 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [input_type] cannot be specified for task type [completion];")); } + public void testValidation_Completion_WithReturnDocuments() { + InferenceAction.Request queryRequest = new InferenceAction.Request( + TaskType.COMPLETION, + "model", + "", + Boolean.TRUE, + null, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException queryError = queryRequest.validate(); + assertNotNull(queryError); + assertThat( + queryError.getMessage(), + is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [completion];") + ); + } + + public void testValidation_Completion_WithTopN() { + InferenceAction.Request queryRequest = new InferenceAction.Request( + TaskType.COMPLETION, + "model", + "", + null, + 77, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException queryError = queryRequest.validate(); + assertNotNull(queryError); + assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [top_n] cannot be specified for task type [completion];")); + } + public void testValidation_ChatCompletion_WithInputType() { InferenceAction.Request queryRequest = new InferenceAction.Request( TaskType.CHAT_COMPLETION, "model", "", + null, + null, List.of("input"), null, InputType.SEARCH, @@ -244,6 +387,45 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes ); } + public void testValidation_ChatCompletion_WithReturnDocuments() { + InferenceAction.Request queryRequest = new InferenceAction.Request( + TaskType.CHAT_COMPLETION, + "model", + "", + Boolean.TRUE, + null, + List.of("input"), + null, + null, + null, + false + ); + ActionRequestValidationException queryError = queryRequest.validate(); + assertNotNull(queryError); + assertThat( + queryError.getMessage(), + is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [chat_completion];") + ); + } + + public void testValidation_ChatCompletion_WithTopN() { + InferenceAction.Request queryRequest = new InferenceAction.Request( + TaskType.CHAT_COMPLETION, + "model", + "", + null, + 11, + List.of("input"), + null, + InputType.SEARCH, + null, + false + ); + ActionRequestValidationException queryError = queryRequest.validate(); + assertNotNull(queryError); + assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [top_n] cannot be specified for task type [chat_completion];")); + } + public void testParseRequest_DefaultsInputTypeToIngest() throws IOException { String singleInputRequest = """ { @@ -271,6 +453,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes nextTask, instance.getInferenceEntityId(), instance.getQuery(), + instance.getReturnDocuments(), + instance.getTopN(), instance.getInput(), instance.getTaskSettings(), instance.getInputType(), @@ -283,6 +467,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes instance.getTaskType(), instance.getInferenceEntityId() + "foo", instance.getQuery(), + instance.getReturnDocuments(), + instance.getTopN(), instance.getInput(), instance.getTaskSettings(), instance.getInputType(), @@ -297,6 +483,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes instance.getTaskType(), instance.getInferenceEntityId(), instance.getQuery(), + instance.getReturnDocuments(), + instance.getTopN(), changedInputs, instance.getTaskSettings(), instance.getInputType(), @@ -317,6 +505,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes instance.getTaskType(), instance.getInferenceEntityId(), instance.getQuery(), + instance.getReturnDocuments(), + instance.getTopN(), instance.getInput(), taskSettings, instance.getInputType(), @@ -331,6 +521,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes instance.getTaskType(), instance.getInferenceEntityId(), instance.getQuery(), + instance.getReturnDocuments(), + instance.getTopN(), instance.getInput(), instance.getTaskSettings(), nextInputType, @@ -343,6 +535,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes instance.getTaskType(), instance.getInferenceEntityId(), instance.getQuery() == null ? randomAlphaOfLength(10) : instance.getQuery() + randomAlphaOfLength(1), + instance.getReturnDocuments(), + instance.getTopN(), instance.getInput(), instance.getTaskSettings(), instance.getInputType(), @@ -360,6 +554,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes instance.getTaskType(), instance.getInferenceEntityId(), instance.getQuery(), + instance.getReturnDocuments(), + instance.getTopN(), instance.getInput(), instance.getTaskSettings(), instance.getInputType(), @@ -374,6 +570,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes instance.getTaskType(), instance.getInferenceEntityId(), instance.getQuery(), + instance.getReturnDocuments(), + instance.getTopN(), instance.getInput(), instance.getTaskSettings(), instance.getInputType(), @@ -395,6 +593,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes instance.getTaskType(), instance.getInferenceEntityId(), null, + null, + null, instance.getInput().subList(0, 1), instance.getTaskSettings(), InputType.UNSPECIFIED, @@ -406,6 +606,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes instance.getTaskType(), instance.getInferenceEntityId(), null, + null, + null, instance.getInput(), instance.getTaskSettings(), InputType.UNSPECIFIED, @@ -420,6 +622,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes instance.getTaskType(), instance.getInferenceEntityId(), null, + null, + null, instance.getInput(), instance.getTaskSettings(), InputType.INGEST, @@ -432,6 +636,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes instance.getTaskType(), instance.getInferenceEntityId(), null, + null, + null, instance.getInput(), instance.getTaskSettings(), InputType.UNSPECIFIED, @@ -443,6 +649,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes instance.getTaskType(), instance.getInferenceEntityId(), null, + null, + null, instance.getInput(), instance.getTaskSettings(), instance.getInputType(), @@ -455,6 +663,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes instance.getTaskType(), instance.getInferenceEntityId(), instance.getQuery(), + null, + null, instance.getInput(), instance.getTaskSettings(), instance.getInputType(), @@ -462,9 +672,24 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes false, InferenceContext.EMPTY_INSTANCE ); - } else { - mutated = instance; - } + } else if (version.before(TransportVersions.RERANK_COMMON_OPTIONS_ADDED) + && version.isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19) == false) { + mutated = new InferenceAction.Request( + instance.getTaskType(), + instance.getInferenceEntityId(), + instance.getQuery(), + null, + null, + instance.getInput(), + instance.getTaskSettings(), + instance.getInputType(), + instance.getInferenceTimeout(), + false, + instance.getContext() + ); + } else { + mutated = instance; + } // We always assume that a request has been rerouted, if it came from a node without adaptive rate limiting if (version.before(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) { @@ -481,6 +706,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes TaskType.TEXT_EMBEDDING, "model", null, + null, + null, List.of(), Map.of(), InputType.UNSPECIFIED, @@ -503,6 +730,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes TaskType.TEXT_EMBEDDING, "model", null, + null, + null, List.of(), Map.of(), InputType.INGEST, @@ -525,6 +754,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes TaskType.TEXT_EMBEDDING, "model", null, + null, + null, List.of("input"), Map.of(), InputType.UNSPECIFIED, @@ -548,6 +779,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes TaskType.TEXT_EMBEDDING, "model", null, + null, + null, List.of("input"), Map.of(), InputType.UNSPECIFIED, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 51ae8b5437b4..ad6f1b88de32 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -110,6 +110,8 @@ public class TestDenseInferenceServiceExtension implements InferenceServiceExten public void infer( Model model, @Nullable String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, List input, boolean stream, Map taskSettings, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index 8be2317a9ee6..d4e3642affdd 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -102,6 +102,8 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension public void infer( Model model, @Nullable String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, List input, boolean stream, Map taskSettings, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index b860bb85ebd0..6f533d83884e 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -103,6 +103,8 @@ public class TestSparseInferenceServiceExtension implements InferenceServiceExte public void infer( Model model, @Nullable String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, List input, boolean stream, Map taskSettings, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 8c876e9947bb..6bcec22bb50b 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -103,6 +104,8 @@ public class TestStreamingCompletionServiceExtension implements InferenceService public void infer( Model model, String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, List input, boolean stream, Map taskSettings, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index e8f52e42f570..7d24b7766baa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -77,6 +77,8 @@ public class TransportInferenceAction extends BaseTransportInferenceAction new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model), + (rerankInput) -> new VoyageAIRerankRequest( + rerankInput.getQuery(), + rerankInput.getChunks(), + rerankInput.getReturnDocuments(), + rerankInput.getTopN(), + model + ), QueryAndDocsInputs.class ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java index 446db40aa5ae..b50b1e3fbad8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchRerankRequestManager.java @@ -69,6 +69,8 @@ public class AlibabaCloudSearchRerankRequestManager extends AlibabaCloudSearchRe account, rerankInput.getQuery(), rerankInput.getChunks(), + rerankInput.getReturnDocuments(), + rerankInput.getTopN(), model ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java index d27812b17399..4d379a5c8fee 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java @@ -49,7 +49,13 @@ public class CohereRerankRequestManager extends CohereRequestManager { ActionListener listener ) { var rerankInput = QueryAndDocsInputs.of(inferenceInputs); - CohereRerankRequest request = new CohereRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); + CohereRerankRequest request = new CohereRerankRequest( + rerankInput.getQuery(), + rerankInput.getChunks(), + rerankInput.getReturnDocuments(), + rerankInput.getTopN(), + model + ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java index e74f0049fffb..f499917a8c93 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java @@ -62,7 +62,13 @@ public class GoogleVertexAiRerankRequestManager extends GoogleVertexAiRequestMan ActionListener listener ) { var rerankInput = QueryAndDocsInputs.of(inferenceInputs); - GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); + GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest( + rerankInput.getQuery(), + rerankInput.getChunks(), + rerankInput.getReturnDocuments(), + rerankInput.getTopN(), + model + ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRerankRequestManager.java index 26f134873bca..4fc49eaf442e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIRerankRequestManager.java @@ -49,7 +49,13 @@ public class JinaAIRerankRequestManager extends JinaAIRequestManager { ActionListener listener ) { var rerankInput = QueryAndDocsInputs.of(inferenceInputs); - JinaAIRerankRequest request = new JinaAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); + JinaAIRerankRequest request = new JinaAIRerankRequest( + rerankInput.getQuery(), + rerankInput.getChunks(), + rerankInput.getReturnDocuments(), + rerankInput.getTopN(), + model + ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java index 5af5245ac5b4..d755ac982ac3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.external.http.sender; +import org.elasticsearch.core.Nullable; + import java.util.List; import java.util.Objects; @@ -22,15 +24,25 @@ public class QueryAndDocsInputs extends InferenceInputs { private final String query; private final List chunks; + private final Boolean returnDocuments; + private final Integer topN; public QueryAndDocsInputs(String query, List chunks) { - this(query, chunks, false); + this(query, chunks, null, null, false); } - public QueryAndDocsInputs(String query, List chunks, boolean stream) { + public QueryAndDocsInputs( + String query, + List chunks, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + boolean stream + ) { super(stream); this.query = Objects.requireNonNull(query); this.chunks = Objects.requireNonNull(chunks); + this.returnDocuments = returnDocuments; + this.topN = topN; } public String getQuery() { @@ -41,6 +53,14 @@ public class QueryAndDocsInputs extends InferenceInputs { return chunks; } + public Boolean getReturnDocuments() { + return returnDocuments; + } + + public Integer getTopN() { + return topN; + } + public int inputSize() { return chunks.size(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java index 878bcc6e6a0d..5e392725b9f4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequest.java @@ -12,6 +12,7 @@ import org.apache.http.client.methods.HttpPost; import org.apache.http.client.utils.URIBuilder; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount; import org.elasticsearch.xpack.inference.external.request.HttpRequest; @@ -32,6 +33,8 @@ public class AlibabaCloudSearchRerankRequest implements Request { private final AlibabaCloudSearchAccount account; private final String query; private final List input; + private final Boolean returnDocuments; + private final Integer topN; private final URI uri; private final AlibabaCloudSearchRerankTaskSettings taskSettings; private final String model; @@ -44,6 +47,8 @@ public class AlibabaCloudSearchRerankRequest implements Request { AlibabaCloudSearchAccount account, String query, List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, AlibabaCloudSearchRerankModel rerankModel ) { Objects.requireNonNull(rerankModel); @@ -51,6 +56,8 @@ public class AlibabaCloudSearchRerankRequest implements Request { this.account = Objects.requireNonNull(account); this.query = Objects.requireNonNull(query); this.input = Objects.requireNonNull(input); + this.returnDocuments = returnDocuments; + this.topN = topN; taskSettings = rerankModel.getTaskSettings(); model = rerankModel.getServiceSettings().getCommonSettings().modelId(); host = rerankModel.getServiceSettings().getCommonSettings().getHost(); @@ -67,7 +74,8 @@ public class AlibabaCloudSearchRerankRequest implements Request { HttpPost httpPost = new HttpPost(uri); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new AlibabaCloudSearchRerankRequestEntity(query, input, taskSettings)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new AlibabaCloudSearchRerankRequestEntity(query, input, returnDocuments, topN, taskSettings)) + .getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java index 054e373e3e52..a5731f29d93e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntity.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings; @@ -15,9 +16,13 @@ import java.io.IOException; import java.util.List; import java.util.Objects; -public record AlibabaCloudSearchRerankRequestEntity(String query, List input, AlibabaCloudSearchRerankTaskSettings taskSettings) - implements - ToXContentObject { +public record AlibabaCloudSearchRerankRequestEntity( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + AlibabaCloudSearchRerankTaskSettings taskSettings +) implements ToXContentObject { private static final String SEARCH_QUERY = "query"; private static final String TEXTS_FIELD = "docs"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequest.java index 4ec04c018732..283ed759884c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequest.java @@ -11,6 +11,7 @@ import org.apache.http.client.methods.HttpPost; import org.apache.http.client.utils.URIBuilder; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xpack.inference.external.cohere.CohereAccount; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; @@ -28,16 +29,26 @@ public class CohereRerankRequest extends CohereRequest { private final CohereAccount account; private final String query; private final List input; + private final Boolean returnDocuments; + private final Integer topN; private final CohereRerankTaskSettings taskSettings; private final String model; private final String inferenceEntityId; - public CohereRerankRequest(String query, List input, CohereRerankModel model) { + public CohereRerankRequest( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + CohereRerankModel model + ) { Objects.requireNonNull(model); this.account = CohereAccount.of(model, CohereRerankRequest::buildDefaultUri); this.input = Objects.requireNonNull(input); this.query = Objects.requireNonNull(query); + this.returnDocuments = returnDocuments; + this.topN = topN; taskSettings = model.getTaskSettings(); this.model = model.getServiceSettings().modelId(); inferenceEntityId = model.getInferenceEntityId(); @@ -48,7 +59,8 @@ public class CohereRerankRequest extends CohereRequest { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new CohereRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new CohereRerankRequestEntity(query, input, returnDocuments, topN, taskSettings, model)) + .getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntity.java index e7abe0990eb0..085aa0a14316 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntity.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.cohere; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; @@ -15,9 +16,14 @@ import java.io.IOException; import java.util.List; import java.util.Objects; -public record CohereRerankRequestEntity(String model, String query, List documents, CohereRerankTaskSettings taskSettings) - implements - ToXContentObject { +public record CohereRerankRequestEntity( + String model, + String query, + List documents, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + CohereRerankTaskSettings taskSettings +) implements ToXContentObject { private static final String DOCUMENTS_FIELD = "documents"; private static final String QUERY_FIELD = "query"; @@ -29,8 +35,15 @@ public record CohereRerankRequestEntity(String model, String query, List Objects.requireNonNull(taskSettings); } - public CohereRerankRequestEntity(String query, List input, CohereRerankTaskSettings taskSettings, String model) { - this(model, query, input, taskSettings); + public CohereRerankRequestEntity( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + CohereRerankTaskSettings taskSettings, + String model + ) { + this(model, query, input, returnDocuments, topN, taskSettings); } @Override @@ -41,11 +54,17 @@ public record CohereRerankRequestEntity(String model, String query, List builder.field(QUERY_FIELD, query); builder.field(DOCUMENTS_FIELD, documents); - if (taskSettings.getDoesReturnDocuments() != null) { + // prefer the root level return_documents over task settings + if (returnDocuments != null) { + builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments); + } else if (taskSettings.getDoesReturnDocuments() != null) { builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments()); } - if (taskSettings.getTopNDocumentsOnly() != null) { + // prefer the root level top_n over task settings + if (topN != null) { + builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, topN); + } else if (taskSettings.getTopNDocumentsOnly() != null) { builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly()); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequest.java index 79606c63e0ed..9004c061423c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequest.java @@ -11,6 +11,7 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; @@ -29,10 +30,22 @@ public class GoogleVertexAiRerankRequest implements GoogleVertexAiRequest { private final List input; - public GoogleVertexAiRerankRequest(String query, List input, GoogleVertexAiRerankModel model) { + private final Boolean returnDocuments; + + private final Integer topN; + + public GoogleVertexAiRerankRequest( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + GoogleVertexAiRerankModel model + ) { this.model = Objects.requireNonNull(model); this.query = Objects.requireNonNull(query); this.input = Objects.requireNonNull(input); + this.returnDocuments = returnDocuments; + this.topN = topN; } @Override @@ -41,7 +54,13 @@ public class GoogleVertexAiRerankRequest implements GoogleVertexAiRequest { ByteArrayEntity byteEntity = new ByteArrayEntity( Strings.toString( - new GoogleVertexAiRerankRequestEntity(query, input, model.getServiceSettings().modelId(), model.getTaskSettings().topN()) + new GoogleVertexAiRerankRequestEntity( + query, + input, + returnDocuments, + topN != null ? topN : model.getTaskSettings().topN(), + model.getServiceSettings().modelId() + ) ).getBytes(StandardCharsets.UTF_8) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntity.java index 2cac067f622c..13f6b1da9fc8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntity.java @@ -15,9 +15,13 @@ import java.io.IOException; import java.util.List; import java.util.Objects; -public record GoogleVertexAiRerankRequestEntity(String query, List inputs, @Nullable String model, @Nullable Integer topN) - implements - ToXContentObject { +public record GoogleVertexAiRerankRequestEntity( + String query, + List inputs, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + @Nullable String model +) implements ToXContentObject { private static final String MODEL_FIELD = "model"; private static final String QUERY_FIELD = "query"; @@ -26,6 +30,7 @@ public record GoogleVertexAiRerankRequestEntity(String query, List input private static final String CONTENT_FIELD = "content"; private static final String TOP_N_FIELD = "topN"; + private static final String IGNORE_RECORD_DETAILS_IN_RESPONSE_FIELD = "ignoreRecordDetailsInResponse"; public GoogleVertexAiRerankRequestEntity { Objects.requireNonNull(query); @@ -57,10 +62,16 @@ public record GoogleVertexAiRerankRequestEntity(String query, List input builder.endArray(); + // prefer the root level top_n over task settings if (topN != null) { builder.field(TOP_N_FIELD, topN); } + if (returnDocuments != null) { + // if returnDocuments = true, we do not want to ignore record details + builder.field(IGNORE_RECORD_DETAILS_IN_RESPONSE_FIELD, returnDocuments == Boolean.TRUE ? Boolean.FALSE : Boolean.TRUE); + } + builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequest.java index 93d4ab830c60..8994a23f4272 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequest.java @@ -11,6 +11,7 @@ import org.apache.http.client.methods.HttpPost; import org.apache.http.client.utils.URIBuilder; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xpack.inference.external.jinaai.JinaAIAccount; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; @@ -28,16 +29,26 @@ public class JinaAIRerankRequest extends JinaAIRequest { private final JinaAIAccount account; private final String query; private final List input; + private final Boolean returnDocuments; + private final Integer topN; private final JinaAIRerankTaskSettings taskSettings; private final String model; private final String inferenceEntityId; - public JinaAIRerankRequest(String query, List input, JinaAIRerankModel model) { + public JinaAIRerankRequest( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + JinaAIRerankModel model + ) { Objects.requireNonNull(model); this.account = JinaAIAccount.of(model, JinaAIRerankRequest::buildDefaultUri); this.input = Objects.requireNonNull(input); this.query = Objects.requireNonNull(query); + this.returnDocuments = returnDocuments; + this.topN = topN; taskSettings = model.getTaskSettings(); this.model = model.getServiceSettings().modelId(); inferenceEntityId = model.getInferenceEntityId(); @@ -48,7 +59,8 @@ public class JinaAIRerankRequest extends JinaAIRequest { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new JinaAIRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8) + Strings.toString(new JinaAIRerankRequestEntity(query, input, returnDocuments, topN, taskSettings, model)) + .getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntity.java index 7f470d5fa91f..1a770026f9d2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntity.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.jinaai; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings; @@ -15,9 +16,14 @@ import java.io.IOException; import java.util.List; import java.util.Objects; -public record JinaAIRerankRequestEntity(String model, String query, List documents, JinaAIRerankTaskSettings taskSettings) - implements - ToXContentObject { +public record JinaAIRerankRequestEntity( + String model, + String query, + List documents, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + JinaAIRerankTaskSettings taskSettings +) implements ToXContentObject { private static final String DOCUMENTS_FIELD = "documents"; private static final String QUERY_FIELD = "query"; @@ -30,8 +36,15 @@ public record JinaAIRerankRequestEntity(String model, String query, List Objects.requireNonNull(taskSettings); } - public JinaAIRerankRequestEntity(String query, List input, JinaAIRerankTaskSettings taskSettings, String model) { - this(model, query, input, taskSettings != null ? taskSettings : JinaAIRerankTaskSettings.EMPTY_SETTINGS); + public JinaAIRerankRequestEntity( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + JinaAIRerankTaskSettings taskSettings, + String model + ) { + this(model, query, input, returnDocuments, topN, taskSettings != null ? taskSettings : JinaAIRerankTaskSettings.EMPTY_SETTINGS); } @Override @@ -42,13 +55,18 @@ public record JinaAIRerankRequestEntity(String model, String query, List builder.field(QUERY_FIELD, query); builder.field(DOCUMENTS_FIELD, documents); - if (taskSettings.getTopNDocumentsOnly() != null) { + // prefer the root level top_n over task settings + if (topN != null) { + builder.field(JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, topN); + } else if (taskSettings.getTopNDocumentsOnly() != null) { builder.field(JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly()); } - var return_documents = taskSettings.getDoesReturnDocuments(); - if (return_documents != null) { - builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, return_documents); + // prefer the root level return_documents over task settings + if (returnDocuments != null) { + builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments); + } else if (taskSettings.getDoesReturnDocuments() != null) { + builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments()); } builder.endObject(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java index 9fb50720e4d5..9b0b4268fc70 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequest.java @@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.external.request.voyageai; import org.apache.http.client.methods.HttpPost; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel; @@ -23,13 +24,23 @@ public class VoyageAIRerankRequest extends VoyageAIRequest { private final String query; private final List input; + private final Boolean returnDocuments; + private final Integer topN; private final VoyageAIRerankModel model; - public VoyageAIRerankRequest(String query, List input, VoyageAIRerankModel model) { + public VoyageAIRerankRequest( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + VoyageAIRerankModel model + ) { this.model = Objects.requireNonNull(model); this.input = Objects.requireNonNull(input); this.query = Objects.requireNonNull(query); + this.returnDocuments = returnDocuments; + this.topN = topN; } @Override @@ -37,8 +48,16 @@ public class VoyageAIRerankRequest extends VoyageAIRequest { HttpPost httpPost = new HttpPost(model.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new VoyageAIRerankRequestEntity(query, input, model.getTaskSettings(), model.getServiceSettings().modelId())) - .getBytes(StandardCharsets.UTF_8) + Strings.toString( + new VoyageAIRerankRequestEntity( + query, + input, + returnDocuments, + topN, + model.getTaskSettings(), + model.getServiceSettings().modelId() + ) + ).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java index 0f7baaa35044..a52013f5d6f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntity.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.external.request.voyageai; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings; @@ -15,15 +16,19 @@ import java.io.IOException; import java.util.List; import java.util.Objects; -public record VoyageAIRerankRequestEntity(String model, String query, List documents, VoyageAIRerankTaskSettings taskSettings) - implements - ToXContentObject { +public record VoyageAIRerankRequestEntity( + String model, + String query, + List documents, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + VoyageAIRerankTaskSettings taskSettings +) implements ToXContentObject { private static final String DOCUMENTS_FIELD = "documents"; private static final String QUERY_FIELD = "query"; private static final String MODEL_FIELD = "model"; public static final String TRUNCATION_FIELD = "truncation"; - public static final String RETURN_DOCUMENTS_FIELD = "return_documents"; public VoyageAIRerankRequestEntity { Objects.requireNonNull(query); @@ -32,8 +37,15 @@ public record VoyageAIRerankRequestEntity(String model, String query, List input, VoyageAIRerankTaskSettings taskSettings, String model) { - this(model, query, input, taskSettings != null ? taskSettings : VoyageAIRerankTaskSettings.EMPTY_SETTINGS); + public VoyageAIRerankRequestEntity( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + VoyageAIRerankTaskSettings taskSettings, + String model + ) { + this(model, query, input, returnDocuments, topN, taskSettings != null ? taskSettings : VoyageAIRerankTaskSettings.EMPTY_SETTINGS); } @Override @@ -44,11 +56,17 @@ public record VoyageAIRerankRequestEntity(String model, String query, List { var parsedRankedDoc = RankedDoc.parse(parser); - if (parsedRankedDoc.content == null) { - throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.CONTENT.getPreferredName())); - } - if (parsedRankedDoc.score == null) { throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.SCORE.getPreferredName())); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java index 96dbd3948cdc..182c083ef1c2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilder.java @@ -232,6 +232,8 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder input, boolean stream, Map taskSettings, @@ -68,7 +70,7 @@ public abstract class SenderService implements InferenceService { ActionListener listener ) { init(); - var inferenceInput = createInput(this, model, input, inputType, query, stream); + var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream); doInfer(model, inferenceInput, taskSettings, timeout, listener); } @@ -78,11 +80,20 @@ public abstract class SenderService implements InferenceService { List input, InputType inputType, @Nullable String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, boolean stream ) { return switch (model.getTaskType()) { case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream); - case RERANK -> new QueryAndDocsInputs(query, input, stream); + case RERANK -> { + ValidationException validationException = new ValidationException(); + service.validateRerankParameters(returnDocuments, topN, validationException); + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream); + } case TEXT_EMBEDDING, SPARSE_EMBEDDING -> { ValidationException validationException = new ValidationException(); service.validateInputType(inputType, model, validationException); @@ -141,6 +152,8 @@ public abstract class SenderService implements InferenceService { protected abstract void validateInputType(InputType inputType, Model model, ValidationException validationException); + protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {} + protected abstract void doUnifiedCompletionInfer( Model model, UnifiedChatInput inputs, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 1ca63908ec5f..a0c77599b6ce 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -735,6 +735,8 @@ public final class ServiceUtils { service.infer( model, null, + null, + null, List.of(TEST_EMBEDDING_INPUT), false, Map.of(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index fe844bbe0c1a..bf1fbda2b826 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; @@ -300,6 +301,24 @@ public class AlibabaCloudSearchService extends SenderService { ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException); } + @Override + protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) { + if (returnDocuments != null) { + validationException.addValidationError( + Strings.format( + "Invalid return_documents [%s]. The return_documents option is not supported by this service", + returnDocuments + ) + ); + } + + if (topN != null) { + validationException.addValidationError( + Strings.format("Invalid top_n [%s]. The top_n option is not supported by this service", topN) + ); + } + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 196231532556..cbf203ee4a68 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -620,6 +620,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi public void infer( Model model, @Nullable String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, List input, boolean stream, Map taskSettings, @@ -632,7 +634,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi if (TaskType.TEXT_EMBEDDING.equals(taskType)) { inferTextEmbedding(esModel, input, inputType, timeout, listener); } else if (TaskType.RERANK.equals(taskType)) { - inferRerank(esModel, query, input, inputType, timeout, taskSettings, listener); + inferRerank(esModel, query, input, returnDocuments, topN, inputType, timeout, taskSettings, listener); } else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { inferSparseEmbedding(esModel, input, inputType, timeout, listener); } else { @@ -693,6 +695,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi ElasticsearchInternalModel model, String query, List inputs, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, InputType inputType, TimeValue timeout, Map requestTaskSettings, @@ -701,7 +705,9 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout); var returnDocs = Boolean.TRUE; - if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) { + if (returnDocuments != null) { + returnDocs = returnDocuments; + } else if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) { var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings); returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments(); } @@ -709,7 +715,9 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi Function inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null; ActionListener mlResultsListener = listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse(textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier)) + (l, inferenceResult) -> l.onResponse( + textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier, topN) + ) ); var maybeDeployListener = mlResultsListener.delegateResponse( @@ -824,7 +832,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi private RankedDocsResults textSimilarityResultsToRankedDocs( List results, - Function inputSupplier + Function inputSupplier, + @Nullable Integer topN ) { List rankings = new ArrayList<>(results.size()); for (int i = 0; i < results.size(); i++) { @@ -851,7 +860,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi } Collections.sort(rankings); - return new RankedDocsResults(rankings); + return new RankedDocsResults(topN != null ? rankings.subList(0, topN) : rankings); } public List defaultConfigIds() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java index 08cb9933c2b3..4c48e3018b95 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java @@ -30,6 +30,8 @@ public class SimpleServiceIntegrationValidator implements ServiceIntegrationVali service.infer( model, model.getTaskType().equals(TaskType.RERANK) ? QUERY : null, + null, + null, TEST_INPUT, false, Map.of(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index caea68bd861d..aa0f89b4a382 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -423,9 +423,9 @@ public abstract class BaseTransportInferenceActionTestCase { - listenerAction.accept(ans.getArgument(7)); + listenerAction.accept(ans.getArgument(9)); return null; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any()); doAnswer(ans -> { listenerAction.accept(ans.getArgument(3)); return null; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java index 12d67ae3dc96..e91b0b3451a7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java @@ -23,7 +23,7 @@ public class InferenceInputsTests extends ESTestCase { var emptyRequest = new UnifiedCompletionRequest(List.of(), null, null, null, null, null, null, null); assertThat(new UnifiedChatInput(emptyRequest, false).castTo(UnifiedChatInput.class), Matchers.instanceOf(UnifiedChatInput.class)); assertThat( - new QueryAndDocsInputs("hello", List.of(), false).castTo(QueryAndDocsInputs.class), + new QueryAndDocsInputs("hello", List.of(), Boolean.TRUE, 33, false).castTo(QueryAndDocsInputs.class), Matchers.instanceOf(QueryAndDocsInputs.class) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java index 8f981d64d36e..0d48e7692b2e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/alibabacloudsearch/AlibabaCloudSearchRerankRequestEntityTests.java @@ -22,7 +22,13 @@ import static org.hamcrest.CoreMatchers.is; public class AlibabaCloudSearchRerankRequestEntityTests extends ESTestCase { public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { - var entity = new AlibabaCloudSearchRerankRequestEntity("query", List.of("abc"), new AlibabaCloudSearchRerankTaskSettings()); + var entity = new AlibabaCloudSearchRerankRequestEntity( + "query", + List.of("abc"), + Boolean.TRUE, + 22, + new AlibabaCloudSearchRerankTaskSettings() + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntityTests.java new file mode 100644 index 000000000000..c33d72d6bd74 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereRerankRequestEntityTests.java @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.cohere; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class CohereRerankRequestEntityTests extends ESTestCase { + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new CohereRerankRequestEntity( + "query", + List.of("abc"), + Boolean.TRUE, + 22, + new CohereRerankTaskSettings(null, null, 3), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":22,"max_chunks_per_doc":3}""")); + } + + public void testXContent_WritesMinimalFields() throws IOException { + var entity = new CohereRerankRequestEntity( + "query", + List.of("abc"), + null, + null, + new CohereRerankTaskSettings(null, null, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"model":"model","query":"query","documents":["abc"]}""")); + } + + public void testXContent_PrefersRootLevelReturnDocumentsAndTopN() throws IOException { + var entity = new CohereRerankRequestEntity( + "query", + List.of("abc"), + Boolean.FALSE, + 99, + new CohereRerankTaskSettings(33, Boolean.TRUE, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"model":"model","query":"query","documents":["abc"],"return_documents":false,"top_n":99}""")); + } + + public void testXContent_UsesTaskSettingsIfNoRootOptionsDefined() throws IOException { + var entity = new CohereRerankRequestEntity( + "query", + List.of("abc"), + null, + null, + new CohereRerankTaskSettings(33, Boolean.TRUE, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":33}""")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntityTests.java index fd18d2573efc..764aedfc5a19 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntityTests.java @@ -20,8 +20,8 @@ import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhi import static org.hamcrest.MatcherAssert.assertThat; public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase { - public void testXContent_SingleRequest_WritesModelAndTopNIfDefined() throws IOException { - var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), "model", 8); + public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException { + var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), Boolean.TRUE, 10, "model"); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -37,13 +37,14 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase { "content": "abc" } ], - "topN": 8 + "topN": 10, + "ignoreRecordDetailsInResponse": false } """)); } - public void testXContent_SingleRequest_DoesNotWriteModelAndTopNIfNull() throws IOException { - var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), null, null); + public void testXContent_SingleRequest_WritesMinimalFields() throws IOException { + var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), null, null, null); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -62,8 +63,8 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase { """)); } - public void testXContent_MultipleRequests_WritesModelAndTopNIfDefined() throws IOException { - var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), "model", 8); + public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException { + var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), Boolean.FALSE, 12, "model"); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -83,13 +84,14 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase { "content": "def" } ], - "topN": 8 + "topN": 12, + "ignoreRecordDetailsInResponse": true } """)); } - public void testXContent_MultipleRequests_DoesNotWriteModelAndTopNIfNull() throws IOException { - var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), null, null); + public void testXContent_MultipleRequests_WritesMinimalFields() throws IOException { + var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), null, null, null); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -111,5 +113,4 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase { } """)); } - } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestTests.java index 811adb6612a4..20aa270c0808 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestTests.java @@ -29,11 +29,11 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase { private static final String AUTH_HEADER_VALUE = "foo"; - public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException { + public void testCreateRequest_WithMinimalFieldsSet() throws IOException { var input = "input"; var query = "query"; - var request = createRequest(query, input, null, null); + var request = createRequest(query, input, null, null, null, null); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -53,8 +53,9 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase { var input = "input"; var query = "query"; var topN = 1; + var taskSettingsTopN = 3; - var request = createRequest(query, input, null, topN); + var request = createRequest(query, input, null, topN, null, taskSettingsTopN); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -71,12 +72,55 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase { assertThat(requestMap.get("topN"), is(topN)); } + public void testCreateRequest_UsesTaskSettingsTopNWhenRootLevelIsNull() throws IOException { + var input = "input"; + var query = "query"; + var topN = 1; + + var request = createRequest(query, input, null, null, null, topN); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get("records"), is(List.of(Map.of("id", "0", "content", input)))); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("topN"), is(topN)); + } + + public void testCreateRequest_WithReturnDocumentsSet() throws IOException { + var input = "input"; + var query = "query"; + + var request = createRequest(query, input, null, null, Boolean.TRUE, null); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get("records"), is(List.of(Map.of("id", "0", "content", input)))); + assertThat(requestMap.get("query"), is(query)); + assertThat(requestMap.get("ignoreRecordDetailsInResponse"), is(Boolean.FALSE)); + } + public void testCreateRequest_WithModelSet() throws IOException { var input = "input"; var query = "query"; var modelId = "model"; - var request = createRequest(query, input, modelId, null); + var request = createRequest(query, input, modelId, null, null, null); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -94,24 +138,37 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase { } public void testTruncate_DoesNotTruncate() { - var request = createRequest("query", "input", null, null); + var request = createRequest("query", "input", null, null, null, null); var truncatedRequest = request.truncate(); assertThat(truncatedRequest, sameInstance(request)); } - private static GoogleVertexAiRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topN) { - var rerankModel = GoogleVertexAiRerankModelTests.createModel(modelId, topN); + private static GoogleVertexAiRerankRequest createRequest( + String query, + String input, + @Nullable String modelId, + @Nullable Integer topN, + @Nullable Boolean returnDocuments, + @Nullable Integer taskSettingsTopN + ) { + var rerankModel = GoogleVertexAiRerankModelTests.createModel(modelId, taskSettingsTopN); - return new GoogleVertexAiRerankWithoutAuthRequest(query, List.of(input), rerankModel); + return new GoogleVertexAiRerankWithoutAuthRequest(query, List.of(input), rerankModel, topN, returnDocuments); } /** * We use this class to fake the auth implementation to avoid static mocking of {@link GoogleVertexAiRequest} */ private static class GoogleVertexAiRerankWithoutAuthRequest extends GoogleVertexAiRerankRequest { - GoogleVertexAiRerankWithoutAuthRequest(String query, List input, GoogleVertexAiRerankModel model) { - super(query, input, model); + GoogleVertexAiRerankWithoutAuthRequest( + String query, + List input, + GoogleVertexAiRerankModel model, + @Nullable Integer topN, + @Nullable Boolean returnDocuments + ) { + super(query, input, returnDocuments, topN, model); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntityTests.java index 7fd738fa2a8e..11f2810e13e0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestEntityTests.java @@ -21,8 +21,15 @@ import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhi import static org.hamcrest.MatcherAssert.assertThat; public class JinaAIRerankRequestEntityTests extends ESTestCase { - public void testXContent_SingleRequest_WritesModelAndTopNIfDefined() throws IOException { - var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, null), "model"); + public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException { + var entity = new JinaAIRerankRequestEntity( + "query", + List.of("abc"), + Boolean.TRUE, + 12, + new JinaAIRerankTaskSettings(8, Boolean.FALSE), + "model" + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -35,33 +42,86 @@ public class JinaAIRerankRequestEntityTests extends ESTestCase { "documents": [ "abc" ], - "top_n": 8 - } - """)); - } - - public void testXContent_SingleRequest_WritesModelAndTopNIfDefined_ReturnDocumentsTrue() throws IOException { - var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, true), "model"); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" - { - "model": "model", - "query": "query", - "documents": [ - "abc" - ], - "top_n": 8, + "top_n": 12, "return_documents": true } """)); } - public void testXContent_SingleRequest_WritesModelAndTopNIfDefined_ReturnDocumentsFalse() throws IOException { - var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, false), "model"); + public void testXContent_SingleRequest_WritesMinimalFields() throws IOException { + var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), null, null, new JinaAIRerankTaskSettings(null, null), "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ] + } + """)); + } + + public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException { + var entity = new JinaAIRerankRequestEntity( + "query", + List.of("abc", "def"), + Boolean.FALSE, + 12, + new JinaAIRerankTaskSettings(8, Boolean.TRUE), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc", + "def" + ], + "top_n": 12, + "return_documents": false + } + """)); + } + + public void testXContent_MultipleRequests_WritesMinimalFields() throws IOException { + var entity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), null, null, null, "model"); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc", + "def" + ] + } + """)); + } + + public void testXContent_SingleRequest_UsesTaskSettingsTopNIfRootIsNotDefined() throws IOException { + var entity = new JinaAIRerankRequestEntity( + "query", + List.of("abc"), + null, + null, + new JinaAIRerankTaskSettings(8, Boolean.FALSE), + "model" + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -80,61 +140,4 @@ public class JinaAIRerankRequestEntityTests extends ESTestCase { """)); } - public void testXContent_SingleRequest_DoesNotWriteTopNIfNull() throws IOException { - var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), null, "model"); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" - { - "model": "model", - "query": "query", - "documents": [ - "abc" - ] - } - """)); - } - - public void testXContent_MultipleRequests_WritesModelAndTopNIfDefined() throws IOException { - var entity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), new JinaAIRerankTaskSettings(8, null), "model"); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" - { - "model": "model", - "query": "query", - "documents": [ - "abc", - "def" - ], - "top_n": 8 - } - """)); - } - - public void testXContent_MultipleRequests_DoesNotWriteTopNIfNull() throws IOException { - var entity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), null, "model"); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" - { - "model": "model", - "query": "query", - "documents": [ - "abc", - "def" - ] - } - """)); - } - } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestTests.java index 819362d397ba..439bcf3ae006 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/jinaai/JinaAIRerankRequestTests.java @@ -27,12 +27,12 @@ public class JinaAIRerankRequestTests extends ESTestCase { private static final String API_KEY = "foo"; - public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException { + public void testCreateRequest_WithMinimalFieldsSet() throws IOException { var input = "input"; var query = "query"; var modelId = "model"; - var request = createRequest(query, input, modelId, null); + var request = createRequest(query, input, modelId, null, null, null); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -49,13 +49,14 @@ public class JinaAIRerankRequestTests extends ESTestCase { assertThat(requestMap.get("model"), is(modelId)); } - public void testCreateRequest_WithTopNSet() throws IOException { + public void testCreateRequest_WithAllFieldsSet() throws IOException { var input = "input"; var query = "query"; var topN = 1; + var taskSettingsTopN = 2; var modelId = "model"; - var request = createRequest(query, input, modelId, topN); + var request = createRequest(query, input, modelId, topN, Boolean.FALSE, taskSettingsTopN); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -66,10 +67,11 @@ public class JinaAIRerankRequestTests extends ESTestCase { var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(4)); + assertThat(requestMap, aMapWithSize(5)); assertThat(requestMap.get("documents"), is(List.of(input))); assertThat(requestMap.get("query"), is(query)); assertThat(requestMap.get("top_n"), is(topN)); + assertThat(requestMap.get("return_documents"), is(Boolean.FALSE)); assertThat(requestMap.get("model"), is(modelId)); } @@ -78,7 +80,7 @@ public class JinaAIRerankRequestTests extends ESTestCase { var query = "query"; var modelId = "model"; - var request = createRequest(query, input, modelId, null); + var request = createRequest(query, input, modelId, null, null, null); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -96,15 +98,22 @@ public class JinaAIRerankRequestTests extends ESTestCase { } public void testTruncate_DoesNotTruncate() { - var request = createRequest("query", "input", "null", null); + var request = createRequest("query", "input", "null", null, null, null); var truncatedRequest = request.truncate(); assertThat(truncatedRequest, sameInstance(request)); } - private static JinaAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topN) { - var rerankModel = JinaAIRerankModelTests.createModel(API_KEY, modelId, topN); - return new JinaAIRerankRequest(query, List.of(input), rerankModel); + private static JinaAIRerankRequest createRequest( + String query, + String input, + @Nullable String modelId, + @Nullable Integer topN, + @Nullable Boolean returnDocuments, + @Nullable Integer taskSettingsTopN + ) { + var rerankModel = JinaAIRerankModelTests.createModel(API_KEY, modelId, taskSettingsTopN); + return new JinaAIRerankRequest(query, List.of(input), returnDocuments, topN, rerankModel); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java index ae431b4b7bb1..f05e9052861f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestEntityTests.java @@ -20,27 +20,15 @@ import java.util.List; import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; public class VoyageAIRerankRequestEntityTests extends ESTestCase { - public void testXContent_SingleRequest_WritesModelAndTopKIfDefined() throws IOException { - var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, null, null), "model"); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" - { - "model": "model", - "query": "query", - "documents": [ - "abc" - ], - "top_k": 8 - } - """)); - } - - public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsTrue() throws IOException { - var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, true, null), "model"); + public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException { + var entity = new VoyageAIRerankRequestEntity( + "query", + List.of("abc"), + Boolean.TRUE, + 12, + new VoyageAIRerankTaskSettings(8, Boolean.FALSE, null), + "model" + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -54,13 +42,20 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase { "abc" ], "return_documents": true, - "top_k": 8 + "top_k": 12 } """)); } - public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsFalse() throws IOException { - var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, null), "model"); + public void testXContent_SingleRequest_WritesMinimalFields() throws IOException { + var entity = new VoyageAIRerankRequestEntity( + "query", + List.of("abc"), + null, + null, + new VoyageAIRerankTaskSettings(null, true, null), + "model" + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -73,14 +68,20 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase { "documents": [ "abc" ], - "return_documents": false, - "top_k": 8 + "return_documents": true } """)); } public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationTrue() throws IOException { - var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, true), "model"); + var entity = new VoyageAIRerankRequestEntity( + "query", + List.of("abc"), + null, + null, + new VoyageAIRerankTaskSettings(8, false, true), + "model" + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -101,7 +102,14 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase { } public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationFalse() throws IOException { - var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, false), "model"); + var entity = new VoyageAIRerankRequestEntity( + "query", + List.of("abc"), + null, + null, + new VoyageAIRerankTaskSettings(8, false, false), + "model" + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -121,28 +129,12 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase { """)); } - public void testXContent_SingleRequest_DoesNotWriteTopKIfNull() throws IOException { - var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), null, "model"); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" - { - "model": "model", - "query": "query", - "documents": [ - "abc" - ] - } - """)); - } - - public void testXContent_MultipleRequests_WritesModelAndTopKIfDefined() throws IOException { + public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException { var entity = new VoyageAIRerankRequestEntity( "query", List.of("abc", "def"), + Boolean.FALSE, + 11, new VoyageAIRerankTaskSettings(8, null, null), "model" ); @@ -159,13 +151,14 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase { "abc", "def" ], - "top_k": 8 + "return_documents": false, + "top_k": 11 } """)); } public void testXContent_MultipleRequests_DoesNotWriteTopKIfNull() throws IOException { - var entity = new VoyageAIRerankRequestEntity("query", List.of("abc", "def"), null, "model"); + var entity = new VoyageAIRerankRequestEntity("query", List.of("abc", "def"), null, null, null, "model"); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); @@ -183,4 +176,31 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase { """)); } + public void testXContent_UsesTaskSettingsTopNIfRootIsNotDefined() throws IOException { + var entity = new VoyageAIRerankRequestEntity( + "query", + List.of("abc"), + null, + null, + new VoyageAIRerankTaskSettings(8, Boolean.FALSE, null), + "model" + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model": "model", + "query": "query", + "documents": [ + "abc" + ], + "return_documents": false, + "top_k": 8 + } + """)); + } + } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java index a11d259200b9..00237496304d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/voyageai/VoyageAIRerankRequestTests.java @@ -27,12 +27,12 @@ public class VoyageAIRerankRequestTests extends ESTestCase { private static final String API_KEY = "foo"; - public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException { + public void testCreateRequest_WithMinimalFields() throws IOException { var input = "input"; var query = "query"; var modelId = "model"; - var request = createRequest(query, input, modelId, null); + var request = createRequest(query, input, modelId, null, null, null); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -49,13 +49,14 @@ public class VoyageAIRerankRequestTests extends ESTestCase { assertThat(requestMap.get("model"), is(modelId)); } - public void testCreateRequest_WithTopNSet() throws IOException { + public void testCreateRequest_WithAllFieldsDefined() throws IOException { var input = "input"; var query = "query"; var topK = 1; + var taskSettingsTopK = 2; var modelId = "model"; - var request = createRequest(query, input, modelId, topK); + var request = createRequest(query, input, modelId, topK, Boolean.FALSE, taskSettingsTopK); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -66,11 +67,12 @@ public class VoyageAIRerankRequestTests extends ESTestCase { var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(4)); + assertThat(requestMap, aMapWithSize(5)); assertThat(requestMap.get("documents"), is(List.of(input))); assertThat(requestMap.get("query"), is(query)); assertThat(requestMap.get("top_k"), is(topK)); assertThat(requestMap.get("model"), is(modelId)); + assertThat(requestMap.get("return_documents"), is(Boolean.FALSE)); } public void testCreateRequest_WithModelSet() throws IOException { @@ -78,7 +80,7 @@ public class VoyageAIRerankRequestTests extends ESTestCase { var query = "query"; var modelId = "model"; - var request = createRequest(query, input, modelId, null); + var request = createRequest(query, input, modelId, null, null, null); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -96,15 +98,22 @@ public class VoyageAIRerankRequestTests extends ESTestCase { } public void testTruncate_DoesNotTruncate() { - var request = createRequest("query", "input", "null", null); + var request = createRequest("query", "input", "null", null, null, null); var truncatedRequest = request.truncate(); assertThat(truncatedRequest, sameInstance(request)); } - private static VoyageAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topK) { - var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, topK); - return new VoyageAIRerankRequest(query, List.of(input), rerankModel); + private static VoyageAIRerankRequest createRequest( + String query, + String input, + @Nullable String modelId, + @Nullable Integer topK, + @Nullable Boolean returnDocuments, + @Nullable Integer taskSettingsTopK + ) { + var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, taskSettingsTopK); + return new VoyageAIRerankRequest(query, List.of(input), returnDocuments, topK, rerankModel); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java index 7ff79e261842..eba6887fe5c4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java @@ -42,6 +42,26 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase { assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2")))); } + public void testFromResponse_CreatesResultsForASingleItem_NoContent() throws IOException { + String responseJson = """ + { + "records": [ + { + "id": "2", + "title": "title 2", + "score": 0.97 + } + ] + } + """; + + RankedDocsResults parsedResults = GoogleVertexAiRerankResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, null)))); + } + public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { String responseJson = """ { @@ -72,6 +92,34 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase { ); } + public void testFromResponse_CreatesResultsForMultipleItems_NoContent() throws IOException { + String responseJson = """ + { + "records": [ + { + "id": "2", + "title": "title 2", + "score": 0.97 + }, + { + "id": "1", + "title": "title 1", + "score": 0.90 + } + ] + } + """; + + RankedDocsResults parsedResults = GoogleVertexAiRerankResponseEntity.fromResponse( + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.getRankedDocs(), + is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, null), new RankedDocsResults.RankedDoc(1, 0.90F, null))) + ); + } + public void testFromResponse_FailsWhenRecordsFieldIsNotPresent() { String responseJson = """ { @@ -102,36 +150,6 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase { assertThat(thrownException.getMessage(), is("Failed to find required field [records] in Google Vertex AI rerank response")); } - public void testFromResponse_FailsWhenContentFieldIsNotPresent() { - String responseJson = """ - { - "records": [ - { - "id": "2", - "title": "title 2", - "content": "content 2", - "score": 0.97 - }, - { - "id": "1", - "title": "title 1", - "not_content": "content 1", - "score": 0.97 - } - ] - } - """; - - var thrownException = expectThrows( - IllegalStateException.class, - () -> GoogleVertexAiRerankResponseEntity.fromResponse( - new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) - ) - ); - - assertThat(thrownException.getMessage(), is("Failed to find required field [content] in Google Vertex AI rerank response")); - } - public void testFromResponse_FailsWhenScoreFieldIsNotPresent() { String responseJson = """ { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java index c4b3c0783992..ae6e5fb5a53a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java @@ -98,6 +98,8 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase { TaskType.RERANK, this.inferenceId, inferenceText, + null, + null, docFeatures, Map.of("inferenceResultCount", inferenceResultCount), InputType.INTERNAL_SEARCH, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java index 0d821f411d0b..dc0e2cc10501 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java @@ -225,6 +225,8 @@ public class TextSimilarityTestPlugin extends Plugin implements ActionPlugin { TaskType.RERANK, inferenceId, inferenceText, + null, + null, docFeatures, Map.of("throwing", true), InputType.INTERNAL_SEARCH, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index dfc64b8fb932..190520fbc3b6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -910,11 +910,11 @@ public class ServiceUtilsTests extends ESTestCase { when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(7); + ActionListener listener = invocation.getArgument(9); listener.onResponse(new TextEmbeddingFloatResults(List.of())); return Void.TYPE; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); @@ -932,11 +932,11 @@ public class ServiceUtilsTests extends ESTestCase { when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(7); + ActionListener listener = invocation.getArgument(9); listener.onResponse(new TextEmbeddingByteResults(List.of())); return Void.TYPE; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); @@ -956,11 +956,11 @@ public class ServiceUtilsTests extends ESTestCase { var textEmbedding = TextEmbeddingFloatResultsTests.createRandomResults(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(7); + ActionListener listener = invocation.getArgument(9); listener.onResponse(textEmbedding); return Void.TYPE; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); @@ -979,11 +979,11 @@ public class ServiceUtilsTests extends ESTestCase { var textEmbedding = TextEmbeddingByteResultsTests.createRandomResults(); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(7); + ActionListener listener = invocation.getArgument(9); listener.onResponse(textEmbedding); return Void.TYPE; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index 2e0f64ed4ef9..bcf2fb85ae9d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -389,6 +389,8 @@ public class AlibabaCloudSearchServiceTests extends ESTestCase { () -> service.infer( model, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -431,6 +433,8 @@ public class AlibabaCloudSearchServiceTests extends ESTestCase { () -> service.infer( model, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -446,6 +450,53 @@ public class AlibabaCloudSearchServiceTests extends ESTestCase { } } + public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + Map serviceSettingsMap = new HashMap<>(); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id"); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host"); + serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default"); + serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536); + + Map taskSettingsMap = new HashMap<>(); + + Map secretSettingsMap = new HashMap<>(); + secretSettingsMap.put("api_key", "secret"); + + var model = AlibabaCloudSearchEmbeddingsModelTests.createModel( + "service", + TaskType.RERANK, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap + ); + try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) { + PlainActionFuture listener = new PlainActionFuture<>(); + var thrownException = expectThrows( + ValidationException.class, + () -> service.infer( + model, + "hi", + Boolean.TRUE, + 10, + List.of("a"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ) + ); + assertThat( + thrownException.getMessage(), + is( + "Validation Failed: 1: Invalid return_documents [true]. The return_documents option is not supported by this " + + "service;2: Invalid top_n [10]. The top_n option is not supported by this service;" + ) + ); + } + } + public void testChunkedInfer_TextEmbeddingChunkingSettingsSet() throws IOException { testChunkedInfer(TaskType.TEXT_EMBEDDING, ChunkingSettingsTests.createRandomChunkingSettings()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 1b9bba3fa1b0..0ec3799cb7dc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -932,6 +932,8 @@ public class AmazonBedrockServiceTests extends ESTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -979,6 +981,8 @@ public class AmazonBedrockServiceTests extends ESTestCase { () -> service.infer( model, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -1029,6 +1033,8 @@ public class AmazonBedrockServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1071,6 +1077,8 @@ public class AmazonBedrockServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1414,6 +1422,8 @@ public class AmazonBedrockServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index 405aba35e8d3..71e35aa211d3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -458,6 +458,8 @@ public class AnthropicServiceTests extends ESTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -513,6 +515,8 @@ public class AnthropicServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("input"), false, new HashMap<>(), @@ -571,6 +575,8 @@ public class AnthropicServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), true, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index e20fc54598aa..00cfa2a53f8b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -1096,6 +1096,8 @@ public class AzureAiStudioServiceTests extends ESTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -1134,6 +1136,8 @@ public class AzureAiStudioServiceTests extends ESTestCase { () -> service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -1296,6 +1300,8 @@ public class AzureAiStudioServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1347,6 +1353,8 @@ public class AzureAiStudioServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1403,6 +1411,8 @@ public class AzureAiStudioServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), true, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index 2e0c0d04fa9c..ffda34f0e8fd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -766,6 +766,8 @@ public class AzureOpenAiServiceTests extends ESTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -822,6 +824,8 @@ public class AzureOpenAiServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1286,6 +1290,8 @@ public class AzureOpenAiServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1453,6 +1459,8 @@ public class AzureOpenAiServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), true, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index dec1052589c9..b17a8b29bce2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -788,6 +788,8 @@ public class CohereServiceTests extends ESTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -856,6 +858,8 @@ public class CohereServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1147,6 +1151,8 @@ public class CohereServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1207,6 +1213,8 @@ public class CohereServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1281,6 +1289,8 @@ public class CohereServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, null), @@ -1353,6 +1363,8 @@ public class CohereServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1629,6 +1641,8 @@ public class CohereServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), true, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index aa1313793274..88a2fc76aadc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -232,7 +232,7 @@ public class DeepSeekServiceTests extends ESTestCase { try (var service = createService()) { var model = createModel(service, TaskType.COMPLETION); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); + service.infer(model, null, null, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); assertThat(result, isA(ChatCompletionResults.class)); var completionResults = (ChatCompletionResults) result; @@ -255,7 +255,7 @@ public class DeepSeekServiceTests extends ESTestCase { try (var service = createService()) { var model = createModel(service, TaskType.COMPLETION); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); + service.infer(model, null, null, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream().hasNoErrors().hasEvent(""" {"completion":[{"delta":"hello, world"}]}"""); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index f9fb6521e979..4f61269fcc6c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -368,6 +368,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -404,6 +406,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -443,6 +447,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -494,6 +500,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { service.infer( model, null, + null, + null, List.of("input text"), false, new HashMap<>(), @@ -551,6 +559,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { service.infer( model, null, + null, + null, List.of("input text"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 3a50e716ab16..a1430d36a0f5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -662,6 +662,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -700,6 +702,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase { () -> service.infer( model, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -775,6 +779,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("input"), false, new HashMap<>(), @@ -832,6 +838,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of(input), false, new HashMap<>(), @@ -1005,6 +1013,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java index fe8172cf5db0..3be4b72c1237 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java @@ -65,6 +65,8 @@ public class HuggingFaceBaseServiceTests extends ESTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 7c4b0de656c3..957532149492 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -556,6 +556,8 @@ public class HuggingFaceServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -593,6 +595,8 @@ public class HuggingFaceServiceTests extends ESTestCase { () -> service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -627,6 +631,8 @@ public class HuggingFaceServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 7a78dbce6310..3f508c6cbca5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -602,6 +602,8 @@ public class IbmWatsonxServiceTests extends ESTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -641,6 +643,8 @@ public class IbmWatsonxServiceTests extends ESTestCase { () -> service.infer( model, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -697,6 +701,8 @@ public class IbmWatsonxServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of(input), false, new HashMap<>(), @@ -840,6 +846,8 @@ public class IbmWatsonxServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index fabcca09d3e3..e1446c36b893 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -782,6 +782,8 @@ public class JinaAIServiceTests extends ESTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -1044,6 +1046,8 @@ public class JinaAIServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1076,6 +1080,8 @@ public class JinaAIServiceTests extends ESTestCase { service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2"), false, new HashMap<>(), @@ -1132,6 +1138,8 @@ public class JinaAIServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1201,6 +1209,8 @@ public class JinaAIServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1254,6 +1264,8 @@ public class JinaAIServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1320,7 +1332,18 @@ public class JinaAIServiceTests extends ESTestCase { JinaAIEmbeddingType.FLOAT ); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, null, List.of("abc"), false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener); + service.infer( + model, + null, + null, + null, + List.of("abc"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -1371,6 +1394,8 @@ public class JinaAIServiceTests extends ESTestCase { service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3"), false, new HashMap<>(), @@ -1454,6 +1479,8 @@ public class JinaAIServiceTests extends ESTestCase { service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3", "candidate4"), false, new HashMap<>(), @@ -1549,6 +1576,8 @@ public class JinaAIServiceTests extends ESTestCase { service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3"), false, new HashMap<>(), @@ -1630,6 +1659,8 @@ public class JinaAIServiceTests extends ESTestCase { service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3", "candidate4"), false, new HashMap<>(), @@ -1724,6 +1755,8 @@ public class JinaAIServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index a2ee005d719f..db771f13cc0a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -586,6 +586,8 @@ public class MistralServiceTests extends ESTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -625,6 +627,8 @@ public class MistralServiceTests extends ESTestCase { () -> service.infer( model, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -781,6 +785,8 @@ public class MistralServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 219ff210cfa9..70452db7c171 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -852,6 +852,8 @@ public class OpenAiServiceTests extends ESTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -890,6 +892,8 @@ public class OpenAiServiceTests extends ESTestCase { () -> service.infer( model, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -925,6 +929,8 @@ public class OpenAiServiceTests extends ESTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -964,6 +970,8 @@ public class OpenAiServiceTests extends ESTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -1024,6 +1032,8 @@ public class OpenAiServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1263,6 +1273,8 @@ public class OpenAiServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), true, new HashMap<>(), @@ -1794,6 +1806,8 @@ public class OpenAiServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java index 44de4a3d9ccd..9ee2201b4f02 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java @@ -63,6 +63,8 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase { .infer( eq(mockModel), eq(null), + eq(null), + eq(null), eq(TEST_INPUT), eq(false), eq(Map.of()), @@ -97,13 +99,15 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase { private void mockSuccessfulCallToService(String query, InferenceServiceResults result) { doAnswer(ans -> { - ActionListener responseListener = ans.getArgument(7); + ActionListener responseListener = ans.getArgument(9); responseListener.onResponse(result); return null; }).when(mockInferenceService) .infer( eq(mockModel), eq(query), + eq(null), + eq(null), eq(TEST_INPUT), eq(false), eq(Map.of()), @@ -120,6 +124,8 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase { verify(mockInferenceService).infer( eq(mockModel), eq(withQuery ? TEST_QUERY : null), + eq(null), + eq(null), eq(TEST_INPUT), eq(false), eq(Map.of()), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index bddcc27194c4..521d042bb861 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -722,6 +722,8 @@ public class VoyageAIServiceTests extends ESTestCase { service.infer( mockModel, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -768,6 +770,8 @@ public class VoyageAIServiceTests extends ESTestCase { () -> service.infer( model, null, + null, + null, List.of(""), false, new HashMap<>(), @@ -1017,6 +1021,8 @@ public class VoyageAIServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1049,6 +1055,8 @@ public class VoyageAIServiceTests extends ESTestCase { service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2"), false, new HashMap<>(), @@ -1103,6 +1111,8 @@ public class VoyageAIServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1183,6 +1193,8 @@ public class VoyageAIServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), @@ -1260,7 +1272,18 @@ public class VoyageAIServiceTests extends ESTestCase { (SimilarityMeasure) null ); PlainActionFuture listener = new PlainActionFuture<>(); - service.infer(model, null, List.of("abc"), false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener); + service.infer( + model, + null, + null, + null, + List.of("abc"), + false, + new HashMap<>(), + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -1315,6 +1338,8 @@ public class VoyageAIServiceTests extends ESTestCase { service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3"), false, new HashMap<>(), @@ -1401,6 +1426,8 @@ public class VoyageAIServiceTests extends ESTestCase { service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3", "candidate4"), false, new HashMap<>(), @@ -1493,6 +1520,8 @@ public class VoyageAIServiceTests extends ESTestCase { service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3"), false, new HashMap<>(), @@ -1569,6 +1598,8 @@ public class VoyageAIServiceTests extends ESTestCase { service.infer( model, "query", + null, + null, List.of("candidate1", "candidate2", "candidate3", "candidate4"), false, new HashMap<>(), @@ -1663,6 +1694,8 @@ public class VoyageAIServiceTests extends ESTestCase { service.infer( model, null, + null, + null, List.of("abc"), false, new HashMap<>(), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java index 65b5d3b3110f..a3ad59644772 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java @@ -123,6 +123,8 @@ public class TransportCoordinatedInferenceAction extends HandledTransportAction< TaskType.ANY, request.getModelId(), null, + null, + null, request.getInputs(), request.getTaskSettings(), inputType,