Adding common rerank options to Perform Inference API (#125239)

* wip

* Adding rerank common options

* Linting

* Linting

* [CI] Auto commit changes from spotless

* Update docs/changelog/125239.yaml

* PR feedback

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
This commit is contained in:
Ying Mao 2025-03-25 12:32:18 -04:00 committed by GitHub
parent af1f1452d7
commit a6f685cc2a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
66 changed files with 1306 additions and 287 deletions

View file

@ -0,0 +1,6 @@
pr: 125239
summary: Adding common rerank options to Perform Inference API
area: Machine Learning
type: enhancement
issues:
- 111273

View file

@ -155,6 +155,7 @@ public class TransportVersions {
public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL_8_19 = def(8_841_0_12);
public static final TransportVersion INFERENCE_MODEL_REGISTRY_METADATA_8_19 = def(8_841_0_13);
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@ -201,6 +202,7 @@ public class TransportVersions {
public static final TransportVersion INDEXING_STATS_INCLUDES_RECENT_WRITE_LOAD = def(9_034_0_00);
public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL = def(9_035_0_00);
public static final TransportVersion INDEX_METADATA_INCLUDES_RECENT_WRITE_LOAD = def(9_036_0_00);
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED = def(9_037_0_00);
/*
* STOP! READ THIS FIRST! No, really,

View file

@ -91,18 +91,22 @@ public interface InferenceService extends Closeable {
/**
* Perform inference on the model.
*
* @param model The model
* @param query Inference query, mainly for re-ranking
* @param input Inference input
* @param stream Stream inference results
* @param taskSettings Settings in the request to override the model's defaults
* @param inputType For search, ingest etc
* @param timeout The timeout for the request
* @param listener Inference result listener
* @param model The model
* @param query Inference query, mainly for re-ranking
* @param returnDocuments For re-ranking task type, whether to return documents
* @param topN For re-ranking task type, how many docs to return
* @param input Inference input
* @param stream Stream inference results
* @param taskSettings Settings in the request to override the model's defaults
* @param inputType For search, ingest etc
* @param timeout The timeout for the request
* @param listener Inference result listener
*/
void infer(
Model model,
@Nullable String query,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,

View file

@ -60,6 +60,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
public static final ParseField INPUT_TYPE = new ParseField("input_type");
public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
public static final ParseField QUERY = new ParseField("query");
public static final ParseField RETURN_DOCUMENTS = new ParseField("return_documents");
public static final ParseField TOP_N = new ParseField("top_n");
public static final ParseField TIMEOUT = new ParseField("timeout");
static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
@ -68,6 +70,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
PARSER.declareString(Request.Builder::setInputType, INPUT_TYPE);
PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
PARSER.declareString(Request.Builder::setQuery, QUERY);
PARSER.declareBoolean(Request.Builder::setReturnDocuments, RETURN_DOCUMENTS);
PARSER.declareInt(Request.Builder::setTopN, TOP_N);
PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
}
@ -89,6 +93,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
private final TaskType taskType;
private final String inferenceEntityId;
private final String query;
private final Boolean returnDocuments;
private final Integer topN;
private final List<String> input;
private final Map<String, Object> taskSettings;
private final InputType inputType;
@ -99,6 +105,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
TaskType taskType,
String inferenceEntityId,
String query,
Boolean returnDocuments,
Integer topN,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
@ -109,6 +117,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
taskType,
inferenceEntityId,
query,
returnDocuments,
topN,
input,
taskSettings,
inputType,
@ -122,6 +132,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
TaskType taskType,
String inferenceEntityId,
String query,
Boolean returnDocuments,
Integer topN,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
@ -133,6 +145,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
this.query = query;
this.returnDocuments = returnDocuments;
this.topN = topN;
this.input = input;
this.taskSettings = taskSettings;
this.inputType = inputType;
@ -164,6 +178,15 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
this.inferenceTimeout = DEFAULT_TIMEOUT;
}
if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED)
|| in.getTransportVersion().isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) {
this.returnDocuments = in.readOptionalBoolean();
this.topN = in.readOptionalInt();
} else {
this.returnDocuments = null;
this.topN = null;
}
// streaming is not supported yet for transport traffic
this.stream = false;
}
@ -184,6 +207,14 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
return query;
}
/**
 * Returns the {@code return_documents} option carried by this request.
 *
 * @return the stored flag, or {@code null} when the option was not set
 */
public Boolean getReturnDocuments() {
    return this.returnDocuments;
}
/**
 * Returns the {@code top_n} option carried by this request.
 *
 * @return the stored value, or {@code null} when the option was not set
 */
public Integer getTopN() {
    return this.topN;
}
public Map<String, Object> getTaskSettings() {
return taskSettings;
}
@ -225,6 +256,17 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
e.addValidationError(format("Field [query] cannot be empty for task type [%s]", TaskType.RERANK));
return e;
}
} else if (taskType.equals(TaskType.ANY) == false) {
if (returnDocuments != null) {
var e = new ActionRequestValidationException();
e.addValidationError(format("Field [return_documents] cannot be specified for task type [%s]", taskType));
return e;
}
if (topN != null) {
var e = new ActionRequestValidationException();
e.addValidationError(format("Field [top_n] cannot be specified for task type [%s]", taskType));
return e;
}
}
if (taskType.equals(TaskType.TEXT_EMBEDDING) == false
@ -258,6 +300,12 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
out.writeOptionalString(query);
out.writeTimeValue(inferenceTimeout);
}
if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED)
|| out.getTransportVersion().isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) {
out.writeOptionalBoolean(returnDocuments);
out.writeOptionalInt(topN);
}
}
// default for easier testing
@ -283,6 +331,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
&& taskType == request.taskType
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
&& Objects.equals(query, request.query)
&& Objects.equals(returnDocuments, request.returnDocuments)
&& Objects.equals(topN, request.topN)
&& Objects.equals(input, request.input)
&& Objects.equals(taskSettings, request.taskSettings)
&& inputType == request.inputType
@ -296,6 +346,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
taskType,
inferenceEntityId,
query,
returnDocuments,
topN,
input,
taskSettings,
inputType,
@ -312,6 +364,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
private InputType inputType = InputType.UNSPECIFIED;
private Map<String, Object> taskSettings = Map.of();
private String query;
private Boolean returnDocuments;
private Integer topN;
private TimeValue timeout = DEFAULT_TIMEOUT;
private boolean stream = false;
private InferenceContext context;
@ -338,6 +392,16 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
return this;
}
/**
 * Records the {@code return_documents} option on this builder.
 *
 * @param returnDocuments the value to store; {@code null} leaves the option unset
 * @return this builder, so calls can be chained
 */
public Builder setReturnDocuments(Boolean returnDocuments) {
    this.returnDocuments = returnDocuments;
    return this;
}
/**
 * Records the {@code top_n} option on this builder.
 *
 * @param topN the value to store; {@code null} leaves the option unset
 * @return this builder, so calls can be chained
 */
public Builder setTopN(Integer topN) {
    this.topN = topN;
    return this;
}
public Builder setInputType(InputType inputType) {
this.inputType = inputType;
return this;
@ -373,7 +437,19 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
}
public Request build() {
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, context);
return new Request(
taskType,
inferenceEntityId,
query,
returnDocuments,
topN,
input,
taskSettings,
inputType,
timeout,
stream,
context
);
}
}
@ -384,6 +460,10 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
+ this.getInferenceEntityId()
+ ", query="
+ this.getQuery()
+ ", returnDocuments="
+ this.getReturnDocuments()
+ ", topN="
+ this.getTopN()
+ ", input="
+ this.getInput()
+ ", taskSettings="

View file

@ -44,6 +44,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
randomFrom(TaskType.values()),
randomAlphaOfLength(6),
randomAlphaOfLengthOrNull(10),
randomBoolean(),
randomIntBetween(0, 10),
randomList(1, 5, () -> randomAlphaOfLength(8)),
randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
randomFrom(InputType.values()),
@ -85,6 +87,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
TaskType.TEXT_EMBEDDING,
"model",
null,
null,
null,
List.of("input"),
null,
null,
@ -100,6 +104,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
TaskType.RERANK,
"model",
"query",
Boolean.TRUE,
34,
List.of("input"),
null,
null,
@ -119,6 +125,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
null,
null,
null,
null,
null,
false
);
ActionRequestValidationException inputNullError = inputNullRequest.validate();
@ -131,6 +139,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
TaskType.TEXT_EMBEDDING,
"model",
null,
null,
null,
List.of(),
null,
null,
@ -142,11 +152,52 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
assertThat(inputEmptyError.getMessage(), is("Validation Failed: 1: Field [input] cannot be an empty array;"));
}
/**
 * A {@code text_embedding} request that sets {@code return_documents} must fail validation,
 * because that option is rejected for every concrete task type other than rerank.
 */
public void testValidation_TextEmbedding_WithReturnDocument() {
    // top_n stays null so the return_documents check is the one that produces the error.
    var request = new InferenceAction.Request(
        TaskType.TEXT_EMBEDDING,
        "model",
        null,
        Boolean.TRUE,
        null,
        List.of("input"),
        null,
        null,
        null,
        false
    );

    var validationError = request.validate();

    assertNotNull(validationError);
    final String expected = "Validation Failed: 1: Field [return_documents] cannot be specified for task type [text_embedding];";
    assertThat(validationError.getMessage(), is(expected));
}
/**
 * A {@code text_embedding} request that sets {@code top_n} must fail validation with the
 * {@code top_n}-specific message.
 */
public void testValidation_TextEmbedding_WithTopN() {
    // return_documents stays null so validation reaches the top_n check.
    var request = new InferenceAction.Request(
        TaskType.TEXT_EMBEDDING,
        "model",
        null,
        null,
        12,
        List.of("input"),
        null,
        null,
        null,
        false
    );

    var validationError = request.validate();

    assertNotNull(validationError);
    final String expected = "Validation Failed: 1: Field [top_n] cannot be specified for task type [text_embedding];";
    assertThat(validationError.getMessage(), is(expected));
}
public void testValidation_Rerank_Null() {
InferenceAction.Request queryNullRequest = new InferenceAction.Request(
TaskType.RERANK,
"model",
null,
null,
null,
List.of("input"),
null,
null,
@ -163,6 +214,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
TaskType.RERANK,
"model",
"",
null,
null,
List.of("input"),
null,
null,
@ -179,6 +232,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
TaskType.RERANK,
"model",
"query",
null,
null,
List.of("input"),
null,
InputType.SEARCH,
@ -195,6 +250,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
TaskType.SPARSE_EMBEDDING,
"model",
"",
null,
null,
List.of("input"),
null,
InputType.SEARCH,
@ -209,11 +266,56 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
);
}
/**
 * A {@code sparse_embedding} request that sets {@code return_documents} must fail validation.
 */
public void testValidation_SparseEmbedding_WithReturnDocument() {
    // An explicit FALSE still counts as "specified": the validation only tests for non-null.
    var request = new InferenceAction.Request(
        TaskType.SPARSE_EMBEDDING,
        "model",
        "",
        Boolean.FALSE,
        null,
        List.of("input"),
        null,
        null,
        null,
        false
    );

    var validationError = request.validate();

    assertNotNull(validationError);
    final String expected = "Validation Failed: 1: Field [return_documents] cannot be specified for task type [sparse_embedding];";
    assertThat(validationError.getMessage(), is(expected));
}
/**
 * A {@code sparse_embedding} request that sets {@code top_n} must fail validation with the
 * {@code top_n}-specific message.
 */
public void testValidation_SparseEmbedding_WithTopN() {
    // return_documents stays null so validation reaches the top_n check.
    var request = new InferenceAction.Request(
        TaskType.SPARSE_EMBEDDING,
        "model",
        "",
        null,
        22,
        List.of("input"),
        null,
        null,
        null,
        false
    );

    var validationError = request.validate();

    assertNotNull(validationError);
    final String expected = "Validation Failed: 1: Field [top_n] cannot be specified for task type [sparse_embedding];";
    assertThat(validationError.getMessage(), is(expected));
}
public void testValidation_Completion_WithInputType() {
InferenceAction.Request queryRequest = new InferenceAction.Request(
TaskType.COMPLETION,
"model",
"",
null,
null,
List.of("input"),
null,
InputType.SEARCH,
@ -225,11 +327,52 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [input_type] cannot be specified for task type [completion];"));
}
/**
 * A {@code completion} request that sets {@code return_documents} must fail validation.
 */
public void testValidation_Completion_WithReturnDocuments() {
    // top_n stays null so the return_documents check is the one that produces the error.
    var request = new InferenceAction.Request(
        TaskType.COMPLETION,
        "model",
        "",
        Boolean.TRUE,
        null,
        List.of("input"),
        null,
        null,
        null,
        false
    );

    var validationError = request.validate();

    assertNotNull(validationError);
    final String expected = "Validation Failed: 1: Field [return_documents] cannot be specified for task type [completion];";
    assertThat(validationError.getMessage(), is(expected));
}
/**
 * A {@code completion} request that sets {@code top_n} must fail validation with the
 * {@code top_n}-specific message.
 */
public void testValidation_Completion_WithTopN() {
    // return_documents stays null so validation reaches the top_n check.
    var request = new InferenceAction.Request(
        TaskType.COMPLETION,
        "model",
        "",
        null,
        77,
        List.of("input"),
        null,
        null,
        null,
        false
    );

    var validationError = request.validate();

    assertNotNull(validationError);
    final String expected = "Validation Failed: 1: Field [top_n] cannot be specified for task type [completion];";
    assertThat(validationError.getMessage(), is(expected));
}
public void testValidation_ChatCompletion_WithInputType() {
InferenceAction.Request queryRequest = new InferenceAction.Request(
TaskType.CHAT_COMPLETION,
"model",
"",
null,
null,
List.of("input"),
null,
InputType.SEARCH,
@ -244,6 +387,45 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
);
}
/**
 * A {@code chat_completion} request that sets {@code return_documents} must fail validation.
 */
public void testValidation_ChatCompletion_WithReturnDocuments() {
    // top_n stays null so the return_documents check is the one that produces the error.
    var request = new InferenceAction.Request(
        TaskType.CHAT_COMPLETION,
        "model",
        "",
        Boolean.TRUE,
        null,
        List.of("input"),
        null,
        null,
        null,
        false
    );

    var validationError = request.validate();

    assertNotNull(validationError);
    final String expected = "Validation Failed: 1: Field [return_documents] cannot be specified for task type [chat_completion];";
    assertThat(validationError.getMessage(), is(expected));
}
/**
 * A {@code chat_completion} request that sets {@code top_n} must fail validation with the
 * {@code top_n}-specific message.
 */
public void testValidation_ChatCompletion_WithTopN() {
    // top_n is the only invalid field here. The input type is deliberately left null (as in the
    // other *_WithTopN tests) so the assertion does not depend on validate() checking top_n
    // before input_type.
    InferenceAction.Request queryRequest = new InferenceAction.Request(
        TaskType.CHAT_COMPLETION,
        "model",
        "",
        null,
        11,
        List.of("input"),
        null,
        null,
        null,
        false
    );
    ActionRequestValidationException queryError = queryRequest.validate();
    assertNotNull(queryError);
    assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [top_n] cannot be specified for task type [chat_completion];"));
}
public void testParseRequest_DefaultsInputTypeToIngest() throws IOException {
String singleInputRequest = """
{
@ -271,6 +453,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
nextTask,
instance.getInferenceEntityId(),
instance.getQuery(),
instance.getReturnDocuments(),
instance.getTopN(),
instance.getInput(),
instance.getTaskSettings(),
instance.getInputType(),
@ -283,6 +467,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
instance.getTaskType(),
instance.getInferenceEntityId() + "foo",
instance.getQuery(),
instance.getReturnDocuments(),
instance.getTopN(),
instance.getInput(),
instance.getTaskSettings(),
instance.getInputType(),
@ -297,6 +483,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
instance.getTaskType(),
instance.getInferenceEntityId(),
instance.getQuery(),
instance.getReturnDocuments(),
instance.getTopN(),
changedInputs,
instance.getTaskSettings(),
instance.getInputType(),
@ -317,6 +505,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
instance.getTaskType(),
instance.getInferenceEntityId(),
instance.getQuery(),
instance.getReturnDocuments(),
instance.getTopN(),
instance.getInput(),
taskSettings,
instance.getInputType(),
@ -331,6 +521,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
instance.getTaskType(),
instance.getInferenceEntityId(),
instance.getQuery(),
instance.getReturnDocuments(),
instance.getTopN(),
instance.getInput(),
instance.getTaskSettings(),
nextInputType,
@ -343,6 +535,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
instance.getTaskType(),
instance.getInferenceEntityId(),
instance.getQuery() == null ? randomAlphaOfLength(10) : instance.getQuery() + randomAlphaOfLength(1),
instance.getReturnDocuments(),
instance.getTopN(),
instance.getInput(),
instance.getTaskSettings(),
instance.getInputType(),
@ -360,6 +554,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
instance.getTaskType(),
instance.getInferenceEntityId(),
instance.getQuery(),
instance.getReturnDocuments(),
instance.getTopN(),
instance.getInput(),
instance.getTaskSettings(),
instance.getInputType(),
@ -374,6 +570,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
instance.getTaskType(),
instance.getInferenceEntityId(),
instance.getQuery(),
instance.getReturnDocuments(),
instance.getTopN(),
instance.getInput(),
instance.getTaskSettings(),
instance.getInputType(),
@ -395,6 +593,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
instance.getTaskType(),
instance.getInferenceEntityId(),
null,
null,
null,
instance.getInput().subList(0, 1),
instance.getTaskSettings(),
InputType.UNSPECIFIED,
@ -406,6 +606,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
instance.getTaskType(),
instance.getInferenceEntityId(),
null,
null,
null,
instance.getInput(),
instance.getTaskSettings(),
InputType.UNSPECIFIED,
@ -420,6 +622,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
instance.getTaskType(),
instance.getInferenceEntityId(),
null,
null,
null,
instance.getInput(),
instance.getTaskSettings(),
InputType.INGEST,
@ -432,6 +636,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
instance.getTaskType(),
instance.getInferenceEntityId(),
null,
null,
null,
instance.getInput(),
instance.getTaskSettings(),
InputType.UNSPECIFIED,
@ -443,6 +649,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
instance.getTaskType(),
instance.getInferenceEntityId(),
null,
null,
null,
instance.getInput(),
instance.getTaskSettings(),
instance.getInputType(),
@ -455,6 +663,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
instance.getTaskType(),
instance.getInferenceEntityId(),
instance.getQuery(),
null,
null,
instance.getInput(),
instance.getTaskSettings(),
instance.getInputType(),
@ -462,9 +672,24 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
false,
InferenceContext.EMPTY_INSTANCE
);
} else {
mutated = instance;
}
} else if (version.before(TransportVersions.RERANK_COMMON_OPTIONS_ADDED)
&& version.isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19) == false) {
mutated = new InferenceAction.Request(
instance.getTaskType(),
instance.getInferenceEntityId(),
instance.getQuery(),
null,
null,
instance.getInput(),
instance.getTaskSettings(),
instance.getInputType(),
instance.getInferenceTimeout(),
false,
instance.getContext()
);
} else {
mutated = instance;
}
// We always assume that a request has been rerouted, if it came from a node without adaptive rate limiting
if (version.before(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
@ -481,6 +706,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
TaskType.TEXT_EMBEDDING,
"model",
null,
null,
null,
List.of(),
Map.of(),
InputType.UNSPECIFIED,
@ -503,6 +730,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
TaskType.TEXT_EMBEDDING,
"model",
null,
null,
null,
List.of(),
Map.of(),
InputType.INGEST,
@ -525,6 +754,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
TaskType.TEXT_EMBEDDING,
"model",
null,
null,
null,
List.of("input"),
Map.of(),
InputType.UNSPECIFIED,
@ -548,6 +779,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
TaskType.TEXT_EMBEDDING,
"model",
null,
null,
null,
List.of("input"),
Map.of(),
InputType.UNSPECIFIED,

View file

@ -110,6 +110,8 @@ public class TestDenseInferenceServiceExtension implements InferenceServiceExten
public void infer(
Model model,
@Nullable String query,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,

View file

@ -102,6 +102,8 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
public void infer(
Model model,
@Nullable String query,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,

View file

@ -103,6 +103,8 @@ public class TestSparseInferenceServiceExtension implements InferenceServiceExte
public void infer(
Model model,
@Nullable String query,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,

View file

@ -15,6 +15,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
@ -103,6 +104,8 @@ public class TestStreamingCompletionServiceExtension implements InferenceService
public void infer(
Model model,
String query,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,

View file

@ -77,6 +77,8 @@ public class TransportInferenceAction extends BaseTransportInferenceAction<Infer
service.infer(
model,
request.getQuery(),
request.getReturnDocuments(),
request.getTopN(),
request.getInput(),
request.isStreaming(),
request.getTaskSettings(),

View file

@ -75,7 +75,13 @@ public class VoyageAIActionCreator implements VoyageAIActionVisitor {
serviceComponents.threadPool(),
overriddenModel,
RERANK_HANDLER,
(rerankInput) -> new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model),
(rerankInput) -> new VoyageAIRerankRequest(
rerankInput.getQuery(),
rerankInput.getChunks(),
rerankInput.getReturnDocuments(),
rerankInput.getTopN(),
model
),
QueryAndDocsInputs.class
);

View file

@ -69,6 +69,8 @@ public class AlibabaCloudSearchRerankRequestManager extends AlibabaCloudSearchRe
account,
rerankInput.getQuery(),
rerankInput.getChunks(),
rerankInput.getReturnDocuments(),
rerankInput.getTopN(),
model
);

View file

@ -49,7 +49,13 @@ public class CohereRerankRequestManager extends CohereRequestManager {
ActionListener<InferenceServiceResults> listener
) {
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
CohereRerankRequest request = new CohereRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
CohereRerankRequest request = new CohereRerankRequest(
rerankInput.getQuery(),
rerankInput.getChunks(),
rerankInput.getReturnDocuments(),
rerankInput.getTopN(),
model
);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}

View file

@ -62,7 +62,13 @@ public class GoogleVertexAiRerankRequestManager extends GoogleVertexAiRequestMan
ActionListener<InferenceServiceResults> listener
) {
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest(
rerankInput.getQuery(),
rerankInput.getChunks(),
rerankInput.getReturnDocuments(),
rerankInput.getTopN(),
model
);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}

View file

@ -49,7 +49,13 @@ public class JinaAIRerankRequestManager extends JinaAIRequestManager {
ActionListener<InferenceServiceResults> listener
) {
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
JinaAIRerankRequest request = new JinaAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
JinaAIRerankRequest request = new JinaAIRerankRequest(
rerankInput.getQuery(),
rerankInput.getChunks(),
rerankInput.getReturnDocuments(),
rerankInput.getTopN(),
model
);
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}

View file

@ -7,6 +7,8 @@
package org.elasticsearch.xpack.inference.external.http.sender;
import org.elasticsearch.core.Nullable;
import java.util.List;
import java.util.Objects;
@ -22,15 +24,25 @@ public class QueryAndDocsInputs extends InferenceInputs {
private final String query;
private final List<String> chunks;
private final Boolean returnDocuments;
private final Integer topN;
public QueryAndDocsInputs(String query, List<String> chunks) {
this(query, chunks, false);
this(query, chunks, null, null, false);
}
public QueryAndDocsInputs(String query, List<String> chunks, boolean stream) {
public QueryAndDocsInputs(
String query,
List<String> chunks,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
boolean stream
) {
super(stream);
this.query = Objects.requireNonNull(query);
this.chunks = Objects.requireNonNull(chunks);
this.returnDocuments = returnDocuments;
this.topN = topN;
}
public String getQuery() {
@ -41,6 +53,14 @@ public class QueryAndDocsInputs extends InferenceInputs {
return chunks;
}
/**
 * Returns the {@code returnDocuments} value supplied when these inputs were constructed.
 *
 * @return the stored flag; may be {@code null} because the constructor accepts a nullable value
 */
public Boolean getReturnDocuments() {
    return this.returnDocuments;
}
/**
 * Returns the {@code topN} value supplied when these inputs were constructed.
 *
 * @return the stored value; may be {@code null} because the constructor accepts a nullable value
 */
public Integer getTopN() {
    return this.topN;
}
public int inputSize() {
return chunks.size();
}

View file

@ -12,6 +12,7 @@ import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
@ -32,6 +33,8 @@ public class AlibabaCloudSearchRerankRequest implements Request {
private final AlibabaCloudSearchAccount account;
private final String query;
private final List<String> input;
private final Boolean returnDocuments;
private final Integer topN;
private final URI uri;
private final AlibabaCloudSearchRerankTaskSettings taskSettings;
private final String model;
@ -44,6 +47,8 @@ public class AlibabaCloudSearchRerankRequest implements Request {
AlibabaCloudSearchAccount account,
String query,
List<String> input,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
AlibabaCloudSearchRerankModel rerankModel
) {
Objects.requireNonNull(rerankModel);
@ -51,6 +56,8 @@ public class AlibabaCloudSearchRerankRequest implements Request {
this.account = Objects.requireNonNull(account);
this.query = Objects.requireNonNull(query);
this.input = Objects.requireNonNull(input);
this.returnDocuments = returnDocuments;
this.topN = topN;
taskSettings = rerankModel.getTaskSettings();
model = rerankModel.getServiceSettings().getCommonSettings().modelId();
host = rerankModel.getServiceSettings().getCommonSettings().getHost();
@ -67,7 +74,8 @@ public class AlibabaCloudSearchRerankRequest implements Request {
HttpPost httpPost = new HttpPost(uri);
ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new AlibabaCloudSearchRerankRequestEntity(query, input, taskSettings)).getBytes(StandardCharsets.UTF_8)
Strings.toString(new AlibabaCloudSearchRerankRequestEntity(query, input, returnDocuments, topN, taskSettings))
.getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);

View file

@ -7,6 +7,7 @@
package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings;
@ -15,9 +16,13 @@ import java.io.IOException;
import java.util.List;
import java.util.Objects;
public record AlibabaCloudSearchRerankRequestEntity(String query, List<String> input, AlibabaCloudSearchRerankTaskSettings taskSettings)
implements
ToXContentObject {
public record AlibabaCloudSearchRerankRequestEntity(
String query,
List<String> input,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
AlibabaCloudSearchRerankTaskSettings taskSettings
) implements ToXContentObject {
private static final String SEARCH_QUERY = "query";
private static final String TEXTS_FIELD = "docs";

View file

@ -11,6 +11,7 @@ import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.inference.external.cohere.CohereAccount;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
@ -28,16 +29,26 @@ public class CohereRerankRequest extends CohereRequest {
private final CohereAccount account;
private final String query;
private final List<String> input;
private final Boolean returnDocuments;
private final Integer topN;
private final CohereRerankTaskSettings taskSettings;
private final String model;
private final String inferenceEntityId;
public CohereRerankRequest(String query, List<String> input, CohereRerankModel model) {
public CohereRerankRequest(
String query,
List<String> input,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
CohereRerankModel model
) {
Objects.requireNonNull(model);
this.account = CohereAccount.of(model, CohereRerankRequest::buildDefaultUri);
this.input = Objects.requireNonNull(input);
this.query = Objects.requireNonNull(query);
this.returnDocuments = returnDocuments;
this.topN = topN;
taskSettings = model.getTaskSettings();
this.model = model.getServiceSettings().modelId();
inferenceEntityId = model.getInferenceEntityId();
@ -48,7 +59,8 @@ public class CohereRerankRequest extends CohereRequest {
HttpPost httpPost = new HttpPost(account.uri());
ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new CohereRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8)
Strings.toString(new CohereRerankRequestEntity(query, input, returnDocuments, topN, taskSettings, model))
.getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);

View file

@ -7,6 +7,7 @@
package org.elasticsearch.xpack.inference.external.request.cohere;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
@ -15,9 +16,14 @@ import java.io.IOException;
import java.util.List;
import java.util.Objects;
public record CohereRerankRequestEntity(String model, String query, List<String> documents, CohereRerankTaskSettings taskSettings)
implements
ToXContentObject {
public record CohereRerankRequestEntity(
String model,
String query,
List<String> documents,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
CohereRerankTaskSettings taskSettings
) implements ToXContentObject {
private static final String DOCUMENTS_FIELD = "documents";
private static final String QUERY_FIELD = "query";
@ -29,8 +35,15 @@ public record CohereRerankRequestEntity(String model, String query, List<String>
Objects.requireNonNull(taskSettings);
}
public CohereRerankRequestEntity(String query, List<String> input, CohereRerankTaskSettings taskSettings, String model) {
this(model, query, input, taskSettings);
/**
 * Convenience constructor that accepts the model id last and forwards every argument, unchanged,
 * to the canonical record constructor (which takes the model id first).
 *
 * @param query           the query the documents are ranked against
 * @param input           the documents to rank; stored as the record's {@code documents} component
 * @param returnDocuments request-level {@code return_documents}; when non-null it is serialized in
 *                        preference to the task-settings value
 * @param topN            request-level {@code top_n}; when non-null it is serialized in preference to
 *                        the task-settings value
 * @param taskSettings    task settings used as the fallback for the two options above; must not be null
 *                        (the compact constructor rejects null)
 * @param model           the model id
 */
public CohereRerankRequestEntity(
String query,
List<String> input,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
CohereRerankTaskSettings taskSettings,
String model
) {
this(model, query, input, returnDocuments, topN, taskSettings);
}
@Override
@ -41,11 +54,17 @@ public record CohereRerankRequestEntity(String model, String query, List<String>
builder.field(QUERY_FIELD, query);
builder.field(DOCUMENTS_FIELD, documents);
if (taskSettings.getDoesReturnDocuments() != null) {
// prefer the root level return_documents over task settings
if (returnDocuments != null) {
builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments);
} else if (taskSettings.getDoesReturnDocuments() != null) {
builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments());
}
if (taskSettings.getTopNDocumentsOnly() != null) {
// prefer the root level top_n over task settings
if (topN != null) {
builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, topN);
} else if (taskSettings.getTopNDocumentsOnly() != null) {
builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly());
}

View file

@ -11,6 +11,7 @@ import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
@ -29,10 +30,22 @@ public class GoogleVertexAiRerankRequest implements GoogleVertexAiRequest {
private final List<String> input;
public GoogleVertexAiRerankRequest(String query, List<String> input, GoogleVertexAiRerankModel model) {
private final Boolean returnDocuments;
private final Integer topN;
/**
 * Creates a Google Vertex AI rerank request.
 *
 * @param query           the query to rank against; must not be null
 * @param input           the documents to rank; must not be null
 * @param returnDocuments optional request-level {@code return_documents} value, stored as given
 * @param topN            optional request-level {@code top_n}; when null the request body falls back to
 *                        the model's task-settings {@code topN}
 * @param model           the rerank model; must not be null
 */
public GoogleVertexAiRerankRequest(
    String query,
    List<String> input,
    @Nullable Boolean returnDocuments,
    @Nullable Integer topN,
    GoogleVertexAiRerankModel model
) {
    // Required arguments are null-checked in the order model, query, input.
    this.model = Objects.requireNonNull(model);
    this.query = Objects.requireNonNull(query);
    this.input = Objects.requireNonNull(input);

    // Optional request-level options: null means "not specified by the caller".
    this.topN = topN;
    this.returnDocuments = returnDocuments;
}
@Override
@ -41,7 +54,13 @@ public class GoogleVertexAiRerankRequest implements GoogleVertexAiRequest {
ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(
new GoogleVertexAiRerankRequestEntity(query, input, model.getServiceSettings().modelId(), model.getTaskSettings().topN())
new GoogleVertexAiRerankRequestEntity(
query,
input,
returnDocuments,
topN != null ? topN : model.getTaskSettings().topN(),
model.getServiceSettings().modelId()
)
).getBytes(StandardCharsets.UTF_8)
);

View file

@ -15,9 +15,13 @@ import java.io.IOException;
import java.util.List;
import java.util.Objects;
public record GoogleVertexAiRerankRequestEntity(String query, List<String> inputs, @Nullable String model, @Nullable Integer topN)
implements
ToXContentObject {
public record GoogleVertexAiRerankRequestEntity(
String query,
List<String> inputs,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
@Nullable String model
) implements ToXContentObject {
private static final String MODEL_FIELD = "model";
private static final String QUERY_FIELD = "query";
@ -26,6 +30,7 @@ public record GoogleVertexAiRerankRequestEntity(String query, List<String> input
private static final String CONTENT_FIELD = "content";
private static final String TOP_N_FIELD = "topN";
private static final String IGNORE_RECORD_DETAILS_IN_RESPONSE_FIELD = "ignoreRecordDetailsInResponse";
public GoogleVertexAiRerankRequestEntity {
Objects.requireNonNull(query);
@ -57,10 +62,16 @@ public record GoogleVertexAiRerankRequestEntity(String query, List<String> input
builder.endArray();
// prefer the root level top_n over task settings
if (topN != null) {
builder.field(TOP_N_FIELD, topN);
}
if (returnDocuments != null) {
    // The Vertex AI flag is the inverse of return_documents: when documents should be returned,
    // record details must not be ignored. Use equals() instead of == so the outcome depends on the
    // Boolean's value rather than on which Boolean instance the caller happened to pass.
    builder.field(IGNORE_RECORD_DETAILS_IN_RESPONSE_FIELD, Boolean.TRUE.equals(returnDocuments) ? Boolean.FALSE : Boolean.TRUE);
}
builder.endObject();
return builder;

View file

@ -11,6 +11,7 @@ import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.inference.external.jinaai.JinaAIAccount;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
@ -28,16 +29,26 @@ public class JinaAIRerankRequest extends JinaAIRequest {
private final JinaAIAccount account;
private final String query;
private final List<String> input;
private final Boolean returnDocuments;
private final Integer topN;
private final JinaAIRerankTaskSettings taskSettings;
private final String model;
private final String inferenceEntityId;
public JinaAIRerankRequest(String query, List<String> input, JinaAIRerankModel model) {
public JinaAIRerankRequest(
String query,
List<String> input,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
JinaAIRerankModel model
) {
Objects.requireNonNull(model);
this.account = JinaAIAccount.of(model, JinaAIRerankRequest::buildDefaultUri);
this.input = Objects.requireNonNull(input);
this.query = Objects.requireNonNull(query);
this.returnDocuments = returnDocuments;
this.topN = topN;
taskSettings = model.getTaskSettings();
this.model = model.getServiceSettings().modelId();
inferenceEntityId = model.getInferenceEntityId();
@ -48,7 +59,8 @@ public class JinaAIRerankRequest extends JinaAIRequest {
HttpPost httpPost = new HttpPost(account.uri());
ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new JinaAIRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8)
Strings.toString(new JinaAIRerankRequestEntity(query, input, returnDocuments, topN, taskSettings, model))
.getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);

View file

@ -7,6 +7,7 @@
package org.elasticsearch.xpack.inference.external.request.jinaai;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
@ -15,9 +16,14 @@ import java.io.IOException;
import java.util.List;
import java.util.Objects;
public record JinaAIRerankRequestEntity(String model, String query, List<String> documents, JinaAIRerankTaskSettings taskSettings)
implements
ToXContentObject {
public record JinaAIRerankRequestEntity(
String model,
String query,
List<String> documents,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
JinaAIRerankTaskSettings taskSettings
) implements ToXContentObject {
private static final String DOCUMENTS_FIELD = "documents";
private static final String QUERY_FIELD = "query";
@ -30,8 +36,15 @@ public record JinaAIRerankRequestEntity(String model, String query, List<String>
Objects.requireNonNull(taskSettings);
}
public JinaAIRerankRequestEntity(String query, List<String> input, JinaAIRerankTaskSettings taskSettings, String model) {
this(model, query, input, taskSettings != null ? taskSettings : JinaAIRerankTaskSettings.EMPTY_SETTINGS);
/**
 * Convenience constructor that accepts the model id last and forwards to the canonical record
 * constructor (which takes the model id first).
 *
 * @param query           the query the documents are ranked against
 * @param input           the documents to rank; stored as the record's {@code documents} component
 * @param returnDocuments request-level {@code return_documents}; when non-null it is serialized in
 *                        preference to the task-settings value
 * @param topN            request-level {@code top_n}; when non-null it is serialized in preference to
 *                        the task-settings value
 * @param taskSettings    task settings used as the fallback for the two options above; a null value is
 *                        replaced with {@link JinaAIRerankTaskSettings#EMPTY_SETTINGS}
 * @param model           the model id
 */
public JinaAIRerankRequestEntity(
String query,
List<String> input,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
JinaAIRerankTaskSettings taskSettings,
String model
) {
this(model, query, input, returnDocuments, topN, taskSettings != null ? taskSettings : JinaAIRerankTaskSettings.EMPTY_SETTINGS);
}
@Override
@ -42,13 +55,18 @@ public record JinaAIRerankRequestEntity(String model, String query, List<String>
builder.field(QUERY_FIELD, query);
builder.field(DOCUMENTS_FIELD, documents);
if (taskSettings.getTopNDocumentsOnly() != null) {
// prefer the root level top_n over task settings
if (topN != null) {
builder.field(JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, topN);
} else if (taskSettings.getTopNDocumentsOnly() != null) {
builder.field(JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly());
}
var return_documents = taskSettings.getDoesReturnDocuments();
if (return_documents != null) {
builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, return_documents);
// prefer the root level return_documents over task settings
if (returnDocuments != null) {
builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments);
} else if (taskSettings.getDoesReturnDocuments() != null) {
builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments());
}
builder.endObject();

View file

@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.external.request.voyageai;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
@ -23,13 +24,23 @@ public class VoyageAIRerankRequest extends VoyageAIRequest {
private final String query;
private final List<String> input;
private final Boolean returnDocuments;
private final Integer topN;
private final VoyageAIRerankModel model;
public VoyageAIRerankRequest(String query, List<String> input, VoyageAIRerankModel model) {
/**
 * Creates a VoyageAI rerank request.
 *
 * @param query           the query to rank against; must not be null
 * @param input           the documents to rank; must not be null
 * @param returnDocuments optional request-level {@code return_documents} value, stored as given
 * @param topN            optional request-level {@code top_n} value, stored as given
 * @param model           the rerank model; must not be null
 */
public VoyageAIRerankRequest(
    String query,
    List<String> input,
    @Nullable Boolean returnDocuments,
    @Nullable Integer topN,
    VoyageAIRerankModel model
) {
    // Required arguments are null-checked in the order model, input, query.
    this.model = Objects.requireNonNull(model);
    this.input = Objects.requireNonNull(input);
    this.query = Objects.requireNonNull(query);

    // Optional request-level options: null means "not specified by the caller".
    this.topN = topN;
    this.returnDocuments = returnDocuments;
}
@Override
@ -37,8 +48,16 @@ public class VoyageAIRerankRequest extends VoyageAIRequest {
HttpPost httpPost = new HttpPost(model.uri());
ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new VoyageAIRerankRequestEntity(query, input, model.getTaskSettings(), model.getServiceSettings().modelId()))
.getBytes(StandardCharsets.UTF_8)
Strings.toString(
new VoyageAIRerankRequestEntity(
query,
input,
returnDocuments,
topN,
model.getTaskSettings(),
model.getServiceSettings().modelId()
)
).getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);

View file

@ -7,6 +7,7 @@
package org.elasticsearch.xpack.inference.external.request.voyageai;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;
@ -15,15 +16,19 @@ import java.io.IOException;
import java.util.List;
import java.util.Objects;
public record VoyageAIRerankRequestEntity(String model, String query, List<String> documents, VoyageAIRerankTaskSettings taskSettings)
implements
ToXContentObject {
public record VoyageAIRerankRequestEntity(
String model,
String query,
List<String> documents,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
VoyageAIRerankTaskSettings taskSettings
) implements ToXContentObject {
private static final String DOCUMENTS_FIELD = "documents";
private static final String QUERY_FIELD = "query";
private static final String MODEL_FIELD = "model";
public static final String TRUNCATION_FIELD = "truncation";
public static final String RETURN_DOCUMENTS_FIELD = "return_documents";
public VoyageAIRerankRequestEntity {
Objects.requireNonNull(query);
@ -32,8 +37,15 @@ public record VoyageAIRerankRequestEntity(String model, String query, List<Strin
Objects.requireNonNull(taskSettings);
}
public VoyageAIRerankRequestEntity(String query, List<String> input, VoyageAIRerankTaskSettings taskSettings, String model) {
this(model, query, input, taskSettings != null ? taskSettings : VoyageAIRerankTaskSettings.EMPTY_SETTINGS);
/**
 * Convenience constructor that accepts the model id last and forwards to the canonical record
 * constructor (which takes the model id first).
 *
 * @param query           the query the documents are ranked against
 * @param input           the documents to rank; stored as the record's {@code documents} component
 * @param returnDocuments request-level {@code return_documents}; when non-null it is serialized in
 *                        preference to the task-settings value
 * @param topN            request-level {@code top_n}; when non-null it is serialized (under the
 *                        task settings' top-k field name) in preference to the task-settings value
 * @param taskSettings    task settings used as the fallback for the two options above; a null value is
 *                        replaced with {@link VoyageAIRerankTaskSettings#EMPTY_SETTINGS}
 * @param model           the model id
 */
public VoyageAIRerankRequestEntity(
String query,
List<String> input,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
VoyageAIRerankTaskSettings taskSettings,
String model
) {
this(model, query, input, returnDocuments, topN, taskSettings != null ? taskSettings : VoyageAIRerankTaskSettings.EMPTY_SETTINGS);
}
@Override
@ -44,11 +56,17 @@ public record VoyageAIRerankRequestEntity(String model, String query, List<Strin
builder.field(QUERY_FIELD, query);
builder.field(DOCUMENTS_FIELD, documents);
if (taskSettings.getDoesReturnDocuments() != null) {
// prefer the root level return_documents over task settings
if (returnDocuments != null) {
builder.field(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments);
} else if (taskSettings.getDoesReturnDocuments() != null) {
builder.field(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments());
}
if (taskSettings.getTopKDocumentsOnly() != null) {
// prefer the root level top_n over task settings
if (topN != null) {
builder.field(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, topN);
} else if (taskSettings.getTopKDocumentsOnly() != null) {
builder.field(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, taskSettings.getTopKDocumentsOnly());
}

View file

@ -103,10 +103,6 @@ public class GoogleVertexAiRerankResponseEntity {
return parseList(parser, (listParser, index) -> {
var parsedRankedDoc = RankedDoc.parse(parser);
if (parsedRankedDoc.content == null) {
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.CONTENT.getPreferredName()));
}
if (parsedRankedDoc.score == null) {
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.SCORE.getPreferredName()));
}

View file

@ -232,6 +232,8 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
TaskType.ANY,
inferenceId,
null,
null,
null,
List.of(query),
Map.of(),
InputType.INTERNAL_SEARCH,

View file

@ -153,6 +153,8 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
TaskType.RERANK,
inferenceId,
inferenceText,
null,
null,
docFeatures,
Map.of(),
InputType.INTERNAL_SEARCH,

View file

@ -60,6 +60,8 @@ public abstract class SenderService implements InferenceService {
public void infer(
Model model,
@Nullable String query,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,
@ -68,7 +70,7 @@ public abstract class SenderService implements InferenceService {
ActionListener<InferenceServiceResults> listener
) {
init();
var inferenceInput = createInput(this, model, input, inputType, query, stream);
var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
doInfer(model, inferenceInput, taskSettings, timeout, listener);
}
@ -78,11 +80,20 @@ public abstract class SenderService implements InferenceService {
List<String> input,
InputType inputType,
@Nullable String query,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
boolean stream
) {
return switch (model.getTaskType()) {
case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream);
case RERANK -> new QueryAndDocsInputs(query, input, stream);
case RERANK -> {
ValidationException validationException = new ValidationException();
service.validateRerankParameters(returnDocuments, topN, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream);
}
case TEXT_EMBEDDING, SPARSE_EMBEDDING -> {
ValidationException validationException = new ValidationException();
service.validateInputType(inputType, model, validationException);
@ -141,6 +152,8 @@ public abstract class SenderService implements InferenceService {
protected abstract void validateInputType(InputType inputType, Model model, ValidationException validationException);
/**
 * Hook for validating the request-level rerank options before a rerank input is created.
 * This default implementation adds no errors, i.e. it accepts any values. Subclasses can override
 * it to record problems on {@code validationException}; the rerank branch of {@code createInput}
 * throws that exception when it contains at least one error.
 *
 * @param returnDocuments     the request-level {@code return_documents} value, may be null
 * @param topN                the request-level {@code top_n} value, may be null
 * @param validationException collector that validation errors should be added to
 */
protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {}
protected abstract void doUnifiedCompletionInfer(
Model model,
UnifiedChatInput inputs,

View file

@ -735,6 +735,8 @@ public final class ServiceUtils {
service.infer(
model,
null,
null,
null,
List.of(TEST_EMBEDDING_INPUT),
false,
Map.of(),

View file

@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
@ -300,6 +301,24 @@ public class AlibabaCloudSearchService extends SenderService {
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
}
/**
 * Rejects the request-level rerank options: this service supports neither {@code return_documents}
 * nor {@code top_n}, so each one that is supplied (non-null) adds its own validation error.
 *
 * @param returnDocuments     the request-level {@code return_documents} value, may be null
 * @param topN                the request-level {@code top_n} value, may be null
 * @param validationException collector the errors are added to
 */
@Override
protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {
    if (returnDocuments != null) {
        String unsupportedReturnDocuments = Strings.format(
            "Invalid return_documents [%s]. The return_documents option is not supported by this service",
            returnDocuments
        );
        validationException.addValidationError(unsupportedReturnDocuments);
    }

    if (topN != null) {
        String unsupportedTopN = Strings.format("Invalid top_n [%s]. The top_n option is not supported by this service", topN);
        validationException.addValidationError(unsupportedTopN);
    }
}
@Override
protected void doChunkedInfer(
Model model,

View file

@ -620,6 +620,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
public void infer(
Model model,
@Nullable String query,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,
@ -632,7 +634,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
inferTextEmbedding(esModel, input, inputType, timeout, listener);
} else if (TaskType.RERANK.equals(taskType)) {
inferRerank(esModel, query, input, inputType, timeout, taskSettings, listener);
inferRerank(esModel, query, input, returnDocuments, topN, inputType, timeout, taskSettings, listener);
} else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
inferSparseEmbedding(esModel, input, inputType, timeout, listener);
} else {
@ -693,6 +695,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
ElasticsearchInternalModel model,
String query,
List<String> inputs,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
InputType inputType,
TimeValue timeout,
Map<String, Object> requestTaskSettings,
@ -701,7 +705,9 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout);
var returnDocs = Boolean.TRUE;
if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
if (returnDocuments != null) {
returnDocs = returnDocuments;
} else if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings);
returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
}
@ -709,7 +715,9 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
// Compare by value, not by reference: returnDocs may be a caller-supplied Boolean instance.
// A null value still selects the "no document text" supplier, as before.
Function<Integer, String> inputSupplier = Boolean.TRUE.equals(returnDocs) ? inputs::get : i -> null;
ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
(l, inferenceResult) -> l.onResponse(textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier))
(l, inferenceResult) -> l.onResponse(
textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier, topN)
)
);
var maybeDeployListener = mlResultsListener.delegateResponse(
@ -824,7 +832,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
private RankedDocsResults textSimilarityResultsToRankedDocs(
List<? extends InferenceResults> results,
Function<Integer, String> inputSupplier
Function<Integer, String> inputSupplier,
@Nullable Integer topN
) {
List<RankedDocsResults.RankedDoc> rankings = new ArrayList<>(results.size());
for (int i = 0; i < results.size(); i++) {
@ -851,7 +860,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
}
Collections.sort(rankings);
return new RankedDocsResults(rankings);
// Clamp top_n to the number of rankings: List.subList throws IndexOutOfBoundsException when the
// requested end index exceeds the list size, which would fail any request whose top_n is larger
// than the number of documents supplied.
return new RankedDocsResults(topN != null ? rankings.subList(0, Math.min(topN, rankings.size())) : rankings);
}
public List<DefaultConfigId> defaultConfigIds() {

View file

@ -30,6 +30,8 @@ public class SimpleServiceIntegrationValidator implements ServiceIntegrationVali
service.infer(
model,
model.getTaskType().equals(TaskType.RERANK) ? QUERY : null,
null,
null,
TEST_INPUT,
false,
Map.of(),

View file

@ -423,9 +423,9 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
when(service.canStream(any())).thenReturn(stream);
when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks);
doAnswer(ans -> {
listenerAction.accept(ans.getArgument(7));
listenerAction.accept(ans.getArgument(9));
return null;
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
}).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
doAnswer(ans -> {
listenerAction.accept(ans.getArgument(3));
return null;

View file

@ -23,7 +23,7 @@ public class InferenceInputsTests extends ESTestCase {
var emptyRequest = new UnifiedCompletionRequest(List.of(), null, null, null, null, null, null, null);
assertThat(new UnifiedChatInput(emptyRequest, false).castTo(UnifiedChatInput.class), Matchers.instanceOf(UnifiedChatInput.class));
assertThat(
new QueryAndDocsInputs("hello", List.of(), false).castTo(QueryAndDocsInputs.class),
new QueryAndDocsInputs("hello", List.of(), Boolean.TRUE, 33, false).castTo(QueryAndDocsInputs.class),
Matchers.instanceOf(QueryAndDocsInputs.class)
);
}

View file

@ -22,7 +22,13 @@ import static org.hamcrest.CoreMatchers.is;
public class AlibabaCloudSearchRerankRequestEntityTests extends ESTestCase {
public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
var entity = new AlibabaCloudSearchRerankRequestEntity("query", List.of("abc"), new AlibabaCloudSearchRerankTaskSettings());
var entity = new AlibabaCloudSearchRerankRequestEntity(
"query",
List.of("abc"),
Boolean.TRUE,
22,
new AlibabaCloudSearchRerankTaskSettings()
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);

View file

@ -0,0 +1,95 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.cohere;
import org.elasticsearch.common.Strings;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.hamcrest.MatcherAssert;
import java.io.IOException;
import java.util.List;
import static org.hamcrest.CoreMatchers.is;
/**
 * Checks the JSON body written by {@link CohereRerankRequestEntity}, in particular that the
 * request-level {@code return_documents} / {@code top_n} values win over the task settings and that
 * the task settings are used when no request-level values are given.
 */
public class CohereRerankRequestEntityTests extends ESTestCase {

    public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
        var settings = new CohereRerankTaskSettings(null, null, 3);
        var entity = new CohereRerankRequestEntity("query", List.of("abc"), Boolean.TRUE, 22, settings, "model");

        MatcherAssert.assertThat(toJson(entity), is("""
            {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":22,"max_chunks_per_doc":3}"""));
    }

    public void testXContent_WritesMinimalFields() throws IOException {
        var settings = new CohereRerankTaskSettings(null, null, null);
        var entity = new CohereRerankRequestEntity("query", List.of("abc"), null, null, settings, "model");

        MatcherAssert.assertThat(toJson(entity), is("""
            {"model":"model","query":"query","documents":["abc"]}"""));
    }

    public void testXContent_PrefersRootLevelReturnDocumentsAndTopN() throws IOException {
        // Task settings carry different values; the request-level ones are expected in the output.
        var settings = new CohereRerankTaskSettings(33, Boolean.TRUE, null);
        var entity = new CohereRerankRequestEntity("query", List.of("abc"), Boolean.FALSE, 99, settings, "model");

        MatcherAssert.assertThat(toJson(entity), is("""
            {"model":"model","query":"query","documents":["abc"],"return_documents":false,"top_n":99}"""));
    }

    public void testXContent_UsesTaskSettingsIfNoRootOptionsDefined() throws IOException {
        // No request-level options, so the task-settings values are expected in the output.
        var settings = new CohereRerankTaskSettings(33, Boolean.TRUE, null);
        var entity = new CohereRerankRequestEntity("query", List.of("abc"), null, null, settings, "model");

        MatcherAssert.assertThat(toJson(entity), is("""
            {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":33}"""));
    }

    /** Serializes the entity with a fresh JSON builder and returns the resulting string. */
    private static String toJson(CohereRerankRequestEntity entity) throws IOException {
        XContentBuilder jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
        entity.toXContent(jsonBuilder, null);
        return Strings.toString(jsonBuilder);
    }
}

