diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 0ec58df6f58d..cc63063ecf9d 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -206,6 +206,7 @@ public class TransportVersions { public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_56); public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC_8_19 = def(8_841_0_57); public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS_8_19 = def(8_841_0_58); + public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED_8_19 = def(8_841_0_59); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -318,6 +319,7 @@ public class TransportVersions { public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC = def(9_106_0_00); public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS = def(9_107_0_00); public static final TransportVersion CLUSTER_STATE_PROJECTS_SETTINGS = def(9_108_0_00); + public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED = def(9_109_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java index 42c1237cdcdb..e548fcc4f2eb 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -33,7 +33,7 @@ public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEIS var allModels = getAllModels(); var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION); - assertThat(allModels, hasSize(6)); + assertThat(allModels, hasSize(7)); assertThat(chatCompletionModels, hasSize(1)); for (var model : chatCompletionModels) { @@ -42,6 +42,7 @@ public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEIS assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION); assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING); + assertInferenceIdTaskType(allModels, ".multilingual-embed-v1-elastic", TaskType.TEXT_EMBEDDING); assertInferenceIdTaskType(allModels, ".rerank-v1-elastic", TaskType.RERANK); } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 1f2047d500b5..b96c94db438a 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ 
b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -20,6 +20,7 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { @@ -76,16 +77,21 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { } public void testGetServicesWithTextEmbeddingTaskType() throws IOException { + List services = getServices(TaskType.TEXT_EMBEDDING); + assertThat(services.size(), equalTo(18)); + assertThat( providersFor(TaskType.TEXT_EMBEDDING), containsInAnyOrder( List.of( "alibabacloud-ai-search", "amazonbedrock", + "amazon_sagemaker", "azureaistudio", "azureopenai", "cohere", "custom", + "elastic", "elasticsearch", "googleaistudio", "googlevertexai", @@ -95,8 +101,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { "openai", "text_embedding_test_service", "voyageai", - "watsonxai", - "amazon_sagemaker" + "watsonxai" ).toArray() ) ); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java index 4bdf9aa40b2c..7f0212167f8a 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java @@ -43,6 +43,10 @@ public class MockElasticInferenceServiceAuthorizationServer implements TestRule "task_types": ["embed/text/sparse"] }, { + "model_name": "multilingual-embed-v1", + "task_types": ["embed/text/dense"] + }, + { "model_name": "rerank-v1", "task_types": ["rerank/text/text-similarity"] } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 3f6c66151cf4..4c200c6f2024 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -11,6 +11,7 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; @@ -43,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static 
org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.mockito.Mockito.mock; public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase { @@ -190,13 +192,17 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase { String responseJson = """ { "models": [ + { + "model_name": "elser-v2", + "task_types": ["embed/text/sparse"] + }, { "model_name": "rainbow-sprinkles", "task_types": ["chat"] }, { - "model_name": "elser-v2", - "task_types": ["embed/text/sparse"] + "model_name": "multilingual-embed-v1", + "task_types": ["embed/text/dense"] }, { "model_name": "rerank-v1", @@ -214,36 +220,48 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertThat( service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".elser-v2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service + containsInAnyOrder( + new InferenceService.DefaultConfigId( + ".elser-v2-elastic", + MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), + service + ), + new InferenceService.DefaultConfigId( + ".rainbow-sprinkles-elastic", + MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), + service + ), + new InferenceService.DefaultConfigId( + ".multilingual-embed-v1-elastic", + MinimalServiceSettings.textEmbedding( + ElasticInferenceService.NAME, + ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT ), - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".rerank-v1-elastic", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ) + service + ), + new InferenceService.DefaultConfigId( + ".rerank-v1-elastic", + MinimalServiceSettings.rerank(ElasticInferenceService.NAME), + service ) ) ); assertThat( service.supportedTaskTypes(), - is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)) + is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING)) ); PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic")); - assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic")); + assertThat( + listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), + is(".multilingual-embed-v1-elastic") + ); + assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + assertThat(listener.actionGet(TIMEOUT).get(3).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic")); var getModelListener = new PlainActionFuture(); // persists the default endpoints @@ -265,6 +283,10 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase { { "model_name": "rerank-v1", "task_types": ["rerank/text/text-similarity"] + }, + { + "model_name": "multilingual-embed-v1", + 
"task_types": ["embed/text/dense"] } ] } @@ -278,22 +300,33 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase { assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); assertThat( service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".elser-v2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service + containsInAnyOrder( + new InferenceService.DefaultConfigId( + ".elser-v2-elastic", + MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), + service + ), + new InferenceService.DefaultConfigId( + ".multilingual-embed-v1-elastic", + MinimalServiceSettings.textEmbedding( + ElasticInferenceService.NAME, + ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT ), - new InferenceService.DefaultConfigId( - ".rerank-v1-elastic", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ) + service + ), + new InferenceService.DefaultConfigId( + ".rerank-v1-elastic", + MinimalServiceSettings.rerank(ElasticInferenceService.NAME), + service ) ) ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK))); + assertThat( + service.supportedTaskTypes(), + is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)) + ); var getModelListener = new PlainActionFuture(); modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java new file mode 100644 index 000000000000..a96ebc0048f7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/elastic/ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.elastic; + +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; + +public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntity { + + /** + * Parses the Elastic Inference Service Dense Text Embeddings response. + * + * For a request like: + * + *
+     *     <code>
+     *         {
+     *             "inputs": ["Embed this text", "Embed this text, too"]
+     *         }
+     *     </code>
+     * </pre>
+     *
+     * The response would look like:
+     *
+     * <pre>
+     *     <code>
+     *         {
+     *             "data": [
+     *                  [
+     *                      2.1259406,
+     *                      1.7073475,
+     *                      0.9020516
+     *                  ],
+     *                  (...)
+     *             ],
+     *             "meta": {
+     *                  "usage": {...}
+     *             }
+     *         }
+     *     </code>
+     * </pre>
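+     *
+     * As an illustrative sketch only (the local variable names {@code request} and {@code httpResult}
+     * are assumed, not part of this class), parsing the response above yields one embedding per inner
+     * array of {@code data}:
+     *
+     * <pre>
+     *     <code>
+     *         TextEmbeddingFloatResults results =
+     *             ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(request, httpResult);
+     *         // One TextEmbeddingFloatResults.Embedding is produced per inner array of "data";
+     *         // the first embedding holds the values [2.1259406f, 1.7073475f, 0.9020516f] from the example above.
+     *     </code>
+     * </pre>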
+ */ + public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException { + try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) { + return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults(); + } + } + + public record EmbeddingFloatResult(List embeddingResults) { + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + EmbeddingFloatResult.class.getSimpleName(), + true, + args -> new EmbeddingFloatResult((List) args[0]) + ); + + static { + // Custom field declaration to handle array of arrays format + PARSER.declareField(constructorArg(), (parser, context) -> { + return XContentParserUtils.parseList(parser, (p, index) -> { + List embedding = XContentParserUtils.parseList(p, (innerParser, innerIndex) -> innerParser.floatValue()); + return EmbeddingFloatResultEntry.fromFloatArray(embedding); + }); + }, new ParseField("data"), org.elasticsearch.xcontent.ObjectParser.ValueType.OBJECT_ARRAY); + } + + public TextEmbeddingFloatResults toTextEmbeddingFloatResults() { + return new TextEmbeddingFloatResults( + embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList() + ); + } + } + + /** + * Represents a single embedding entry in the response. + * For the Elastic Inference Service, each entry is just an array of floats (no wrapper object). + * This is a simpler wrapper that just holds the float array. + */ + public record EmbeddingFloatResultEntry(List embedding) { + public static EmbeddingFloatResultEntry fromFloatArray(List floats) { + return new EmbeddingFloatResultEntry(floats); + } + } + + private ElasticInferenceServiceDenseTextEmbeddingsResponseEntity() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index e6b27b6a641c..640929b05876 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySecretSettings; @@ -28,6 +29,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; @@ -54,6 +56,8 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import 
org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; @@ -84,15 +88,20 @@ public class ElasticInferenceService extends SenderService { public static final String NAME = "elastic"; public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service"; - public static final int SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 512; + public static final Integer DENSE_TEXT_EMBEDDINGS_DIMENSIONS = 1024; + public static final Integer SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 512; private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of( TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, - TaskType.RERANK + TaskType.RERANK, + TaskType.TEXT_EMBEDDING ); private static final String SERVICE_NAME = "Elastic"; + // TODO: check with team, what makes the most sense + private static final Integer DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE = 32; + // rainbow-sprinkles static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); @@ -101,6 +110,10 @@ public class ElasticInferenceService extends SenderService { static final String DEFAULT_ELSER_MODEL_ID_V2 = "elser-v2"; static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2); + // multilingual-text-embed + static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "multilingual-embed-v1"; + static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = defaultEndpointId(DEFAULT_MULTILINGUAL_EMBED_MODEL_ID); + // rerank-v1 static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1"; static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1); @@ -108,7 +121,11 @@ public class ElasticInferenceService extends SenderService { /** * The task types that the {@link InferenceAction.Request} can accept. 
*/ - private static final EnumSet SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK); + private static final EnumSet SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of( + TaskType.SPARSE_EMBEDDING, + TaskType.RERANK, + TaskType.TEXT_EMBEDDING + ); public static String defaultEndpointId(String modelId) { return Strings.format(".%s-elastic", modelId); @@ -171,6 +188,32 @@ public class ElasticInferenceService extends SenderService { ), MinimalServiceSettings.sparseEmbedding(NAME) ), + DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + new DefaultModelConfig( + new ElasticInferenceServiceDenseTextEmbeddingsModel( + DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + TaskType.TEXT_EMBEDDING, + NAME, + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + defaultDenseTextEmbeddingsSimilarity(), + null, + null, + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS + ), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + MinimalServiceSettings.textEmbedding( + NAME, + DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT + ) + ), + DEFAULT_RERANK_MODEL_ID_V1, new DefaultModelConfig( new ElasticInferenceServiceRerankModel( @@ -310,6 +353,23 @@ public class ElasticInferenceService extends SenderService { TimeValue timeout, ActionListener> listener ) { + if (model instanceof ElasticInferenceServiceDenseTextEmbeddingsModel denseTextEmbeddingsModel) { + var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo()); + + List batchedRequests = new EmbeddingRequestChunker<>( + inputs.getInputs(), + DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE, + denseTextEmbeddingsModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = denseTextEmbeddingsModel.accept(actionCreator, taskSettings); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + } + + return; + } + if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel sparseTextEmbeddingsModel) { var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo()); @@ -348,7 +408,7 @@ public class ElasticInferenceService extends SenderService { Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); ChunkingSettings chunkingSettings = null; - if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { + if (TaskType.SPARSE_EMBEDDING.equals(taskType) || TaskType.TEXT_EMBEDDING.equals(taskType)) { chunkingSettings = ChunkingSettingsBuilder.fromMap( removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) ); @@ -359,11 +419,11 @@ public class ElasticInferenceService extends SenderService { taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, serviceSettingsMap, elasticInferenceServiceComponents, TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), - ConfigurationParseContext.REQUEST, - chunkingSettings + ConfigurationParseContext.REQUEST ); throwIfNotEmptyMap(config, NAME); @@ -396,11 +456,11 @@ public class ElasticInferenceService extends SenderService { TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, 
ElasticInferenceServiceComponents elasticInferenceServiceComponents, String failureMessage, - ConfigurationParseContext context, - ChunkingSettings chunkingSettings + ConfigurationParseContext context ) { return switch (taskType) { case SPARSE_EMBEDDING -> new ElasticInferenceServiceSparseEmbeddingsModel( @@ -434,6 +494,17 @@ public class ElasticInferenceService extends SenderService { elasticInferenceServiceComponents, context ); + case TEXT_EMBEDDING -> new ElasticInferenceServiceDenseTextEmbeddingsModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + elasticInferenceServiceComponents, + context, + chunkingSettings + ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } @@ -450,7 +521,7 @@ public class ElasticInferenceService extends SenderService { Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); ChunkingSettings chunkingSettings = null; - if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { + if (TaskType.SPARSE_EMBEDDING.equals(taskType) || TaskType.TEXT_EMBEDDING.equals(taskType)) { chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } @@ -459,9 +530,9 @@ public class ElasticInferenceService extends SenderService { taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, secretSettingsMap, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME), - chunkingSettings + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); } @@ -471,7 +542,7 @@ public class ElasticInferenceService extends SenderService { Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); ChunkingSettings chunkingSettings = null; - if (TaskType.SPARSE_EMBEDDING.equals(taskType)) { + if (TaskType.SPARSE_EMBEDDING.equals(taskType) || TaskType.TEXT_EMBEDDING.equals(taskType)) { chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); } @@ -480,9 +551,9 @@ public class ElasticInferenceService extends SenderService { taskType, serviceSettingsMap, taskSettingsMap, + chunkingSettings, null, - parsePersistedConfigErrorMsg(inferenceEntityId, NAME), - chunkingSettings + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) ); } @@ -496,23 +567,51 @@ public class ElasticInferenceService extends SenderService { TaskType taskType, Map serviceSettings, Map taskSettings, + ChunkingSettings chunkingSettings, @Nullable Map secretSettings, - String failureMessage, - ChunkingSettings chunkingSettings + String failureMessage ) { return createModel( inferenceEntityId, taskType, serviceSettings, taskSettings, + chunkingSettings, secretSettings, elasticInferenceServiceComponents, failureMessage, - ConfigurationParseContext.PERSISTENT, - chunkingSettings + ConfigurationParseContext.PERSISTENT ); } + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof ElasticInferenceServiceDenseTextEmbeddingsModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + var modelId = serviceSettings.modelId(); + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? 
defaultDenseTextEmbeddingsSimilarity() : similarityFromModel; + var maxInputTokens = serviceSettings.maxInputTokens(); + + var updateServiceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + similarityToUse, + embeddingSize, + maxInputTokens, + serviceSettings.rateLimitSettings() + ); + + return new ElasticInferenceServiceDenseTextEmbeddingsModel(embeddingsModel, updateServiceSettings); + } else { + throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); + } + } + + public static SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() { + // TODO: double-check + return SimilarityMeasure.COSINE; + } + private static List translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { var inputsAsList = EmbeddingsInput.of(inputs).getStringInputs(); @@ -550,8 +649,9 @@ public class ElasticInferenceService extends SenderService { configurationMap.put( MODEL_ID, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK)) - .setDescription("The name of the model to use for the inference task.") + new SettingsConfiguration.Builder( + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING) + ).setDescription("The name of the model to use for the inference task.") .setLabel("Model ID") .setRequired(true) .setSensitive(false) @@ -562,7 +662,7 @@ public class ElasticInferenceService extends SenderService { configurationMap.put( MAX_INPUT_TOKENS, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription( + new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING)).setDescription( "Allows you to specify the maximum number of tokens per input." 
) .setLabel("Maximum Input Tokens") @@ -575,7 +675,7 @@ public class ElasticInferenceService extends SenderService { configurationMap.putAll( RateLimitSettings.toSettingsConfiguration( - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK) + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING) ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java index c3b406266019..8b987cd53bc8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java @@ -11,14 +11,18 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequest; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity; import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceRerankResponseEntity; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceDenseTextEmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequest; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -31,17 +35,22 @@ import static org.elasticsearch.xpack.inference.services.elastic.request.Elastic public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor { - private final Sender sender; - - private final ServiceComponents serviceComponents; - - private final TraceContext traceContext; + static final ResponseHandler DENSE_TEXT_EMBEDDINGS_HANDLER = new ElasticInferenceServiceResponseHandler( + "elastic dense text embedding", + ElasticInferenceServiceDenseTextEmbeddingsResponseEntity::fromResponse + ); static final ResponseHandler RERANK_HANDLER = new ElasticInferenceServiceResponseHandler( "elastic rerank", (request, 
response) -> ElasticInferenceServiceRerankResponseEntity.fromResponse(response) ); + private final Sender sender; + + private final ServiceComponents serviceComponents; + + private final TraceContext traceContext; + public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents, TraceContext traceContext) { this.sender = Objects.requireNonNull(sender); this.serviceComponents = Objects.requireNonNull(serviceComponents); @@ -77,4 +86,26 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer var errorMessage = constructFailedToSendRequestMessage(Strings.format("%s rerank", ELASTIC_INFERENCE_SERVICE_IDENTIFIER)); return new SenderExecutableAction(sender, requestManager, errorMessage); } + + @Override + public ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model) { + var threadPool = serviceComponents.threadPool(); + + var manager = new GenericRequestManager<>( + threadPool, + model, + DENSE_TEXT_EMBEDDINGS_HANDLER, + (embeddingsInput) -> new ElasticInferenceServiceDenseTextEmbeddingsRequest( + model, + embeddingsInput.getStringInputs(), + traceContext, + extractRequestMetadataFromThreadContext(threadPool.getThreadContext()), + embeddingsInput.getInputType() + ), + EmbeddingsInput.class + ); + + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Elastic dense text embeddings"); + return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java index 2bfb6d9f1222..4f8a9c9ec20a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionVisitor.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.elastic.action; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; @@ -17,4 +18,5 @@ public interface ElasticInferenceServiceActionVisitor { ExecutableAction create(ElasticInferenceServiceRerankModel model); + ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java new file mode 100644 index 000000000000..dfbfaf47e2d2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModel.java @@ -0,0 +1,116 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceExecutableActionModel; +import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionVisitor; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; + +public class ElasticInferenceServiceDenseTextEmbeddingsModel extends ElasticInferenceServiceExecutableActionModel { + + private final URI uri; + + public ElasticInferenceServiceDenseTextEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + Map secrets, + ElasticInferenceServiceComponents elasticInferenceServiceComponents, + ConfigurationParseContext context, + ChunkingSettings chunkingSettings + ) { + this( + inferenceEntityId, + taskType, + service, + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap(serviceSettings, context), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents, + chunkingSettings + ); + } + + public ElasticInferenceServiceDenseTextEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings serviceSettings, + @Nullable TaskSettings taskSettings, + @Nullable SecretSettings secretSettings, + ElasticInferenceServiceComponents elasticInferenceServiceComponents, + ChunkingSettings chunkingSettings + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings), + new ModelSecrets(secretSettings), + serviceSettings, + elasticInferenceServiceComponents + ); + this.uri = createUri(); + } + + public ElasticInferenceServiceDenseTextEmbeddingsModel( + ElasticInferenceServiceDenseTextEmbeddingsModel model, + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings serviceSettings + ) { + super(model, serviceSettings); + this.uri = createUri(); + } + + @Override + public ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map taskSettings) { + return visitor.create(this); + } + + @Override + public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings getServiceSettings() { + return (ElasticInferenceServiceDenseTextEmbeddingsServiceSettings) super.getServiceSettings(); + } + + public URI uri() { + return uri; + } + + private URI createUri() throws ElasticsearchStatusException { + try { + // TODO, consider transforming the base URL into 
a URI for better error handling. + return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/embed/text/dense"); + } catch (URISyntaxException e) { + throw new ElasticsearchStatusException( + "Failed to create URI for service [" + + this.getConfigurations().getService() + + "] with taskType [" + + this.getTaskType() + + "]: " + + e.getMessage(), + RestStatus.BAD_REQUEST, + e + ); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java new file mode 100644 index 000000000000..5047f34a1b2e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java @@ -0,0 +1,236 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; + +public class ElasticInferenceServiceDenseTextEmbeddingsServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + ElasticInferenceServiceRateLimitServiceSettings { + + public static final String NAME = "elastic_inference_service_dense_embeddings_service_settings"; + + public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000); + + private final String modelId; + private final 
SimilarityMeasure similarity; + private final Integer dimensions; + private final Integer maxInputTokens; + private final RateLimitSettings rateLimitSettings; + + public static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromMap( + Map map, + ConfigurationParseContext context + ) { + return switch (context) { + case REQUEST -> fromRequestMap(map, context); + case PERSISTENT -> fromPersistentMap(map, context); + }; + } + + private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromRequestMap( + Map map, + ConfigurationParseContext context + ) { + ValidationException validationException = new ValidationException(); + + String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + ElasticInferenceService.NAME, + context + ); + + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(modelId, similarity, dims, maxInputTokens, rateLimitSettings); + } + + private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromPersistentMap( + Map map, + ConfigurationParseContext context + ) { + ValidationException validationException = new ValidationException(); + + String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + ElasticInferenceService.NAME, + context + ); + + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(modelId, similarity, dims, maxInputTokens, rateLimitSettings); + } + + public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + String modelId, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + RateLimitSettings rateLimitSettings + ) { + this.modelId = modelId; + this.similarity = similarity; + this.dimensions = dimensions; + this.maxInputTokens = maxInputTokens; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.similarity = in.readOptionalEnum(SimilarityMeasure.class); + this.dimensions = in.readOptionalVInt(); + this.maxInputTokens = in.readOptionalVInt(); + this.rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public Integer dimensions() { + return dimensions; + } + + public Integer maxInputTokens() { + return maxInputTokens; + } + + @Override + public String modelId() { + 
return modelId; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + public RateLimitSettings getRateLimitSettings() { + return rateLimitSettings; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + + toXContentFragmentOfExposedFields(builder, params); + + builder.endObject(); + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); + out.writeOptionalVInt(dimensions); + out.writeOptionalVInt(maxInputTokens); + rateLimitSettings.writeTo(out); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings that = (ElasticInferenceServiceDenseTextEmbeddingsServiceSettings) o; + return Objects.equals(modelId, that.modelId) + && similarity == that.similarity + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, similarity, dimensions, maxInputTokens, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequest.java new file mode 100644 index 000000000000..8a873504ee12 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequest.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; +import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceSparseEmbeddingsRequest.inputTypeToUsageContext; + +public class ElasticInferenceServiceDenseTextEmbeddingsRequest extends ElasticInferenceServiceRequest { + + private final URI uri; + private final ElasticInferenceServiceDenseTextEmbeddingsModel model; + private final List inputs; + private final TraceContextHandler traceContextHandler; + private final InputType inputType; + + public ElasticInferenceServiceDenseTextEmbeddingsRequest( + ElasticInferenceServiceDenseTextEmbeddingsModel model, + List inputs, + TraceContext traceContext, + ElasticInferenceServiceRequestMetadata metadata, + InputType inputType + ) { + super(metadata); + this.inputs = inputs; + this.model = Objects.requireNonNull(model); + this.uri = model.uri(); + this.traceContextHandler = new TraceContextHandler(traceContext); + this.inputType = inputType; + } + + @Override + public HttpRequestBase createHttpRequestBase() { + var httpPost = new HttpPost(uri); + var usageContext = inputTypeToUsageContext(inputType); + var requestEntity = Strings.toString( + new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(inputs, model.getServiceSettings().modelId(), usageContext) + ); + + ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + + traceContextHandler.propagateTraceContext(httpPost); + httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); + + return httpPost; + } + + public TraceContext getTraceContext() { + return traceContextHandler.traceContext(); + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java new file mode 100644 index 000000000000..6d7862f83cb6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntity.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.request; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List inputs, + String modelId, + @Nullable ElasticInferenceServiceUsageContext usageContext +) implements ToXContentObject { + + private static final String INPUT_FIELD = "input"; + private static final String MODEL_FIELD = "model"; + private static final String USAGE_CONTEXT = "usage_context"; + + public ElasticInferenceServiceDenseTextEmbeddingsRequestEntity { + Objects.requireNonNull(inputs); + Objects.requireNonNull(modelId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startArray(INPUT_FIELD); + + for (String input : inputs) { + builder.value(input); + } + + builder.endArray(); + + builder.field(MODEL_FIELD, modelId); + + // optional field + if (Objects.nonNull(usageContext) && usageContext != ElasticInferenceServiceUsageContext.UNSPECIFIED) { + builder.field(USAGE_CONTEXT, usageContext); + } + + builder.endObject(); + + return builder; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequest.java similarity index 91% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequest.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequest.java index 08b3fd238464..63b26f2a1223 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequest.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.external.request.elastic.rerank; +package org.elasticsearch.xpack.inference.services.elastic.request; import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; @@ -15,8 +15,6 @@ import org.apache.http.message.BasicHeader; import org.elasticsearch.common.Strings; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequest; -import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestMetadata; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequestEntity.java similarity index 95% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequestEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequestEntity.java index b542af93047f..1e21b6f7d8ee 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceRerankRequestEntity.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.external.request.elastic.rerank; +package org.elasticsearch.xpack.inference.services.elastic.request; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ToXContentObject; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java index bb34ac202bd5..451c601e7cc9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java @@ -44,6 +44,8 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements Infer TaskType.SPARSE_EMBEDDING, "chat", TaskType.CHAT_COMPLETION, + "embed/text/dense", + TaskType.TEXT_EMBEDDING, "rerank/text/text-similarity", TaskType.RERANK ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java index 8c6149b9e5a2..7eb2e50372c5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceSparseEmbeddingsResponseEntity.java @@ -51,11 +51,11 @@ public class ElasticInferenceServiceSparseEmbeddingsResponseEntity { * * { * "data": [ - * { - * "Embed": 2.1259406, - * "this": 1.7073475, - * "text": 0.9020516 - * }, + * [ + * 2.1259406, + * 1.7073475, + * 0.9020516 + * ], * (...) 
* ], * "meta": { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestEntityTests.java index 407d3e38b4da..a484c690b260 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestEntityTests.java @@ -12,7 +12,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequestEntity; +import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequestEntity; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java index 4e6efed6faa5..58a357684961 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java @@ -10,7 +10,7 @@ package org.elasticsearch.xpack.inference.external.request.elastic; import org.apache.http.client.methods.HttpPost; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequest; +import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequest; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests; import org.elasticsearch.xpack.inference.telemetry.TraceContext; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 9ff4a04add8b..236a6be3d742 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; @@ -29,7 +30,6 @@ import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; 
import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; -import org.elasticsearch.inference.WeightedToken; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.test.http.MockResponse; @@ -40,9 +40,8 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; @@ -59,6 +58,7 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests; import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; @@ -421,47 +421,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { } } - public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException { - var sender = mock(Sender.class); - - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - - var mockModel = getInvalidModel("model_id", "service_name", TaskType.TEXT_EMBEDDING); - - try (var service = createService(factory)) { - PlainActionFuture listener = new PlainActionFuture<>(); - service.infer( - mockModel, - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - - var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - MatcherAssert.assertThat( - thrownException.getMessage(), - is( - "Inference entity [model_id] does not support task type [text_embedding] " - + "for inference, the task type must be one of [sparse_embedding, rerank]." 
- ) - ); - - verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); - } - - verify(sender, times(1)).close(); - verifyNoMoreInteractions(factory); - verifyNoMoreInteractions(sender); - } - public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException { var sender = mock(Sender.class); @@ -490,7 +449,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { thrownException.getMessage(), is( "Inference entity [model_id] does not support task type [chat_completion] " - + "for inference, the task type must be one of [sparse_embedding, rerank]. " + + "for inference, the task type must be one of [text_embedding, sparse_embedding, rerank]. " + "The task type for the inference entity is chat_completion, " + "please use the _inference/chat_completion/model_id/_stream URL." ) @@ -701,82 +660,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { } } - public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var elasticInferenceServiceURL = getUrl(webServer); - - try (var service = createService(senderFactory, elasticInferenceServiceURL)) { - String responseJson = """ - { - "data": [ - { - "hello": 2.1259406, - "greet": 1.7073475 - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - // Set up the product use case in the thread context - String productUseCase = "test-product-use-case"; - threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase); - - var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id"); - PlainActionFuture> listener = new PlainActionFuture<>(); - - try { - service.chunkedInfer( - model, - null, - List.of(new ChunkInferenceInput("input text")), - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - - var results = listener.actionGet(TIMEOUT); - - // Verify the response was processed correctly - ChunkedInference inferenceResult = results.getFirst(); - assertThat(inferenceResult, instanceOf(ChunkedInferenceEmbedding.class)); - var sparseResult = (ChunkedInferenceEmbedding) inferenceResult; - assertThat( - sparseResult.chunks(), - is( - List.of( - new EmbeddingResults.Chunk( - new SparseEmbeddingResults.Embedding( - List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)), - false - ), - new ChunkedInference.TextOffset(0, "input text".length()) - ) - ) - ) - ); - - // Verify the request was sent and contains expected headers - MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - var request = webServer.requests().getFirst(); - assertNull(request.getUri().getQuery()); - MatcherAssert.assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - - // Check that the product use case header was set correctly - assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase)); - - // Verify request body - var requestMap = entityAsMap(request.getBody()); - assertThat(requestMap, is(Map.of("input", List.of("input text"), "model", "my-model-id", "usage_context", "ingest"))); - } finally { - // Clean up the thread context - threadPool.getThreadContext().stashContext(); - } - } - } - public void testUnifiedCompletionInfer_PropagatesProductUseCaseHeader() 
throws IOException { var elasticInferenceServiceURL = getUrl(webServer); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -835,30 +718,45 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { } } - public void testChunkedInfer() throws IOException { + public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var elasticInferenceServiceURL = getUrl(webServer); - try (var service = createService(senderFactory, elasticInferenceServiceURL)) { + try (var service = createService(senderFactory, getUrl(webServer))) { + + // Batching will call the service with 2 inputs String responseJson = """ { "data": [ - { - "hello": 2.1259406, - "greet": 1.7073475 + [ + 0.123, + -0.456, + 0.789 + ], + [ + 0.987, + -0.654, + 0.321 + ] + ], + "meta": { + "usage": { + "total_tokens": 10 } - ] + } } """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); + + String productUseCase = "test-product-use-case"; + threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase); - var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id"); PlainActionFuture> listener = new PlainActionFuture<>(); + // 2 inputs service.chunkedInfer( model, null, - List.of(new ChunkInferenceInput("input text")), + List.of(new ChunkInferenceInput("hello world"), new ChunkInferenceInput("dense embedding")), new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -866,32 +764,106 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { ); var results = listener.actionGet(TIMEOUT); - assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class)); - var sparseResult = (ChunkedInferenceEmbedding) results.get(0); - assertThat( - sparseResult.chunks(), - is( - List.of( - new EmbeddingResults.Chunk( - new SparseEmbeddingResults.Embedding( - List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)), - false - ), - new ChunkedInference.TextOffset(0, "input text".length()) - ) - ) - ) + assertThat(results, hasSize(2)); + + // Verify the response was processed correctly + ChunkedInference inferenceResult = results.getFirst(); + assertThat(inferenceResult, instanceOf(ChunkedInferenceEmbedding.class)); + + // Verify the request was sent and contains expected headers + assertThat(webServer.requests(), hasSize(1)); + var request = webServer.requests().getFirst(); + assertNull(request.getUri().getQuery()); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + // Check that the product use case header was set correctly + assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase)); + + } finally { + // Clean up the thread context + threadPool.getThreadContext().stashContext(); + } + } + + public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException { + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = createService(senderFactory, getUrl(webServer))) 
{ + + // Batching will call the service with 2 inputs + String responseJson = """ + { + "data": [ + [ + 0.123, + -0.456, + 0.789 + ], + [ + 0.987, + -0.654, + 0.321 + ] + ], + "meta": { + "usage": { + "total_tokens": 10 + } + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + // 2 inputs + service.chunkedInfer( + model, + null, + List.of(new ChunkInferenceInput("hello world"), new ChunkInferenceInput("dense embedding")), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener ); - MatcherAssert.assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + var results = listener.actionGet(TIMEOUT); + assertThat(results, hasSize(2)); + + // First result + { + assertThat(results.getFirst(), instanceOf(ChunkedInferenceEmbedding.class)); + var denseResult = (ChunkedInferenceEmbedding) results.getFirst(); + assertThat(denseResult.chunks(), hasSize(1)); + assertEquals(new ChunkedInference.TextOffset(0, "hello world".length()), denseResult.chunks().getFirst().offset()); + assertThat(denseResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + + var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding(); + assertArrayEquals(new float[] { 0.123f, -0.456f, 0.789f }, embedding.values(), 0.0f); + } + + // Second result + { + assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class)); + var denseResult = (ChunkedInferenceEmbedding) results.get(1); + assertThat(denseResult.chunks(), hasSize(1)); + assertEquals(new ChunkedInference.TextOffset(0, "dense embedding".length()), denseResult.chunks().getFirst().offset()); + assertThat(denseResult.chunks().getFirst().embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); + + var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding(); + assertArrayEquals(new float[] { 0.987f, -0.654f, 0.321f }, embedding.values(), 0.0f); + } + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); MatcherAssert.assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaType()) + requestMap, + is(Map.of("input", List.of("hello world", "dense embedding"), "model", "my-dense-model-id", "usage_context", "ingest")) ); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("input", List.of("input text"), "model", "my-model-id", "usage_context", "ingest"))); } } @@ -903,27 +875,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { } } - public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNotImplemented() throws Exception { - try ( - var service = createServiceWithMockSender( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.TEXT_EMBEDDING) - ) - ) - ) - ) - ) - ) { - ensureAuthorizationCallFinished(service); - - assertTrue(service.hideFromConfigurationApi()); - } - } - public void 
testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() throws Exception { try ( var service = createServiceWithMockSender( @@ -953,7 +904,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { List.of( new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( "model-1", - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION) + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) ) ) ) @@ -966,7 +917,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { { "service": "elastic", "name": "Elastic", - "task_types": ["sparse_embedding", "chat_completion"], + "task_types": ["sparse_embedding", "chat_completion", "text_embedding"], "configurations": { "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -975,7 +926,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"] + "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -984,7 +935,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"] + "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] }, "max_input_tokens": { "description": "Allows you to specify the maximum number of tokens per input.", @@ -993,7 +944,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding"] + "supported_task_types": ["text_embedding", "sparse_embedding"] } } } @@ -1030,7 +981,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"] + "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -1039,7 +990,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"] + "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] }, "max_input_tokens": { "description": "Allows you to specify the maximum number of tokens per input.", @@ -1048,7 +999,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding"] + "supported_task_types": ["text_embedding", "sparse_embedding"] } } } @@ -1090,7 +1041,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { { "service": "elastic", "name": "Elastic", - "task_types": [], + "task_types": ["text_embedding"], "configurations": { "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -1099,7 +1050,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": 
["sparse_embedding" , "rerank", "chat_completion"] + "supported_task_types": ["text_embedding" , "sparse_embedding", "rerank", "chat_completion"] }, "model_id": { "description": "The name of the model to use for the inference task.", @@ -1108,7 +1059,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"] + "supported_task_types": ["text_embedding" , "sparse_embedding", "rerank", "chat_completion"] }, "max_input_tokens": { "description": "Allows you to specify the maximum number of tokens per input.", @@ -1117,7 +1068,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["sparse_embedding"] + "supported_task_types": ["text_embedding", "sparse_embedding"] } } } @@ -1296,6 +1247,10 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { "task_types": ["embed/text/sparse"] }, { + "model_name": "multilingual-embed-v1", + "task_types": ["embed/text/dense"] + }, + { "model_name": "rerank-v1", "task_types": ["rerank/text/text-similarity"] } @@ -1319,6 +1274,16 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), service ), + new InferenceService.DefaultConfigId( + ".multilingual-embed-v1-elastic", + MinimalServiceSettings.textEmbedding( + ElasticInferenceService.NAME, + ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT + ), + service + ), new InferenceService.DefaultConfigId( ".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), @@ -1332,16 +1297,19 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { ) ) ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))); + assertThat( + service.supportedTaskTypes(), + is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING)) + ); PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); var models = listener.actionGet(TIMEOUT); - assertThat(models.size(), is(3)); + assertThat(models.size(), is(4)); assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic")); - assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic")); - + assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".multilingual-embed-v1-elastic")); + assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + assertThat(models.get(3).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java index 49957800f3a8..c8701b47a20b 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java @@ -22,12 +22,14 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import org.junit.After; @@ -46,6 +48,7 @@ import static org.elasticsearch.xpack.inference.external.http.retry.RetrySetting import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -256,6 +259,206 @@ public class ElasticInferenceServiceActionCreatorTests extends ESTestCase { } } + @SuppressWarnings("unchecked") + public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "data": [ + [ + 2.1259406, + 1.7073475, + 0.9020516 + ], + [ + 1.8342123, + 2.3456789, + 0.7654321 + ] + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("hello world", "second text"), null, InputType.UNSPECIFIED), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result, instanceOf(TextEmbeddingFloatResults.class)); + var textEmbeddingResults = (TextEmbeddingFloatResults) result; + assertThat(textEmbeddingResults.embeddings(), hasSize(2)); + + var firstEmbedding = textEmbeddingResults.embeddings().get(0); + 
assertThat(firstEmbedding.values(), is(new float[] { 2.1259406f, 1.7073475f, 0.9020516f })); + + var secondEmbedding = textEmbeddingResults.embeddings().get(1); + assertThat(secondEmbedding.values(), is(new float[] { 1.8342123f, 2.3456789f, 0.7654321f })); + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("input"), instanceOf(List.class)); + var inputList = (List) requestMap.get("input"); + assertThat(inputList, contains("hello world", "second text")); + assertThat(requestMap.get("model"), is("my-dense-model-id")); + } + } + + @SuppressWarnings("unchecked") + public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_WithUsageContext() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "data": [ + [ + 0.1234567, + 0.9876543 + ] + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("search query"), null, InputType.SEARCH), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result, instanceOf(TextEmbeddingFloatResults.class)); + var textEmbeddingResults = (TextEmbeddingFloatResults) result; + assertThat(textEmbeddingResults.embeddings(), hasSize(1)); + + var embedding = textEmbeddingResults.embeddings().get(0); + assertThat(embedding.values(), is(new float[] { 0.1234567f, 0.9876543f })); + + assertThat(webServer.requests(), hasSize(1)); + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(3)); + assertThat(requestMap.get("input"), instanceOf(List.class)); + var inputList = (List) requestMap.get("input"); + assertThat(inputList, contains("search query")); + assertThat(requestMap.get("model"), is("my-dense-model-id")); + assertThat(requestMap.get("usage_context"), is("search")); + } + } + + @SuppressWarnings("unchecked") + public void testSend_FailsFromInvalidResponseFormat_ForDenseTextEmbeddingsAction() throws IOException { + // timeout as zero for no retries + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + // This will fail because the expected output is {"data": [[...]]} + String responseJson = """ + { + "data": { + "embedding": [2.1259406, 1.7073475] + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), 
"my-dense-model-id"); + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), containsString("[EmbeddingFloatResult] failed to parse field [data]")); + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("input"), instanceOf(List.class)); + var inputList = (List) requestMap.get("input"); + assertThat(inputList, contains("hello world")); + assertThat(requestMap.get("model"), is("my-dense-model-id")); + } + } + + @SuppressWarnings("unchecked") + public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_EmptyEmbeddings() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "data": [] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id"); + var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext()); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new EmbeddingsInput(List.of(), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result, instanceOf(TextEmbeddingFloatResults.class)); + var textEmbeddingResults = (TextEmbeddingFloatResults) result; + assertThat(textEmbeddingResults.embeddings(), hasSize(0)); + + assertThat(webServer.requests(), hasSize(1)); + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.get("input"), instanceOf(List.class)); + var inputList = (List) requestMap.get("input"); + assertThat(inputList, hasSize(0)); + } + } + public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java new file mode 100644 index 000000000000..fe0e4efc85a5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings; + +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +public class ElasticInferenceServiceDenseTextEmbeddingsModelTests { + + public static ElasticInferenceServiceDenseTextEmbeddingsModel createModel(String url, String modelId) { + return new ElasticInferenceServiceDenseTextEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "elastic", + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + SimilarityMeasure.COSINE, + null, + null, + new RateLimitSettings(1000L) + ), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + ElasticInferenceServiceComponents.of(url), + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java new file mode 100644 index 000000000000..a9263d5624dc --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java @@ -0,0 +1,165 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase< + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings> { + + @Override + protected Writeable.Reader<ElasticInferenceServiceDenseTextEmbeddingsServiceSettings> instanceReader() { + return ElasticInferenceServiceDenseTextEmbeddingsServiceSettings::new; + } + + @Override + protected ElasticInferenceServiceDenseTextEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected ElasticInferenceServiceDenseTextEmbeddingsServiceSettings mutateInstance( + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings instance + ) throws IOException { + return randomValueOtherThan(instance, ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests::createRandom); + } + + public void testFromMap_Request_WithAllSettings() { + var modelId = "my-dense-model-id"; + var similarity = SimilarityMeasure.COSINE; + var dimensions = 384; + var maxInputTokens = 512; + + var serviceSettings = ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + modelId, + ServiceFields.SIMILARITY, + similarity.toString(), + ServiceFields.DIMENSIONS, + dimensions, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens + ) + ), + ConfigurationParseContext.REQUEST + ); + + assertThat(serviceSettings.modelId(), is(modelId)); + assertThat(serviceSettings.similarity(), is(similarity)); + assertThat(serviceSettings.dimensions(), is(dimensions)); + assertThat(serviceSettings.maxInputTokens(), is(maxInputTokens)); + } + + public void testToXContent_WritesAllFields() throws IOException { + var modelId = "my-dense-model"; + var similarity = SimilarityMeasure.DOT_PRODUCT; + var dimensions = 1024; + var maxInputTokens = 256; + var rateLimitSettings = new RateLimitSettings(5000); + + var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + similarity, + dimensions, + maxInputTokens, + rateLimitSettings + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + String expectedResult = Strings.format( + """ + {"similarity":"%s","dimensions":%d,"max_input_tokens":%d,"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", + similarity, + dimensions, + maxInputTokens, + modelId, + rateLimitSettings.requestsPerTimeUnit() + ); + + assertThat(xContentResult, is(expectedResult)); + } + + public void testToXContent_WritesOnlyNonNullFields() throws IOException { + var modelId = "my-dense-model"; + var rateLimitSettings = new RateLimitSettings(2000); + + var serviceSettings = new
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + null, // similarity + null, // dimensions + null, // maxInputTokens + rateLimitSettings + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(Strings.format(""" + {"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", modelId, rateLimitSettings.requestsPerTimeUnit()))); + } + + public void testToXContentFragmentOfExposedFields() throws IOException { + var modelId = "my-dense-model"; + var rateLimitSettings = new RateLimitSettings(1500); + + var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + SimilarityMeasure.COSINE, + 512, + 128, + rateLimitSettings + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + builder.startObject(); + serviceSettings.toXContentFragmentOfExposedFields(builder, null); + builder.endObject(); + String xContentResult = Strings.toString(builder); + + // Only model_id and rate_limit should be in exposed fields + assertThat(xContentResult, is(Strings.format(""" + {"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", modelId, rateLimitSettings.requestsPerTimeUnit()))); + } + + public static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings createRandom() { + var modelId = randomAlphaOfLength(10); + var similarity = SimilarityMeasure.COSINE; + var dimensions = randomBoolean() ? randomIntBetween(1, 1024) : null; + var maxInputTokens = randomBoolean() ? randomIntBetween(128, 256) : null; + var rateLimitSettings = randomBoolean() ? new RateLimitSettings(randomIntBetween(1, 10000)) : null; + + return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + modelId, + similarity, + dimensions, + maxInputTokens, + rateLimitSettings + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntityTests.java new file mode 100644 index 000000000000..f0ac37174f15 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestEntityTests.java @@ -0,0 +1,147 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.request; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class ElasticInferenceServiceDenseTextEmbeddingsRequestEntityTests extends ESTestCase { + + public void testToXContent_SingleInput_UnspecifiedUsageContext() throws IOException { + var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List.of("abc"), + "my-model-id", + ElasticInferenceServiceUsageContext.UNSPECIFIED + ); + String xContentString = xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": ["abc"], + "model": "my-model-id" + }""")); + } + + public void testToXContent_MultipleInputs_UnspecifiedUsageContext() throws IOException { + var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List.of("abc", "def"), + "my-model-id", + ElasticInferenceServiceUsageContext.UNSPECIFIED + ); + String xContentString = xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": [ + "abc", + "def" + ], + "model": "my-model-id" + } + """)); + } + + public void testToXContent_SingleInput_SearchUsageContext() throws IOException { + var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List.of("abc"), + "my-model-id", + ElasticInferenceServiceUsageContext.SEARCH + ); + String xContentString = xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": ["abc"], + "model": "my-model-id", + "usage_context": "search" + } + """)); + } + + public void testToXContent_SingleInput_IngestUsageContext() throws IOException { + var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List.of("abc"), + "my-model-id", + ElasticInferenceServiceUsageContext.INGEST + ); + String xContentString = xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": ["abc"], + "model": "my-model-id", + "usage_context": "ingest" + } + """)); + } + + public void testToXContent_MultipleInputs_SearchUsageContext() throws IOException { + var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List.of("first input", "second input", "third input"), + "my-dense-model", + ElasticInferenceServiceUsageContext.SEARCH + ); + String xContentString = xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": [ + "first input", + "second input", + "third input" + ], + "model": "my-dense-model", + "usage_context": "search" + } + """)); + } + + public void testToXContent_MultipleInputs_IngestUsageContext() throws IOException { + var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List.of("document one", "document two"), + "embedding-model-v2", + ElasticInferenceServiceUsageContext.INGEST + ); + String xContentString = xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": [ + "document one", + 
"document two" + ], + "model": "embedding-model-v2", + "usage_context": "ingest" + } + """)); + } + + public void testToXContent_EmptyInput_UnspecifiedUsageContext() throws IOException { + var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity( + List.of(""), + "my-model-id", + ElasticInferenceServiceUsageContext.UNSPECIFIED + ); + String xContentString = xContentEntityToString(entity); + assertThat(xContentString, equalToIgnoringWhitespaceInJsonString(""" + { + "input": [""], + "model": "my-model-id" + } + """)); + } + + private String xContentEntityToString(ElasticInferenceServiceDenseTextEmbeddingsRequestEntity entity) throws IOException { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + return Strings.toString(builder); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java new file mode 100644 index 000000000000..86687980acdf --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceDenseTextEmbeddingsRequestTests.java @@ -0,0 +1,165 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.request; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; + +public class ElasticInferenceServiceDenseTextEmbeddingsRequestTests extends ESTestCase { + + public void testCreateHttpRequest_UsageContextSearch() throws IOException { + var url = "http://eis-gateway.com"; + var input = List.of("input text"); + var modelId = "my-dense-model-id"; + + var request = createRequest(url, modelId, input, InputType.SEARCH); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.size(), equalTo(3)); + assertThat(requestMap.get("input"), is(input)); + 
assertThat(requestMap.get("model"), is(modelId)); + assertThat(requestMap.get("usage_context"), equalTo("search")); + } + + public void testCreateHttpRequest_UsageContextIngest() throws IOException { + var url = "http://eis-gateway.com"; + var input = List.of("ingest text"); + var modelId = "my-dense-model-id"; + + var request = createRequest(url, modelId, input, InputType.INGEST); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.size(), equalTo(3)); + assertThat(requestMap.get("input"), is(input)); + assertThat(requestMap.get("model"), is(modelId)); + assertThat(requestMap.get("usage_context"), equalTo("ingest")); + } + + public void testCreateHttpRequest_UsageContextUnspecified() throws IOException { + var url = "http://eis-gateway.com"; + var input = List.of("unspecified text"); + var modelId = "my-dense-model-id"; + + var request = createRequest(url, modelId, input, InputType.UNSPECIFIED); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("input"), is(input)); + assertThat(requestMap.get("model"), is(modelId)); + // usage_context should not be present for UNSPECIFIED + } + + public void testCreateHttpRequest_MultipleInputs() throws IOException { + var url = "http://eis-gateway.com"; + var inputs = List.of("first input", "second input", "third input"); + var modelId = "my-dense-model-id"; + + var request = createRequest(url, modelId, inputs, InputType.SEARCH); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.size(), equalTo(3)); + assertThat(requestMap.get("input"), is(inputs)); + assertThat(requestMap.get("model"), is(modelId)); + assertThat(requestMap.get("usage_context"), equalTo("search")); + } + + public void testTraceContextPropagatedThroughHTTPHeaders() { + var url = "http://eis-gateway.com"; + var input = List.of("input text"); + var modelId = "my-dense-model-id"; + + var request = createRequest(url, modelId, input, InputType.UNSPECIFIED); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var traceParent = request.getTraceContext().traceParent(); + var traceState = request.getTraceContext().traceState(); + + assertThat(httpPost.getLastHeader(Task.TRACE_PARENT_HTTP_HEADER).getValue(), is(traceParent)); + assertThat(httpPost.getLastHeader(Task.TRACE_STATE).getValue(), is(traceState)); + } + + public void testTruncate_ReturnsSameInstance() { + var url = "http://eis-gateway.com"; + var input = List.of("input 
text"); + var modelId = "my-dense-model-id"; + + var request = createRequest(url, modelId, input, InputType.UNSPECIFIED); + var truncatedRequest = request.truncate(); + + // Dense text embeddings request doesn't support truncation, should return same instance + assertThat(truncatedRequest, is(request)); + } + + public void testGetTruncationInfo_ReturnsNull() { + var url = "http://eis-gateway.com"; + var input = List.of("input text"); + var modelId = "my-dense-model-id"; + + var request = createRequest(url, modelId, input, InputType.UNSPECIFIED); + + // Dense text embeddings request doesn't support truncation info + assertThat(request.getTruncationInfo(), is(nullValue())); + } + + private ElasticInferenceServiceDenseTextEmbeddingsRequest createRequest( + String url, + String modelId, + List<String> inputs, + InputType inputType + ) { + var embeddingsModel = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(url, modelId); + + return new ElasticInferenceServiceDenseTextEmbeddingsRequest( + embeddingsModel, + inputs, + new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)), + randomElasticInferenceServiceRequestMetadata(), + inputType + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java new file mode 100644 index 000000000000..2883a1ab73c2 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity; + +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.hasSize; +import static org.mockito.Mockito.mock; + +public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests extends ESTestCase { + + public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_NoMeta() throws Exception { + String responseJson = """ + { + "data": [ + [ + 1.23, + 4.56, + 7.89 + ] + ] + } + """; + + TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults.embeddings(), hasSize(1)); + + var embedding = parsedResults.embeddings().get(0); + assertThat(embedding.values(), is(new float[] { 1.23f, 4.56f, 7.89f })); + } + + public void testDenseTextEmbeddingsResponse_MultipleEmbeddingsInData_NoMeta() throws Exception { + String responseJson = """ + { + "data": [ + [ + 1.23, + 4.56, + 7.89 + ], + [ + 0.12, + 0.34, + 0.56 + ] + ] + } + """; + + TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults.embeddings(), hasSize(2)); + + var firstEmbedding = parsedResults.embeddings().get(0); + assertThat(firstEmbedding.values(), is(new float[] { 1.23f, 4.56f, 7.89f })); + + var secondEmbedding = parsedResults.embeddings().get(1); + assertThat(secondEmbedding.values(), is(new float[] { 0.12f, 0.34f, 0.56f })); + } + + public void testDenseTextEmbeddingsResponse_EmptyData() throws Exception { + String responseJson = """ + { + "data": [] + } + """; + + TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults.embeddings(), hasSize(0)); + } + + public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_IgnoresMeta() throws Exception { + String responseJson = """ + { + "data": [ + [ + -1.0, + 0.0, + 1.0 + ] + ], + "meta": { + "usage": { + "total_tokens": 5 + } + } + } + """; + + TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults.embeddings(), hasSize(1)); + + var embedding = parsedResults.embeddings().get(0); + assertThat(embedding.values(), is(new float[] { -1.0f, 0.0f, 1.0f })); + } +}
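As a quick illustration of the response shape exercised throughout these tests (not part of the change above; the class and method names below are purely illustrative), the EIS dense text embeddings endpoint returns one numeric array per input under "data", and each array becomes one float vector in TextEmbeddingFloatResults. A minimal standalone Java sketch of that mapping, assuming the "data" field has already been deserialized into lists of doubles:

import java.util.List;

// Illustrative sketch only: mirrors how the "data" array of the EIS dense text
// embeddings response (one list of numbers per input) maps onto one float vector
// per input, as the tests above assert through TextEmbeddingFloatResults.
public class DenseEmbeddingsResponseShapeSketch {

    // Convert the parsed "data" field into float vectors, one per input string.
    static float[][] toFloatVectors(List<List<Double>> data) {
        float[][] vectors = new float[data.size()][];
        for (int i = 0; i < data.size(); i++) {
            List<Double> row = data.get(i);
            float[] vector = new float[row.size()];
            for (int j = 0; j < row.size(); j++) {
                vector[j] = row.get(j).floatValue();
            }
            vectors[i] = vector;
        }
        return vectors;
    }

    public static void main(String[] args) {
        // Same values as the mocked response used in the chunked-infer tests above.
        var data = List.of(List.of(0.123, -0.456, 0.789), List.of(0.987, -0.654, 0.321));
        float[][] vectors = toFloatVectors(data);
        System.out.println(vectors.length + " embeddings of dimension " + vectors[0].length);
    }
}

The empty "data" case asserted above falls out of the same loop: zero rows produce zero embeddings.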