mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
Adding common rerank options to Perform Inference API (#125239)
* wip * Adding rerank common options * Linting * Linting * [CI] Auto commit changes from spotless * Update docs/changelog/125239.yaml * PR feedback --------- Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
This commit is contained in:
parent
af1f1452d7
commit
a6f685cc2a
66 changed files with 1306 additions and 287 deletions
6
docs/changelog/125239.yaml
Normal file
6
docs/changelog/125239.yaml
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
pr: 125239
|
||||||
|
summary: Adding common rerank options to Perform Inference API
|
||||||
|
area: Machine Learning
|
||||||
|
type: enhancement
|
||||||
|
issues:
|
||||||
|
- 111273
|
|
@ -155,6 +155,7 @@ public class TransportVersions {
|
||||||
public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL_8_19 = def(8_841_0_12);
|
public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL_8_19 = def(8_841_0_12);
|
||||||
public static final TransportVersion INFERENCE_MODEL_REGISTRY_METADATA_8_19 = def(8_841_0_13);
|
public static final TransportVersion INFERENCE_MODEL_REGISTRY_METADATA_8_19 = def(8_841_0_13);
|
||||||
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
|
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
|
||||||
|
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
|
||||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
|
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
|
||||||
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
|
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
|
||||||
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
|
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
|
||||||
|
@ -201,6 +202,7 @@ public class TransportVersions {
|
||||||
public static final TransportVersion INDEXING_STATS_INCLUDES_RECENT_WRITE_LOAD = def(9_034_0_00);
|
public static final TransportVersion INDEXING_STATS_INCLUDES_RECENT_WRITE_LOAD = def(9_034_0_00);
|
||||||
public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL = def(9_035_0_00);
|
public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL = def(9_035_0_00);
|
||||||
public static final TransportVersion INDEX_METADATA_INCLUDES_RECENT_WRITE_LOAD = def(9_036_0_00);
|
public static final TransportVersion INDEX_METADATA_INCLUDES_RECENT_WRITE_LOAD = def(9_036_0_00);
|
||||||
|
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED = def(9_037_0_00);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* STOP! READ THIS FIRST! No, really,
|
* STOP! READ THIS FIRST! No, really,
|
||||||
|
|
|
@ -93,6 +93,8 @@ public interface InferenceService extends Closeable {
|
||||||
*
|
*
|
||||||
* @param model The model
|
* @param model The model
|
||||||
* @param query Inference query, mainly for re-ranking
|
* @param query Inference query, mainly for re-ranking
|
||||||
|
* @param returnDocuments For re-ranking task type, whether to return documents
|
||||||
|
* @param topN For re-ranking task type, how many docs to return
|
||||||
* @param input Inference input
|
* @param input Inference input
|
||||||
* @param stream Stream inference results
|
* @param stream Stream inference results
|
||||||
* @param taskSettings Settings in the request to override the model's defaults
|
* @param taskSettings Settings in the request to override the model's defaults
|
||||||
|
@ -103,6 +105,8 @@ public interface InferenceService extends Closeable {
|
||||||
void infer(
|
void infer(
|
||||||
Model model,
|
Model model,
|
||||||
@Nullable String query,
|
@Nullable String query,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
List<String> input,
|
List<String> input,
|
||||||
boolean stream,
|
boolean stream,
|
||||||
Map<String, Object> taskSettings,
|
Map<String, Object> taskSettings,
|
||||||
|
|
|
@ -60,6 +60,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
public static final ParseField INPUT_TYPE = new ParseField("input_type");
|
public static final ParseField INPUT_TYPE = new ParseField("input_type");
|
||||||
public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
|
public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
|
||||||
public static final ParseField QUERY = new ParseField("query");
|
public static final ParseField QUERY = new ParseField("query");
|
||||||
|
public static final ParseField RETURN_DOCUMENTS = new ParseField("return_documents");
|
||||||
|
public static final ParseField TOP_N = new ParseField("top_n");
|
||||||
public static final ParseField TIMEOUT = new ParseField("timeout");
|
public static final ParseField TIMEOUT = new ParseField("timeout");
|
||||||
|
|
||||||
static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
|
static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
|
||||||
|
@ -68,6 +70,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
PARSER.declareString(Request.Builder::setInputType, INPUT_TYPE);
|
PARSER.declareString(Request.Builder::setInputType, INPUT_TYPE);
|
||||||
PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
|
PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
|
||||||
PARSER.declareString(Request.Builder::setQuery, QUERY);
|
PARSER.declareString(Request.Builder::setQuery, QUERY);
|
||||||
|
PARSER.declareBoolean(Request.Builder::setReturnDocuments, RETURN_DOCUMENTS);
|
||||||
|
PARSER.declareInt(Request.Builder::setTopN, TOP_N);
|
||||||
PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
|
PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,6 +93,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
private final TaskType taskType;
|
private final TaskType taskType;
|
||||||
private final String inferenceEntityId;
|
private final String inferenceEntityId;
|
||||||
private final String query;
|
private final String query;
|
||||||
|
private final Boolean returnDocuments;
|
||||||
|
private final Integer topN;
|
||||||
private final List<String> input;
|
private final List<String> input;
|
||||||
private final Map<String, Object> taskSettings;
|
private final Map<String, Object> taskSettings;
|
||||||
private final InputType inputType;
|
private final InputType inputType;
|
||||||
|
@ -99,6 +105,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
TaskType taskType,
|
TaskType taskType,
|
||||||
String inferenceEntityId,
|
String inferenceEntityId,
|
||||||
String query,
|
String query,
|
||||||
|
Boolean returnDocuments,
|
||||||
|
Integer topN,
|
||||||
List<String> input,
|
List<String> input,
|
||||||
Map<String, Object> taskSettings,
|
Map<String, Object> taskSettings,
|
||||||
InputType inputType,
|
InputType inputType,
|
||||||
|
@ -109,6 +117,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
taskType,
|
taskType,
|
||||||
inferenceEntityId,
|
inferenceEntityId,
|
||||||
query,
|
query,
|
||||||
|
returnDocuments,
|
||||||
|
topN,
|
||||||
input,
|
input,
|
||||||
taskSettings,
|
taskSettings,
|
||||||
inputType,
|
inputType,
|
||||||
|
@ -122,6 +132,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
TaskType taskType,
|
TaskType taskType,
|
||||||
String inferenceEntityId,
|
String inferenceEntityId,
|
||||||
String query,
|
String query,
|
||||||
|
Boolean returnDocuments,
|
||||||
|
Integer topN,
|
||||||
List<String> input,
|
List<String> input,
|
||||||
Map<String, Object> taskSettings,
|
Map<String, Object> taskSettings,
|
||||||
InputType inputType,
|
InputType inputType,
|
||||||
|
@ -133,6 +145,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
this.taskType = taskType;
|
this.taskType = taskType;
|
||||||
this.inferenceEntityId = inferenceEntityId;
|
this.inferenceEntityId = inferenceEntityId;
|
||||||
this.query = query;
|
this.query = query;
|
||||||
|
this.returnDocuments = returnDocuments;
|
||||||
|
this.topN = topN;
|
||||||
this.input = input;
|
this.input = input;
|
||||||
this.taskSettings = taskSettings;
|
this.taskSettings = taskSettings;
|
||||||
this.inputType = inputType;
|
this.inputType = inputType;
|
||||||
|
@ -164,6 +178,15 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
this.inferenceTimeout = DEFAULT_TIMEOUT;
|
this.inferenceTimeout = DEFAULT_TIMEOUT;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED)
|
||||||
|
|| in.getTransportVersion().isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) {
|
||||||
|
this.returnDocuments = in.readOptionalBoolean();
|
||||||
|
this.topN = in.readOptionalInt();
|
||||||
|
} else {
|
||||||
|
this.returnDocuments = null;
|
||||||
|
this.topN = null;
|
||||||
|
}
|
||||||
|
|
||||||
// streaming is not supported yet for transport traffic
|
// streaming is not supported yet for transport traffic
|
||||||
this.stream = false;
|
this.stream = false;
|
||||||
}
|
}
|
||||||
|
@ -184,6 +207,14 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
return query;
|
return query;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Boolean getReturnDocuments() {
|
||||||
|
return returnDocuments;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getTopN() {
|
||||||
|
return topN;
|
||||||
|
}
|
||||||
|
|
||||||
public Map<String, Object> getTaskSettings() {
|
public Map<String, Object> getTaskSettings() {
|
||||||
return taskSettings;
|
return taskSettings;
|
||||||
}
|
}
|
||||||
|
@ -225,6 +256,17 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
e.addValidationError(format("Field [query] cannot be empty for task type [%s]", TaskType.RERANK));
|
e.addValidationError(format("Field [query] cannot be empty for task type [%s]", TaskType.RERANK));
|
||||||
return e;
|
return e;
|
||||||
}
|
}
|
||||||
|
} else if (taskType.equals(TaskType.ANY) == false) {
|
||||||
|
if (returnDocuments != null) {
|
||||||
|
var e = new ActionRequestValidationException();
|
||||||
|
e.addValidationError(format("Field [return_documents] cannot be specified for task type [%s]", taskType));
|
||||||
|
return e;
|
||||||
|
}
|
||||||
|
if (topN != null) {
|
||||||
|
var e = new ActionRequestValidationException();
|
||||||
|
e.addValidationError(format("Field [top_n] cannot be specified for task type [%s]", taskType));
|
||||||
|
return e;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (taskType.equals(TaskType.TEXT_EMBEDDING) == false
|
if (taskType.equals(TaskType.TEXT_EMBEDDING) == false
|
||||||
|
@ -258,6 +300,12 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
out.writeOptionalString(query);
|
out.writeOptionalString(query);
|
||||||
out.writeTimeValue(inferenceTimeout);
|
out.writeTimeValue(inferenceTimeout);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED)
|
||||||
|
|| out.getTransportVersion().isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) {
|
||||||
|
out.writeOptionalBoolean(returnDocuments);
|
||||||
|
out.writeOptionalInt(topN);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// default for easier testing
|
// default for easier testing
|
||||||
|
@ -283,6 +331,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
&& taskType == request.taskType
|
&& taskType == request.taskType
|
||||||
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
|
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
|
||||||
&& Objects.equals(query, request.query)
|
&& Objects.equals(query, request.query)
|
||||||
|
&& Objects.equals(returnDocuments, request.returnDocuments)
|
||||||
|
&& Objects.equals(topN, request.topN)
|
||||||
&& Objects.equals(input, request.input)
|
&& Objects.equals(input, request.input)
|
||||||
&& Objects.equals(taskSettings, request.taskSettings)
|
&& Objects.equals(taskSettings, request.taskSettings)
|
||||||
&& inputType == request.inputType
|
&& inputType == request.inputType
|
||||||
|
@ -296,6 +346,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
taskType,
|
taskType,
|
||||||
inferenceEntityId,
|
inferenceEntityId,
|
||||||
query,
|
query,
|
||||||
|
returnDocuments,
|
||||||
|
topN,
|
||||||
input,
|
input,
|
||||||
taskSettings,
|
taskSettings,
|
||||||
inputType,
|
inputType,
|
||||||
|
@ -312,6 +364,8 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
private InputType inputType = InputType.UNSPECIFIED;
|
private InputType inputType = InputType.UNSPECIFIED;
|
||||||
private Map<String, Object> taskSettings = Map.of();
|
private Map<String, Object> taskSettings = Map.of();
|
||||||
private String query;
|
private String query;
|
||||||
|
private Boolean returnDocuments;
|
||||||
|
private Integer topN;
|
||||||
private TimeValue timeout = DEFAULT_TIMEOUT;
|
private TimeValue timeout = DEFAULT_TIMEOUT;
|
||||||
private boolean stream = false;
|
private boolean stream = false;
|
||||||
private InferenceContext context;
|
private InferenceContext context;
|
||||||
|
@ -338,6 +392,16 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Builder setReturnDocuments(Boolean returnDocuments) {
|
||||||
|
this.returnDocuments = returnDocuments;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder setTopN(Integer topN) {
|
||||||
|
this.topN = topN;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
public Builder setInputType(InputType inputType) {
|
public Builder setInputType(InputType inputType) {
|
||||||
this.inputType = inputType;
|
this.inputType = inputType;
|
||||||
return this;
|
return this;
|
||||||
|
@ -373,7 +437,19 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Request build() {
|
public Request build() {
|
||||||
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, context);
|
return new Request(
|
||||||
|
taskType,
|
||||||
|
inferenceEntityId,
|
||||||
|
query,
|
||||||
|
returnDocuments,
|
||||||
|
topN,
|
||||||
|
input,
|
||||||
|
taskSettings,
|
||||||
|
inputType,
|
||||||
|
timeout,
|
||||||
|
stream,
|
||||||
|
context
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -384,6 +460,10 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
|
||||||
+ this.getInferenceEntityId()
|
+ this.getInferenceEntityId()
|
||||||
+ ", query="
|
+ ", query="
|
||||||
+ this.getQuery()
|
+ this.getQuery()
|
||||||
|
+ ", returnDocuments="
|
||||||
|
+ this.getReturnDocuments()
|
||||||
|
+ ", topN="
|
||||||
|
+ this.getTopN()
|
||||||
+ ", input="
|
+ ", input="
|
||||||
+ this.getInput()
|
+ this.getInput()
|
||||||
+ ", taskSettings="
|
+ ", taskSettings="
|
||||||
|
|
|
@ -44,6 +44,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
randomFrom(TaskType.values()),
|
randomFrom(TaskType.values()),
|
||||||
randomAlphaOfLength(6),
|
randomAlphaOfLength(6),
|
||||||
randomAlphaOfLengthOrNull(10),
|
randomAlphaOfLengthOrNull(10),
|
||||||
|
randomBoolean(),
|
||||||
|
randomIntBetween(0, 10),
|
||||||
randomList(1, 5, () -> randomAlphaOfLength(8)),
|
randomList(1, 5, () -> randomAlphaOfLength(8)),
|
||||||
randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
|
randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))),
|
||||||
randomFrom(InputType.values()),
|
randomFrom(InputType.values()),
|
||||||
|
@ -85,6 +87,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
TaskType.TEXT_EMBEDDING,
|
TaskType.TEXT_EMBEDDING,
|
||||||
"model",
|
"model",
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("input"),
|
List.of("input"),
|
||||||
null,
|
null,
|
||||||
null,
|
null,
|
||||||
|
@ -100,6 +104,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
TaskType.RERANK,
|
TaskType.RERANK,
|
||||||
"model",
|
"model",
|
||||||
"query",
|
"query",
|
||||||
|
Boolean.TRUE,
|
||||||
|
34,
|
||||||
List.of("input"),
|
List.of("input"),
|
||||||
null,
|
null,
|
||||||
null,
|
null,
|
||||||
|
@ -119,6 +125,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
null,
|
null,
|
||||||
null,
|
null,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
false
|
false
|
||||||
);
|
);
|
||||||
ActionRequestValidationException inputNullError = inputNullRequest.validate();
|
ActionRequestValidationException inputNullError = inputNullRequest.validate();
|
||||||
|
@ -131,6 +139,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
TaskType.TEXT_EMBEDDING,
|
TaskType.TEXT_EMBEDDING,
|
||||||
"model",
|
"model",
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(),
|
List.of(),
|
||||||
null,
|
null,
|
||||||
null,
|
null,
|
||||||
|
@ -142,11 +152,52 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
assertThat(inputEmptyError.getMessage(), is("Validation Failed: 1: Field [input] cannot be an empty array;"));
|
assertThat(inputEmptyError.getMessage(), is("Validation Failed: 1: Field [input] cannot be an empty array;"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testValidation_TextEmbedding_WithReturnDocument() {
|
||||||
|
InferenceAction.Request inputRequest = new InferenceAction.Request(
|
||||||
|
TaskType.TEXT_EMBEDDING,
|
||||||
|
"model",
|
||||||
|
null,
|
||||||
|
Boolean.TRUE,
|
||||||
|
null,
|
||||||
|
List.of("input"),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
ActionRequestValidationException inputError = inputRequest.validate();
|
||||||
|
assertNotNull(inputError);
|
||||||
|
assertThat(
|
||||||
|
inputError.getMessage(),
|
||||||
|
is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [text_embedding];")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testValidation_TextEmbedding_WithTopN() {
|
||||||
|
InferenceAction.Request inputRequest = new InferenceAction.Request(
|
||||||
|
TaskType.TEXT_EMBEDDING,
|
||||||
|
"model",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
12,
|
||||||
|
List.of("input"),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
ActionRequestValidationException inputError = inputRequest.validate();
|
||||||
|
assertNotNull(inputError);
|
||||||
|
assertThat(inputError.getMessage(), is("Validation Failed: 1: Field [top_n] cannot be specified for task type [text_embedding];"));
|
||||||
|
}
|
||||||
|
|
||||||
public void testValidation_Rerank_Null() {
|
public void testValidation_Rerank_Null() {
|
||||||
InferenceAction.Request queryNullRequest = new InferenceAction.Request(
|
InferenceAction.Request queryNullRequest = new InferenceAction.Request(
|
||||||
TaskType.RERANK,
|
TaskType.RERANK,
|
||||||
"model",
|
"model",
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("input"),
|
List.of("input"),
|
||||||
null,
|
null,
|
||||||
null,
|
null,
|
||||||
|
@ -163,6 +214,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
TaskType.RERANK,
|
TaskType.RERANK,
|
||||||
"model",
|
"model",
|
||||||
"",
|
"",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("input"),
|
List.of("input"),
|
||||||
null,
|
null,
|
||||||
null,
|
null,
|
||||||
|
@ -179,6 +232,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
TaskType.RERANK,
|
TaskType.RERANK,
|
||||||
"model",
|
"model",
|
||||||
"query",
|
"query",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("input"),
|
List.of("input"),
|
||||||
null,
|
null,
|
||||||
InputType.SEARCH,
|
InputType.SEARCH,
|
||||||
|
@ -195,6 +250,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
TaskType.SPARSE_EMBEDDING,
|
TaskType.SPARSE_EMBEDDING,
|
||||||
"model",
|
"model",
|
||||||
"",
|
"",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("input"),
|
List.of("input"),
|
||||||
null,
|
null,
|
||||||
InputType.SEARCH,
|
InputType.SEARCH,
|
||||||
|
@ -209,11 +266,56 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testValidation_SparseEmbedding_WithReturnDocument() {
|
||||||
|
InferenceAction.Request queryRequest = new InferenceAction.Request(
|
||||||
|
TaskType.SPARSE_EMBEDDING,
|
||||||
|
"model",
|
||||||
|
"",
|
||||||
|
Boolean.FALSE,
|
||||||
|
null,
|
||||||
|
List.of("input"),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
ActionRequestValidationException queryError = queryRequest.validate();
|
||||||
|
assertNotNull(queryError);
|
||||||
|
assertThat(
|
||||||
|
queryError.getMessage(),
|
||||||
|
is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [sparse_embedding];")
|
||||||
|
);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testValidation_SparseEmbedding_WithTopN() {
|
||||||
|
InferenceAction.Request queryRequest = new InferenceAction.Request(
|
||||||
|
TaskType.SPARSE_EMBEDDING,
|
||||||
|
"model",
|
||||||
|
"",
|
||||||
|
null,
|
||||||
|
22,
|
||||||
|
List.of("input"),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
ActionRequestValidationException queryError = queryRequest.validate();
|
||||||
|
assertNotNull(queryError);
|
||||||
|
assertThat(
|
||||||
|
queryError.getMessage(),
|
||||||
|
is("Validation Failed: 1: Field [top_n] cannot be specified for task type [sparse_embedding];")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
public void testValidation_Completion_WithInputType() {
|
public void testValidation_Completion_WithInputType() {
|
||||||
InferenceAction.Request queryRequest = new InferenceAction.Request(
|
InferenceAction.Request queryRequest = new InferenceAction.Request(
|
||||||
TaskType.COMPLETION,
|
TaskType.COMPLETION,
|
||||||
"model",
|
"model",
|
||||||
"",
|
"",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("input"),
|
List.of("input"),
|
||||||
null,
|
null,
|
||||||
InputType.SEARCH,
|
InputType.SEARCH,
|
||||||
|
@ -225,11 +327,52 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [input_type] cannot be specified for task type [completion];"));
|
assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [input_type] cannot be specified for task type [completion];"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testValidation_Completion_WithReturnDocuments() {
|
||||||
|
InferenceAction.Request queryRequest = new InferenceAction.Request(
|
||||||
|
TaskType.COMPLETION,
|
||||||
|
"model",
|
||||||
|
"",
|
||||||
|
Boolean.TRUE,
|
||||||
|
null,
|
||||||
|
List.of("input"),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
ActionRequestValidationException queryError = queryRequest.validate();
|
||||||
|
assertNotNull(queryError);
|
||||||
|
assertThat(
|
||||||
|
queryError.getMessage(),
|
||||||
|
is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [completion];")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testValidation_Completion_WithTopN() {
|
||||||
|
InferenceAction.Request queryRequest = new InferenceAction.Request(
|
||||||
|
TaskType.COMPLETION,
|
||||||
|
"model",
|
||||||
|
"",
|
||||||
|
null,
|
||||||
|
77,
|
||||||
|
List.of("input"),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
ActionRequestValidationException queryError = queryRequest.validate();
|
||||||
|
assertNotNull(queryError);
|
||||||
|
assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [top_n] cannot be specified for task type [completion];"));
|
||||||
|
}
|
||||||
|
|
||||||
public void testValidation_ChatCompletion_WithInputType() {
|
public void testValidation_ChatCompletion_WithInputType() {
|
||||||
InferenceAction.Request queryRequest = new InferenceAction.Request(
|
InferenceAction.Request queryRequest = new InferenceAction.Request(
|
||||||
TaskType.CHAT_COMPLETION,
|
TaskType.CHAT_COMPLETION,
|
||||||
"model",
|
"model",
|
||||||
"",
|
"",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("input"),
|
List.of("input"),
|
||||||
null,
|
null,
|
||||||
InputType.SEARCH,
|
InputType.SEARCH,
|
||||||
|
@ -244,6 +387,45 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testValidation_ChatCompletion_WithReturnDocuments() {
|
||||||
|
InferenceAction.Request queryRequest = new InferenceAction.Request(
|
||||||
|
TaskType.CHAT_COMPLETION,
|
||||||
|
"model",
|
||||||
|
"",
|
||||||
|
Boolean.TRUE,
|
||||||
|
null,
|
||||||
|
List.of("input"),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
ActionRequestValidationException queryError = queryRequest.validate();
|
||||||
|
assertNotNull(queryError);
|
||||||
|
assertThat(
|
||||||
|
queryError.getMessage(),
|
||||||
|
is("Validation Failed: 1: Field [return_documents] cannot be specified for task type [chat_completion];")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testValidation_ChatCompletion_WithTopN() {
|
||||||
|
InferenceAction.Request queryRequest = new InferenceAction.Request(
|
||||||
|
TaskType.CHAT_COMPLETION,
|
||||||
|
"model",
|
||||||
|
"",
|
||||||
|
null,
|
||||||
|
11,
|
||||||
|
List.of("input"),
|
||||||
|
null,
|
||||||
|
InputType.SEARCH,
|
||||||
|
null,
|
||||||
|
false
|
||||||
|
);
|
||||||
|
ActionRequestValidationException queryError = queryRequest.validate();
|
||||||
|
assertNotNull(queryError);
|
||||||
|
assertThat(queryError.getMessage(), is("Validation Failed: 1: Field [top_n] cannot be specified for task type [chat_completion];"));
|
||||||
|
}
|
||||||
|
|
||||||
public void testParseRequest_DefaultsInputTypeToIngest() throws IOException {
|
public void testParseRequest_DefaultsInputTypeToIngest() throws IOException {
|
||||||
String singleInputRequest = """
|
String singleInputRequest = """
|
||||||
{
|
{
|
||||||
|
@ -271,6 +453,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
nextTask,
|
nextTask,
|
||||||
instance.getInferenceEntityId(),
|
instance.getInferenceEntityId(),
|
||||||
instance.getQuery(),
|
instance.getQuery(),
|
||||||
|
instance.getReturnDocuments(),
|
||||||
|
instance.getTopN(),
|
||||||
instance.getInput(),
|
instance.getInput(),
|
||||||
instance.getTaskSettings(),
|
instance.getTaskSettings(),
|
||||||
instance.getInputType(),
|
instance.getInputType(),
|
||||||
|
@ -283,6 +467,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
instance.getTaskType(),
|
instance.getTaskType(),
|
||||||
instance.getInferenceEntityId() + "foo",
|
instance.getInferenceEntityId() + "foo",
|
||||||
instance.getQuery(),
|
instance.getQuery(),
|
||||||
|
instance.getReturnDocuments(),
|
||||||
|
instance.getTopN(),
|
||||||
instance.getInput(),
|
instance.getInput(),
|
||||||
instance.getTaskSettings(),
|
instance.getTaskSettings(),
|
||||||
instance.getInputType(),
|
instance.getInputType(),
|
||||||
|
@ -297,6 +483,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
instance.getTaskType(),
|
instance.getTaskType(),
|
||||||
instance.getInferenceEntityId(),
|
instance.getInferenceEntityId(),
|
||||||
instance.getQuery(),
|
instance.getQuery(),
|
||||||
|
instance.getReturnDocuments(),
|
||||||
|
instance.getTopN(),
|
||||||
changedInputs,
|
changedInputs,
|
||||||
instance.getTaskSettings(),
|
instance.getTaskSettings(),
|
||||||
instance.getInputType(),
|
instance.getInputType(),
|
||||||
|
@ -317,6 +505,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
instance.getTaskType(),
|
instance.getTaskType(),
|
||||||
instance.getInferenceEntityId(),
|
instance.getInferenceEntityId(),
|
||||||
instance.getQuery(),
|
instance.getQuery(),
|
||||||
|
instance.getReturnDocuments(),
|
||||||
|
instance.getTopN(),
|
||||||
instance.getInput(),
|
instance.getInput(),
|
||||||
taskSettings,
|
taskSettings,
|
||||||
instance.getInputType(),
|
instance.getInputType(),
|
||||||
|
@ -331,6 +521,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
instance.getTaskType(),
|
instance.getTaskType(),
|
||||||
instance.getInferenceEntityId(),
|
instance.getInferenceEntityId(),
|
||||||
instance.getQuery(),
|
instance.getQuery(),
|
||||||
|
instance.getReturnDocuments(),
|
||||||
|
instance.getTopN(),
|
||||||
instance.getInput(),
|
instance.getInput(),
|
||||||
instance.getTaskSettings(),
|
instance.getTaskSettings(),
|
||||||
nextInputType,
|
nextInputType,
|
||||||
|
@ -343,6 +535,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
instance.getTaskType(),
|
instance.getTaskType(),
|
||||||
instance.getInferenceEntityId(),
|
instance.getInferenceEntityId(),
|
||||||
instance.getQuery() == null ? randomAlphaOfLength(10) : instance.getQuery() + randomAlphaOfLength(1),
|
instance.getQuery() == null ? randomAlphaOfLength(10) : instance.getQuery() + randomAlphaOfLength(1),
|
||||||
|
instance.getReturnDocuments(),
|
||||||
|
instance.getTopN(),
|
||||||
instance.getInput(),
|
instance.getInput(),
|
||||||
instance.getTaskSettings(),
|
instance.getTaskSettings(),
|
||||||
instance.getInputType(),
|
instance.getInputType(),
|
||||||
|
@ -360,6 +554,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
instance.getTaskType(),
|
instance.getTaskType(),
|
||||||
instance.getInferenceEntityId(),
|
instance.getInferenceEntityId(),
|
||||||
instance.getQuery(),
|
instance.getQuery(),
|
||||||
|
instance.getReturnDocuments(),
|
||||||
|
instance.getTopN(),
|
||||||
instance.getInput(),
|
instance.getInput(),
|
||||||
instance.getTaskSettings(),
|
instance.getTaskSettings(),
|
||||||
instance.getInputType(),
|
instance.getInputType(),
|
||||||
|
@ -374,6 +570,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
instance.getTaskType(),
|
instance.getTaskType(),
|
||||||
instance.getInferenceEntityId(),
|
instance.getInferenceEntityId(),
|
||||||
instance.getQuery(),
|
instance.getQuery(),
|
||||||
|
instance.getReturnDocuments(),
|
||||||
|
instance.getTopN(),
|
||||||
instance.getInput(),
|
instance.getInput(),
|
||||||
instance.getTaskSettings(),
|
instance.getTaskSettings(),
|
||||||
instance.getInputType(),
|
instance.getInputType(),
|
||||||
|
@ -395,6 +593,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
instance.getTaskType(),
|
instance.getTaskType(),
|
||||||
instance.getInferenceEntityId(),
|
instance.getInferenceEntityId(),
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
instance.getInput().subList(0, 1),
|
instance.getInput().subList(0, 1),
|
||||||
instance.getTaskSettings(),
|
instance.getTaskSettings(),
|
||||||
InputType.UNSPECIFIED,
|
InputType.UNSPECIFIED,
|
||||||
|
@ -406,6 +606,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
instance.getTaskType(),
|
instance.getTaskType(),
|
||||||
instance.getInferenceEntityId(),
|
instance.getInferenceEntityId(),
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
instance.getInput(),
|
instance.getInput(),
|
||||||
instance.getTaskSettings(),
|
instance.getTaskSettings(),
|
||||||
InputType.UNSPECIFIED,
|
InputType.UNSPECIFIED,
|
||||||
|
@ -420,6 +622,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
instance.getTaskType(),
|
instance.getTaskType(),
|
||||||
instance.getInferenceEntityId(),
|
instance.getInferenceEntityId(),
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
instance.getInput(),
|
instance.getInput(),
|
||||||
instance.getTaskSettings(),
|
instance.getTaskSettings(),
|
||||||
InputType.INGEST,
|
InputType.INGEST,
|
||||||
|
@ -432,6 +636,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
instance.getTaskType(),
|
instance.getTaskType(),
|
||||||
instance.getInferenceEntityId(),
|
instance.getInferenceEntityId(),
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
instance.getInput(),
|
instance.getInput(),
|
||||||
instance.getTaskSettings(),
|
instance.getTaskSettings(),
|
||||||
InputType.UNSPECIFIED,
|
InputType.UNSPECIFIED,
|
||||||
|
@ -443,6 +649,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
instance.getTaskType(),
|
instance.getTaskType(),
|
||||||
instance.getInferenceEntityId(),
|
instance.getInferenceEntityId(),
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
instance.getInput(),
|
instance.getInput(),
|
||||||
instance.getTaskSettings(),
|
instance.getTaskSettings(),
|
||||||
instance.getInputType(),
|
instance.getInputType(),
|
||||||
|
@ -455,6 +663,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
instance.getTaskType(),
|
instance.getTaskType(),
|
||||||
instance.getInferenceEntityId(),
|
instance.getInferenceEntityId(),
|
||||||
instance.getQuery(),
|
instance.getQuery(),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
instance.getInput(),
|
instance.getInput(),
|
||||||
instance.getTaskSettings(),
|
instance.getTaskSettings(),
|
||||||
instance.getInputType(),
|
instance.getInputType(),
|
||||||
|
@ -462,6 +672,21 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
false,
|
false,
|
||||||
InferenceContext.EMPTY_INSTANCE
|
InferenceContext.EMPTY_INSTANCE
|
||||||
);
|
);
|
||||||
|
} else if (version.before(TransportVersions.RERANK_COMMON_OPTIONS_ADDED)
|
||||||
|
&& version.isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19) == false) {
|
||||||
|
mutated = new InferenceAction.Request(
|
||||||
|
instance.getTaskType(),
|
||||||
|
instance.getInferenceEntityId(),
|
||||||
|
instance.getQuery(),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
instance.getInput(),
|
||||||
|
instance.getTaskSettings(),
|
||||||
|
instance.getInputType(),
|
||||||
|
instance.getInferenceTimeout(),
|
||||||
|
false,
|
||||||
|
instance.getContext()
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
mutated = instance;
|
mutated = instance;
|
||||||
}
|
}
|
||||||
|
@ -481,6 +706,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
TaskType.TEXT_EMBEDDING,
|
TaskType.TEXT_EMBEDDING,
|
||||||
"model",
|
"model",
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(),
|
List.of(),
|
||||||
Map.of(),
|
Map.of(),
|
||||||
InputType.UNSPECIFIED,
|
InputType.UNSPECIFIED,
|
||||||
|
@ -503,6 +730,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
TaskType.TEXT_EMBEDDING,
|
TaskType.TEXT_EMBEDDING,
|
||||||
"model",
|
"model",
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(),
|
List.of(),
|
||||||
Map.of(),
|
Map.of(),
|
||||||
InputType.INGEST,
|
InputType.INGEST,
|
||||||
|
@ -525,6 +754,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
TaskType.TEXT_EMBEDDING,
|
TaskType.TEXT_EMBEDDING,
|
||||||
"model",
|
"model",
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("input"),
|
List.of("input"),
|
||||||
Map.of(),
|
Map.of(),
|
||||||
InputType.UNSPECIFIED,
|
InputType.UNSPECIFIED,
|
||||||
|
@ -548,6 +779,8 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
|
||||||
TaskType.TEXT_EMBEDDING,
|
TaskType.TEXT_EMBEDDING,
|
||||||
"model",
|
"model",
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("input"),
|
List.of("input"),
|
||||||
Map.of(),
|
Map.of(),
|
||||||
InputType.UNSPECIFIED,
|
InputType.UNSPECIFIED,
|
||||||
|
|
|
@ -110,6 +110,8 @@ public class TestDenseInferenceServiceExtension implements InferenceServiceExten
|
||||||
public void infer(
|
public void infer(
|
||||||
Model model,
|
Model model,
|
||||||
@Nullable String query,
|
@Nullable String query,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
List<String> input,
|
List<String> input,
|
||||||
boolean stream,
|
boolean stream,
|
||||||
Map<String, Object> taskSettings,
|
Map<String, Object> taskSettings,
|
||||||
|
|
|
@ -102,6 +102,8 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
|
||||||
public void infer(
|
public void infer(
|
||||||
Model model,
|
Model model,
|
||||||
@Nullable String query,
|
@Nullable String query,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
List<String> input,
|
List<String> input,
|
||||||
boolean stream,
|
boolean stream,
|
||||||
Map<String, Object> taskSettings,
|
Map<String, Object> taskSettings,
|
||||||
|
|
|
@ -103,6 +103,8 @@ public class TestSparseInferenceServiceExtension implements InferenceServiceExte
|
||||||
public void infer(
|
public void infer(
|
||||||
Model model,
|
Model model,
|
||||||
@Nullable String query,
|
@Nullable String query,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
List<String> input,
|
List<String> input,
|
||||||
boolean stream,
|
boolean stream,
|
||||||
Map<String, Object> taskSettings,
|
Map<String, Object> taskSettings,
|
||||||
|
|
|
@ -15,6 +15,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
import org.elasticsearch.common.util.LazyInitializable;
|
import org.elasticsearch.common.util.LazyInitializable;
|
||||||
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
|
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.core.TimeValue;
|
import org.elasticsearch.core.TimeValue;
|
||||||
import org.elasticsearch.inference.ChunkedInference;
|
import org.elasticsearch.inference.ChunkedInference;
|
||||||
import org.elasticsearch.inference.InferenceServiceConfiguration;
|
import org.elasticsearch.inference.InferenceServiceConfiguration;
|
||||||
|
@ -103,6 +104,8 @@ public class TestStreamingCompletionServiceExtension implements InferenceService
|
||||||
public void infer(
|
public void infer(
|
||||||
Model model,
|
Model model,
|
||||||
String query,
|
String query,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
List<String> input,
|
List<String> input,
|
||||||
boolean stream,
|
boolean stream,
|
||||||
Map<String, Object> taskSettings,
|
Map<String, Object> taskSettings,
|
||||||
|
|
|
@ -77,6 +77,8 @@ public class TransportInferenceAction extends BaseTransportInferenceAction<Infer
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
request.getQuery(),
|
request.getQuery(),
|
||||||
|
request.getReturnDocuments(),
|
||||||
|
request.getTopN(),
|
||||||
request.getInput(),
|
request.getInput(),
|
||||||
request.isStreaming(),
|
request.isStreaming(),
|
||||||
request.getTaskSettings(),
|
request.getTaskSettings(),
|
||||||
|
|
|
@ -75,7 +75,13 @@ public class VoyageAIActionCreator implements VoyageAIActionVisitor {
|
||||||
serviceComponents.threadPool(),
|
serviceComponents.threadPool(),
|
||||||
overriddenModel,
|
overriddenModel,
|
||||||
RERANK_HANDLER,
|
RERANK_HANDLER,
|
||||||
(rerankInput) -> new VoyageAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model),
|
(rerankInput) -> new VoyageAIRerankRequest(
|
||||||
|
rerankInput.getQuery(),
|
||||||
|
rerankInput.getChunks(),
|
||||||
|
rerankInput.getReturnDocuments(),
|
||||||
|
rerankInput.getTopN(),
|
||||||
|
model
|
||||||
|
),
|
||||||
QueryAndDocsInputs.class
|
QueryAndDocsInputs.class
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -69,6 +69,8 @@ public class AlibabaCloudSearchRerankRequestManager extends AlibabaCloudSearchRe
|
||||||
account,
|
account,
|
||||||
rerankInput.getQuery(),
|
rerankInput.getQuery(),
|
||||||
rerankInput.getChunks(),
|
rerankInput.getChunks(),
|
||||||
|
rerankInput.getReturnDocuments(),
|
||||||
|
rerankInput.getTopN(),
|
||||||
model
|
model
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,13 @@ public class CohereRerankRequestManager extends CohereRequestManager {
|
||||||
ActionListener<InferenceServiceResults> listener
|
ActionListener<InferenceServiceResults> listener
|
||||||
) {
|
) {
|
||||||
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
|
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
|
||||||
CohereRerankRequest request = new CohereRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
|
CohereRerankRequest request = new CohereRerankRequest(
|
||||||
|
rerankInput.getQuery(),
|
||||||
|
rerankInput.getChunks(),
|
||||||
|
rerankInput.getReturnDocuments(),
|
||||||
|
rerankInput.getTopN(),
|
||||||
|
model
|
||||||
|
);
|
||||||
|
|
||||||
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
|
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,7 +62,13 @@ public class GoogleVertexAiRerankRequestManager extends GoogleVertexAiRequestMan
|
||||||
ActionListener<InferenceServiceResults> listener
|
ActionListener<InferenceServiceResults> listener
|
||||||
) {
|
) {
|
||||||
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
|
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
|
||||||
GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
|
GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest(
|
||||||
|
rerankInput.getQuery(),
|
||||||
|
rerankInput.getChunks(),
|
||||||
|
rerankInput.getReturnDocuments(),
|
||||||
|
rerankInput.getTopN(),
|
||||||
|
model
|
||||||
|
);
|
||||||
|
|
||||||
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
|
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,13 @@ public class JinaAIRerankRequestManager extends JinaAIRequestManager {
|
||||||
ActionListener<InferenceServiceResults> listener
|
ActionListener<InferenceServiceResults> listener
|
||||||
) {
|
) {
|
||||||
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
|
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
|
||||||
JinaAIRerankRequest request = new JinaAIRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);
|
JinaAIRerankRequest request = new JinaAIRerankRequest(
|
||||||
|
rerankInput.getQuery(),
|
||||||
|
rerankInput.getChunks(),
|
||||||
|
rerankInput.getReturnDocuments(),
|
||||||
|
rerankInput.getTopN(),
|
||||||
|
model
|
||||||
|
);
|
||||||
|
|
||||||
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
|
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,8 @@
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.external.http.sender;
|
package org.elasticsearch.xpack.inference.external.http.sender;
|
||||||
|
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
|
@ -22,15 +24,25 @@ public class QueryAndDocsInputs extends InferenceInputs {
|
||||||
|
|
||||||
private final String query;
|
private final String query;
|
||||||
private final List<String> chunks;
|
private final List<String> chunks;
|
||||||
|
private final Boolean returnDocuments;
|
||||||
|
private final Integer topN;
|
||||||
|
|
||||||
public QueryAndDocsInputs(String query, List<String> chunks) {
|
public QueryAndDocsInputs(String query, List<String> chunks) {
|
||||||
this(query, chunks, false);
|
this(query, chunks, null, null, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
public QueryAndDocsInputs(String query, List<String> chunks, boolean stream) {
|
public QueryAndDocsInputs(
|
||||||
|
String query,
|
||||||
|
List<String> chunks,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
boolean stream
|
||||||
|
) {
|
||||||
super(stream);
|
super(stream);
|
||||||
this.query = Objects.requireNonNull(query);
|
this.query = Objects.requireNonNull(query);
|
||||||
this.chunks = Objects.requireNonNull(chunks);
|
this.chunks = Objects.requireNonNull(chunks);
|
||||||
|
this.returnDocuments = returnDocuments;
|
||||||
|
this.topN = topN;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getQuery() {
|
public String getQuery() {
|
||||||
|
@ -41,6 +53,14 @@ public class QueryAndDocsInputs extends InferenceInputs {
|
||||||
return chunks;
|
return chunks;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Boolean getReturnDocuments() {
|
||||||
|
return returnDocuments;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getTopN() {
|
||||||
|
return topN;
|
||||||
|
}
|
||||||
|
|
||||||
public int inputSize() {
|
public int inputSize() {
|
||||||
return chunks.size();
|
return chunks.size();
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,7 @@ import org.apache.http.client.methods.HttpPost;
|
||||||
import org.apache.http.client.utils.URIBuilder;
|
import org.apache.http.client.utils.URIBuilder;
|
||||||
import org.apache.http.entity.ByteArrayEntity;
|
import org.apache.http.entity.ByteArrayEntity;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.xcontent.XContentType;
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
|
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
|
||||||
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||||
|
@ -32,6 +33,8 @@ public class AlibabaCloudSearchRerankRequest implements Request {
|
||||||
private final AlibabaCloudSearchAccount account;
|
private final AlibabaCloudSearchAccount account;
|
||||||
private final String query;
|
private final String query;
|
||||||
private final List<String> input;
|
private final List<String> input;
|
||||||
|
private final Boolean returnDocuments;
|
||||||
|
private final Integer topN;
|
||||||
private final URI uri;
|
private final URI uri;
|
||||||
private final AlibabaCloudSearchRerankTaskSettings taskSettings;
|
private final AlibabaCloudSearchRerankTaskSettings taskSettings;
|
||||||
private final String model;
|
private final String model;
|
||||||
|
@ -44,6 +47,8 @@ public class AlibabaCloudSearchRerankRequest implements Request {
|
||||||
AlibabaCloudSearchAccount account,
|
AlibabaCloudSearchAccount account,
|
||||||
String query,
|
String query,
|
||||||
List<String> input,
|
List<String> input,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
AlibabaCloudSearchRerankModel rerankModel
|
AlibabaCloudSearchRerankModel rerankModel
|
||||||
) {
|
) {
|
||||||
Objects.requireNonNull(rerankModel);
|
Objects.requireNonNull(rerankModel);
|
||||||
|
@ -51,6 +56,8 @@ public class AlibabaCloudSearchRerankRequest implements Request {
|
||||||
this.account = Objects.requireNonNull(account);
|
this.account = Objects.requireNonNull(account);
|
||||||
this.query = Objects.requireNonNull(query);
|
this.query = Objects.requireNonNull(query);
|
||||||
this.input = Objects.requireNonNull(input);
|
this.input = Objects.requireNonNull(input);
|
||||||
|
this.returnDocuments = returnDocuments;
|
||||||
|
this.topN = topN;
|
||||||
taskSettings = rerankModel.getTaskSettings();
|
taskSettings = rerankModel.getTaskSettings();
|
||||||
model = rerankModel.getServiceSettings().getCommonSettings().modelId();
|
model = rerankModel.getServiceSettings().getCommonSettings().modelId();
|
||||||
host = rerankModel.getServiceSettings().getCommonSettings().getHost();
|
host = rerankModel.getServiceSettings().getCommonSettings().getHost();
|
||||||
|
@ -67,7 +74,8 @@ public class AlibabaCloudSearchRerankRequest implements Request {
|
||||||
HttpPost httpPost = new HttpPost(uri);
|
HttpPost httpPost = new HttpPost(uri);
|
||||||
|
|
||||||
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
||||||
Strings.toString(new AlibabaCloudSearchRerankRequestEntity(query, input, taskSettings)).getBytes(StandardCharsets.UTF_8)
|
Strings.toString(new AlibabaCloudSearchRerankRequestEntity(query, input, returnDocuments, topN, taskSettings))
|
||||||
|
.getBytes(StandardCharsets.UTF_8)
|
||||||
);
|
);
|
||||||
httpPost.setEntity(byteEntity);
|
httpPost.setEntity(byteEntity);
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
|
package org.elasticsearch.xpack.inference.external.request.alibabacloudsearch;
|
||||||
|
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.xcontent.ToXContentObject;
|
import org.elasticsearch.xcontent.ToXContentObject;
|
||||||
import org.elasticsearch.xcontent.XContentBuilder;
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings;
|
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankTaskSettings;
|
||||||
|
@ -15,9 +16,13 @@ import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
public record AlibabaCloudSearchRerankRequestEntity(String query, List<String> input, AlibabaCloudSearchRerankTaskSettings taskSettings)
|
public record AlibabaCloudSearchRerankRequestEntity(
|
||||||
implements
|
String query,
|
||||||
ToXContentObject {
|
List<String> input,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
AlibabaCloudSearchRerankTaskSettings taskSettings
|
||||||
|
) implements ToXContentObject {
|
||||||
|
|
||||||
private static final String SEARCH_QUERY = "query";
|
private static final String SEARCH_QUERY = "query";
|
||||||
private static final String TEXTS_FIELD = "docs";
|
private static final String TEXTS_FIELD = "docs";
|
||||||
|
|
|
@ -11,6 +11,7 @@ import org.apache.http.client.methods.HttpPost;
|
||||||
import org.apache.http.client.utils.URIBuilder;
|
import org.apache.http.client.utils.URIBuilder;
|
||||||
import org.apache.http.entity.ByteArrayEntity;
|
import org.apache.http.entity.ByteArrayEntity;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.xpack.inference.external.cohere.CohereAccount;
|
import org.elasticsearch.xpack.inference.external.cohere.CohereAccount;
|
||||||
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||||
|
@ -28,16 +29,26 @@ public class CohereRerankRequest extends CohereRequest {
|
||||||
private final CohereAccount account;
|
private final CohereAccount account;
|
||||||
private final String query;
|
private final String query;
|
||||||
private final List<String> input;
|
private final List<String> input;
|
||||||
|
private final Boolean returnDocuments;
|
||||||
|
private final Integer topN;
|
||||||
private final CohereRerankTaskSettings taskSettings;
|
private final CohereRerankTaskSettings taskSettings;
|
||||||
private final String model;
|
private final String model;
|
||||||
private final String inferenceEntityId;
|
private final String inferenceEntityId;
|
||||||
|
|
||||||
public CohereRerankRequest(String query, List<String> input, CohereRerankModel model) {
|
public CohereRerankRequest(
|
||||||
|
String query,
|
||||||
|
List<String> input,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
CohereRerankModel model
|
||||||
|
) {
|
||||||
Objects.requireNonNull(model);
|
Objects.requireNonNull(model);
|
||||||
|
|
||||||
this.account = CohereAccount.of(model, CohereRerankRequest::buildDefaultUri);
|
this.account = CohereAccount.of(model, CohereRerankRequest::buildDefaultUri);
|
||||||
this.input = Objects.requireNonNull(input);
|
this.input = Objects.requireNonNull(input);
|
||||||
this.query = Objects.requireNonNull(query);
|
this.query = Objects.requireNonNull(query);
|
||||||
|
this.returnDocuments = returnDocuments;
|
||||||
|
this.topN = topN;
|
||||||
taskSettings = model.getTaskSettings();
|
taskSettings = model.getTaskSettings();
|
||||||
this.model = model.getServiceSettings().modelId();
|
this.model = model.getServiceSettings().modelId();
|
||||||
inferenceEntityId = model.getInferenceEntityId();
|
inferenceEntityId = model.getInferenceEntityId();
|
||||||
|
@ -48,7 +59,8 @@ public class CohereRerankRequest extends CohereRequest {
|
||||||
HttpPost httpPost = new HttpPost(account.uri());
|
HttpPost httpPost = new HttpPost(account.uri());
|
||||||
|
|
||||||
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
||||||
Strings.toString(new CohereRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8)
|
Strings.toString(new CohereRerankRequestEntity(query, input, returnDocuments, topN, taskSettings, model))
|
||||||
|
.getBytes(StandardCharsets.UTF_8)
|
||||||
);
|
);
|
||||||
httpPost.setEntity(byteEntity);
|
httpPost.setEntity(byteEntity);
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.external.request.cohere;
|
package org.elasticsearch.xpack.inference.external.request.cohere;
|
||||||
|
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.xcontent.ToXContentObject;
|
import org.elasticsearch.xcontent.ToXContentObject;
|
||||||
import org.elasticsearch.xcontent.XContentBuilder;
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
|
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
|
||||||
|
@ -15,9 +16,14 @@ import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
public record CohereRerankRequestEntity(String model, String query, List<String> documents, CohereRerankTaskSettings taskSettings)
|
public record CohereRerankRequestEntity(
|
||||||
implements
|
String model,
|
||||||
ToXContentObject {
|
String query,
|
||||||
|
List<String> documents,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
CohereRerankTaskSettings taskSettings
|
||||||
|
) implements ToXContentObject {
|
||||||
|
|
||||||
private static final String DOCUMENTS_FIELD = "documents";
|
private static final String DOCUMENTS_FIELD = "documents";
|
||||||
private static final String QUERY_FIELD = "query";
|
private static final String QUERY_FIELD = "query";
|
||||||
|
@ -29,8 +35,15 @@ public record CohereRerankRequestEntity(String model, String query, List<String>
|
||||||
Objects.requireNonNull(taskSettings);
|
Objects.requireNonNull(taskSettings);
|
||||||
}
|
}
|
||||||
|
|
||||||
public CohereRerankRequestEntity(String query, List<String> input, CohereRerankTaskSettings taskSettings, String model) {
|
public CohereRerankRequestEntity(
|
||||||
this(model, query, input, taskSettings);
|
String query,
|
||||||
|
List<String> input,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
CohereRerankTaskSettings taskSettings,
|
||||||
|
String model
|
||||||
|
) {
|
||||||
|
this(model, query, input, returnDocuments, topN, taskSettings);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -41,11 +54,17 @@ public record CohereRerankRequestEntity(String model, String query, List<String>
|
||||||
builder.field(QUERY_FIELD, query);
|
builder.field(QUERY_FIELD, query);
|
||||||
builder.field(DOCUMENTS_FIELD, documents);
|
builder.field(DOCUMENTS_FIELD, documents);
|
||||||
|
|
||||||
if (taskSettings.getDoesReturnDocuments() != null) {
|
// prefer the root level return_documents over task settings
|
||||||
|
if (returnDocuments != null) {
|
||||||
|
builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments);
|
||||||
|
} else if (taskSettings.getDoesReturnDocuments() != null) {
|
||||||
builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments());
|
builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (taskSettings.getTopNDocumentsOnly() != null) {
|
// prefer the root level top_n over task settings
|
||||||
|
if (topN != null) {
|
||||||
|
builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, topN);
|
||||||
|
} else if (taskSettings.getTopNDocumentsOnly() != null) {
|
||||||
builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly());
|
builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ import org.apache.http.HttpHeaders;
|
||||||
import org.apache.http.client.methods.HttpPost;
|
import org.apache.http.client.methods.HttpPost;
|
||||||
import org.apache.http.entity.ByteArrayEntity;
|
import org.apache.http.entity.ByteArrayEntity;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.xcontent.XContentType;
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||||
|
@ -29,10 +30,22 @@ public class GoogleVertexAiRerankRequest implements GoogleVertexAiRequest {
|
||||||
|
|
||||||
private final List<String> input;
|
private final List<String> input;
|
||||||
|
|
||||||
public GoogleVertexAiRerankRequest(String query, List<String> input, GoogleVertexAiRerankModel model) {
|
private final Boolean returnDocuments;
|
||||||
|
|
||||||
|
private final Integer topN;
|
||||||
|
|
||||||
|
public GoogleVertexAiRerankRequest(
|
||||||
|
String query,
|
||||||
|
List<String> input,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
GoogleVertexAiRerankModel model
|
||||||
|
) {
|
||||||
this.model = Objects.requireNonNull(model);
|
this.model = Objects.requireNonNull(model);
|
||||||
this.query = Objects.requireNonNull(query);
|
this.query = Objects.requireNonNull(query);
|
||||||
this.input = Objects.requireNonNull(input);
|
this.input = Objects.requireNonNull(input);
|
||||||
|
this.returnDocuments = returnDocuments;
|
||||||
|
this.topN = topN;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -41,7 +54,13 @@ public class GoogleVertexAiRerankRequest implements GoogleVertexAiRequest {
|
||||||
|
|
||||||
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
||||||
Strings.toString(
|
Strings.toString(
|
||||||
new GoogleVertexAiRerankRequestEntity(query, input, model.getServiceSettings().modelId(), model.getTaskSettings().topN())
|
new GoogleVertexAiRerankRequestEntity(
|
||||||
|
query,
|
||||||
|
input,
|
||||||
|
returnDocuments,
|
||||||
|
topN != null ? topN : model.getTaskSettings().topN(),
|
||||||
|
model.getServiceSettings().modelId()
|
||||||
|
)
|
||||||
).getBytes(StandardCharsets.UTF_8)
|
).getBytes(StandardCharsets.UTF_8)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -15,9 +15,13 @@ import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
public record GoogleVertexAiRerankRequestEntity(String query, List<String> inputs, @Nullable String model, @Nullable Integer topN)
|
public record GoogleVertexAiRerankRequestEntity(
|
||||||
implements
|
String query,
|
||||||
ToXContentObject {
|
List<String> inputs,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
@Nullable String model
|
||||||
|
) implements ToXContentObject {
|
||||||
|
|
||||||
private static final String MODEL_FIELD = "model";
|
private static final String MODEL_FIELD = "model";
|
||||||
private static final String QUERY_FIELD = "query";
|
private static final String QUERY_FIELD = "query";
|
||||||
|
@ -26,6 +30,7 @@ public record GoogleVertexAiRerankRequestEntity(String query, List<String> input
|
||||||
|
|
||||||
private static final String CONTENT_FIELD = "content";
|
private static final String CONTENT_FIELD = "content";
|
||||||
private static final String TOP_N_FIELD = "topN";
|
private static final String TOP_N_FIELD = "topN";
|
||||||
|
private static final String IGNORE_RECORD_DETAILS_IN_RESPONSE_FIELD = "ignoreRecordDetailsInResponse";
|
||||||
|
|
||||||
public GoogleVertexAiRerankRequestEntity {
|
public GoogleVertexAiRerankRequestEntity {
|
||||||
Objects.requireNonNull(query);
|
Objects.requireNonNull(query);
|
||||||
|
@ -57,10 +62,16 @@ public record GoogleVertexAiRerankRequestEntity(String query, List<String> input
|
||||||
|
|
||||||
builder.endArray();
|
builder.endArray();
|
||||||
|
|
||||||
|
// prefer the root level top_n over task settings
|
||||||
if (topN != null) {
|
if (topN != null) {
|
||||||
builder.field(TOP_N_FIELD, topN);
|
builder.field(TOP_N_FIELD, topN);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (returnDocuments != null) {
|
||||||
|
// if returnDocuments = true, we do not want to ignore record details
|
||||||
|
builder.field(IGNORE_RECORD_DETAILS_IN_RESPONSE_FIELD, returnDocuments == Boolean.TRUE ? Boolean.FALSE : Boolean.TRUE);
|
||||||
|
}
|
||||||
|
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
|
|
||||||
return builder;
|
return builder;
|
||||||
|
|
|
@ -11,6 +11,7 @@ import org.apache.http.client.methods.HttpPost;
|
||||||
import org.apache.http.client.utils.URIBuilder;
|
import org.apache.http.client.utils.URIBuilder;
|
||||||
import org.apache.http.entity.ByteArrayEntity;
|
import org.apache.http.entity.ByteArrayEntity;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.xpack.inference.external.jinaai.JinaAIAccount;
|
import org.elasticsearch.xpack.inference.external.jinaai.JinaAIAccount;
|
||||||
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||||
|
@ -28,16 +29,26 @@ public class JinaAIRerankRequest extends JinaAIRequest {
|
||||||
private final JinaAIAccount account;
|
private final JinaAIAccount account;
|
||||||
private final String query;
|
private final String query;
|
||||||
private final List<String> input;
|
private final List<String> input;
|
||||||
|
private final Boolean returnDocuments;
|
||||||
|
private final Integer topN;
|
||||||
private final JinaAIRerankTaskSettings taskSettings;
|
private final JinaAIRerankTaskSettings taskSettings;
|
||||||
private final String model;
|
private final String model;
|
||||||
private final String inferenceEntityId;
|
private final String inferenceEntityId;
|
||||||
|
|
||||||
public JinaAIRerankRequest(String query, List<String> input, JinaAIRerankModel model) {
|
public JinaAIRerankRequest(
|
||||||
|
String query,
|
||||||
|
List<String> input,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
JinaAIRerankModel model
|
||||||
|
) {
|
||||||
Objects.requireNonNull(model);
|
Objects.requireNonNull(model);
|
||||||
|
|
||||||
this.account = JinaAIAccount.of(model, JinaAIRerankRequest::buildDefaultUri);
|
this.account = JinaAIAccount.of(model, JinaAIRerankRequest::buildDefaultUri);
|
||||||
this.input = Objects.requireNonNull(input);
|
this.input = Objects.requireNonNull(input);
|
||||||
this.query = Objects.requireNonNull(query);
|
this.query = Objects.requireNonNull(query);
|
||||||
|
this.returnDocuments = returnDocuments;
|
||||||
|
this.topN = topN;
|
||||||
taskSettings = model.getTaskSettings();
|
taskSettings = model.getTaskSettings();
|
||||||
this.model = model.getServiceSettings().modelId();
|
this.model = model.getServiceSettings().modelId();
|
||||||
inferenceEntityId = model.getInferenceEntityId();
|
inferenceEntityId = model.getInferenceEntityId();
|
||||||
|
@ -48,7 +59,8 @@ public class JinaAIRerankRequest extends JinaAIRequest {
|
||||||
HttpPost httpPost = new HttpPost(account.uri());
|
HttpPost httpPost = new HttpPost(account.uri());
|
||||||
|
|
||||||
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
||||||
Strings.toString(new JinaAIRerankRequestEntity(query, input, taskSettings, model)).getBytes(StandardCharsets.UTF_8)
|
Strings.toString(new JinaAIRerankRequestEntity(query, input, returnDocuments, topN, taskSettings, model))
|
||||||
|
.getBytes(StandardCharsets.UTF_8)
|
||||||
);
|
);
|
||||||
httpPost.setEntity(byteEntity);
|
httpPost.setEntity(byteEntity);
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.external.request.jinaai;
|
package org.elasticsearch.xpack.inference.external.request.jinaai;
|
||||||
|
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.xcontent.ToXContentObject;
|
import org.elasticsearch.xcontent.ToXContentObject;
|
||||||
import org.elasticsearch.xcontent.XContentBuilder;
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
|
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings;
|
||||||
|
@ -15,9 +16,14 @@ import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
public record JinaAIRerankRequestEntity(String model, String query, List<String> documents, JinaAIRerankTaskSettings taskSettings)
|
public record JinaAIRerankRequestEntity(
|
||||||
implements
|
String model,
|
||||||
ToXContentObject {
|
String query,
|
||||||
|
List<String> documents,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
JinaAIRerankTaskSettings taskSettings
|
||||||
|
) implements ToXContentObject {
|
||||||
|
|
||||||
private static final String DOCUMENTS_FIELD = "documents";
|
private static final String DOCUMENTS_FIELD = "documents";
|
||||||
private static final String QUERY_FIELD = "query";
|
private static final String QUERY_FIELD = "query";
|
||||||
|
@ -30,8 +36,15 @@ public record JinaAIRerankRequestEntity(String model, String query, List<String>
|
||||||
Objects.requireNonNull(taskSettings);
|
Objects.requireNonNull(taskSettings);
|
||||||
}
|
}
|
||||||
|
|
||||||
public JinaAIRerankRequestEntity(String query, List<String> input, JinaAIRerankTaskSettings taskSettings, String model) {
|
public JinaAIRerankRequestEntity(
|
||||||
this(model, query, input, taskSettings != null ? taskSettings : JinaAIRerankTaskSettings.EMPTY_SETTINGS);
|
String query,
|
||||||
|
List<String> input,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
JinaAIRerankTaskSettings taskSettings,
|
||||||
|
String model
|
||||||
|
) {
|
||||||
|
this(model, query, input, returnDocuments, topN, taskSettings != null ? taskSettings : JinaAIRerankTaskSettings.EMPTY_SETTINGS);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -42,13 +55,18 @@ public record JinaAIRerankRequestEntity(String model, String query, List<String>
|
||||||
builder.field(QUERY_FIELD, query);
|
builder.field(QUERY_FIELD, query);
|
||||||
builder.field(DOCUMENTS_FIELD, documents);
|
builder.field(DOCUMENTS_FIELD, documents);
|
||||||
|
|
||||||
if (taskSettings.getTopNDocumentsOnly() != null) {
|
// prefer the root level top_n over task settings
|
||||||
|
if (topN != null) {
|
||||||
|
builder.field(JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, topN);
|
||||||
|
} else if (taskSettings.getTopNDocumentsOnly() != null) {
|
||||||
builder.field(JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly());
|
builder.field(JinaAIRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly());
|
||||||
}
|
}
|
||||||
|
|
||||||
var return_documents = taskSettings.getDoesReturnDocuments();
|
// prefer the root level return_documents over task settings
|
||||||
if (return_documents != null) {
|
if (returnDocuments != null) {
|
||||||
builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, return_documents);
|
builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments);
|
||||||
|
} else if (taskSettings.getDoesReturnDocuments() != null) {
|
||||||
|
builder.field(JinaAIRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments());
|
||||||
}
|
}
|
||||||
|
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
|
|
|
@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.external.request.voyageai;
|
||||||
import org.apache.http.client.methods.HttpPost;
|
import org.apache.http.client.methods.HttpPost;
|
||||||
import org.apache.http.entity.ByteArrayEntity;
|
import org.apache.http.entity.ByteArrayEntity;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||||
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
|
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel;
|
||||||
|
@ -23,13 +24,23 @@ public class VoyageAIRerankRequest extends VoyageAIRequest {
|
||||||
|
|
||||||
private final String query;
|
private final String query;
|
||||||
private final List<String> input;
|
private final List<String> input;
|
||||||
|
private final Boolean returnDocuments;
|
||||||
|
private final Integer topN;
|
||||||
private final VoyageAIRerankModel model;
|
private final VoyageAIRerankModel model;
|
||||||
|
|
||||||
public VoyageAIRerankRequest(String query, List<String> input, VoyageAIRerankModel model) {
|
public VoyageAIRerankRequest(
|
||||||
|
String query,
|
||||||
|
List<String> input,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
VoyageAIRerankModel model
|
||||||
|
) {
|
||||||
this.model = Objects.requireNonNull(model);
|
this.model = Objects.requireNonNull(model);
|
||||||
|
|
||||||
this.input = Objects.requireNonNull(input);
|
this.input = Objects.requireNonNull(input);
|
||||||
this.query = Objects.requireNonNull(query);
|
this.query = Objects.requireNonNull(query);
|
||||||
|
this.returnDocuments = returnDocuments;
|
||||||
|
this.topN = topN;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -37,8 +48,16 @@ public class VoyageAIRerankRequest extends VoyageAIRequest {
|
||||||
HttpPost httpPost = new HttpPost(model.uri());
|
HttpPost httpPost = new HttpPost(model.uri());
|
||||||
|
|
||||||
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
||||||
Strings.toString(new VoyageAIRerankRequestEntity(query, input, model.getTaskSettings(), model.getServiceSettings().modelId()))
|
Strings.toString(
|
||||||
.getBytes(StandardCharsets.UTF_8)
|
new VoyageAIRerankRequestEntity(
|
||||||
|
query,
|
||||||
|
input,
|
||||||
|
returnDocuments,
|
||||||
|
topN,
|
||||||
|
model.getTaskSettings(),
|
||||||
|
model.getServiceSettings().modelId()
|
||||||
|
)
|
||||||
|
).getBytes(StandardCharsets.UTF_8)
|
||||||
);
|
);
|
||||||
httpPost.setEntity(byteEntity);
|
httpPost.setEntity(byteEntity);
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.external.request.voyageai;
|
package org.elasticsearch.xpack.inference.external.request.voyageai;
|
||||||
|
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.xcontent.ToXContentObject;
|
import org.elasticsearch.xcontent.ToXContentObject;
|
||||||
import org.elasticsearch.xcontent.XContentBuilder;
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;
|
import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankTaskSettings;
|
||||||
|
@ -15,15 +16,19 @@ import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
public record VoyageAIRerankRequestEntity(String model, String query, List<String> documents, VoyageAIRerankTaskSettings taskSettings)
|
public record VoyageAIRerankRequestEntity(
|
||||||
implements
|
String model,
|
||||||
ToXContentObject {
|
String query,
|
||||||
|
List<String> documents,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
VoyageAIRerankTaskSettings taskSettings
|
||||||
|
) implements ToXContentObject {
|
||||||
|
|
||||||
private static final String DOCUMENTS_FIELD = "documents";
|
private static final String DOCUMENTS_FIELD = "documents";
|
||||||
private static final String QUERY_FIELD = "query";
|
private static final String QUERY_FIELD = "query";
|
||||||
private static final String MODEL_FIELD = "model";
|
private static final String MODEL_FIELD = "model";
|
||||||
public static final String TRUNCATION_FIELD = "truncation";
|
public static final String TRUNCATION_FIELD = "truncation";
|
||||||
public static final String RETURN_DOCUMENTS_FIELD = "return_documents";
|
|
||||||
|
|
||||||
public VoyageAIRerankRequestEntity {
|
public VoyageAIRerankRequestEntity {
|
||||||
Objects.requireNonNull(query);
|
Objects.requireNonNull(query);
|
||||||
|
@ -32,8 +37,15 @@ public record VoyageAIRerankRequestEntity(String model, String query, List<Strin
|
||||||
Objects.requireNonNull(taskSettings);
|
Objects.requireNonNull(taskSettings);
|
||||||
}
|
}
|
||||||
|
|
||||||
public VoyageAIRerankRequestEntity(String query, List<String> input, VoyageAIRerankTaskSettings taskSettings, String model) {
|
public VoyageAIRerankRequestEntity(
|
||||||
this(model, query, input, taskSettings != null ? taskSettings : VoyageAIRerankTaskSettings.EMPTY_SETTINGS);
|
String query,
|
||||||
|
List<String> input,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
VoyageAIRerankTaskSettings taskSettings,
|
||||||
|
String model
|
||||||
|
) {
|
||||||
|
this(model, query, input, returnDocuments, topN, taskSettings != null ? taskSettings : VoyageAIRerankTaskSettings.EMPTY_SETTINGS);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -44,11 +56,17 @@ public record VoyageAIRerankRequestEntity(String model, String query, List<Strin
|
||||||
builder.field(QUERY_FIELD, query);
|
builder.field(QUERY_FIELD, query);
|
||||||
builder.field(DOCUMENTS_FIELD, documents);
|
builder.field(DOCUMENTS_FIELD, documents);
|
||||||
|
|
||||||
if (taskSettings.getDoesReturnDocuments() != null) {
|
// prefer the root level return_documents over task settings
|
||||||
|
if (returnDocuments != null) {
|
||||||
|
builder.field(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments);
|
||||||
|
} else if (taskSettings.getDoesReturnDocuments() != null) {
|
||||||
builder.field(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments());
|
builder.field(VoyageAIRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (taskSettings.getTopKDocumentsOnly() != null) {
|
// prefer the root level top_n over task settings
|
||||||
|
if (topN != null) {
|
||||||
|
builder.field(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, topN);
|
||||||
|
} else if (taskSettings.getTopKDocumentsOnly() != null) {
|
||||||
builder.field(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, taskSettings.getTopKDocumentsOnly());
|
builder.field(VoyageAIRerankTaskSettings.TOP_K_DOCS_ONLY, taskSettings.getTopKDocumentsOnly());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -103,10 +103,6 @@ public class GoogleVertexAiRerankResponseEntity {
|
||||||
return parseList(parser, (listParser, index) -> {
|
return parseList(parser, (listParser, index) -> {
|
||||||
var parsedRankedDoc = RankedDoc.parse(parser);
|
var parsedRankedDoc = RankedDoc.parse(parser);
|
||||||
|
|
||||||
if (parsedRankedDoc.content == null) {
|
|
||||||
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.CONTENT.getPreferredName()));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (parsedRankedDoc.score == null) {
|
if (parsedRankedDoc.score == null) {
|
||||||
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.SCORE.getPreferredName()));
|
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.SCORE.getPreferredName()));
|
||||||
}
|
}
|
||||||
|
|
|
@ -232,6 +232,8 @@ public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuil
|
||||||
TaskType.ANY,
|
TaskType.ANY,
|
||||||
inferenceId,
|
inferenceId,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(query),
|
List.of(query),
|
||||||
Map.of(),
|
Map.of(),
|
||||||
InputType.INTERNAL_SEARCH,
|
InputType.INTERNAL_SEARCH,
|
||||||
|
|
|
@ -153,6 +153,8 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
|
||||||
TaskType.RERANK,
|
TaskType.RERANK,
|
||||||
inferenceId,
|
inferenceId,
|
||||||
inferenceText,
|
inferenceText,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
docFeatures,
|
docFeatures,
|
||||||
Map.of(),
|
Map.of(),
|
||||||
InputType.INTERNAL_SEARCH,
|
InputType.INTERNAL_SEARCH,
|
||||||
|
|
|
@ -60,6 +60,8 @@ public abstract class SenderService implements InferenceService {
|
||||||
public void infer(
|
public void infer(
|
||||||
Model model,
|
Model model,
|
||||||
@Nullable String query,
|
@Nullable String query,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
List<String> input,
|
List<String> input,
|
||||||
boolean stream,
|
boolean stream,
|
||||||
Map<String, Object> taskSettings,
|
Map<String, Object> taskSettings,
|
||||||
|
@ -68,7 +70,7 @@ public abstract class SenderService implements InferenceService {
|
||||||
ActionListener<InferenceServiceResults> listener
|
ActionListener<InferenceServiceResults> listener
|
||||||
) {
|
) {
|
||||||
init();
|
init();
|
||||||
var inferenceInput = createInput(this, model, input, inputType, query, stream);
|
var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
|
||||||
doInfer(model, inferenceInput, taskSettings, timeout, listener);
|
doInfer(model, inferenceInput, taskSettings, timeout, listener);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,11 +80,20 @@ public abstract class SenderService implements InferenceService {
|
||||||
List<String> input,
|
List<String> input,
|
||||||
InputType inputType,
|
InputType inputType,
|
||||||
@Nullable String query,
|
@Nullable String query,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
boolean stream
|
boolean stream
|
||||||
) {
|
) {
|
||||||
return switch (model.getTaskType()) {
|
return switch (model.getTaskType()) {
|
||||||
case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream);
|
case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream);
|
||||||
case RERANK -> new QueryAndDocsInputs(query, input, stream);
|
case RERANK -> {
|
||||||
|
ValidationException validationException = new ValidationException();
|
||||||
|
service.validateRerankParameters(returnDocuments, topN, validationException);
|
||||||
|
if (validationException.validationErrors().isEmpty() == false) {
|
||||||
|
throw validationException;
|
||||||
|
}
|
||||||
|
yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream);
|
||||||
|
}
|
||||||
case TEXT_EMBEDDING, SPARSE_EMBEDDING -> {
|
case TEXT_EMBEDDING, SPARSE_EMBEDDING -> {
|
||||||
ValidationException validationException = new ValidationException();
|
ValidationException validationException = new ValidationException();
|
||||||
service.validateInputType(inputType, model, validationException);
|
service.validateInputType(inputType, model, validationException);
|
||||||
|
@ -141,6 +152,8 @@ public abstract class SenderService implements InferenceService {
|
||||||
|
|
||||||
protected abstract void validateInputType(InputType inputType, Model model, ValidationException validationException);
|
protected abstract void validateInputType(InputType inputType, Model model, ValidationException validationException);
|
||||||
|
|
||||||
|
protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {}
|
||||||
|
|
||||||
protected abstract void doUnifiedCompletionInfer(
|
protected abstract void doUnifiedCompletionInfer(
|
||||||
Model model,
|
Model model,
|
||||||
UnifiedChatInput inputs,
|
UnifiedChatInput inputs,
|
||||||
|
|
|
@ -735,6 +735,8 @@ public final class ServiceUtils {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(TEST_EMBEDDING_INPUT),
|
List.of(TEST_EMBEDDING_INPUT),
|
||||||
false,
|
false,
|
||||||
Map.of(),
|
Map.of(),
|
||||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionListener;
|
||||||
import org.elasticsearch.common.ValidationException;
|
import org.elasticsearch.common.ValidationException;
|
||||||
import org.elasticsearch.common.util.LazyInitializable;
|
import org.elasticsearch.common.util.LazyInitializable;
|
||||||
import org.elasticsearch.core.Nullable;
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.core.Strings;
|
||||||
import org.elasticsearch.core.TimeValue;
|
import org.elasticsearch.core.TimeValue;
|
||||||
import org.elasticsearch.inference.ChunkedInference;
|
import org.elasticsearch.inference.ChunkedInference;
|
||||||
import org.elasticsearch.inference.ChunkingSettings;
|
import org.elasticsearch.inference.ChunkingSettings;
|
||||||
|
@ -300,6 +301,24 @@ public class AlibabaCloudSearchService extends SenderService {
|
||||||
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
|
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {
|
||||||
|
if (returnDocuments != null) {
|
||||||
|
validationException.addValidationError(
|
||||||
|
Strings.format(
|
||||||
|
"Invalid return_documents [%s]. The return_documents option is not supported by this service",
|
||||||
|
returnDocuments
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (topN != null) {
|
||||||
|
validationException.addValidationError(
|
||||||
|
Strings.format("Invalid top_n [%s]. The top_n option is not supported by this service", topN)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void doChunkedInfer(
|
protected void doChunkedInfer(
|
||||||
Model model,
|
Model model,
|
||||||
|
|
|
@ -620,6 +620,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
|
||||||
public void infer(
|
public void infer(
|
||||||
Model model,
|
Model model,
|
||||||
@Nullable String query,
|
@Nullable String query,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
List<String> input,
|
List<String> input,
|
||||||
boolean stream,
|
boolean stream,
|
||||||
Map<String, Object> taskSettings,
|
Map<String, Object> taskSettings,
|
||||||
|
@ -632,7 +634,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
|
||||||
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
|
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
|
||||||
inferTextEmbedding(esModel, input, inputType, timeout, listener);
|
inferTextEmbedding(esModel, input, inputType, timeout, listener);
|
||||||
} else if (TaskType.RERANK.equals(taskType)) {
|
} else if (TaskType.RERANK.equals(taskType)) {
|
||||||
inferRerank(esModel, query, input, inputType, timeout, taskSettings, listener);
|
inferRerank(esModel, query, input, returnDocuments, topN, inputType, timeout, taskSettings, listener);
|
||||||
} else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
|
} else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
|
||||||
inferSparseEmbedding(esModel, input, inputType, timeout, listener);
|
inferSparseEmbedding(esModel, input, inputType, timeout, listener);
|
||||||
} else {
|
} else {
|
||||||
|
@ -693,6 +695,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
|
||||||
ElasticsearchInternalModel model,
|
ElasticsearchInternalModel model,
|
||||||
String query,
|
String query,
|
||||||
List<String> inputs,
|
List<String> inputs,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
InputType inputType,
|
InputType inputType,
|
||||||
TimeValue timeout,
|
TimeValue timeout,
|
||||||
Map<String, Object> requestTaskSettings,
|
Map<String, Object> requestTaskSettings,
|
||||||
|
@ -701,7 +705,9 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
|
||||||
var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout);
|
var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout);
|
||||||
|
|
||||||
var returnDocs = Boolean.TRUE;
|
var returnDocs = Boolean.TRUE;
|
||||||
if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
|
if (returnDocuments != null) {
|
||||||
|
returnDocs = returnDocuments;
|
||||||
|
} else if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
|
||||||
var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings);
|
var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings);
|
||||||
returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
|
returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
|
||||||
}
|
}
|
||||||
|
@ -709,7 +715,9 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
|
||||||
Function<Integer, String> inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null;
|
Function<Integer, String> inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null;
|
||||||
|
|
||||||
ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
|
ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
|
||||||
(l, inferenceResult) -> l.onResponse(textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier))
|
(l, inferenceResult) -> l.onResponse(
|
||||||
|
textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier, topN)
|
||||||
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
var maybeDeployListener = mlResultsListener.delegateResponse(
|
var maybeDeployListener = mlResultsListener.delegateResponse(
|
||||||
|
@ -824,7 +832,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
|
||||||
|
|
||||||
private RankedDocsResults textSimilarityResultsToRankedDocs(
|
private RankedDocsResults textSimilarityResultsToRankedDocs(
|
||||||
List<? extends InferenceResults> results,
|
List<? extends InferenceResults> results,
|
||||||
Function<Integer, String> inputSupplier
|
Function<Integer, String> inputSupplier,
|
||||||
|
@Nullable Integer topN
|
||||||
) {
|
) {
|
||||||
List<RankedDocsResults.RankedDoc> rankings = new ArrayList<>(results.size());
|
List<RankedDocsResults.RankedDoc> rankings = new ArrayList<>(results.size());
|
||||||
for (int i = 0; i < results.size(); i++) {
|
for (int i = 0; i < results.size(); i++) {
|
||||||
|
@ -851,7 +860,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
|
||||||
}
|
}
|
||||||
|
|
||||||
Collections.sort(rankings);
|
Collections.sort(rankings);
|
||||||
return new RankedDocsResults(rankings);
|
return new RankedDocsResults(topN != null ? rankings.subList(0, topN) : rankings);
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<DefaultConfigId> defaultConfigIds() {
|
public List<DefaultConfigId> defaultConfigIds() {
|
||||||
|
|
|
@ -30,6 +30,8 @@ public class SimpleServiceIntegrationValidator implements ServiceIntegrationVali
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
model.getTaskType().equals(TaskType.RERANK) ? QUERY : null,
|
model.getTaskType().equals(TaskType.RERANK) ? QUERY : null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
TEST_INPUT,
|
TEST_INPUT,
|
||||||
false,
|
false,
|
||||||
Map.of(),
|
Map.of(),
|
||||||
|
|
|
@ -423,9 +423,9 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
|
||||||
when(service.canStream(any())).thenReturn(stream);
|
when(service.canStream(any())).thenReturn(stream);
|
||||||
when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks);
|
when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks);
|
||||||
doAnswer(ans -> {
|
doAnswer(ans -> {
|
||||||
listenerAction.accept(ans.getArgument(7));
|
listenerAction.accept(ans.getArgument(9));
|
||||||
return null;
|
return null;
|
||||||
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
}).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
||||||
doAnswer(ans -> {
|
doAnswer(ans -> {
|
||||||
listenerAction.accept(ans.getArgument(3));
|
listenerAction.accept(ans.getArgument(3));
|
||||||
return null;
|
return null;
|
||||||
|
|
|
@ -23,7 +23,7 @@ public class InferenceInputsTests extends ESTestCase {
|
||||||
var emptyRequest = new UnifiedCompletionRequest(List.of(), null, null, null, null, null, null, null);
|
var emptyRequest = new UnifiedCompletionRequest(List.of(), null, null, null, null, null, null, null);
|
||||||
assertThat(new UnifiedChatInput(emptyRequest, false).castTo(UnifiedChatInput.class), Matchers.instanceOf(UnifiedChatInput.class));
|
assertThat(new UnifiedChatInput(emptyRequest, false).castTo(UnifiedChatInput.class), Matchers.instanceOf(UnifiedChatInput.class));
|
||||||
assertThat(
|
assertThat(
|
||||||
new QueryAndDocsInputs("hello", List.of(), false).castTo(QueryAndDocsInputs.class),
|
new QueryAndDocsInputs("hello", List.of(), Boolean.TRUE, 33, false).castTo(QueryAndDocsInputs.class),
|
||||||
Matchers.instanceOf(QueryAndDocsInputs.class)
|
Matchers.instanceOf(QueryAndDocsInputs.class)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,13 @@ import static org.hamcrest.CoreMatchers.is;
|
||||||
|
|
||||||
public class AlibabaCloudSearchRerankRequestEntityTests extends ESTestCase {
|
public class AlibabaCloudSearchRerankRequestEntityTests extends ESTestCase {
|
||||||
public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
|
public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
|
||||||
var entity = new AlibabaCloudSearchRerankRequestEntity("query", List.of("abc"), new AlibabaCloudSearchRerankTaskSettings());
|
var entity = new AlibabaCloudSearchRerankRequestEntity(
|
||||||
|
"query",
|
||||||
|
List.of("abc"),
|
||||||
|
Boolean.TRUE,
|
||||||
|
22,
|
||||||
|
new AlibabaCloudSearchRerankTaskSettings()
|
||||||
|
);
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
entity.toXContent(builder, null);
|
entity.toXContent(builder, null);
|
||||||
|
|
|
@ -0,0 +1,95 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.external.request.cohere;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xcontent.XContentFactory;
|
||||||
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
|
||||||
|
import org.hamcrest.MatcherAssert;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.hamcrest.CoreMatchers.is;
|
||||||
|
|
||||||
|
public class CohereRerankRequestEntityTests extends ESTestCase {
|
||||||
|
public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
|
||||||
|
var entity = new CohereRerankRequestEntity(
|
||||||
|
"query",
|
||||||
|
List.of("abc"),
|
||||||
|
Boolean.TRUE,
|
||||||
|
22,
|
||||||
|
new CohereRerankTaskSettings(null, null, 3),
|
||||||
|
"model"
|
||||||
|
);
|
||||||
|
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
entity.toXContent(builder, null);
|
||||||
|
String xContentResult = Strings.toString(builder);
|
||||||
|
|
||||||
|
MatcherAssert.assertThat(xContentResult, is("""
|
||||||
|
{"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":22,"max_chunks_per_doc":3}"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testXContent_WritesMinimalFields() throws IOException {
|
||||||
|
var entity = new CohereRerankRequestEntity(
|
||||||
|
"query",
|
||||||
|
List.of("abc"),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
new CohereRerankTaskSettings(null, null, null),
|
||||||
|
"model"
|
||||||
|
);
|
||||||
|
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
entity.toXContent(builder, null);
|
||||||
|
String xContentResult = Strings.toString(builder);
|
||||||
|
|
||||||
|
MatcherAssert.assertThat(xContentResult, is("""
|
||||||
|
{"model":"model","query":"query","documents":["abc"]}"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testXContent_PrefersRootLevelReturnDocumentsAndTopN() throws IOException {
|
||||||
|
var entity = new CohereRerankRequestEntity(
|
||||||
|
"query",
|
||||||
|
List.of("abc"),
|
||||||
|
Boolean.FALSE,
|
||||||
|
99,
|
||||||
|
new CohereRerankTaskSettings(33, Boolean.TRUE, null),
|
||||||
|
"model"
|
||||||
|
);
|
||||||
|
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
entity.toXContent(builder, null);
|
||||||
|
String xContentResult = Strings.toString(builder);
|
||||||
|
|
||||||
|
MatcherAssert.assertThat(xContentResult, is("""
|
||||||
|
{"model":"model","query":"query","documents":["abc"],"return_documents":false,"top_n":99}"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testXContent_UsesTaskSettingsIfNoRootOptionsDefined() throws IOException {
|
||||||
|
var entity = new CohereRerankRequestEntity(
|
||||||
|
"query",
|
||||||
|
List.of("abc"),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
new CohereRerankTaskSettings(33, Boolean.TRUE, null),
|
||||||
|
"model"
|
||||||
|
);
|
||||||
|
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
entity.toXContent(builder, null);
|
||||||
|
String xContentResult = Strings.toString(builder);
|
||||||
|
|
||||||
|
MatcherAssert.assertThat(xContentResult, is("""
|
||||||
|
{"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":33}"""));
|
||||||
|
}
|
||||||
|
}
|
|
@ -20,8 +20,8 @@ import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhi
|
||||||
import static org.hamcrest.MatcherAssert.assertThat;
|
import static org.hamcrest.MatcherAssert.assertThat;
|
||||||
|
|
||||||
public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
|
public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
|
||||||
public void testXContent_SingleRequest_WritesModelAndTopNIfDefined() throws IOException {
|
public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException {
|
||||||
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), "model", 8);
|
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), Boolean.TRUE, 10, "model");
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
entity.toXContent(builder, null);
|
entity.toXContent(builder, null);
|
||||||
|
@ -37,13 +37,14 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
|
||||||
"content": "abc"
|
"content": "abc"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"topN": 8
|
"topN": 10,
|
||||||
|
"ignoreRecordDetailsInResponse": false
|
||||||
}
|
}
|
||||||
"""));
|
"""));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testXContent_SingleRequest_DoesNotWriteModelAndTopNIfNull() throws IOException {
|
public void testXContent_SingleRequest_WritesMinimalFields() throws IOException {
|
||||||
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), null, null);
|
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), null, null, null);
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
entity.toXContent(builder, null);
|
entity.toXContent(builder, null);
|
||||||
|
@ -62,8 +63,8 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
|
||||||
"""));
|
"""));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testXContent_MultipleRequests_WritesModelAndTopNIfDefined() throws IOException {
|
public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException {
|
||||||
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), "model", 8);
|
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), Boolean.FALSE, 12, "model");
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
entity.toXContent(builder, null);
|
entity.toXContent(builder, null);
|
||||||
|
@ -83,13 +84,14 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
|
||||||
"content": "def"
|
"content": "def"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"topN": 8
|
"topN": 12,
|
||||||
|
"ignoreRecordDetailsInResponse": true
|
||||||
}
|
}
|
||||||
"""));
|
"""));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testXContent_MultipleRequests_DoesNotWriteModelAndTopNIfNull() throws IOException {
|
public void testXContent_MultipleRequests_WritesMinimalFields() throws IOException {
|
||||||
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), null, null);
|
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), null, null, null);
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
entity.toXContent(builder, null);
|
entity.toXContent(builder, null);
|
||||||
|
@ -111,5 +113,4 @@ public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
"""));
|
"""));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,11 +29,11 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase {
|
||||||
|
|
||||||
private static final String AUTH_HEADER_VALUE = "foo";
|
private static final String AUTH_HEADER_VALUE = "foo";
|
||||||
|
|
||||||
public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException {
|
public void testCreateRequest_WithMinimalFieldsSet() throws IOException {
|
||||||
var input = "input";
|
var input = "input";
|
||||||
var query = "query";
|
var query = "query";
|
||||||
|
|
||||||
var request = createRequest(query, input, null, null);
|
var request = createRequest(query, input, null, null, null, null);
|
||||||
var httpRequest = request.createHttpRequest();
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
@ -53,8 +53,9 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase {
|
||||||
var input = "input";
|
var input = "input";
|
||||||
var query = "query";
|
var query = "query";
|
||||||
var topN = 1;
|
var topN = 1;
|
||||||
|
var taskSettingsTopN = 3;
|
||||||
|
|
||||||
var request = createRequest(query, input, null, topN);
|
var request = createRequest(query, input, null, topN, null, taskSettingsTopN);
|
||||||
var httpRequest = request.createHttpRequest();
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
@ -71,12 +72,55 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase {
|
||||||
assertThat(requestMap.get("topN"), is(topN));
|
assertThat(requestMap.get("topN"), is(topN));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testCreateRequest_UsesTaskSettingsTopNWhenRootLevelIsNull() throws IOException {
|
||||||
|
var input = "input";
|
||||||
|
var query = "query";
|
||||||
|
var topN = 1;
|
||||||
|
|
||||||
|
var request = createRequest(query, input, null, null, null, topN);
|
||||||
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||||
|
|
||||||
|
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||||
|
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
|
||||||
|
|
||||||
|
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||||
|
|
||||||
|
assertThat(requestMap, aMapWithSize(3));
|
||||||
|
assertThat(requestMap.get("records"), is(List.of(Map.of("id", "0", "content", input))));
|
||||||
|
assertThat(requestMap.get("query"), is(query));
|
||||||
|
assertThat(requestMap.get("topN"), is(topN));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testCreateRequest_WithReturnDocumentsSet() throws IOException {
|
||||||
|
var input = "input";
|
||||||
|
var query = "query";
|
||||||
|
|
||||||
|
var request = createRequest(query, input, null, null, Boolean.TRUE, null);
|
||||||
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||||
|
|
||||||
|
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||||
|
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
|
||||||
|
|
||||||
|
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||||
|
|
||||||
|
assertThat(requestMap, aMapWithSize(3));
|
||||||
|
assertThat(requestMap.get("records"), is(List.of(Map.of("id", "0", "content", input))));
|
||||||
|
assertThat(requestMap.get("query"), is(query));
|
||||||
|
assertThat(requestMap.get("ignoreRecordDetailsInResponse"), is(Boolean.FALSE));
|
||||||
|
}
|
||||||
|
|
||||||
public void testCreateRequest_WithModelSet() throws IOException {
|
public void testCreateRequest_WithModelSet() throws IOException {
|
||||||
var input = "input";
|
var input = "input";
|
||||||
var query = "query";
|
var query = "query";
|
||||||
var modelId = "model";
|
var modelId = "model";
|
||||||
|
|
||||||
var request = createRequest(query, input, modelId, null);
|
var request = createRequest(query, input, modelId, null, null, null);
|
||||||
var httpRequest = request.createHttpRequest();
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
@ -94,24 +138,37 @@ public class GoogleVertexAiRerankRequestTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testTruncate_DoesNotTruncate() {
|
public void testTruncate_DoesNotTruncate() {
|
||||||
var request = createRequest("query", "input", null, null);
|
var request = createRequest("query", "input", null, null, null, null);
|
||||||
var truncatedRequest = request.truncate();
|
var truncatedRequest = request.truncate();
|
||||||
|
|
||||||
assertThat(truncatedRequest, sameInstance(request));
|
assertThat(truncatedRequest, sameInstance(request));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static GoogleVertexAiRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topN) {
|
private static GoogleVertexAiRerankRequest createRequest(
|
||||||
var rerankModel = GoogleVertexAiRerankModelTests.createModel(modelId, topN);
|
String query,
|
||||||
|
String input,
|
||||||
|
@Nullable String modelId,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer taskSettingsTopN
|
||||||
|
) {
|
||||||
|
var rerankModel = GoogleVertexAiRerankModelTests.createModel(modelId, taskSettingsTopN);
|
||||||
|
|
||||||
return new GoogleVertexAiRerankWithoutAuthRequest(query, List.of(input), rerankModel);
|
return new GoogleVertexAiRerankWithoutAuthRequest(query, List.of(input), rerankModel, topN, returnDocuments);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* We use this class to fake the auth implementation to avoid static mocking of {@link GoogleVertexAiRequest}
|
* We use this class to fake the auth implementation to avoid static mocking of {@link GoogleVertexAiRequest}
|
||||||
*/
|
*/
|
||||||
private static class GoogleVertexAiRerankWithoutAuthRequest extends GoogleVertexAiRerankRequest {
|
private static class GoogleVertexAiRerankWithoutAuthRequest extends GoogleVertexAiRerankRequest {
|
||||||
GoogleVertexAiRerankWithoutAuthRequest(String query, List<String> input, GoogleVertexAiRerankModel model) {
|
GoogleVertexAiRerankWithoutAuthRequest(
|
||||||
super(query, input, model);
|
String query,
|
||||||
|
List<String> input,
|
||||||
|
GoogleVertexAiRerankModel model,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
@Nullable Boolean returnDocuments
|
||||||
|
) {
|
||||||
|
super(query, input, returnDocuments, topN, model);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -21,8 +21,15 @@ import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhi
|
||||||
import static org.hamcrest.MatcherAssert.assertThat;
|
import static org.hamcrest.MatcherAssert.assertThat;
|
||||||
|
|
||||||
public class JinaAIRerankRequestEntityTests extends ESTestCase {
|
public class JinaAIRerankRequestEntityTests extends ESTestCase {
|
||||||
public void testXContent_SingleRequest_WritesModelAndTopNIfDefined() throws IOException {
|
public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException {
|
||||||
var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, null), "model");
|
var entity = new JinaAIRerankRequestEntity(
|
||||||
|
"query",
|
||||||
|
List.of("abc"),
|
||||||
|
Boolean.TRUE,
|
||||||
|
12,
|
||||||
|
new JinaAIRerankTaskSettings(8, Boolean.FALSE),
|
||||||
|
"model"
|
||||||
|
);
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
entity.toXContent(builder, null);
|
entity.toXContent(builder, null);
|
||||||
|
@ -35,33 +42,86 @@ public class JinaAIRerankRequestEntityTests extends ESTestCase {
|
||||||
"documents": [
|
"documents": [
|
||||||
"abc"
|
"abc"
|
||||||
],
|
],
|
||||||
"top_n": 8
|
"top_n": 12,
|
||||||
}
|
|
||||||
"""));
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testXContent_SingleRequest_WritesModelAndTopNIfDefined_ReturnDocumentsTrue() throws IOException {
|
|
||||||
var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, true), "model");
|
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
|
||||||
entity.toXContent(builder, null);
|
|
||||||
String xContentResult = Strings.toString(builder);
|
|
||||||
|
|
||||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
|
||||||
{
|
|
||||||
"model": "model",
|
|
||||||
"query": "query",
|
|
||||||
"documents": [
|
|
||||||
"abc"
|
|
||||||
],
|
|
||||||
"top_n": 8,
|
|
||||||
"return_documents": true
|
"return_documents": true
|
||||||
}
|
}
|
||||||
"""));
|
"""));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testXContent_SingleRequest_WritesModelAndTopNIfDefined_ReturnDocumentsFalse() throws IOException {
|
public void testXContent_SingleRequest_WritesMinimalFields() throws IOException {
|
||||||
var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), new JinaAIRerankTaskSettings(8, false), "model");
|
var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), null, null, new JinaAIRerankTaskSettings(null, null), "model");
|
||||||
|
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
entity.toXContent(builder, null);
|
||||||
|
String xContentResult = Strings.toString(builder);
|
||||||
|
|
||||||
|
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
||||||
|
{
|
||||||
|
"model": "model",
|
||||||
|
"query": "query",
|
||||||
|
"documents": [
|
||||||
|
"abc"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException {
|
||||||
|
var entity = new JinaAIRerankRequestEntity(
|
||||||
|
"query",
|
||||||
|
List.of("abc", "def"),
|
||||||
|
Boolean.FALSE,
|
||||||
|
12,
|
||||||
|
new JinaAIRerankTaskSettings(8, Boolean.TRUE),
|
||||||
|
"model"
|
||||||
|
);
|
||||||
|
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
entity.toXContent(builder, null);
|
||||||
|
String xContentResult = Strings.toString(builder);
|
||||||
|
|
||||||
|
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
||||||
|
{
|
||||||
|
"model": "model",
|
||||||
|
"query": "query",
|
||||||
|
"documents": [
|
||||||
|
"abc",
|
||||||
|
"def"
|
||||||
|
],
|
||||||
|
"top_n": 12,
|
||||||
|
"return_documents": false
|
||||||
|
}
|
||||||
|
"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testXContent_MultipleRequests_WritesMinimalFields() throws IOException {
|
||||||
|
var entity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), null, null, null, "model");
|
||||||
|
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
entity.toXContent(builder, null);
|
||||||
|
String xContentResult = Strings.toString(builder);
|
||||||
|
|
||||||
|
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
||||||
|
{
|
||||||
|
"model": "model",
|
||||||
|
"query": "query",
|
||||||
|
"documents": [
|
||||||
|
"abc",
|
||||||
|
"def"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testXContent_SingleRequest_UsesTaskSettingsTopNIfRootIsNotDefined() throws IOException {
|
||||||
|
var entity = new JinaAIRerankRequestEntity(
|
||||||
|
"query",
|
||||||
|
List.of("abc"),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
new JinaAIRerankTaskSettings(8, Boolean.FALSE),
|
||||||
|
"model"
|
||||||
|
);
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
entity.toXContent(builder, null);
|
entity.toXContent(builder, null);
|
||||||
|
@ -80,61 +140,4 @@ public class JinaAIRerankRequestEntityTests extends ESTestCase {
|
||||||
"""));
|
"""));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testXContent_SingleRequest_DoesNotWriteTopNIfNull() throws IOException {
|
|
||||||
var entity = new JinaAIRerankRequestEntity("query", List.of("abc"), null, "model");
|
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
|
||||||
entity.toXContent(builder, null);
|
|
||||||
String xContentResult = Strings.toString(builder);
|
|
||||||
|
|
||||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
|
||||||
{
|
|
||||||
"model": "model",
|
|
||||||
"query": "query",
|
|
||||||
"documents": [
|
|
||||||
"abc"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
"""));
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testXContent_MultipleRequests_WritesModelAndTopNIfDefined() throws IOException {
|
|
||||||
var entity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), new JinaAIRerankTaskSettings(8, null), "model");
|
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
|
||||||
entity.toXContent(builder, null);
|
|
||||||
String xContentResult = Strings.toString(builder);
|
|
||||||
|
|
||||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
|
||||||
{
|
|
||||||
"model": "model",
|
|
||||||
"query": "query",
|
|
||||||
"documents": [
|
|
||||||
"abc",
|
|
||||||
"def"
|
|
||||||
],
|
|
||||||
"top_n": 8
|
|
||||||
}
|
|
||||||
"""));
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testXContent_MultipleRequests_DoesNotWriteTopNIfNull() throws IOException {
|
|
||||||
var entity = new JinaAIRerankRequestEntity("query", List.of("abc", "def"), null, "model");
|
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
|
||||||
entity.toXContent(builder, null);
|
|
||||||
String xContentResult = Strings.toString(builder);
|
|
||||||
|
|
||||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
|
||||||
{
|
|
||||||
"model": "model",
|
|
||||||
"query": "query",
|
|
||||||
"documents": [
|
|
||||||
"abc",
|
|
||||||
"def"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
"""));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,12 +27,12 @@ public class JinaAIRerankRequestTests extends ESTestCase {
|
||||||
|
|
||||||
private static final String API_KEY = "foo";
|
private static final String API_KEY = "foo";
|
||||||
|
|
||||||
public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException {
|
public void testCreateRequest_WithMinimalFieldsSet() throws IOException {
|
||||||
var input = "input";
|
var input = "input";
|
||||||
var query = "query";
|
var query = "query";
|
||||||
var modelId = "model";
|
var modelId = "model";
|
||||||
|
|
||||||
var request = createRequest(query, input, modelId, null);
|
var request = createRequest(query, input, modelId, null, null, null);
|
||||||
var httpRequest = request.createHttpRequest();
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
@ -49,13 +49,14 @@ public class JinaAIRerankRequestTests extends ESTestCase {
|
||||||
assertThat(requestMap.get("model"), is(modelId));
|
assertThat(requestMap.get("model"), is(modelId));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testCreateRequest_WithTopNSet() throws IOException {
|
public void testCreateRequest_WithAllFieldsSet() throws IOException {
|
||||||
var input = "input";
|
var input = "input";
|
||||||
var query = "query";
|
var query = "query";
|
||||||
var topN = 1;
|
var topN = 1;
|
||||||
|
var taskSettingsTopN = 2;
|
||||||
var modelId = "model";
|
var modelId = "model";
|
||||||
|
|
||||||
var request = createRequest(query, input, modelId, topN);
|
var request = createRequest(query, input, modelId, topN, Boolean.FALSE, taskSettingsTopN);
|
||||||
var httpRequest = request.createHttpRequest();
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
@ -66,10 +67,11 @@ public class JinaAIRerankRequestTests extends ESTestCase {
|
||||||
|
|
||||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||||
|
|
||||||
assertThat(requestMap, aMapWithSize(4));
|
assertThat(requestMap, aMapWithSize(5));
|
||||||
assertThat(requestMap.get("documents"), is(List.of(input)));
|
assertThat(requestMap.get("documents"), is(List.of(input)));
|
||||||
assertThat(requestMap.get("query"), is(query));
|
assertThat(requestMap.get("query"), is(query));
|
||||||
assertThat(requestMap.get("top_n"), is(topN));
|
assertThat(requestMap.get("top_n"), is(topN));
|
||||||
|
assertThat(requestMap.get("return_documents"), is(Boolean.FALSE));
|
||||||
assertThat(requestMap.get("model"), is(modelId));
|
assertThat(requestMap.get("model"), is(modelId));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,7 +80,7 @@ public class JinaAIRerankRequestTests extends ESTestCase {
|
||||||
var query = "query";
|
var query = "query";
|
||||||
var modelId = "model";
|
var modelId = "model";
|
||||||
|
|
||||||
var request = createRequest(query, input, modelId, null);
|
var request = createRequest(query, input, modelId, null, null, null);
|
||||||
var httpRequest = request.createHttpRequest();
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
@ -96,15 +98,22 @@ public class JinaAIRerankRequestTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testTruncate_DoesNotTruncate() {
|
public void testTruncate_DoesNotTruncate() {
|
||||||
var request = createRequest("query", "input", "null", null);
|
var request = createRequest("query", "input", "null", null, null, null);
|
||||||
var truncatedRequest = request.truncate();
|
var truncatedRequest = request.truncate();
|
||||||
|
|
||||||
assertThat(truncatedRequest, sameInstance(request));
|
assertThat(truncatedRequest, sameInstance(request));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static JinaAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topN) {
|
private static JinaAIRerankRequest createRequest(
|
||||||
var rerankModel = JinaAIRerankModelTests.createModel(API_KEY, modelId, topN);
|
String query,
|
||||||
return new JinaAIRerankRequest(query, List.of(input), rerankModel);
|
String input,
|
||||||
|
@Nullable String modelId,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer taskSettingsTopN
|
||||||
|
) {
|
||||||
|
var rerankModel = JinaAIRerankModelTests.createModel(API_KEY, modelId, taskSettingsTopN);
|
||||||
|
return new JinaAIRerankRequest(query, List.of(input), returnDocuments, topN, rerankModel);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,27 +20,15 @@ import java.util.List;
|
||||||
import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
|
import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
|
||||||
|
|
||||||
public class VoyageAIRerankRequestEntityTests extends ESTestCase {
|
public class VoyageAIRerankRequestEntityTests extends ESTestCase {
|
||||||
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined() throws IOException {
|
public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException {
|
||||||
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, null, null), "model");
|
var entity = new VoyageAIRerankRequestEntity(
|
||||||
|
"query",
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
List.of("abc"),
|
||||||
entity.toXContent(builder, null);
|
Boolean.TRUE,
|
||||||
String xContentResult = Strings.toString(builder);
|
12,
|
||||||
|
new VoyageAIRerankTaskSettings(8, Boolean.FALSE, null),
|
||||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
"model"
|
||||||
{
|
);
|
||||||
"model": "model",
|
|
||||||
"query": "query",
|
|
||||||
"documents": [
|
|
||||||
"abc"
|
|
||||||
],
|
|
||||||
"top_k": 8
|
|
||||||
}
|
|
||||||
"""));
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsTrue() throws IOException {
|
|
||||||
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, true, null), "model");
|
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
entity.toXContent(builder, null);
|
entity.toXContent(builder, null);
|
||||||
|
@ -54,13 +42,20 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
|
||||||
"abc"
|
"abc"
|
||||||
],
|
],
|
||||||
"return_documents": true,
|
"return_documents": true,
|
||||||
"top_k": 8
|
"top_k": 12
|
||||||
}
|
}
|
||||||
"""));
|
"""));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_ReturnDocumentsFalse() throws IOException {
|
public void testXContent_SingleRequest_WritesMinimalFields() throws IOException {
|
||||||
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, null), "model");
|
var entity = new VoyageAIRerankRequestEntity(
|
||||||
|
"query",
|
||||||
|
List.of("abc"),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
new VoyageAIRerankTaskSettings(null, true, null),
|
||||||
|
"model"
|
||||||
|
);
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
entity.toXContent(builder, null);
|
entity.toXContent(builder, null);
|
||||||
|
@ -73,14 +68,20 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
|
||||||
"documents": [
|
"documents": [
|
||||||
"abc"
|
"abc"
|
||||||
],
|
],
|
||||||
"return_documents": false,
|
"return_documents": true
|
||||||
"top_k": 8
|
|
||||||
}
|
}
|
||||||
"""));
|
"""));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationTrue() throws IOException {
|
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationTrue() throws IOException {
|
||||||
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, true), "model");
|
var entity = new VoyageAIRerankRequestEntity(
|
||||||
|
"query",
|
||||||
|
List.of("abc"),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
new VoyageAIRerankTaskSettings(8, false, true),
|
||||||
|
"model"
|
||||||
|
);
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
entity.toXContent(builder, null);
|
entity.toXContent(builder, null);
|
||||||
|
@ -101,7 +102,14 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationFalse() throws IOException {
|
public void testXContent_SingleRequest_WritesModelAndTopKIfDefined_TruncationFalse() throws IOException {
|
||||||
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), new VoyageAIRerankTaskSettings(8, false, false), "model");
|
var entity = new VoyageAIRerankRequestEntity(
|
||||||
|
"query",
|
||||||
|
List.of("abc"),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
new VoyageAIRerankTaskSettings(8, false, false),
|
||||||
|
"model"
|
||||||
|
);
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
entity.toXContent(builder, null);
|
entity.toXContent(builder, null);
|
||||||
|
@ -121,28 +129,12 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
|
||||||
"""));
|
"""));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testXContent_SingleRequest_DoesNotWriteTopKIfNull() throws IOException {
|
public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException {
|
||||||
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc"), null, "model");
|
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
|
||||||
entity.toXContent(builder, null);
|
|
||||||
String xContentResult = Strings.toString(builder);
|
|
||||||
|
|
||||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
|
||||||
{
|
|
||||||
"model": "model",
|
|
||||||
"query": "query",
|
|
||||||
"documents": [
|
|
||||||
"abc"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
"""));
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testXContent_MultipleRequests_WritesModelAndTopKIfDefined() throws IOException {
|
|
||||||
var entity = new VoyageAIRerankRequestEntity(
|
var entity = new VoyageAIRerankRequestEntity(
|
||||||
"query",
|
"query",
|
||||||
List.of("abc", "def"),
|
List.of("abc", "def"),
|
||||||
|
Boolean.FALSE,
|
||||||
|
11,
|
||||||
new VoyageAIRerankTaskSettings(8, null, null),
|
new VoyageAIRerankTaskSettings(8, null, null),
|
||||||
"model"
|
"model"
|
||||||
);
|
);
|
||||||
|
@ -159,13 +151,14 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
|
||||||
"abc",
|
"abc",
|
||||||
"def"
|
"def"
|
||||||
],
|
],
|
||||||
"top_k": 8
|
"return_documents": false,
|
||||||
|
"top_k": 11
|
||||||
}
|
}
|
||||||
"""));
|
"""));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testXContent_MultipleRequests_DoesNotWriteTopKIfNull() throws IOException {
|
public void testXContent_MultipleRequests_DoesNotWriteTopKIfNull() throws IOException {
|
||||||
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc", "def"), null, "model");
|
var entity = new VoyageAIRerankRequestEntity("query", List.of("abc", "def"), null, null, null, "model");
|
||||||
|
|
||||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
entity.toXContent(builder, null);
|
entity.toXContent(builder, null);
|
||||||
|
@ -183,4 +176,31 @@ public class VoyageAIRerankRequestEntityTests extends ESTestCase {
|
||||||
"""));
|
"""));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testXContent_UsesTaskSettingsTopNIfRootIsNotDefined() throws IOException {
|
||||||
|
var entity = new VoyageAIRerankRequestEntity(
|
||||||
|
"query",
|
||||||
|
List.of("abc"),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
new VoyageAIRerankTaskSettings(8, Boolean.FALSE, null),
|
||||||
|
"model"
|
||||||
|
);
|
||||||
|
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
entity.toXContent(builder, null);
|
||||||
|
String xContentResult = Strings.toString(builder);
|
||||||
|
|
||||||
|
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
||||||
|
{
|
||||||
|
"model": "model",
|
||||||
|
"query": "query",
|
||||||
|
"documents": [
|
||||||
|
"abc"
|
||||||
|
],
|
||||||
|
"return_documents": false,
|
||||||
|
"top_k": 8
|
||||||
|
}
|
||||||
|
"""));
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,12 +27,12 @@ public class VoyageAIRerankRequestTests extends ESTestCase {
|
||||||
|
|
||||||
private static final String API_KEY = "foo";
|
private static final String API_KEY = "foo";
|
||||||
|
|
||||||
public void testCreateRequest_WithoutModelSet_And_WithoutTopNSet() throws IOException {
|
public void testCreateRequest_WithMinimalFields() throws IOException {
|
||||||
var input = "input";
|
var input = "input";
|
||||||
var query = "query";
|
var query = "query";
|
||||||
var modelId = "model";
|
var modelId = "model";
|
||||||
|
|
||||||
var request = createRequest(query, input, modelId, null);
|
var request = createRequest(query, input, modelId, null, null, null);
|
||||||
var httpRequest = request.createHttpRequest();
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
@ -49,13 +49,14 @@ public class VoyageAIRerankRequestTests extends ESTestCase {
|
||||||
assertThat(requestMap.get("model"), is(modelId));
|
assertThat(requestMap.get("model"), is(modelId));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testCreateRequest_WithTopNSet() throws IOException {
|
public void testCreateRequest_WithAllFieldsDefined() throws IOException {
|
||||||
var input = "input";
|
var input = "input";
|
||||||
var query = "query";
|
var query = "query";
|
||||||
var topK = 1;
|
var topK = 1;
|
||||||
|
var taskSettingsTopK = 2;
|
||||||
var modelId = "model";
|
var modelId = "model";
|
||||||
|
|
||||||
var request = createRequest(query, input, modelId, topK);
|
var request = createRequest(query, input, modelId, topK, Boolean.FALSE, taskSettingsTopK);
|
||||||
var httpRequest = request.createHttpRequest();
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
@ -66,11 +67,12 @@ public class VoyageAIRerankRequestTests extends ESTestCase {
|
||||||
|
|
||||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||||
|
|
||||||
assertThat(requestMap, aMapWithSize(4));
|
assertThat(requestMap, aMapWithSize(5));
|
||||||
assertThat(requestMap.get("documents"), is(List.of(input)));
|
assertThat(requestMap.get("documents"), is(List.of(input)));
|
||||||
assertThat(requestMap.get("query"), is(query));
|
assertThat(requestMap.get("query"), is(query));
|
||||||
assertThat(requestMap.get("top_k"), is(topK));
|
assertThat(requestMap.get("top_k"), is(topK));
|
||||||
assertThat(requestMap.get("model"), is(modelId));
|
assertThat(requestMap.get("model"), is(modelId));
|
||||||
|
assertThat(requestMap.get("return_documents"), is(Boolean.FALSE));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testCreateRequest_WithModelSet() throws IOException {
|
public void testCreateRequest_WithModelSet() throws IOException {
|
||||||
|
@ -78,7 +80,7 @@ public class VoyageAIRerankRequestTests extends ESTestCase {
|
||||||
var query = "query";
|
var query = "query";
|
||||||
var modelId = "model";
|
var modelId = "model";
|
||||||
|
|
||||||
var request = createRequest(query, input, modelId, null);
|
var request = createRequest(query, input, modelId, null, null, null);
|
||||||
var httpRequest = request.createHttpRequest();
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
@ -96,15 +98,22 @@ public class VoyageAIRerankRequestTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testTruncate_DoesNotTruncate() {
|
public void testTruncate_DoesNotTruncate() {
|
||||||
var request = createRequest("query", "input", "null", null);
|
var request = createRequest("query", "input", "null", null, null, null);
|
||||||
var truncatedRequest = request.truncate();
|
var truncatedRequest = request.truncate();
|
||||||
|
|
||||||
assertThat(truncatedRequest, sameInstance(request));
|
assertThat(truncatedRequest, sameInstance(request));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static VoyageAIRerankRequest createRequest(String query, String input, @Nullable String modelId, @Nullable Integer topK) {
|
private static VoyageAIRerankRequest createRequest(
|
||||||
var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, topK);
|
String query,
|
||||||
return new VoyageAIRerankRequest(query, List.of(input), rerankModel);
|
String input,
|
||||||
|
@Nullable String modelId,
|
||||||
|
@Nullable Integer topK,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer taskSettingsTopK
|
||||||
|
) {
|
||||||
|
var rerankModel = VoyageAIRerankModelTests.createModel(API_KEY, modelId, taskSettingsTopK);
|
||||||
|
return new VoyageAIRerankRequest(query, List.of(input), returnDocuments, topK, rerankModel);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,6 +42,26 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase {
|
||||||
assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2"))));
|
assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2"))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testFromResponse_CreatesResultsForASingleItem_NoContent() throws IOException {
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"records": [
|
||||||
|
{
|
||||||
|
"id": "2",
|
||||||
|
"title": "title 2",
|
||||||
|
"score": 0.97
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
RankedDocsResults parsedResults = GoogleVertexAiRerankResponseEntity.fromResponse(
|
||||||
|
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||||
|
);
|
||||||
|
|
||||||
|
assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, null))));
|
||||||
|
}
|
||||||
|
|
||||||
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
|
public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
|
||||||
String responseJson = """
|
String responseJson = """
|
||||||
{
|
{
|
||||||
|
@ -72,6 +92,34 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testFromResponse_CreatesResultsForMultipleItems_NoContent() throws IOException {
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"records": [
|
||||||
|
{
|
||||||
|
"id": "2",
|
||||||
|
"title": "title 2",
|
||||||
|
"score": 0.97
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "1",
|
||||||
|
"title": "title 1",
|
||||||
|
"score": 0.90
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
RankedDocsResults parsedResults = GoogleVertexAiRerankResponseEntity.fromResponse(
|
||||||
|
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||||
|
);
|
||||||
|
|
||||||
|
assertThat(
|
||||||
|
parsedResults.getRankedDocs(),
|
||||||
|
is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, null), new RankedDocsResults.RankedDoc(1, 0.90F, null)))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
public void testFromResponse_FailsWhenRecordsFieldIsNotPresent() {
|
public void testFromResponse_FailsWhenRecordsFieldIsNotPresent() {
|
||||||
String responseJson = """
|
String responseJson = """
|
||||||
{
|
{
|
||||||
|
@ -102,36 +150,6 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase {
|
||||||
assertThat(thrownException.getMessage(), is("Failed to find required field [records] in Google Vertex AI rerank response"));
|
assertThat(thrownException.getMessage(), is("Failed to find required field [records] in Google Vertex AI rerank response"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testFromResponse_FailsWhenContentFieldIsNotPresent() {
|
|
||||||
String responseJson = """
|
|
||||||
{
|
|
||||||
"records": [
|
|
||||||
{
|
|
||||||
"id": "2",
|
|
||||||
"title": "title 2",
|
|
||||||
"content": "content 2",
|
|
||||||
"score": 0.97
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "1",
|
|
||||||
"title": "title 1",
|
|
||||||
"not_content": "content 1",
|
|
||||||
"score": 0.97
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
""";
|
|
||||||
|
|
||||||
var thrownException = expectThrows(
|
|
||||||
IllegalStateException.class,
|
|
||||||
() -> GoogleVertexAiRerankResponseEntity.fromResponse(
|
|
||||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
assertThat(thrownException.getMessage(), is("Failed to find required field [content] in Google Vertex AI rerank response"));
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testFromResponse_FailsWhenScoreFieldIsNotPresent() {
|
public void testFromResponse_FailsWhenScoreFieldIsNotPresent() {
|
||||||
String responseJson = """
|
String responseJson = """
|
||||||
{
|
{
|
||||||
|
|
|
@ -98,6 +98,8 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
|
||||||
TaskType.RERANK,
|
TaskType.RERANK,
|
||||||
this.inferenceId,
|
this.inferenceId,
|
||||||
inferenceText,
|
inferenceText,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
docFeatures,
|
docFeatures,
|
||||||
Map.of("inferenceResultCount", inferenceResultCount),
|
Map.of("inferenceResultCount", inferenceResultCount),
|
||||||
InputType.INTERNAL_SEARCH,
|
InputType.INTERNAL_SEARCH,
|
||||||
|
|
|
@ -225,6 +225,8 @@ public class TextSimilarityTestPlugin extends Plugin implements ActionPlugin {
|
||||||
TaskType.RERANK,
|
TaskType.RERANK,
|
||||||
inferenceId,
|
inferenceId,
|
||||||
inferenceText,
|
inferenceText,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
docFeatures,
|
docFeatures,
|
||||||
Map.of("throwing", true),
|
Map.of("throwing", true),
|
||||||
InputType.INTERNAL_SEARCH,
|
InputType.INTERNAL_SEARCH,
|
||||||
|
|
|
@ -910,11 +910,11 @@ public class ServiceUtilsTests extends ESTestCase {
|
||||||
when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
|
when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
|
||||||
|
|
||||||
doAnswer(invocation -> {
|
doAnswer(invocation -> {
|
||||||
ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
|
ActionListener<InferenceServiceResults> listener = invocation.getArgument(9);
|
||||||
listener.onResponse(new TextEmbeddingFloatResults(List.of()));
|
listener.onResponse(new TextEmbeddingFloatResults(List.of()));
|
||||||
|
|
||||||
return Void.TYPE;
|
return Void.TYPE;
|
||||||
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
}).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
||||||
|
|
||||||
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
|
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
|
||||||
getEmbeddingSize(model, service, listener);
|
getEmbeddingSize(model, service, listener);
|
||||||
|
@ -932,11 +932,11 @@ public class ServiceUtilsTests extends ESTestCase {
|
||||||
when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
|
when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
|
||||||
|
|
||||||
doAnswer(invocation -> {
|
doAnswer(invocation -> {
|
||||||
ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
|
ActionListener<InferenceServiceResults> listener = invocation.getArgument(9);
|
||||||
listener.onResponse(new TextEmbeddingByteResults(List.of()));
|
listener.onResponse(new TextEmbeddingByteResults(List.of()));
|
||||||
|
|
||||||
return Void.TYPE;
|
return Void.TYPE;
|
||||||
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
}).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
||||||
|
|
||||||
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
|
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
|
||||||
getEmbeddingSize(model, service, listener);
|
getEmbeddingSize(model, service, listener);
|
||||||
|
@ -956,11 +956,11 @@ public class ServiceUtilsTests extends ESTestCase {
|
||||||
var textEmbedding = TextEmbeddingFloatResultsTests.createRandomResults();
|
var textEmbedding = TextEmbeddingFloatResultsTests.createRandomResults();
|
||||||
|
|
||||||
doAnswer(invocation -> {
|
doAnswer(invocation -> {
|
||||||
ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
|
ActionListener<InferenceServiceResults> listener = invocation.getArgument(9);
|
||||||
listener.onResponse(textEmbedding);
|
listener.onResponse(textEmbedding);
|
||||||
|
|
||||||
return Void.TYPE;
|
return Void.TYPE;
|
||||||
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
}).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
||||||
|
|
||||||
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
|
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
|
||||||
getEmbeddingSize(model, service, listener);
|
getEmbeddingSize(model, service, listener);
|
||||||
|
@ -979,11 +979,11 @@ public class ServiceUtilsTests extends ESTestCase {
|
||||||
var textEmbedding = TextEmbeddingByteResultsTests.createRandomResults();
|
var textEmbedding = TextEmbeddingByteResultsTests.createRandomResults();
|
||||||
|
|
||||||
doAnswer(invocation -> {
|
doAnswer(invocation -> {
|
||||||
ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
|
ActionListener<InferenceServiceResults> listener = invocation.getArgument(9);
|
||||||
listener.onResponse(textEmbedding);
|
listener.onResponse(textEmbedding);
|
||||||
|
|
||||||
return Void.TYPE;
|
return Void.TYPE;
|
||||||
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
}).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
||||||
|
|
||||||
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
|
PlainActionFuture<Integer> listener = new PlainActionFuture<>();
|
||||||
getEmbeddingSize(model, service, listener);
|
getEmbeddingSize(model, service, listener);
|
||||||
|
|
|
@ -389,6 +389,8 @@ public class AlibabaCloudSearchServiceTests extends ESTestCase {
|
||||||
() -> service.infer(
|
() -> service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -431,6 +433,8 @@ public class AlibabaCloudSearchServiceTests extends ESTestCase {
|
||||||
() -> service.infer(
|
() -> service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -446,6 +450,53 @@ public class AlibabaCloudSearchServiceTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOException {
|
||||||
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
Map<String, Object> serviceSettingsMap = new HashMap<>();
|
||||||
|
serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id");
|
||||||
|
serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host");
|
||||||
|
serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default");
|
||||||
|
serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536);
|
||||||
|
|
||||||
|
Map<String, Object> taskSettingsMap = new HashMap<>();
|
||||||
|
|
||||||
|
Map<String, Object> secretSettingsMap = new HashMap<>();
|
||||||
|
secretSettingsMap.put("api_key", "secret");
|
||||||
|
|
||||||
|
var model = AlibabaCloudSearchEmbeddingsModelTests.createModel(
|
||||||
|
"service",
|
||||||
|
TaskType.RERANK,
|
||||||
|
serviceSettingsMap,
|
||||||
|
taskSettingsMap,
|
||||||
|
secretSettingsMap
|
||||||
|
);
|
||||||
|
try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
var thrownException = expectThrows(
|
||||||
|
ValidationException.class,
|
||||||
|
() -> service.infer(
|
||||||
|
model,
|
||||||
|
"hi",
|
||||||
|
Boolean.TRUE,
|
||||||
|
10,
|
||||||
|
List.of("a"),
|
||||||
|
false,
|
||||||
|
new HashMap<>(),
|
||||||
|
null,
|
||||||
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||||
|
listener
|
||||||
|
)
|
||||||
|
);
|
||||||
|
assertThat(
|
||||||
|
thrownException.getMessage(),
|
||||||
|
is(
|
||||||
|
"Validation Failed: 1: Invalid return_documents [true]. The return_documents option is not supported by this "
|
||||||
|
+ "service;2: Invalid top_n [10]. The top_n option is not supported by this service;"
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public void testChunkedInfer_TextEmbeddingChunkingSettingsSet() throws IOException {
|
public void testChunkedInfer_TextEmbeddingChunkingSettingsSet() throws IOException {
|
||||||
testChunkedInfer(TaskType.TEXT_EMBEDDING, ChunkingSettingsTests.createRandomChunkingSettings());
|
testChunkedInfer(TaskType.TEXT_EMBEDDING, ChunkingSettingsTests.createRandomChunkingSettings());
|
||||||
}
|
}
|
||||||
|
|
|
@ -932,6 +932,8 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -979,6 +981,8 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
() -> service.infer(
|
() -> service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1029,6 +1033,8 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1071,6 +1077,8 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1414,6 +1422,8 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
|
|
@ -458,6 +458,8 @@ public class AnthropicServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -513,6 +515,8 @@ public class AnthropicServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("input"),
|
List.of("input"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -571,6 +575,8 @@ public class AnthropicServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
true,
|
true,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
|
|
@ -1096,6 +1096,8 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1134,6 +1136,8 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
||||||
() -> service.infer(
|
() -> service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1296,6 +1300,8 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1347,6 +1353,8 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1403,6 +1411,8 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
true,
|
true,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
|
|
@ -766,6 +766,8 @@ public class AzureOpenAiServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -822,6 +824,8 @@ public class AzureOpenAiServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1286,6 +1290,8 @@ public class AzureOpenAiServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1453,6 +1459,8 @@ public class AzureOpenAiServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
true,
|
true,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
|
|
@ -788,6 +788,8 @@ public class CohereServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -856,6 +858,8 @@ public class CohereServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1147,6 +1151,8 @@ public class CohereServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1207,6 +1213,8 @@ public class CohereServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1281,6 +1289,8 @@ public class CohereServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, null),
|
CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, null),
|
||||||
|
@ -1353,6 +1363,8 @@ public class CohereServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1629,6 +1641,8 @@ public class CohereServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
true,
|
true,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
|
|
@ -232,7 +232,7 @@ public class DeepSeekServiceTests extends ESTestCase {
|
||||||
try (var service = createService()) {
|
try (var service = createService()) {
|
||||||
var model = createModel(service, TaskType.COMPLETION);
|
var model = createModel(service, TaskType.COMPLETION);
|
||||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
service.infer(model, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
|
service.infer(model, null, null, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
|
||||||
var result = listener.actionGet(TIMEOUT);
|
var result = listener.actionGet(TIMEOUT);
|
||||||
assertThat(result, isA(ChatCompletionResults.class));
|
assertThat(result, isA(ChatCompletionResults.class));
|
||||||
var completionResults = (ChatCompletionResults) result;
|
var completionResults = (ChatCompletionResults) result;
|
||||||
|
@ -255,7 +255,7 @@ public class DeepSeekServiceTests extends ESTestCase {
|
||||||
try (var service = createService()) {
|
try (var service = createService()) {
|
||||||
var model = createModel(service, TaskType.COMPLETION);
|
var model = createModel(service, TaskType.COMPLETION);
|
||||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
service.infer(model, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
|
service.infer(model, null, null, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
|
||||||
InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream().hasNoErrors().hasEvent("""
|
InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream().hasNoErrors().hasEvent("""
|
||||||
{"completion":[{"delta":"hello, world"}]}""");
|
{"completion":[{"delta":"hello, world"}]}""");
|
||||||
}
|
}
|
||||||
|
|
|
@ -368,6 +368,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -404,6 +406,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -443,6 +447,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -494,6 +500,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("input text"),
|
List.of("input text"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -551,6 +559,8 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("input text"),
|
List.of("input text"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
|
|
@ -662,6 +662,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -700,6 +702,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
|
||||||
() -> service.infer(
|
() -> service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -775,6 +779,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("input"),
|
List.of("input"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -832,6 +838,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(input),
|
List.of(input),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1005,6 +1013,8 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
|
|
@ -65,6 +65,8 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
|
|
@ -556,6 +556,8 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -593,6 +595,8 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
() -> service.infer(
|
() -> service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -627,6 +631,8 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
|
|
@ -602,6 +602,8 @@ public class IbmWatsonxServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -641,6 +643,8 @@ public class IbmWatsonxServiceTests extends ESTestCase {
|
||||||
() -> service.infer(
|
() -> service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -697,6 +701,8 @@ public class IbmWatsonxServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(input),
|
List.of(input),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -840,6 +846,8 @@ public class IbmWatsonxServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
|
|
@ -782,6 +782,8 @@ public class JinaAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1044,6 +1046,8 @@ public class JinaAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1076,6 +1080,8 @@ public class JinaAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
"query",
|
"query",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("candidate1", "candidate2"),
|
List.of("candidate1", "candidate2"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1132,6 +1138,8 @@ public class JinaAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1201,6 +1209,8 @@ public class JinaAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1254,6 +1264,8 @@ public class JinaAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1320,7 +1332,18 @@ public class JinaAIServiceTests extends ESTestCase {
|
||||||
JinaAIEmbeddingType.FLOAT
|
JinaAIEmbeddingType.FLOAT
|
||||||
);
|
);
|
||||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
service.infer(model, null, List.of("abc"), false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
service.infer(
|
||||||
|
model,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
List.of("abc"),
|
||||||
|
false,
|
||||||
|
new HashMap<>(),
|
||||||
|
null,
|
||||||
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||||
|
listener
|
||||||
|
);
|
||||||
|
|
||||||
var result = listener.actionGet(TIMEOUT);
|
var result = listener.actionGet(TIMEOUT);
|
||||||
|
|
||||||
|
@ -1371,6 +1394,8 @@ public class JinaAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
"query",
|
"query",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("candidate1", "candidate2", "candidate3"),
|
List.of("candidate1", "candidate2", "candidate3"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1454,6 +1479,8 @@ public class JinaAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
"query",
|
"query",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("candidate1", "candidate2", "candidate3", "candidate4"),
|
List.of("candidate1", "candidate2", "candidate3", "candidate4"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1549,6 +1576,8 @@ public class JinaAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
"query",
|
"query",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("candidate1", "candidate2", "candidate3"),
|
List.of("candidate1", "candidate2", "candidate3"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1630,6 +1659,8 @@ public class JinaAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
"query",
|
"query",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("candidate1", "candidate2", "candidate3", "candidate4"),
|
List.of("candidate1", "candidate2", "candidate3", "candidate4"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1724,6 +1755,8 @@ public class JinaAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
|
|
@ -586,6 +586,8 @@ public class MistralServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -625,6 +627,8 @@ public class MistralServiceTests extends ESTestCase {
|
||||||
() -> service.infer(
|
() -> service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -781,6 +785,8 @@ public class MistralServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
|
|
@ -852,6 +852,8 @@ public class OpenAiServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -890,6 +892,8 @@ public class OpenAiServiceTests extends ESTestCase {
|
||||||
() -> service.infer(
|
() -> service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -925,6 +929,8 @@ public class OpenAiServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -964,6 +970,8 @@ public class OpenAiServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1024,6 +1032,8 @@ public class OpenAiServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1263,6 +1273,8 @@ public class OpenAiServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
true,
|
true,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1794,6 +1806,8 @@ public class OpenAiServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
|
|
@ -63,6 +63,8 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase {
|
||||||
.infer(
|
.infer(
|
||||||
eq(mockModel),
|
eq(mockModel),
|
||||||
eq(null),
|
eq(null),
|
||||||
|
eq(null),
|
||||||
|
eq(null),
|
||||||
eq(TEST_INPUT),
|
eq(TEST_INPUT),
|
||||||
eq(false),
|
eq(false),
|
||||||
eq(Map.of()),
|
eq(Map.of()),
|
||||||
|
@ -97,13 +99,15 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase {
|
||||||
|
|
||||||
private void mockSuccessfulCallToService(String query, InferenceServiceResults result) {
|
private void mockSuccessfulCallToService(String query, InferenceServiceResults result) {
|
||||||
doAnswer(ans -> {
|
doAnswer(ans -> {
|
||||||
ActionListener<InferenceServiceResults> responseListener = ans.getArgument(7);
|
ActionListener<InferenceServiceResults> responseListener = ans.getArgument(9);
|
||||||
responseListener.onResponse(result);
|
responseListener.onResponse(result);
|
||||||
return null;
|
return null;
|
||||||
}).when(mockInferenceService)
|
}).when(mockInferenceService)
|
||||||
.infer(
|
.infer(
|
||||||
eq(mockModel),
|
eq(mockModel),
|
||||||
eq(query),
|
eq(query),
|
||||||
|
eq(null),
|
||||||
|
eq(null),
|
||||||
eq(TEST_INPUT),
|
eq(TEST_INPUT),
|
||||||
eq(false),
|
eq(false),
|
||||||
eq(Map.of()),
|
eq(Map.of()),
|
||||||
|
@ -120,6 +124,8 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase {
|
||||||
verify(mockInferenceService).infer(
|
verify(mockInferenceService).infer(
|
||||||
eq(mockModel),
|
eq(mockModel),
|
||||||
eq(withQuery ? TEST_QUERY : null),
|
eq(withQuery ? TEST_QUERY : null),
|
||||||
|
eq(null),
|
||||||
|
eq(null),
|
||||||
eq(TEST_INPUT),
|
eq(TEST_INPUT),
|
||||||
eq(false),
|
eq(false),
|
||||||
eq(Map.of()),
|
eq(Map.of()),
|
||||||
|
|
|
@ -722,6 +722,8 @@ public class VoyageAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
mockModel,
|
mockModel,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -768,6 +770,8 @@ public class VoyageAIServiceTests extends ESTestCase {
|
||||||
() -> service.infer(
|
() -> service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of(""),
|
List.of(""),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1017,6 +1021,8 @@ public class VoyageAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1049,6 +1055,8 @@ public class VoyageAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
"query",
|
"query",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("candidate1", "candidate2"),
|
List.of("candidate1", "candidate2"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1103,6 +1111,8 @@ public class VoyageAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1183,6 +1193,8 @@ public class VoyageAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1260,7 +1272,18 @@ public class VoyageAIServiceTests extends ESTestCase {
|
||||||
(SimilarityMeasure) null
|
(SimilarityMeasure) null
|
||||||
);
|
);
|
||||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
service.infer(model, null, List.of("abc"), false, new HashMap<>(), null, InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
service.infer(
|
||||||
|
model,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
List.of("abc"),
|
||||||
|
false,
|
||||||
|
new HashMap<>(),
|
||||||
|
null,
|
||||||
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||||
|
listener
|
||||||
|
);
|
||||||
|
|
||||||
var result = listener.actionGet(TIMEOUT);
|
var result = listener.actionGet(TIMEOUT);
|
||||||
|
|
||||||
|
@ -1315,6 +1338,8 @@ public class VoyageAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
"query",
|
"query",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("candidate1", "candidate2", "candidate3"),
|
List.of("candidate1", "candidate2", "candidate3"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1401,6 +1426,8 @@ public class VoyageAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
"query",
|
"query",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("candidate1", "candidate2", "candidate3", "candidate4"),
|
List.of("candidate1", "candidate2", "candidate3", "candidate4"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1493,6 +1520,8 @@ public class VoyageAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
"query",
|
"query",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("candidate1", "candidate2", "candidate3"),
|
List.of("candidate1", "candidate2", "candidate3"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1569,6 +1598,8 @@ public class VoyageAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
"query",
|
"query",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("candidate1", "candidate2", "candidate3", "candidate4"),
|
List.of("candidate1", "candidate2", "candidate3", "candidate4"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
@ -1663,6 +1694,8 @@ public class VoyageAIServiceTests extends ESTestCase {
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
List.of("abc"),
|
List.of("abc"),
|
||||||
false,
|
false,
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
|
|
|
@ -123,6 +123,8 @@ public class TransportCoordinatedInferenceAction extends HandledTransportAction<
|
||||||
TaskType.ANY,
|
TaskType.ANY,
|
||||||
request.getModelId(),
|
request.getModelId(),
|
||||||
null,
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
request.getInputs(),
|
request.getInputs(),
|
||||||
request.getTaskSettings(),
|
request.getTaskSettings(),
|
||||||
inputType,
|
inputType,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue