diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 00943d04275d..67e69c013607 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -157,6 +157,7 @@ public class TransportVersions { public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14); public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15); public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16); + public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED_8_19 = def(8_841_0_17); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01); public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02); @@ -214,6 +215,7 @@ public class TransportVersions { public static final TransportVersion ESQL_REMOVE_AGGREGATE_TYPE = def(9_045_0_00); public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00); public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0); + public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED = def(9_048_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java index 42289c50864e..5c95150f1a88 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -26,7 +26,7 @@ public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEIS var allModels = getAllModels(); var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION); - assertThat(allModels, hasSize(5)); + assertThat(allModels, hasSize(6)); assertThat(chatCompletionModels, hasSize(1)); for (var model : chatCompletionModels) { @@ -35,6 +35,7 @@ public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEIS assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION); assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING); + assertInferenceIdTaskType(allModels, ".multilingual-embed-elastic", TaskType.TEXT_EMBEDDING); } private static void assertInferenceIdTaskType(List> models, String inferenceId, TaskType taskType) { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 6f9a55048104..f180b995eb8c 100644 --- 
a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -64,7 +64,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { @SuppressWarnings("unchecked") public void testGetServicesWithTextEmbeddingTaskType() throws IOException { List services = getServices(TaskType.TEXT_EMBEDDING); - assertThat(services.size(), equalTo(15)); + assertThat(services.size(), equalTo(16)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -79,6 +79,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { "azureaistudio", "azureopenai", "cohere", + "elastic", "elasticsearch", "googleaistudio", "googlevertexai", diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java index 3ea011c1317c..a032a78a942e 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java @@ -36,6 +36,10 @@ public class MockElasticInferenceServiceAuthorizationServer implements TestRule { "model_name": "elser-v2", "task_types": ["embed/text/sparse"] + }, + { + "model_name": "multilingual-embed", + "task_types": ["embed/text/dense"] } ] } diff --git 
a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 86c1b549d9de..6fabfdcae484 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -11,6 +11,7 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; @@ -197,6 +198,10 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase { { "model_name": "elser-v2", "task_types": ["embed/text/sparse"] + }, + { + "model_name": "multilingual-embed", + "task_types": ["embed/text/dense"] } ] } @@ -221,16 +226,33 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase { ".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), service + ), + new InferenceService.DefaultConfigId( + ".multilingual-embed-elastic", + MinimalServiceSettings.textEmbedding( + ElasticInferenceService.NAME, + ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT + ), + service ) ) ) ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING))); + assertThat( + 
service.supportedTaskTypes(), + is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING)) + ); PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic")); assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + assertThat( + listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), + is(".multilingual-embed-elastic") + ); var getModelListener = new PlainActionFuture(); // persists the default endpoints @@ -267,6 +289,16 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase { ".elser-v2-elastic", MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), service + ), + new InferenceService.DefaultConfigId( + ".multilingual-embed-elastic", + MinimalServiceSettings.textEmbedding( + ElasticInferenceService.NAME, + ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT + ), + service ) ) ) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequest.java new file mode 100644 index 000000000000..8c43f0e12530 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequest.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; +import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceSparseEmbeddingsRequest.inputTypeToUsageContext; + +public class ElasticInferenceServiceDenseTextEmbeddingsRequest extends ElasticInferenceServiceRequest { + + private final URI uri; + private final ElasticInferenceServiceDenseTextEmbeddingsModel model; + private final List inputs; + private final TraceContextHandler traceContextHandler; + private final InputType inputType; + + public ElasticInferenceServiceDenseTextEmbeddingsRequest( + ElasticInferenceServiceDenseTextEmbeddingsModel model, + List inputs, + TraceContext traceContext, + ElasticInferenceServiceRequestMetadata metadata, + InputType inputType + ) { + super(metadata); + this.inputs = inputs; + this.model = Objects.requireNonNull(model); + this.uri = model.uri(); + this.traceContextHandler = new TraceContextHandler(traceContext); + this.inputType = inputType; + } + + @Override + public 
HttpRequestBase createHttpRequestBase() { + var httpPost = new HttpPost(uri); + var usageContext = inputTypeToUsageContext(inputType); + var requestEntity = Strings.toString( + new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(inputs, model.getServiceSettings().modelId(), usageContext) + ); + + ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + + traceContextHandler.propagateTraceContext(httpPost); + httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); + + return httpPost; + } + + public TraceContext getTraceContext() { + return traceContextHandler.traceContext(); + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java new file mode 100644 index 000000000000..99ea5e8d6d3e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List inputs, + String modelId, + @Nullable ElasticInferenceServiceUsageContext usageContext +) implements ToXContentObject { + + private static final String INPUT_FIELD = "input"; + private static final String MODEL_FIELD = "model"; + private static final String USAGE_CONTEXT = "usage_context"; + + public ElasticInferenceServiceDenseTextEmbeddingsRequestEntity { + Objects.requireNonNull(inputs); + Objects.requireNonNull(modelId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startArray(INPUT_FIELD); + + for (String input : inputs) { + builder.value(input); + } + + builder.endArray(); + + builder.field(MODEL_FIELD, modelId); + + // optional field + if ((usageContext == ElasticInferenceServiceUsageContext.UNSPECIFIED) == false) { + builder.field(USAGE_CONTEXT, usageContext); + } + + builder.endObject(); + + return builder; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntity.java index 30c460d5d1ed..3891850ce122 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntity.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceAuthorizationResponseEntity.java @@ -43,7 +43,9 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements Infer "embed/text/sparse", TaskType.SPARSE_EMBEDDING, "chat", - TaskType.CHAT_COMPLETION + TaskType.CHAT_COMPLETION, + "embed/text/dense", + TaskType.TEXT_EMBEDDING ); @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java new file mode 100644 index 000000000000..d80b9c74a6c6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.elastic; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntity { + + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = + "Failed to find required field [%s] in Elastic Inference Service dense text embeddings response"; + + /** + * Parses the Elastic Inference Service Dense Text Embeddings response. + * + * For a request like: + * + *
+     *     <code>
+     *         {
+     *             "inputs": ["Embed this text", "Embed this text, too"]
+     *         }
+     *     </code>
+     * </pre>
+ * + * The response would look like: + * + *
+     *     <code>
+     *         {
+     *             "data": [
+     *                  [
+     *                      2.1259406,
+     *                      1.7073475,
+     *                      0.9020516
+     *                  ],
+     *                  (...)
+     *             ],
+     *             "meta": {
+     *                  "usage": {...}
+     *             }
+     *         }
+     *     </code>
+     * </pre>
+ */ + + public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE); + + List parsedEmbeddings = parseList( + jsonParser, + (parser, index) -> ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.parseTextEmbeddingObject(parser) + ); + + if (parsedEmbeddings.isEmpty()) { + return new TextEmbeddingFloatResults(Collections.emptyList()); + } + + return new TextEmbeddingFloatResults(parsedEmbeddings); + } + } + + private static TextEmbeddingFloatResults.Embedding parseTextEmbeddingObject(XContentParser parser) throws IOException { + List embeddingValueList = parseList( + parser, + ElasticInferenceServiceDenseTextEmbeddingsResponseEntity::parseEmbeddingFloatValueList + ); + return TextEmbeddingFloatResults.Embedding.of(embeddingValueList); + } + + private static float parseEmbeddingFloatValueList(XContentParser parser) throws IOException { + XContentParser.Token token = parser.currentToken(); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); + return parser.floatValue(); + } + + private ElasticInferenceServiceDenseTextEmbeddingsResponseEntity() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java index 
42ca45f75a9c..af71ccb6dc43 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java @@ -51,11 +51,11 @@ public class ElasticInferenceServiceSparseEmbeddingsResponseEntity { * * { * "data": [ - * { - * "Embed": 2.1259406, - * "this": 1.7073475, - * "text": 0.9020516 - * }, + * [ + * 2.1259406, + * 1.7073475, + * 0.9020516 + * ], * (...) * ], * "meta": { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 75a9e44d25b6..99f7287e785d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -16,7 +16,9 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -27,6 +29,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SettingsConfiguration; +import 
org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; @@ -36,6 +39,8 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -51,6 +56,8 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -70,6 
+77,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; @@ -79,10 +87,18 @@ public class ElasticInferenceService extends SenderService { public static final String NAME = "elastic"; public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service"; + public static final Integer DENSE_TEXT_EMBEDDINGS_DIMENSIONS = 1024; - private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION); + private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of( + TaskType.SPARSE_EMBEDDING, + TaskType.CHAT_COMPLETION, + TaskType.TEXT_EMBEDDING + ); private static final String SERVICE_NAME = "Elastic"; + // TODO: check with team, what makes the most sense + private static final Integer DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE = 32; + // rainbow-sprinkles static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); @@ -91,10 +107,17 @@ public class ElasticInferenceService extends SenderService { static final String DEFAULT_ELSER_MODEL_ID_V2 = "elser-v2"; static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2); + // multilingual-text-embed + static final 
String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "multilingual-embed"; + static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = defaultEndpointId(DEFAULT_MULTILINGUAL_EMBED_MODEL_ID); + /** * The task types that the {@link InferenceAction.Request} can accept. */ - private static final EnumSet SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING); + private static final EnumSet SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of( + TaskType.SPARSE_EMBEDDING, + TaskType.TEXT_EMBEDDING + ); public static String defaultEndpointId(String modelId) { return Strings.format(".%s-elastic", modelId); @@ -155,6 +178,31 @@ public class ElasticInferenceService extends SenderService { elasticInferenceServiceComponents ), MinimalServiceSettings.sparseEmbedding(NAME) + ), + DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + new DefaultModelConfig( + new ElasticInferenceServiceDenseTextEmbeddingsModel( + DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + TaskType.TEXT_EMBEDDING, + NAME, + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + defaultDenseTextEmbeddingsSimilarity(), + null, + null, + false, + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS + ), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents + ), + MinimalServiceSettings.textEmbedding( + NAME, + DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT + ) ) ); } @@ -270,12 +318,26 @@ public class ElasticInferenceService extends SenderService { TimeValue timeout, ActionListener> listener ) { - // Pass-through without actually performing chunking (result will have a single chunk per input) - ActionListener inferListener = listener.delegateFailureAndWrap( - (delegate, response) -> delegate.onResponse(translateToChunkedResults(inputs, response)) - ); + // TODO: we probably want to allow chunked inference for both sparse 
and dense? + if (model instanceof ElasticInferenceServiceDenseTextEmbeddingsModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } - doInfer(model, inputs, taskSettings, timeout, inferListener); + ElasticInferenceServiceDenseTextEmbeddingsModel elasticInferenceServiceModel = + (ElasticInferenceServiceDenseTextEmbeddingsModel) model; + var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo()); + + List batchedRequests = new EmbeddingRequestChunker<>( + inputs.getInputs(), + DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE, + elasticInferenceServiceModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = elasticInferenceServiceModel.accept(actionCreator, taskSettings); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + } } @Override @@ -294,11 +356,19 @@ public class ElasticInferenceService extends SenderService { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + ElasticInferenceServiceModel model = createModel( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, serviceSettingsMap, elasticInferenceServiceComponents, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), @@ -335,6 +405,7 @@ public class ElasticInferenceService extends SenderService { TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, ElasticInferenceServiceComponents 
eisServiceComponents, String failureMessage, @@ -361,6 +432,16 @@ public class ElasticInferenceService extends SenderService { eisServiceComponents, context ); + case TEXT_EMBEDDING -> new ElasticInferenceServiceDenseTextEmbeddingsModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + eisServiceComponents, + context + ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } @@ -376,11 +457,17 @@ public class ElasticInferenceService extends SenderService { Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, secretSettingsMap, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -391,11 +478,17 @@ public class ElasticInferenceService extends SenderService { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + return createModelFromPersistent( inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, null, parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); @@ -411,6 +504,7 @@ public class ElasticInferenceService extends SenderService { TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, 
@Nullable Map secretSettings, String failureMessage ) { @@ -419,6 +513,7 @@ public class ElasticInferenceService extends SenderService { taskType, serviceSettings, taskSettings, + chunkingSettings, secretSettings, elasticInferenceServiceComponents, failureMessage, @@ -432,6 +527,36 @@ public class ElasticInferenceService extends SenderService { ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); } + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof ElasticInferenceServiceDenseTextEmbeddingsModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + var modelId = serviceSettings.modelId(); + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? defaultDenseTextEmbeddingsSimilarity() : similarityFromModel; + var maxInputTokens = serviceSettings.maxInputTokens(); + var dimensionsSetByUser = serviceSettings.dimensionsSetByUser(); + + var updateServiceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + similarityToUse, + embeddingSize, + maxInputTokens, + dimensionsSetByUser, + serviceSettings.rateLimitSettings() + ); + + return new ElasticInferenceServiceDenseTextEmbeddingsModel(embeddingsModel, updateServiceSettings); + } else { + throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); + } + } + + public static SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() { + // TODO: double-check + return SimilarityMeasure.COSINE; + } + private static List translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { var inputsAsList = EmbeddingsInput.of(inputs).getStringInputs(); @@ -469,9 +594,9 @@ public class ElasticInferenceService extends SenderService { configurationMap.put( MODEL_ID, - new 
SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription( - "The name of the model to use for the inference task." - ) + new SettingsConfiguration.Builder( + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) + ).setDescription("The name of the model to use for the inference task.") .setLabel("Model ID") .setRequired(true) .setSensitive(false) @@ -482,7 +607,7 @@ public class ElasticInferenceService extends SenderService { configurationMap.put( MAX_INPUT_TOKENS, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription( + new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING)).setDescription( "Allows you to specify the maximum number of tokens per input." ) .setLabel("Maximum Input Tokens") @@ -494,7 +619,9 @@ public class ElasticInferenceService extends SenderService { ); configurationMap.putAll( - RateLimitSettings.toSettingsConfiguration(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)) + RateLimitSettings.toSettingsConfiguration( + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) + ) ); return new InferenceServiceConfiguration.Builder().setService(NAME) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java index e03cc36e6241..34a808611915 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java @@ -7,14 +7,15 @@ package org.elasticsearch.xpack.inference.services.elastic; -import org.elasticsearch.inference.Model; import 
org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.util.Objects; -public abstract class ElasticInferenceServiceModel extends Model { +public abstract class ElasticInferenceServiceModel extends RateLimitGroupingModel { private final ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings; @@ -35,12 +36,18 @@ public abstract class ElasticInferenceServiceModel extends Model { public ElasticInferenceServiceModel(ElasticInferenceServiceModel model, ServiceSettings serviceSettings) { super(model, serviceSettings); - this.rateLimitServiceSettings = model.rateLimitServiceSettings(); + this.rateLimitServiceSettings = model.rateLimitServiceSettings; this.elasticInferenceServiceComponents = model.elasticInferenceServiceComponents(); } - public ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings() { - return rateLimitServiceSettings; + @Override + public int rateLimitGroupingHash() { + // We only have one model for rerank + return Objects.hash(this.getServiceSettings().modelId()); + } + + public RateLimitSettings rateLimitSettings() { + return rateLimitServiceSettings.rateLimitSettings(); } public ElasticInferenceServiceComponents elasticInferenceServiceComponents() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRequestManager.java index 81231b83c767..bcb7d395b7e3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRequestManager.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceRequestManager.java @@ -20,7 +20,7 @@ public abstract class ElasticInferenceServiceRequestManager extends BaseRequestM private final ElasticInferenceServiceRequestMetadata requestMetadata; protected ElasticInferenceServiceRequestManager(ThreadPool threadPool, ElasticInferenceServiceModel model) { - super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitSettings()); this.requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext()); } @@ -32,7 +32,7 @@ public abstract class ElasticInferenceServiceRequestManager extends BaseRequestM public static RateLimitGrouping of(ElasticInferenceServiceModel model) { Objects.requireNonNull(model); - return new RateLimitGrouping(model.rateLimitServiceSettings().modelId().hashCode()); + return new RateLimitGrouping(model.rateLimitGroupingHash()); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java index 5bdae8582f37..18080d9c055e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java @@ -9,9 +9,16 @@ package org.elasticsearch.xpack.inference.services.elastic.action; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import 
org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceDenseTextEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -19,10 +26,16 @@ import java.util.Locale; import java.util.Objects; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequest.extractRequestMetadataFromThreadContext; import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor { + public static final ResponseHandler DENSE_TEXT_EMBEDDINGS_HANDLER = new ElasticInferenceServiceResponseHandler( + "elastic dense text embedding", + ElasticInferenceServiceDenseTextEmbeddingsResponseEntity::fromResponse + ); + private final Sender sender; private final ServiceComponents 
serviceComponents; @@ -43,4 +56,26 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer ); return new SenderExecutableAction(sender, requestManager, errorMessage); } + + @Override + public ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model) { + var threadPool = serviceComponents.threadPool(); + + var manager = new GenericRequestManager<>( + threadPool, + model, + DENSE_TEXT_EMBEDDINGS_HANDLER, + (embeddingsInput) -> new ElasticInferenceServiceDenseTextEmbeddingsRequest( + model, + embeddingsInput.getStringInputs(), + traceContext, + extractRequestMetadataFromThreadContext(threadPool.getThreadContext()), + embeddingsInput.getInputType() + ), + EmbeddingsInput.class + ); + + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Elastic dense text embeddings"); + return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java index 2fdd90b8169b..c19eeb0f3796 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java @@ -8,10 +8,12 @@ package org.elasticsearch.xpack.inference.services.elastic.action; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; public interface 
ElasticInferenceServiceActionVisitor { ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model); + ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java new file mode 100644 index 000000000000..8c2066ba7046 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceExecutableActionModel; +import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionVisitor; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; + +public class ElasticInferenceServiceDenseTextEmbeddingsModel extends ElasticInferenceServiceExecutableActionModel { + + private final URI uri; + + public ElasticInferenceServiceDenseTextEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + Map secrets, + ElasticInferenceServiceComponents elasticInferenceServiceComponents, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap(serviceSettings, context), + // TODO: we probably want dense embeddings task settings + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents + ); + } + + public ElasticInferenceServiceDenseTextEmbeddingsModel( + String inferenceEntityId, + TaskType 
taskType, + String service, + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings serviceSettings, + // TODO: we probably want dense embeddings task settings + @Nullable TaskSettings taskSettings, + @Nullable SecretSettings secretSettings, + ElasticInferenceServiceComponents elasticInferenceServiceComponents + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + serviceSettings, + elasticInferenceServiceComponents + ); + this.uri = createUri(); + } + + public ElasticInferenceServiceDenseTextEmbeddingsModel( + ElasticInferenceServiceDenseTextEmbeddingsModel model, + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings serviceSettings + ) { + super(model, serviceSettings); + this.uri = createUri(); + } + + @Override + public ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map taskSettings) { + return visitor.create(this); + } + + @Override + public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings getServiceSettings() { + return (ElasticInferenceServiceDenseTextEmbeddingsServiceSettings) super.getServiceSettings(); + } + + public URI uri() { + return uri; + } + + private URI createUri() throws ElasticsearchStatusException { + try { + // TODO, consider transforming the base URL into a URI for better error handling. 
+ return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/embed/text/dense"); + } catch (URISyntaxException e) { + throw new ElasticsearchStatusException( + "Failed to create URI for service [" + + this.getConfigurations().getService() + + "] with taskType [" + + this.getTaskType() + + "]: " + + e.getMessage(), + RestStatus.BAD_REQUEST, + e + ); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java new file mode 100644 index 000000000000..93a5be16aeac --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java @@ -0,0 +1,263 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.*; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; + +public class ElasticInferenceServiceDenseTextEmbeddingsServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + ElasticInferenceServiceRateLimitServiceSettings { + + public static final String NAME = "elastic_inference_service_dense_embeddings_service_settings"; + static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; + + public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new 
RateLimitSettings(10_000); + + private final String modelId; + private final SimilarityMeasure similarity; + private final Integer dimensions; + private final Integer maxInputTokens; + private final boolean dimensionsSetByUser; + private final RateLimitSettings rateLimitSettings; + + public static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromMap( + Map map, + ConfigurationParseContext context + ) { + return switch (context) { + case REQUEST -> fromRequestMap(map, context); + case PERSISTENT -> fromPersistentMap(map, context); + }; + } + + private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromRequestMap( + Map map, + ConfigurationParseContext context + ) { + ValidationException validationException = new ValidationException(); + + String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + ElasticInferenceService.NAME, + context + ); + + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + var dimensionsSetByUser = dims != null; + + return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + similarity, + dims, + maxInputTokens, + dimensionsSetByUser, + rateLimitSettings + ); + } + + private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromPersistentMap( + Map map, + ConfigurationParseContext context + ) { + ValidationException validationException = new ValidationException(); + + String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + 
RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + ElasticInferenceService.NAME, + context + ); + + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + Boolean dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class); + + if (dimensionsSetByUser == null) { + dimensionsSetByUser = Boolean.FALSE; + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + similarity, + dims, + maxInputTokens, + dimensionsSetByUser, + rateLimitSettings + ); + } + + public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + String modelId, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + boolean dimensionsSetByUser, + RateLimitSettings rateLimitSettings + ) { + this.modelId = modelId; + this.similarity = similarity; + this.dimensions = dimensions; + this.maxInputTokens = maxInputTokens; + this.dimensionsSetByUser = dimensionsSetByUser; + this.rateLimitSettings = rateLimitSettings; + } + + public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.similarity = in.readOptionalEnum(SimilarityMeasure.class); + this.dimensions = in.readOptionalVInt(); + this.maxInputTokens = in.readOptionalVInt(); + this.dimensionsSetByUser = in.readBoolean(); + this.rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public Integer dimensions() { + return dimensions; + } + + public Integer maxInputTokens() { + return maxInputTokens; + 
} + + @Override + public String modelId() { + return modelId; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public Boolean dimensionsSetByUser() { + return dimensionsSetByUser; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + public RateLimitSettings getRateLimitSettings() { + return rateLimitSettings; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + + toXContentFragmentOfExposedFields(builder, params); + builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + + builder.endObject(); + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); + out.writeOptionalVInt(dimensions); + out.writeOptionalVInt(maxInputTokens); + out.writeBoolean(dimensionsSetByUser); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings that = 
(ElasticInferenceServiceDenseTextEmbeddingsServiceSettings) o; + return Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser) + && Objects.equals(similarity, that.similarity) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens); + } + + @Override + public int hashCode() { + return Objects.hash(similarity, dimensions, maxInputTokens, dimensionsSetByUser); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 6a93b1cc19c8..5b65d96cfbfe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; @@ -38,11 +39,9 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; +import 
org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; -import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -59,6 +58,8 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -85,6 +86,7 @@ import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static 
org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -394,47 +396,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { verifyNoMoreInteractions(sender); } - public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { - var sender = mock(Sender.class); - - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - - var mockModel = getInvalidModel("model_id", "service_name", TaskType.TEXT_EMBEDDING); - - try (var service = createService(factory)) { - PlainActionFuture listener = new PlainActionFuture<>(); - service.infer( - mockModel, - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - - var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - MatcherAssert.assertThat( - thrownException.getMessage(), - is( - "Inference entity [model_id] does not support task type [text_embedding] " - + "for inference, the task type must be one of [sparse_embedding]." - ) - ); - - verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); - } - - verify(sender, times(1)).close(); - verifyNoMoreInteractions(factory); - verifyNoMoreInteractions(sender); - } - public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException { var sender = mock(Sender.class); @@ -463,7 +424,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { thrownException.getMessage(), is( "Inference entity [model_id] does not support task type [chat_completion] " - + "for inference, the task type must be one of [sparse_embedding]. " + + "for inference, the task type must be one of [text_embedding, sparse_embedding]. 
" + "The task type for the inference entity is chat_completion, " + "please use the _inference/chat_completion/model_id/_stream URL." ) @@ -604,82 +565,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { } } - public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var elasticInferenceServiceURL = getUrl(webServer); - - try (var service = createService(senderFactory, elasticInferenceServiceURL)) { - String responseJson = """ - { - "data": [ - { - "hello": 2.1259406, - "greet": 1.7073475 - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - // Set up the product use case in the thread context - String productUseCase = "test-product-use-case"; - threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase); - - var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id"); - PlainActionFuture> listener = new PlainActionFuture<>(); - - try { - service.chunkedInfer( - model, - null, - List.of(new ChunkInferenceInput("input text")), - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - - var results = listener.actionGet(TIMEOUT); - - // Verify the response was processed correctly - ChunkedInference inferenceResult = results.getFirst(); - assertThat(inferenceResult, instanceOf(ChunkedInferenceEmbedding.class)); - var sparseResult = (ChunkedInferenceEmbedding) inferenceResult; - assertThat( - sparseResult.chunks(), - is( - List.of( - new EmbeddingResults.Chunk( - new SparseEmbeddingResults.Embedding( - List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)), - false - ), - new ChunkedInference.TextOffset(0, "input text".length()) - ) - ) - ) - ); - - // Verify the request was sent and contains 
expected headers - MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - var request = webServer.requests().getFirst(); - assertNull(request.getUri().getQuery()); - MatcherAssert.assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - - // Check that the product use case header was set correctly - assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase)); - - // Verify request body - var requestMap = entityAsMap(request.getBody()); - assertThat(requestMap, is(Map.of("input", List.of("input text"), "model", "my-model-id", "usage_context", "ingest"))); - } finally { - // Clean up the thread context - threadPool.getThreadContext().stashContext(); - } - } - } - public void testUnifiedCompletionInfer_PropagatesProductUseCaseHeader() throws IOException { var elasticInferenceServiceURL = getUrl(webServer); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -738,30 +623,45 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { } } - public void testChunkedInfer_PassesThrough() throws IOException { + public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var elasticInferenceServiceURL = getUrl(webServer); - try (var service = createService(senderFactory, elasticInferenceServiceURL)) { + try (var service = createService(senderFactory, getUrl(webServer))) { + + // Batching will call the service with 2 inputs String responseJson = """ { "data": [ - { - "hello": 2.1259406, - "greet": 1.7073475 + [ + 0.123, + -0.456, + 0.789 + ], + [ + 0.987, + -0.654, + 0.321 + ] + ], + "meta": { + "usage": { + "total_tokens": 10 } - ] + } } """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = 
ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); + + String productUseCase = "test-product-use-case"; + threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase); - var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id"); PlainActionFuture> listener = new PlainActionFuture<>(); + // 2 inputs service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("input text")), + List.of(new ChunkInferenceInput("hello world"), new ChunkInferenceInput("dense embedding")), new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -769,32 +669,123 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { ); var results = listener.actionGet(TIMEOUT); - assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class)); - var sparseResult = (ChunkedInferenceEmbedding) results.get(0); - assertThat( - sparseResult.chunks(), - is( - List.of( - new EmbeddingResults.Chunk( - new SparseEmbeddingResults.Embedding( - List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)), - false - ), - new ChunkedInference.TextOffset(0, "input text".length()) - ) - ) - ) + assertThat(results, hasSize(2)); + + // Verify the response was processed correctly + ChunkedInference inferenceResult = results.getFirst(); + assertThat(inferenceResult, instanceOf(ChunkedInferenceEmbedding.class)); + + // Verify the request was sent and contains expected headers + assertThat(webServer.requests(), hasSize(1)); + var request = webServer.requests().getFirst(); + assertNull(request.getUri().getQuery()); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + // Check that the product use case header was set correctly + assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), 
is(productUseCase)); + + } finally { + // Clean up the thread context + threadPool.getThreadContext().stashContext(); + } + } + + public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException { + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel( + getUrl(webServer), + "my-dense-model-id", + createRandomChunkingSettings() + ); + + testChunkedInfer_BatchesCalls(model); + } + + public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException { + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null); + + testChunkedInfer_BatchesCalls(model); + } + + private void testChunkedInfer_BatchesCalls(ElasticInferenceServiceDenseTextEmbeddingsModel model) throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = createService(senderFactory, getUrl(webServer))) { + + // Batching will call the service with 2 inputs + String responseJson = """ + { + "data": [ + [ + 0.123, + -0.456, + 0.789 + ], + [ + 0.987, + -0.654, + 0.321 + ] + ], + "meta": { + "usage": { + "total_tokens": 10 + } + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + // 2 inputs + service.chunkedInfer( + model, + null, + List.of(new ChunkInferenceInput("hello world"), new ChunkInferenceInput("dense embedding")), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener ); + var results = listener.actionGet(TIMEOUT); + assertThat(results, hasSize(2)); + + // First result + { + assertThat(results.getFirst(), instanceOf(ChunkedInferenceEmbedding.class)); + var denseResult = (ChunkedInferenceEmbedding) results.getFirst(); + assertThat(denseResult.chunks(), hasSize(1)); + assertEquals(new ChunkedInference.TextOffset(0, "hello world".length()), 
denseResult.chunks().getFirst().offset()); + assertThat(denseResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + + var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding(); + assertArrayEquals(new float[] { 0.123f, -0.456f, 0.789f }, embedding.values(), 0.0f); + } + + // Second result + { + assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class)); + var denseResult = (ChunkedInferenceEmbedding) results.get(1); + assertThat(denseResult.chunks(), hasSize(1)); + assertEquals(new ChunkedInference.TextOffset(0, "dense embedding".length()), denseResult.chunks().getFirst().offset()); + assertThat(denseResult.chunks().getFirst().embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + + var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding(); + assertArrayEquals(new float[] { 0.987f, -0.654f, 0.321f }, embedding.values(), 0.0f); + } + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()) ); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("input", List.of("input text"), "model", "my-model-id", "usage_context", "ingest"))); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + MatcherAssert.assertThat( + requestMap, + is(Map.of("input", List.of("hello world", "dense embedding"), "model", "my-dense-model-id", "usage_context", "ingest")) + ); } } @@ -806,27 +797,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { } } - public void 
testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNotImplemented() throws Exception { - try ( - var service = createServiceWithMockSender( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING) - ) - ) - ) - ) - ) - ) { - ensureAuthorizationCallFinished(service); - - assertTrue(service.hideFromConfigurationApi()); - } - } - public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() throws Exception { try ( var service = createServiceWithMockSender( @@ -856,7 +826,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( "model-1", - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION) + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) ) ) ) @@ -869,7 +839,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { { "service": "elastic", "name": "Elastic", - "task_types": ["sparse_embedding", "chat_completion"], + "task_types": ["sparse_embedding", "chat_completion", "text_embedding"], "configurations": { "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -878,7 +848,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding" , "chat_completion"] + "supported_task_types": ["text_embedding", "sparse_embedding" , "chat_completion"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -887,7 +857,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["sparse_embedding" , 
"chat_completion"] + "supported_task_types": ["text_embedding", "sparse_embedding" , "chat_completion"] }, "max_input_tokens": { "description": "Allows you to specify the maximum number of tokens per input.", @@ -896,7 +866,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding"] + "supported_task_types": ["text_embedding", "sparse_embedding"] } } } @@ -933,7 +903,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding" , "chat_completion"] + "supported_task_types": ["text_embedding", "sparse_embedding" , "chat_completion"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -942,7 +912,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["sparse_embedding" , "chat_completion"] + "supported_task_types": ["text_embedding", "sparse_embedding" , "chat_completion"] }, "max_input_tokens": { "description": "Allows you to specify the maximum number of tokens per input.", @@ -951,7 +921,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding"] + "supported_task_types": ["text_embedding", "sparse_embedding"] } } } @@ -993,7 +963,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { { "service": "elastic", "name": "Elastic", - "task_types": [], + "task_types": ["text_embedding"], "configurations": { "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -1002,7 +972,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "int", - 
"supported_task_types": ["sparse_embedding" , "chat_completion"] + "supported_task_types": ["text_embedding" , "sparse_embedding", "chat_completion"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -1011,7 +981,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["sparse_embedding" , "chat_completion"] + "supported_task_types": ["text_embedding" , "sparse_embedding", "chat_completion"] }, "max_input_tokens": { "description": "Allows you to specify the maximum number of tokens per input.", @@ -1020,7 +990,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding"] + "supported_task_types": ["text_embedding", "sparse_embedding"] } } } @@ -1197,6 +1167,10 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { { "model_name": "elser-v2", "task_types": ["embed/text/sparse"] + }, + { + "model_name": "multilingual-embed", + "task_types": ["embed/text/dense"] } ] } @@ -1218,6 +1192,16 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), service ), + new InferenceService.DefaultConfigId( + ".multilingual-embed-elastic", + MinimalServiceSettings.textEmbedding( + ElasticInferenceService.NAME, + ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT + ), + service + ), new InferenceService.DefaultConfigId( ".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), @@ -1226,14 +1210,18 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { ) ) ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, 
TaskType.SPARSE_EMBEDDING))); + assertThat( + service.supportedTaskTypes(), + is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING)) + ); PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); var models = listener.actionGet(TIMEOUT); - assertThat(models.size(), is(2)); + assertThat(models.size(), is(3)); assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic")); - assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".multilingual-embed-elastic")); + assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java new file mode 100644 index 000000000000..753e5bb17303 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings; + +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +public class ElasticInferenceServiceDenseTextEmbeddingsModelTests { + + public static ElasticInferenceServiceDenseTextEmbeddingsModel createModel( + String url, + String modelId, + ChunkingSettings chunkingSettings + ) { + return new ElasticInferenceServiceDenseTextEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "elastic", + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + SimilarityMeasure.COSINE, + null, + null, + false, + new RateLimitSettings(1000L) + ), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + ElasticInferenceServiceComponents.of(url) + ); + } + +}