View file

@ -20,8 +20,8 @@ import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhi
import static org.hamcrest.MatcherAssert.assertThat;
public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
public void testXContent_SingleRequest_WritesModelAndTopNIfDefined() throws IOException {
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), "model", 8);
public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException {
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), Boolean.TRUE, 10, "model");
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
@ -37,13 +37,14 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
"content": "abc"
}
],
"topN": 8
"topN": 10,
"ignoreRecordDetailsInResponse": false
}
"""));
}
public void testXContent_SingleRequest_DoesNotWriteModelAndTopNIfNull() throws IOException {
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), null, null);
public void testXContent_SingleRequest_WritesMinimalFields() throws IOException {
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), null, null, null);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
@ -62,8 +63,8 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
"""));
}
public void testXContent_MultipleRequests_WritesModelAndTopNIfDefined() throws IOException {
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), "model", 8);
public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException {
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), Boolean.FALSE, 12, "model");
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
@ -83,13 +84,14 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
"content": "def"
}
],
"topN": 8
"topN": 12,
"ignoreRecordDetailsInResponse": true
}
"""));
}
public void testXContent_MultipleRequests_DoesNotWriteModelAndTopNIfNull() throws IOException {
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), null, null);
public void testXContent_MultipleRequests_WritesMinimalFields() throws IOException {
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), null, null, null);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
@ -111,5 +113,4 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
}
"""));
}
}

View file

@ -29,11 +29,11 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase {
private static final String AUTH_HEADER_VALUE = "foo";
public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException {
public void testCreateRequest_WithMinimalFieldsSet() throws IOException {
var input = "input";
var query = "query";
var request = createRequest(query, input, null, null);
var request = createRequest(query, input, null, null, null, null);
var httpRequest = request.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@ -53,8 +53,9 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase {
var input = "input";
var query = "query";
var topN = 1;
var taskSettingsTopN = 3;
var request = createRequest(query, input, null, topN);
var request = createRequest(query, input, null, topN, null, taskSettingsTopN);
var httpRequest = request.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@ -71,12 +72,55 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase {
assertThat(requestMap.get("topN"), is(topN));
}
/**
 * With no request-level top_n, the value supplied as the last {@code createRequest} argument
 * (presumably the task-settings top_n, as the test name indicates) must appear as {@code topN}
 * in the request body.
 */
public void testCreateRequest_UsesTaskSettingsTopNWhenRootLevelIsNull() throws IOException {
    var query = "query";
    var input = "input";
    var taskSettingsTopN = 1;

    var httpRequest = createRequest(query, input, null, null, null, taskSettingsTopN).createHttpRequest();
    assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));

    var post = (HttpPost) httpRequest.httpRequestBase();
    assertThat(post.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
    assertThat(post.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));

    // Body is expected to hold exactly: records, query, topN.
    var body = entityAsMap(post.getEntity().getContent());
    assertThat(body, aMapWithSize(3));
    assertThat(body.get("records"), is(List.of(Map.of("id", "0", "content", input))));
    assertThat(body.get("query"), is(query));
    assertThat(body.get("topN"), is(taskSettingsTopN));
}
/**
 * Builds a rerank request with {@code returnDocuments} set to {@code true} and checks that the body carries
 * {@code ignoreRecordDetailsInResponse} as {@code false}.
 */
public void testCreateRequest_WithReturnDocumentsSet() throws IOException {
    final String query = "query";
    final String document = "input";

    var rerankRequest = createRequest(query, document, null, null, Boolean.TRUE, null);
    var requestBase = rerankRequest.createHttpRequest().httpRequestBase();
    assertThat(requestBase, instanceOf(HttpPost.class));

    var post = (HttpPost) requestBase;
    assertThat(post.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
    assertThat(post.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));

    var body = entityAsMap(post.getEntity().getContent());
    assertThat(body, aMapWithSize(3));
    assertThat(body.get("query"), is(query));
    assertThat(body.get("records"), is(List.of(Map.of("id", "0", "content", document))));
    // The body field is the inverse of the returnDocuments flag passed above.
    assertThat(body.get("ignoreRecordDetailsInResponse"), is(Boolean.FALSE));
}
public void testCreateRequest_WithModelSet() throws IOException {
var input = "input";
var query = "query";
var modelId = "model";
var request = createRequest(query, input, modelId, null);
var request = createRequest(query, input, modelId, null, null, null);
var httpRequest = request.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@ -94,24 +138,37 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase {
}
public void testTruncate_DoesNotTruncate() {
var request = createRequest("query", "input", null, null);
var request = createRequest("query", "input", null, null, null, null);
var truncatedRequest = request.truncate();
assertThat(truncatedRequest, sameInstance(request));
}
private static GoogleVertexAiRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topN) {
var rerankModel = GoogleVertexAiRerankModelTests.createModel(modelId, topN);
private static GoogleVertexAiRerankRequest createRequest(
String query,
String input,
@Nullable String modelId,
@Nullable Integer topN,
@Nullable Boolean returnDocuments,
@Nullable Integer taskSettingsTopN
) {
var rerankModel = GoogleVertexAiRerankModelTests.createModel(modelId, taskSettingsTopN);
return new GoogleVertexAiRerankWithoutAuthRequest(query, List.of(input), rerankModel);
return new GoogleVertexAiRerankWithoutAuthRequest(query, List.of(input), rerankModel, topN, returnDocuments);
}
/**
* We use this class to fake the auth implementation to avoid static mocking of {@link GoogleVertexAiRequest}
*/
private static class GoogleVertexAiRerankWithoutAuthRequest extends GoogleVertexAiRerankRequest {
GoogleVertexAiRerankWithoutAuthRequest(String query, List<String> input, GoogleVertexAiRerankModel model) {
super(query, input, model);
GoogleVertexAiRerankWithoutAuthRequest(
String query,
List<String> input,
GoogleVertexAiRerankModel model,
@Nullable Integer topN,
@Nullable Boolean returnDocuments
) {
super(query, input, returnDocuments, topN, model);
}
@Override

View file

@ -21,8 +21,15 @@ import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhi
import static org.hamcrest.MatcherAssert.assertThat;
public class JinaAIRerankRequestEntityTests extends ESTestCase {
public void testXContent_SingleRequest_WritesModelAndTopNIfDefined() throws IOException {
var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, null), "model");
public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException {
var entity = new JinaAIRerankRequestEntity(
"query",
List.of("abc"),
Boolean.TRUE,
12,
new JinaAIRerankTaskSettings(8, Boolean.FALSE),
"model"
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
@ -35,33 +42,86 @@ public class JinaAIRerankRequestEntityTests extends ESTestCase {
"documents": [
"abc"
],
"top_n": 8
}
"""));
}
public void testXContent_SingleRequest_WritesModelAndTopNIfDefined_ReturnDocumentsTrue() throws IOException {
var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, true), "model");
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
{
"model": "model",
"query": "query",
"documents": [
"abc"
],
"top_n": 8,
"top_n": 12,
"return_documents": true
}
"""));
}
public void testXContent_SingleRequest_WritesModelAndTopNIfDefined_ReturnDocumentsFalse() throws IOException {
var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, false), "model");
/**
 * Serializes an entity for a single document with no request-level options and task settings holding only
 * nulls, and expects just {@code model}, {@code query} and {@code documents} in the output.
 */
public void testXContent_SingleRequest_WritesMinimalFields() throws IOException {
    var emptyTaskSettings = new JinaAIRerankTaskSettings(null, null);
    JinaAIRerankRequestEntity requestEntity = new JinaAIRerankRequestEntity(
        "query",
        List.of("abc"),
        null,
        null,
        emptyTaskSettings,
        "model"
    );

    var jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
    requestEntity.toXContent(jsonBuilder, null);

    String expectedJson = """
        {
        "model": "model",
        "query": "query",
        "documents": [
        "abc"
        ]
        }
        """;
    assertThat(Strings.toString(jsonBuilder), equalToIgnoringWhitespaceInJsonString(expectedJson));
}
/**
 * Serializes an entity for two documents where both the request-level arguments and the task settings supply
 * {@code top_n} / {@code return_documents}. The expected JSON carries the request-level values (12 and false),
 * not the task-settings ones (8 and true).
 */
public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException {
    var taskSettings = new JinaAIRerankTaskSettings(8, Boolean.TRUE);
    JinaAIRerankRequestEntity requestEntity = new JinaAIRerankRequestEntity(
        "query",
        List.of("abc", "def"),
        Boolean.FALSE,
        12,
        taskSettings,
        "model"
    );

    var jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
    requestEntity.toXContent(jsonBuilder, null);

    String expectedJson = """
        {
        "model": "model",
        "query": "query",
        "documents": [
        "abc",
        "def"
        ],
        "top_n": 12,
        "return_documents": false
        }
        """;
    assertThat(Strings.toString(jsonBuilder), equalToIgnoringWhitespaceInJsonString(expectedJson));
}
/**
 * Serializes an entity for two documents with no request-level options and a {@code null} task-settings
 * object, and expects only {@code model}, {@code query} and {@code documents} in the output.
 */
public void testXContent_MultipleRequests_WritesMinimalFields() throws IOException {
    JinaAIRerankRequestEntity requestEntity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), null, null, null, "model");

    var jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
    requestEntity.toXContent(jsonBuilder, null);

    String expectedJson = """
        {
        "model": "model",
        "query": "query",
        "documents": [
        "abc",
        "def"
        ]
        }
        """;
    assertThat(Strings.toString(jsonBuilder), equalToIgnoringWhitespaceInJsonString(expectedJson));
}
public void testXContent_SingleRequest_UsesTaskSettingsTopNIfRootIsNotDefined() throws IOException {
var entity = new JinaAIRerankRequestEntity(
"query",
List.of("abc"),
null,
null,
new JinaAIRerankTaskSettings(8, Boolean.FALSE),
"model"
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
@ -80,61 +140,4 @@ public class JinaAIRerankRequestEntityTests extends ESTestCase {
"""));
}
/**
 * Serializes a single-document entity built with {@code null} task settings and expects no {@code top_n}
 * field in the resulting JSON.
 */
public void testXContent_SingleRequest_DoesNotWriteTopNIfNull() throws IOException {
    JinaAIRerankRequestEntity requestEntity = new JinaAIRerankRequestEntity("query", List.of("abc"), null, "model");

    var jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
    requestEntity.toXContent(jsonBuilder, null);

    String expectedJson = """
        {
        "model": "model",
        "query": "query",
        "documents": [
        "abc"
        ]
        }
        """;
    assertThat(Strings.toString(jsonBuilder), equalToIgnoringWhitespaceInJsonString(expectedJson));
}
/**
 * Serializes a two-document entity whose task settings define a top-n of 8 and expects both {@code model}
 * and {@code top_n} to appear in the JSON.
 */
public void testXContent_MultipleRequests_WritesModelAndTopNIfDefined() throws IOException {
    var taskSettings = new JinaAIRerankTaskSettings(8, null);
    JinaAIRerankRequestEntity requestEntity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), taskSettings, "model");

    var jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
    requestEntity.toXContent(jsonBuilder, null);

    String expectedJson = """
        {
        "model": "model",
        "query": "query",
        "documents": [
        "abc",
        "def"
        ],
        "top_n": 8
        }
        """;
    assertThat(Strings.toString(jsonBuilder), equalToIgnoringWhitespaceInJsonString(expectedJson));
}
/**
 * Serializes a two-document entity built with {@code null} task settings and expects no {@code top_n}
 * field in the resulting JSON.
 */
public void testXContent_MultipleRequests_DoesNotWriteTopNIfNull() throws IOException {
    JinaAIRerankRequestEntity requestEntity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), null, "model");

    var jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
    requestEntity.toXContent(jsonBuilder, null);

    String expectedJson = """
        {
        "model": "model",
        "query": "query",
        "documents": [
        "abc",
        "def"
        ]
        }
        """;
    assertThat(Strings.toString(jsonBuilder), equalToIgnoringWhitespaceInJsonString(expectedJson));
}
}

View file

@ -27,12 +27,12 @@ public class JinaAIRerankRequestTests extends ESTestCase {
private static final String API_KEY = "foo";
public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException {
public void testCreateRequest_WithMinimalFieldsSet() throws IOException {
var input = "input";
var query = "query";
var modelId = "model";
var request = createRequest(query, input, modelId, null);
var request = createRequest(query, input, modelId, null, null, null);
var httpRequest = request.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@ -49,13 +49,14 @@ public class JinaAIRerankRequestTests extends ESTestCase {
assertThat(requestMap.get("model"), is(modelId));
}
public void testCreateRequest_WithTopNSet() throws IOException {
public void testCreateRequest_WithAllFieldsSet() throws IOException {
var input = "input";
var query = "query";
var topN = 1;
var taskSettingsTopN = 2;
var modelId = "model";
var request = createRequest(query, input, modelId, topN);
var request = createRequest(query, input, modelId, topN, Boolean.FALSE, taskSettingsTopN);
var httpRequest = request.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@ -66,10 +67,11 @@ public class JinaAIRerankRequestTests extends ESTestCase {
var requestMap = entityAsMap(httpPost.getEntity().getContent());
assertThat(requestMap, aMapWithSize(4));
assertThat(requestMap, aMapWithSize(5));
assertThat(requestMap.get("documents"), is(List.of(input)));
assertThat(requestMap.get("query"), is(query));
assertThat(requestMap.get("top_n"), is(topN));
assertThat(requestMap.get("return_documents"), is(Boolean.FALSE));
assertThat(requestMap.get("model"), is(modelId));
}
@ -78,7 +80,7 @@ public class JinaAIRerankRequestTests extends ESTestCase {
var query = "query";
var modelId = "model";
var request = createRequest(query, input, modelId, null);
var request = createRequest(query, input, modelId, null, null, null);
var httpRequest = request.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@ -96,15 +98,22 @@ public class JinaAIRerankRequestTests extends ESTestCase {
}
public void testTruncate_DoesNotTruncate() {
var request = createRequest("query", "input", "null", null);
var request = createRequest("query", "input", "null", null, null, null);
var truncatedRequest = request.truncate();
assertThat(truncatedRequest, sameInstance(request));
}
private static JinaAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topN) {
var rerankModel = JinaAIRerankModelTests.createModel(API_KEY, modelId, topN);
return new JinaAIRerankRequest(query, List.of(input), rerankModel);
private static JinaAIRerankRequest createRequest(
String query,
String input,
@Nullable String modelId,
@Nullable Integer topN,
@Nullable Boolean returnDocuments,
@Nullable Integer taskSettingsTopN
) {
var rerankModel = JinaAIRerankModelTests.createModel(API_KEY, modelId, taskSettingsTopN);
return new JinaAIRerankRequest(query, List.of(input), returnDocuments, topN, rerankModel);
}
}

View file

@ -20,27 +20,15 @@ import java.util.List;
import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
public class VoyageAIRerankRequestEntityTests extends ESTestCase {
/**
 * Serializes a single-document entity whose task settings define only a top-k of 8 and expects
 * {@code model} and {@code top_k} to appear in the JSON.
 */
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined() throws IOException {
    var taskSettings = new VoyageAIRerankTaskSettings(8, null, null);
    VoyageAIRerankRequestEntity requestEntity = new VoyageAIRerankRequestEntity("query", List.of("abc"), taskSettings, "model");

    var jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
    requestEntity.toXContent(jsonBuilder, null);

    String expectedJson = """
        {
        "model": "model",
        "query": "query",
        "documents": [
        "abc"
        ],
        "top_k": 8
        }
        """;
    assertThat(Strings.toString(jsonBuilder), equalToIgnoringWhitespaceInJsonString(expectedJson));
}
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsTrue() throws IOException {
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, true, null), "model");
public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException {
var entity = new VoyageAIRerankRequestEntity(
"query",
List.of("abc"),
Boolean.TRUE,
12,
new VoyageAIRerankTaskSettings(8, Boolean.FALSE, null),
"model"
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
@ -54,13 +42,20 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
"abc"
],
"return_documents": true,
"top_k": 8
"top_k": 12
}
"""));
}
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsFalse() throws IOException {
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, null), "model");
public void testXContent_SingleRequest_WritesMinimalFields() throws IOException {
var entity = new VoyageAIRerankRequestEntity(
"query",
List.of("abc"),
null,
null,
new VoyageAIRerankTaskSettings(null, true, null),
"model"
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
@ -73,14 +68,20 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
"documents": [
"abc"
],
"return_documents": false,
"top_k": 8
"return_documents": true
}
"""));
}
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationTrue() throws IOException {
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, true), "model");
var entity = new VoyageAIRerankRequestEntity(
"query",
List.of("abc"),
null,
null,
new VoyageAIRerankTaskSettings(8, false, true),
"model"
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
@ -101,7 +102,14 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
}
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationFalse() throws IOException {
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, false), "model");
var entity = new VoyageAIRerankRequestEntity(
"query",
List.of("abc"),
null,
null,
new VoyageAIRerankTaskSettings(8, false, false),
"model"
);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
@ -121,28 +129,12 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
"""));
}
/**
 * Serializes a single-document entity built with {@code null} task settings and expects no {@code top_k}
 * field in the resulting JSON.
 */
public void testXContent_SingleRequest_DoesNotWriteTopKIfNull() throws IOException {
    VoyageAIRerankRequestEntity requestEntity = new VoyageAIRerankRequestEntity("query", List.of("abc"), null, "model");

    var jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
    requestEntity.toXContent(jsonBuilder, null);

    String expectedJson = """
        {
        "model": "model",
        "query": "query",
        "documents": [
        "abc"
        ]
        }
        """;
    assertThat(Strings.toString(jsonBuilder), equalToIgnoringWhitespaceInJsonString(expectedJson));
}
public void testXContent_MultipleRequests_WritesModelAndTopKIfDefined() throws IOException {
public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException {
var entity = new VoyageAIRerankRequestEntity(
"query",
List.of("abc", "def"),
Boolean.FALSE,
11,
new VoyageAIRerankTaskSettings(8, null, null),
"model"
);
@ -159,13 +151,14 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
"abc",
"def"
],
"top_k": 8
"return_documents": false,
"top_k": 11
}
"""));
}
public void testXContent_MultipleRequests_DoesNotWriteTopKIfNull() throws IOException {
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc", "def"), null, "model");
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc", "def"), null, null, null, "model");
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
@ -183,4 +176,31 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
"""));
}
/**
 * Serializes an entity whose request-level {@code returnDocuments} and {@code topN} are {@code null} while the
 * task settings hold 8 and false. The expected JSON carries the task-settings values as {@code top_k} and
 * {@code return_documents}.
 */
public void testXContent_UsesTaskSettingsTopNIfRootIsNotDefined() throws IOException {
    var taskSettings = new VoyageAIRerankTaskSettings(8, Boolean.FALSE, null);
    VoyageAIRerankRequestEntity requestEntity = new VoyageAIRerankRequestEntity(
        "query",
        List.of("abc"),
        null,
        null,
        taskSettings,
        "model"
    );

    var jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
    requestEntity.toXContent(jsonBuilder, null);

    String expectedJson = """
        {
        "model": "model",
        "query": "query",
        "documents": [
        "abc"
        ],
        "return_documents": false,
        "top_k": 8
        }
        """;
    assertThat(Strings.toString(jsonBuilder), equalToIgnoringWhitespaceInJsonString(expectedJson));
}
}

View file

@ -27,12 +27,12 @@ public class VoyageAIRerankRequestTests extends ESTestCase {
private static final String API_KEY = "foo";
public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException {
public void testCreateRequest_WithMinimalFields() throws IOException {
var input = "input";
var query = "query";
var modelId = "model";
var request = createRequest(query, input, modelId, null);
var request = createRequest(query, input, modelId, null, null, null);
var httpRequest = request.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@ -49,13 +49,14 @@ public class VoyageAIRerankRequestTests extends ESTestCase {
assertThat(requestMap.get("model"), is(modelId));
}
public void testCreateRequest_WithTopNSet() throws IOException {
public void testCreateRequest_WithAllFieldsDefined() throws IOException {
var input = "input";
var query = "query";
var topK = 1;
var taskSettingsTopK = 2;
var modelId = "model";
var request = createRequest(query, input, modelId, topK);
var request = createRequest(query, input, modelId, topK, Boolean.FALSE, taskSettingsTopK);
var httpRequest = request.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@ -66,11 +67,12 @@ public class VoyageAIRerankRequestTests extends ESTestCase {
var requestMap = entityAsMap(httpPost.getEntity().getContent());
assertThat(requestMap, aMapWithSize(4));
assertThat(requestMap, aMapWithSize(5));
assertThat(requestMap.get("documents"), is(List.of(input)));
assertThat(requestMap.get("query"), is(query));
assertThat(requestMap.get("top_k"), is(topK));
assertThat(requestMap.get("model"), is(modelId));
assertThat(requestMap.get("return_documents"), is(Boolean.FALSE));
}
public void testCreateRequest_WithModelSet() throws IOException {
@ -78,7 +80,7 @@ public class VoyageAIRerankRequestTests extends ESTestCase {
var query = "query";
var modelId = "model";
var request = createRequest(query, input, modelId, null);
var request = createRequest(query, input, modelId, null, null, null);
var httpRequest = request.createHttpRequest();
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@ -96,15 +98,22 @@ public class VoyageAIRerankRequestTests extends ESTestCase {
}
public void testTruncate_DoesNotTruncate() {
var request = createRequest("query", "input", "null", null);
var request = createRequest("query", "input", "null", null, null, null);
var truncatedRequest = request.truncate();
assertThat(truncatedRequest, sameInstance(request));
}
private static VoyageAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topK) {
var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, topK);
return new VoyageAIRerankRequest(query, List.of(input), rerankModel);
private static VoyageAIRerankRequest createRequest(
String query,
String input,
@Nullable String modelId,
@Nullable Integer topK,
@Nullable Boolean returnDocuments,
@Nullable Integer taskSettingsTopK
) {
var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, taskSettingsTopK);
return new VoyageAIRerankRequest(query, List.of(input), returnDocuments, topK, rerankModel);
}
}

View file

@ -42,6 +42,26 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase {
assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2"))));
}
/**
 * Parses a response holding one record that has no {@code content} field and expects a single ranked doc
 * whose text is {@code null}.
 */
public void testFromResponse_CreatesResultsForASingleItem_NoContent() throws IOException {
    final String payload = """
        {
        "records": [
        {
        "id": "2",
        "title": "title 2",
        "score": 0.97
        }
        ]
        }
        """;

    var httpResult = new HttpResult(mock(HttpResponse.class), payload.getBytes(StandardCharsets.UTF_8));
    RankedDocsResults reranked = GoogleVertexAiRerankResponseEntity.fromResponse(httpResult);

    assertThat(reranked.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, null))));
}
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
String responseJson = """
{
@ -72,6 +92,34 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase {
);
}
/**
 * Parses a response holding two records, neither with a {@code content} field, and expects two ranked docs
 * in response order, each with {@code null} text.
 */
public void testFromResponse_CreatesResultsForMultipleItems_NoContent() throws IOException {
    final String payload = """
        {
        "records": [
        {
        "id": "2",
        "title": "title 2",
        "score": 0.97
        },
        {
        "id": "1",
        "title": "title 1",
        "score": 0.90
        }
        ]
        }
        """;

    var httpResult = new HttpResult(mock(HttpResponse.class), payload.getBytes(StandardCharsets.UTF_8));
    RankedDocsResults reranked = GoogleVertexAiRerankResponseEntity.fromResponse(httpResult);

    assertThat(
        reranked.getRankedDocs(),
        is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, null), new RankedDocsResults.RankedDoc(1, 0.90F, null)))
    );
}
public void testFromResponse_FailsWhenRecordsFieldIsNotPresent() {
String responseJson = """
{
@ -102,36 +150,6 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase {
assertThat(thrownException.getMessage(), is("Failed to find required field [records] in Google Vertex AI rerank response"));
}
/**
 * Parses a response whose second record uses {@code not_content} instead of {@code content} and expects an
 * {@link IllegalStateException} naming the missing {@code content} field.
 */
public void testFromResponse_FailsWhenContentFieldIsNotPresent() {
    final String payload = """
        {
        "records": [
        {
        "id": "2",
        "title": "title 2",
        "content": "content 2",
        "score": 0.97
        },
        {
        "id": "1",
        "title": "title 1",
        "not_content": "content 1",
        "score": 0.97
        }
        ]
        }
        """;

    IllegalStateException failure = expectThrows(IllegalStateException.class, () -> {
        var httpResult = new HttpResult(mock(HttpResponse.class), payload.getBytes(StandardCharsets.UTF_8));
        GoogleVertexAiRerankResponseEntity.fromResponse(httpResult);
    });

    assertThat(failure.getMessage(), is("Failed to find required field [content] in Google Vertex AI rerank response"));
}
public void testFromResponse_FailsWhenScoreFieldIsNotPresent() {
String responseJson = """
{

View file

@ -98,6 +98,8 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
TaskType.RERANK,
this.inferenceId,
inferenceText,
null,
null,
docFeatures,
Map.of("inferenceResultCount", inferenceResultCount),
InputType.INTERNAL_SEARCH,

View file

@ -225,6 +225,8 @@ public class TextSimilarityTestPlugin extends Plugin implements ActionPlugin {
TaskType.RERANK,
inferenceId,
inferenceText,
null,
null,
docFeatures,
Map.of("throwing", true),
InputType.INTERNAL_SEARCH,

View file

@ -910,11 +910,11 @@ public class ServiceUtilsTests extends ESTestCase {
when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
doAnswer(invocation -> {
ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
ActionListener<InferenceServiceResults> listener = invocation.getArgument(9);
listener.onResponse(new TextEmbeddingFloatResults(List.of()));
return Void.TYPE;
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
}).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
getEmbeddingSize(model, service, listener);
@ -932,11 +932,11 @@ public class ServiceUtilsTests extends ESTestCase {
when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
doAnswer(invocation -> {
ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
ActionListener<InferenceServiceResults> listener = invocation.getArgument(9);
listener.onResponse(new TextEmbeddingByteResults(List.of()));
return Void.TYPE;
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
}).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
getEmbeddingSize(model, service, listener);
@ -956,11 +956,11 @@ public class ServiceUtilsTests extends ESTestCase {
var textEmbedding = TextEmbeddingFloatResultsTests.createRandomResults();
doAnswer(invocation -> {
ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
ActionListener<InferenceServiceResults> listener = invocation.getArgument(9);
listener.onResponse(textEmbedding);
return Void.TYPE;
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
}).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
getEmbeddingSize(model, service, listener);
@ -979,11 +979,11 @@ public class ServiceUtilsTests extends ESTestCase {
var textEmbedding = TextEmbeddingByteResultsTests.createRandomResults();
doAnswer(invocation -> {
ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
ActionListener<InferenceServiceResults> listener = invocation.getArgument(9);
listener.onResponse(textEmbedding);
return Void.TYPE;
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
}).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
getEmbeddingSize(model, service, listener);

View file

@ -389,6 +389,8 @@ public class AlibabaCloudSearchServiceTests extends ESTestCase {
() -> service.infer(
model,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -431,6 +433,8 @@ public class AlibabaCloudSearchServiceTests extends ESTestCase {
() -> service.infer(
model,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -446,6 +450,53 @@ public class AlibabaCloudSearchServiceTests extends ESTestCase {
}
}
/**
 * Calls {@code infer} on a rerank model with {@code returnDocuments} set to true and {@code topN} set to 10,
 * and expects a {@link ValidationException} reporting both options as unsupported by the service.
 */
public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOException {
    Map<String, Object> secrets = new HashMap<>();
    secrets.put("api_key", "secret");

    Map<String, Object> taskSettings = new HashMap<>();

    Map<String, Object> serviceSettings = new HashMap<>();
    serviceSettings.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id");
    serviceSettings.put(AlibabaCloudSearchServiceSettings.HOST, "host");
    serviceSettings.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default");
    serviceSettings.put(ServiceFields.DIMENSIONS, 1536);

    var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
    var rerankModel = AlibabaCloudSearchEmbeddingsModelTests.createModel("service", TaskType.RERANK, serviceSettings, taskSettings, secrets);

    try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) {
        PlainActionFuture<InferenceServiceResults> future = new PlainActionFuture<>();

        // Arguments after the model: query, returnDocuments and topN are the values under test.
        ValidationException failure = expectThrows(
            ValidationException.class,
            () -> service.infer(
                rerankModel,
                "hi",
                Boolean.TRUE,
                10,
                List.of("a"),
                false,
                new HashMap<>(),
                null,
                InferenceAction.Request.DEFAULT_TIMEOUT,
                future
            )
        );

        String expectedMessage = "Validation Failed: 1: Invalid return_documents [true]. The return_documents option is not supported by this "
            + "service;2: Invalid top_n [10]. The top_n option is not supported by this service;";
        assertThat(failure.getMessage(), is(expectedMessage));
    }
}
/**
 * Runs the shared chunked-inference scenario for the text-embedding task type with randomly generated
 * chunking settings.
 */
public void testChunkedInfer_TextEmbeddingChunkingSettingsSet() throws IOException {
    var randomChunkingSettings = ChunkingSettingsTests.createRandomChunkingSettings();
    testChunkedInfer(TaskType.TEXT_EMBEDDING, randomChunkingSettings);
}

View file

@ -932,6 +932,8 @@ public class AmazonBedrockServiceTests extends ESTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -979,6 +981,8 @@ public class AmazonBedrockServiceTests extends ESTestCase {
() -> service.infer(
model,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -1029,6 +1033,8 @@ public class AmazonBedrockServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1071,6 +1077,8 @@ public class AmazonBedrockServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1414,6 +1422,8 @@ public class AmazonBedrockServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),

View file

@ -458,6 +458,8 @@ public class AnthropicServiceTests extends ESTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -513,6 +515,8 @@ public class AnthropicServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("input"),
false,
new HashMap<>(),
@ -571,6 +575,8 @@ public class AnthropicServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
true,
new HashMap<>(),

View file

@ -1096,6 +1096,8 @@ public class AzureAiStudioServiceTests extends ESTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -1134,6 +1136,8 @@ public class AzureAiStudioServiceTests extends ESTestCase {
() -> service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -1296,6 +1300,8 @@ public class AzureAiStudioServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1347,6 +1353,8 @@ public class AzureAiStudioServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1403,6 +1411,8 @@ public class AzureAiStudioServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
true,
new HashMap<>(),

View file

@ -766,6 +766,8 @@ public class AzureOpenAiServiceTests extends ESTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -822,6 +824,8 @@ public class AzureOpenAiServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1286,6 +1290,8 @@ public class AzureOpenAiServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1453,6 +1459,8 @@ public class AzureOpenAiServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
true,
new HashMap<>(),

View file

@ -788,6 +788,8 @@ public class CohereServiceTests extends ESTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -856,6 +858,8 @@ public class CohereServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1147,6 +1151,8 @@ public class CohereServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1207,6 +1213,8 @@ public class CohereServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1281,6 +1289,8 @@ public class CohereServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, null),
@ -1353,6 +1363,8 @@ public class CohereServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1629,6 +1641,8 @@ public class CohereServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
true,
new HashMap<>(),

View file

@ -232,7 +232,7 @@ public class DeepSeekServiceTests extends ESTestCase {
try (var service = createService()) {
var model = createModel(service, TaskType.COMPLETION);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.infer(model, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
service.infer(model, null, null, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
var result = listener.actionGet(TIMEOUT);
assertThat(result, isA(ChatCompletionResults.class));
var completionResults = (ChatCompletionResults) result;
@ -255,7 +255,7 @@ public class DeepSeekServiceTests extends ESTestCase {
try (var service = createService()) {
var model = createModel(service, TaskType.COMPLETION);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.infer(model, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
service.infer(model, null, null, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream().hasNoErrors().hasEvent("""
{"completion":[{"delta":"hello, world"}]}""");
}

View file

@ -368,6 +368,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -404,6 +406,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -443,6 +447,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -494,6 +500,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
service.infer(
model,
null,
null,
null,
List.of("input text"),
false,
new HashMap<>(),
@ -551,6 +559,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
service.infer(
model,
null,
null,
null,
List.of("input text"),
false,
new HashMap<>(),

View file

@ -662,6 +662,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -700,6 +702,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
() -> service.infer(
model,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -775,6 +779,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("input"),
false,
new HashMap<>(),
@ -832,6 +838,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of(input),
false,
new HashMap<>(),
@ -1005,6 +1013,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),

View file

@ -65,6 +65,8 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),

View file

@ -556,6 +556,8 @@ public class HuggingFaceServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -593,6 +595,8 @@ public class HuggingFaceServiceTests extends ESTestCase {
() -> service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -627,6 +631,8 @@ public class HuggingFaceServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),

View file

@ -602,6 +602,8 @@ public class IbmWatsonxServiceTests extends ESTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -641,6 +643,8 @@ public class IbmWatsonxServiceTests extends ESTestCase {
() -> service.infer(
model,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -697,6 +701,8 @@ public class IbmWatsonxServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of(input),
false,
new HashMap<>(),
@ -840,6 +846,8 @@ public class IbmWatsonxServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),

View file

@ -782,6 +782,8 @@ public class JinaAIServiceTests extends ESTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -1044,6 +1046,8 @@ public class JinaAIServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1076,6 +1080,8 @@ public class JinaAIServiceTests extends ESTestCase {
service.infer(
model,
"query",
null,
null,
List.of("candidate1", "candidate2"),
false,
new HashMap<>(),
@ -1132,6 +1138,8 @@ public class JinaAIServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1201,6 +1209,8 @@ public class JinaAIServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1254,6 +1264,8 @@ public class JinaAIServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1320,7 +1332,18 @@ public class JinaAIServiceTests extends ESTestCase {
JinaAIEmbeddingType.FLOAT
);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.infer(model, null, List.of("abc"), false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener);
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
null,
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);
var result = listener.actionGet(TIMEOUT);
@ -1371,6 +1394,8 @@ public class JinaAIServiceTests extends ESTestCase {
service.infer(
model,
"query",
null,
null,
List.of("candidate1", "candidate2", "candidate3"),
false,
new HashMap<>(),
@ -1454,6 +1479,8 @@ public class JinaAIServiceTests extends ESTestCase {
service.infer(
model,
"query",
null,
null,
List.of("candidate1", "candidate2", "candidate3", "candidate4"),
false,
new HashMap<>(),
@ -1549,6 +1576,8 @@ public class JinaAIServiceTests extends ESTestCase {
service.infer(
model,
"query",
null,
null,
List.of("candidate1", "candidate2", "candidate3"),
false,
new HashMap<>(),
@ -1630,6 +1659,8 @@ public class JinaAIServiceTests extends ESTestCase {
service.infer(
model,
"query",
null,
null,
List.of("candidate1", "candidate2", "candidate3", "candidate4"),
false,
new HashMap<>(),
@ -1724,6 +1755,8 @@ public class JinaAIServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),

View file

@ -586,6 +586,8 @@ public class MistralServiceTests extends ESTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -625,6 +627,8 @@ public class MistralServiceTests extends ESTestCase {
() -> service.infer(
model,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -781,6 +785,8 @@ public class MistralServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),

View file

@ -852,6 +852,8 @@ public class OpenAiServiceTests extends ESTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -890,6 +892,8 @@ public class OpenAiServiceTests extends ESTestCase {
() -> service.infer(
model,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -925,6 +929,8 @@ public class OpenAiServiceTests extends ESTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -964,6 +970,8 @@ public class OpenAiServiceTests extends ESTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -1024,6 +1032,8 @@ public class OpenAiServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1263,6 +1273,8 @@ public class OpenAiServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
true,
new HashMap<>(),
@ -1794,6 +1806,8 @@ public class OpenAiServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),

View file

@ -63,6 +63,8 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase {
.infer(
eq(mockModel),
eq(null),
eq(null),
eq(null),
eq(TEST_INPUT),
eq(false),
eq(Map.of()),
@ -97,13 +99,15 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase {
private void mockSuccessfulCallToService(String query, InferenceServiceResults result) {
doAnswer(ans -> {
ActionListener<InferenceServiceResults> responseListener = ans.getArgument(7);
ActionListener<InferenceServiceResults> responseListener = ans.getArgument(9);
responseListener.onResponse(result);
return null;
}).when(mockInferenceService)
.infer(
eq(mockModel),
eq(query),
eq(null),
eq(null),
eq(TEST_INPUT),
eq(false),
eq(Map.of()),
@ -120,6 +124,8 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase {
verify(mockInferenceService).infer(
eq(mockModel),
eq(withQuery ? TEST_QUERY : null),
eq(null),
eq(null),
eq(TEST_INPUT),
eq(false),
eq(Map.of()),

View file

@ -722,6 +722,8 @@ public class VoyageAIServiceTests extends ESTestCase {
service.infer(
mockModel,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -768,6 +770,8 @@ public class VoyageAIServiceTests extends ESTestCase {
() -> service.infer(
model,
null,
null,
null,
List.of(""),
false,
new HashMap<>(),
@ -1017,6 +1021,8 @@ public class VoyageAIServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1049,6 +1055,8 @@ public class VoyageAIServiceTests extends ESTestCase {
service.infer(
model,
"query",
null,
null,
List.of("candidate1", "candidate2"),
false,
new HashMap<>(),
@ -1103,6 +1111,8 @@ public class VoyageAIServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1183,6 +1193,8 @@ public class VoyageAIServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
@ -1260,7 +1272,18 @@ public class VoyageAIServiceTests extends ESTestCase {
(SimilarityMeasure) null
);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
service.infer(model, null, List.of("abc"), false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener);
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),
null,
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
);
var result = listener.actionGet(TIMEOUT);
@ -1315,6 +1338,8 @@ public class VoyageAIServiceTests extends ESTestCase {
service.infer(
model,
"query",
null,
null,
List.of("candidate1", "candidate2", "candidate3"),
false,
new HashMap<>(),
@ -1401,6 +1426,8 @@ public class VoyageAIServiceTests extends ESTestCase {
service.infer(
model,
"query",
null,
null,
List.of("candidate1", "candidate2", "candidate3", "candidate4"),
false,
new HashMap<>(),
@ -1493,6 +1520,8 @@ public class VoyageAIServiceTests extends ESTestCase {
service.infer(
model,
"query",
null,
null,
List.of("candidate1", "candidate2", "candidate3"),
false,
new HashMap<>(),
@ -1569,6 +1598,8 @@ public class VoyageAIServiceTests extends ESTestCase {
service.infer(
model,
"query",
null,
null,
List.of("candidate1", "candidate2", "candidate3", "candidate4"),
false,
new HashMap<>(),
@ -1663,6 +1694,8 @@ public class VoyageAIServiceTests extends ESTestCase {
service.infer(
model,
null,
null,
null,
List.of("abc"),
false,
new HashMap<>(),

View file

@ -123,6 +123,8 @@ public class TransportCoordinatedInferenceAction extends HandledTransportAction<
TaskType.ANY,
request.getModelId(),
null,
null,
null,
request.getInputs(),
request.getTaskSettings(),
inputType,