diff --git a/docs/changelog/126856.yaml b/docs/changelog/126856.yaml new file mode 100644 index 000000000000..5cc9bdc6946f --- /dev/null +++ b/docs/changelog/126856.yaml @@ -0,0 +1,5 @@ +pr: 126856 +summary: "[ML] Integrate SageMaker with OpenAI Embeddings" +area: Machine Learning +type: enhancement +issues: [] diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml index 8d01a5d66bde..eb43190a68ba 100644 --- a/gradle/verification-metadata.xml +++ b/gradle/verification-metadata.xml @@ -4912,6 +4912,11 @@ + + + + + diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index d2d9fe3fea8a..2fea4d30ba82 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -162,6 +162,7 @@ public class TransportVersions { public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17); public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19); public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL_BACKPORT_8_19 = def(8_841_0_20); + public static final TransportVersion ML_INFERENCE_SAGEMAKER_8_19 = def(8_841_0_21); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00); @@ -232,6 +233,7 @@ public class TransportVersions { public static final TransportVersion PROJECT_METADATA_SETTINGS = def(9_066_00_0); public static final TransportVersion AGGREGATE_METRIC_DOUBLE_BLOCK = def(9_067_00_0); public static final TransportVersion PINNED_RETRIEVER = def(9_068_0_00); + public static final TransportVersion ML_INFERENCE_SAGEMAKER = def(9_069_0_00); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/common/ValidationException.java b/server/src/main/java/org/elasticsearch/common/ValidationException.java index 67ff85f0bae6..aad91dbac9b4 100644 --- a/server/src/main/java/org/elasticsearch/common/ValidationException.java +++ b/server/src/main/java/org/elasticsearch/common/ValidationException.java @@ -53,6 +53,12 @@ public class ValidationException extends IllegalArgumentException { return validationErrors; } + public final void throwIfValidationErrorsExist() { + if (validationErrors().isEmpty() == false) { + throw this; + } + } + @Override public final String getMessage() { StringBuilder sb = new StringBuilder(); diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index b0657968f00f..fba8d9e61f0c 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -62,6 +62,7 @@ dependencies { /* AWS SDK v2 */ implementation("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}") + implementation("software.amazon.awssdk:sagemakerruntime:${versions.awsv2sdk}") api "software.amazon.awssdk:protocol-core:${versions.awsv2sdk}" api "software.amazon.awssdk:aws-json-protocol:${versions.awsv2sdk}" api "software.amazon.awssdk:third-party-jackson-core:${versions.awsv2sdk}" @@ -142,6 +143,7 @@ tasks.named("dependencyLicenses").configure { mapping from: /json-utils.*/, to: 'aws-sdk-2' mapping from: /endpoints-spi.*/, to: 'aws-sdk-2' mapping from: /bedrockruntime.*/, to: 'aws-sdk-2' + mapping from: /sagemakerruntime.*/, to: 'aws-sdk-2' mapping from: /netty-nio-client/, to: 'aws-sdk-2' /* Cannot use REGEX to match netty-* because netty-nio-client is an AWS package */ mapping from: /netty-buffer/, to: 'netty' diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java 
b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index da65da368951..682eebd0fa69 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -18,163 +18,161 @@ import java.util.List; import java.util.Map; import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { - @SuppressWarnings("unchecked") public void testGetServicesWithoutTaskType() throws IOException { List services = getAllServices(); - assertThat(services.size(), equalTo(21)); + assertThat(services.size(), equalTo(22)); - String[] providers = new String[services.size()]; - for (int i = 0; i < services.size(); i++) { - Map serviceConfig = (Map) services.get(i); - providers[i] = (String) serviceConfig.get("service"); - } + var providers = providers(services); - assertArrayEquals( - List.of( - "alibabacloud-ai-search", - "amazonbedrock", - "anthropic", - "azureaistudio", - "azureopenai", - "cohere", - "deepseek", - "elastic", - "elasticsearch", - "googleaistudio", - "googlevertexai", - "hugging_face", - "jinaai", - "mistral", - "openai", - "streaming_completion_test_service", - "test_reranking_service", - "test_service", - "text_embedding_test_service", - "voyageai", - "watsonxai" - ).toArray(), - providers + assertThat( + providers, + containsInAnyOrder( + List.of( + "alibabacloud-ai-search", + "amazonbedrock", + "anthropic", + "azureaistudio", + "azureopenai", + "cohere", + "deepseek", + "elastic", + "elasticsearch", + "googleaistudio", + "googlevertexai", + "hugging_face", + "jinaai", 
+ "mistral", + "openai", + "streaming_completion_test_service", + "test_reranking_service", + "test_service", + "text_embedding_test_service", + "voyageai", + "watsonxai", + "sagemaker" + ).toArray() + ) ); } @SuppressWarnings("unchecked") + private Iterable providers(List services) { + return services.stream().map(service -> { + var serviceConfig = (Map) service; + return (String) serviceConfig.get("service"); + }).toList(); + } + public void testGetServicesWithTextEmbeddingTaskType() throws IOException { List services = getServices(TaskType.TEXT_EMBEDDING); - assertThat(services.size(), equalTo(15)); + assertThat(services.size(), equalTo(16)); - String[] providers = new String[services.size()]; - for (int i = 0; i < services.size(); i++) { - Map serviceConfig = (Map) services.get(i); - providers[i] = (String) serviceConfig.get("service"); - } + var providers = providers(services); - assertArrayEquals( - List.of( - "alibabacloud-ai-search", - "amazonbedrock", - "azureaistudio", - "azureopenai", - "cohere", - "elasticsearch", - "googleaistudio", - "googlevertexai", - "hugging_face", - "jinaai", - "mistral", - "openai", - "text_embedding_test_service", - "voyageai", - "watsonxai" - ).toArray(), - providers + assertThat( + providers, + containsInAnyOrder( + List.of( + "alibabacloud-ai-search", + "amazonbedrock", + "azureaistudio", + "azureopenai", + "cohere", + "elasticsearch", + "googleaistudio", + "googlevertexai", + "hugging_face", + "jinaai", + "mistral", + "openai", + "text_embedding_test_service", + "voyageai", + "watsonxai", + "sagemaker" + ).toArray() + ) ); } - @SuppressWarnings("unchecked") public void testGetServicesWithRerankTaskType() throws IOException { List services = getServices(TaskType.RERANK); assertThat(services.size(), equalTo(7)); - String[] providers = new String[services.size()]; - for (int i = 0; i < services.size(); i++) { - Map serviceConfig = (Map) services.get(i); - providers[i] = (String) serviceConfig.get("service"); - } + var 
providers = providers(services); - assertArrayEquals( - List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai") - .toArray(), - providers + assertThat( + providers, + containsInAnyOrder( + List.of( + "alibabacloud-ai-search", + "cohere", + "elasticsearch", + "googlevertexai", + "jinaai", + "test_reranking_service", + "voyageai" + ).toArray() + ) ); } - @SuppressWarnings("unchecked") public void testGetServicesWithCompletionTaskType() throws IOException { List services = getServices(TaskType.COMPLETION); assertThat(services.size(), equalTo(10)); - String[] providers = new String[services.size()]; - for (int i = 0; i < services.size(); i++) { - Map serviceConfig = (Map) services.get(i); - providers[i] = (String) serviceConfig.get("service"); - } + var providers = providers(services); - assertArrayEquals( - List.of( - "alibabacloud-ai-search", - "amazonbedrock", - "anthropic", - "azureaistudio", - "azureopenai", - "cohere", - "deepseek", - "googleaistudio", - "openai", - "streaming_completion_test_service" - ).toArray(), - providers + assertThat( + providers, + containsInAnyOrder( + List.of( + "alibabacloud-ai-search", + "amazonbedrock", + "anthropic", + "azureaistudio", + "azureopenai", + "cohere", + "deepseek", + "googleaistudio", + "openai", + "streaming_completion_test_service" + ).toArray() + ) ); } - @SuppressWarnings("unchecked") public void testGetServicesWithChatCompletionTaskType() throws IOException { List services = getServices(TaskType.CHAT_COMPLETION); assertThat(services.size(), equalTo(4)); - String[] providers = new String[services.size()]; - for (int i = 0; i < services.size(); i++) { - Map serviceConfig = (Map) services.get(i); - providers[i] = (String) serviceConfig.get("service"); - } + var providers = providers(services); - assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers); + assertThat(providers, 
containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray())); } - @SuppressWarnings("unchecked") public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { List services = getServices(TaskType.SPARSE_EMBEDDING); assertThat(services.size(), equalTo(6)); - String[] providers = new String[services.size()]; - for (int i = 0; i < services.size(); i++) { - Map serviceConfig = (Map) services.get(i); - providers[i] = (String) serviceConfig.get("service"); - } + var providers = providers(services); - assertArrayEquals( - List.of( - "alibabacloud-ai-search", - "elastic", - "elasticsearch", - "hugging_face", - "streaming_completion_test_service", - "test_service" - ).toArray(), - providers + assertThat( + providers, + containsInAnyOrder( + List.of( + "alibabacloud-ai-search", + "elastic", + "elasticsearch", + "hugging_face", + "streaming_completion_test_service", + "test_service" + ).toArray() + ) ); } diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 346be5632e47..6aae961d4504 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -36,6 +36,7 @@ module org.elasticsearch.inference { requires org.elasticsearch.logging; requires org.elasticsearch.sslconfig; requires org.apache.commons.text; + requires software.amazon.awssdk.services.sagemakerruntime; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 63d9d8a3bd9d..5c719c08142a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -92,6 +92,8 @@ import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCo import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; @@ -157,6 +159,8 @@ public class InferenceNamedWriteablesProvider { namedWriteables.addAll(StreamingTaskManager.namedWriteables()); namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables()); + namedWriteables.addAll(SageMakerModel.namedWriteables()); + namedWriteables.addAll(SageMakerSchemas.namedWriteables()); return namedWriteables; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 7edb724132ac..587eafbf553c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -21,6 +21,7 @@ import org.elasticsearch.common.settings.IndexScopedSettings; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.SettingsFilter; +import 
org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.TimeValue; import org.elasticsearch.features.NodeFeature; @@ -132,6 +133,11 @@ import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService; import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; +import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient; +import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerService; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerConfiguration; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas; import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; @@ -294,6 +300,8 @@ public class InferencePlugin extends Plugin services.threadPool() ); + var sageMakerSchemas = new SageMakerSchemas(); + var sageMakerConfigurations = new LazyInitializable<>(new SageMakerConfiguration(sageMakerSchemas)); inferenceServices.add( () -> List.of( context -> new ElasticInferenceService( @@ -302,6 +310,16 @@ public class InferencePlugin extends Plugin inferenceServiceSettings, modelRegistry.get(), authorizationHandler + ), + context -> new SageMakerService( + new SageMakerModelBuilder(sageMakerSchemas), + new SageMakerClient( + new SageMakerClient.Factory(new HttpSettings(settings, services.clusterService())), + services.threadPool() + ), + sageMakerSchemas, + services.threadPool(), + sageMakerConfigurations::getOrCompute ) ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/amazon/AwsSecretSettings.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/amazon/AwsSecretSettings.java index bad98b64a43e..934c78f3e1b2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/amazon/AwsSecretSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/amazon/AwsSecretSettings.java @@ -23,11 +23,12 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; -import java.util.Collections; import java.util.EnumSet; import java.util.HashMap; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.ACCESS_KEY_FIELD; @@ -134,33 +135,39 @@ public class AwsSecretSettings implements SecretSettings { } private static final LazyInitializable, RuntimeException> configuration = - new LazyInitializable<>(() -> { - var configurationMap = new HashMap(); - configurationMap.put( - ACCESS_KEY_FIELD, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription( - "A valid AWS access key that has permissions to use Amazon Bedrock." - ) - .setLabel("Access Key") - .setRequired(true) - .setSensitive(true) - .setUpdatable(true) - .setType(SettingsConfigurationFieldType.STRING) - .build() - ); - configurationMap.put( - SECRET_KEY_FIELD, - new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription( - "A valid AWS secret key that is paired with the access_key." 
- ) - .setLabel("Secret Key") - .setRequired(true) - .setSensitive(true) - .setUpdatable(true) - .setType(SettingsConfigurationFieldType.STRING) - .build() - ); - return Collections.unmodifiableMap(configurationMap); - }); + new LazyInitializable<>( + () -> configuration(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).collect( + Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue) + ) + ); + } + + public static Stream> configuration(EnumSet supportedTaskTypes) { + return Stream.of( + Map.entry( + ACCESS_KEY_FIELD, + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "A valid AWS access key that has permissions to use Amazon Bedrock." + ) + .setLabel("Access Key") + .setRequired(true) + .setSensitive(true) + .setUpdatable(true) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ), + Map.entry( + SECRET_KEY_FIELD, + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "A valid AWS secret key that is paired with the access_key." 
+ ) + .setLabel("Secret Key") + .setRequired(true) + .setSensitive(true) + .setUpdatable(true) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ) + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpSettings.java index 72dfd7bcd1b1..6161b9372844 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpSettings.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.TimeValue; +import java.time.Duration; import java.util.List; import java.util.Objects; @@ -55,6 +56,10 @@ public class HttpSettings { return connectionTimeout; } + public Duration connectionTimeoutDuration() { + return Duration.ofMillis(connectionTimeout); + } + private void setMaxResponseSize(ByteSizeValue maxResponseSize) { this.maxResponseSize = maxResponseSize; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerClient.java new file mode 100644 index 000000000000..854bb1098e42 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerClient.java @@ -0,0 +1,251 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker; + +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; +import software.amazon.awssdk.profiles.ProfileFile; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeAsyncClient; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler; +import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.SpecialPermission; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ContextPreservingActionListener; +import org.elasticsearch.action.support.ListenerTimeouts; +import org.elasticsearch.common.cache.Cache; +import org.elasticsearch.common.cache.CacheBuilder; +import org.elasticsearch.common.cache.CacheLoader; +import org.elasticsearch.common.util.concurrent.FutureUtils; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; 
+import org.elasticsearch.xpack.inference.external.http.HttpSettings; +import org.reactivestreams.FlowAdapters; + +import java.io.Closeable; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; + +public class SageMakerClient implements Closeable { + private static final Logger log = LogManager.getLogger(SageMakerClient.class); + private final Cache existingClients = CacheBuilder.< + RegionAndSecrets, + SageMakerRuntimeAsyncClient>builder() + .removalListener(removal -> removal.getValue().close()) + .setExpireAfterAccess(TimeValue.timeValueMinutes(15)) + .build(); + + private final CacheLoader clientFactory; + private final ThreadPool threadPool; + + public SageMakerClient(CacheLoader clientFactory, ThreadPool threadPool) { + this.clientFactory = clientFactory; + this.threadPool = threadPool; + } + + public void invoke( + RegionAndSecrets regionAndSecrets, + InvokeEndpointRequest request, + TimeValue timeout, + ActionListener listener + ) { + SageMakerRuntimeAsyncClient asyncClient; + try { + asyncClient = existingClients.computeIfAbsent(regionAndSecrets, clientFactory); + } catch (ExecutionException e) { + listener.onFailure(clientFailure(regionAndSecrets, e)); + return; + } + + var contextPreservingListener = new ContextPreservingActionListener<>( + threadPool.getThreadContext().newRestorableContext(false), + listener + ); + + var awsFuture = asyncClient.invokeEndpoint(request); + var timeoutListener = ListenerTimeouts.wrapWithTimeout( + threadPool, + timeout, + threadPool.executor(UTILITY_THREAD_POOL_NAME), + contextPreservingListener, + ignored -> { + 
FutureUtils.cancel(awsFuture); + contextPreservingListener.onFailure( + new ElasticsearchStatusException("Request timed out after [{}]", RestStatus.REQUEST_TIMEOUT, timeout) + ); + } + ); + awsFuture.thenAcceptAsync(timeoutListener::onResponse, threadPool.executor(UTILITY_THREAD_POOL_NAME)) + .exceptionallyAsync(t -> failAndMaybeThrowError(t, timeoutListener), threadPool.executor(UTILITY_THREAD_POOL_NAME)); + } + + private static Exception clientFailure(RegionAndSecrets regionAndSecrets, Exception cause) { + return new ElasticsearchStatusException( + "failed to create SageMakerRuntime client for region [{}]", + RestStatus.INTERNAL_SERVER_ERROR, + cause, + regionAndSecrets.region() + ); + } + + private Void failAndMaybeThrowError(Throwable t, ActionListener listener) { + if (t instanceof CompletionException ce) { + t = ce.getCause(); + } + if (t instanceof Exception e) { + listener.onFailure(e); + } else { + ExceptionsHelper.maybeError(t).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread); + log.atWarn().withThrowable(t).log("Unknown failure calling SageMaker."); + listener.onFailure(new RuntimeException("Unknown failure calling SageMaker.", t)); + } + return null; // Void + } + + public void invokeStream( + RegionAndSecrets regionAndSecrets, + InvokeEndpointWithResponseStreamRequest request, + TimeValue timeout, + ActionListener listener + ) { + SageMakerRuntimeAsyncClient asyncClient; + try { + asyncClient = existingClients.computeIfAbsent(regionAndSecrets, clientFactory); + } catch (ExecutionException e) { + listener.onFailure(clientFailure(regionAndSecrets, e)); + return; + } + + var contextPreservingListener = new ContextPreservingActionListener<>( + threadPool.getThreadContext().newRestorableContext(false), + listener + ); + + var responseStreamProcessor = new SageMakerStreamingResponseProcessor(); + var cancelAwsRequestListener = new AtomicReference>(); + var timeoutListener = ListenerTimeouts.wrapWithTimeout( + threadPool, + timeout, + 
threadPool.executor(UTILITY_THREAD_POOL_NAME), + contextPreservingListener, + ignored -> { + FutureUtils.cancel(cancelAwsRequestListener.get()); + contextPreservingListener.onFailure( + new ElasticsearchStatusException("Request timed out after [{}]", RestStatus.REQUEST_TIMEOUT, timeout) + ); + } + ); + // To stay consistent with HTTP providers, we cancel the TimeoutListener onResponse because we are measuring the time it takes to + // start receiving bytes. + var responseStreamListener = InvokeEndpointWithResponseStreamResponseHandler.builder() + .onResponse(response -> timeoutListener.onResponse(new SageMakerStream(response, responseStreamProcessor))) + .onEventStream(publisher -> responseStreamProcessor.setPublisher(FlowAdapters.toFlowPublisher(publisher))) + .build(); + var awsFuture = asyncClient.invokeEndpointWithResponseStream(request, responseStreamListener); + cancelAwsRequestListener.set(awsFuture); + awsFuture.exceptionallyAsync(t -> failAndMaybeThrowError(t, timeoutListener), threadPool.executor(UTILITY_THREAD_POOL_NAME)); + } + + @Override + public void close() { + existingClients.invalidateAll(); // will close each cached client + } + + public record RegionAndSecrets(String region, AwsSecretSettings secretSettings) {} + + public static class Factory implements CacheLoader { + private final HttpSettings httpSettings; + + public Factory(HttpSettings httpSettings) { + this.httpSettings = httpSettings; + } + + @Override + public SageMakerRuntimeAsyncClient load(RegionAndSecrets key) throws Exception { + SpecialPermission.check(); + // TODO migrate to entitlements + return AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + var credentials = AwsBasicCredentials.create( + key.secretSettings().accessKey().toString(), + key.secretSettings().secretKey().toString() + ); + var credentialsProvider = StaticCredentialsProvider.create(credentials); + var clientConfig = 
NettyNioAsyncHttpClient.builder().connectionTimeout(httpSettings.connectionTimeoutDuration()); + var override = ClientOverrideConfiguration.builder() + // disable profileFile, user credentials will always come from the configured Model Secrets + .defaultProfileFileSupplier(ProfileFile.aggregator()::build) + .defaultProfileFile(ProfileFile.aggregator().build()) + .retryPolicy(retryPolicy -> retryPolicy.numRetries(3)) + .retryStrategy(retryStrategy -> retryStrategy.maxAttempts(3)) + .build(); + return SageMakerRuntimeAsyncClient.builder() + .credentialsProvider(credentialsProvider) + .region(Region.of(key.region())) + .httpClientBuilder(clientConfig) + .overrideConfiguration(override) + .build(); + }); + } + } + + private static class SageMakerStreamingResponseProcessor implements Flow.Publisher { + private static final Logger log = LogManager.getLogger(SageMakerStreamingResponseProcessor.class); + private final AtomicReference, Flow.Subscriber>> holder = + new AtomicReference<>(null); + private final AtomicBoolean subscribeCalledOnce = new AtomicBoolean(false); + + @Override + public void subscribe(Flow.Subscriber subscriber) { + if (subscribeCalledOnce.compareAndSet(false, true) == false) { + subscriber.onError(new IllegalStateException("Subscriber already set.")); + return; + } + if (holder.compareAndSet(null, Tuple.tuple(null, subscriber)) == false) { + log.debug("Subscriber connecting to publisher."); + var publisher = holder.getAndSet(null).v1(); + publisher.subscribe(subscriber); + } else { + log.debug("Subscriber waiting for connection."); + } + } + + private void setPublisher(Flow.Publisher publisher) { + if (holder.compareAndSet(null, Tuple.tuple(publisher, null)) == false) { + log.debug("Publisher connecting to subscriber."); + var subscriber = holder.getAndSet(null).v2(); + publisher.subscribe(subscriber); + } else { + log.debug("Publisher waiting for connection."); + } + } + } + + public record SageMakerStream(InvokeEndpointWithResponseStreamResponse 
response, Flow.Publisher responseStream) {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerInferenceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerInferenceRequest.java new file mode 100644 index 000000000000..60cf70edeea2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerInferenceRequest.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.sagemaker; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; + +import java.util.List; +import java.util.Objects; + +public record SageMakerInferenceRequest( + @Nullable String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + List input, + boolean stream, + InputType inputType +) { + public SageMakerInferenceRequest { + Objects.requireNonNull(input); + Objects.requireNonNull(inputType); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java new file mode 100644 index 000000000000..e4b699424aec --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerService.java @@ -0,0 +1,305 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails; + 
+public class SageMakerService implements InferenceService { + public static final String NAME = "sagemaker"; + private static final int DEFAULT_BATCH_SIZE = 256; + private final SageMakerModelBuilder modelBuilder; + private final SageMakerClient client; + private final SageMakerSchemas schemas; + private final ThreadPool threadPool; + private final LazyInitializable configuration; + + public SageMakerService( + SageMakerModelBuilder modelBuilder, + SageMakerClient client, + SageMakerSchemas schemas, + ThreadPool threadPool, + CheckedSupplier, RuntimeException> configurationMap + ) { + this.modelBuilder = modelBuilder; + this.client = client; + this.schemas = schemas; + this.threadPool = threadPool; + this.configuration = new LazyInitializable<>( + () -> new InferenceServiceConfiguration.Builder().setService(NAME) + .setName("Amazon SageMaker") + .setTaskTypes(supportedTaskTypes()) + .setConfigurations(configurationMap.get()) + .build() + ); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + ActionListener.completeWith(parsedModelListener, () -> modelBuilder.fromRequest(modelId, taskType, NAME, config)); + } + + @Override + public Model parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + return modelBuilder.fromStorage(modelId, taskType, NAME, config, secrets); + } + + @Override + public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + return modelBuilder.fromStorage(modelId, taskType, NAME, config, null); + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return configuration.getOrCompute(); + } + + @Override + public EnumSet supportedTaskTypes() { + return schemas.supportedTaskTypes(); + } + + @Override + public Set supportedStreamingTasks() { + return schemas.supportedStreamingTasks(); + } + + 
@Override + public void infer( + Model model, + @Nullable String query, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + List input, + boolean stream, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof SageMakerModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + var inferenceRequest = new SageMakerInferenceRequest(query, returnDocuments, topN, input, stream, inputType); + + try { + var sageMakerModel = ((SageMakerModel) model).override(taskSettings); + var regionAndSecrets = regionAndSecrets(sageMakerModel); + + if (stream) { + var schema = schemas.streamSchemaFor(sageMakerModel); + var request = schema.streamRequest(sageMakerModel, inferenceRequest); + client.invokeStream( + regionAndSecrets, + request, + timeout, + ActionListener.wrap( + response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)), + e -> listener.onFailure(schema.error(sageMakerModel, e)) + ) + ); + } else { + var schema = schemas.schemaFor(sageMakerModel); + var request = schema.request(sageMakerModel, inferenceRequest); + client.invoke( + regionAndSecrets, + request, + timeout, + ActionListener.wrap( + response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())), + e -> listener.onFailure(schema.error(sageMakerModel, e)) + ) + ); + } + } catch (Exception e) { + listener.onFailure(internalFailure(model, e)); + } + } + + private SageMakerClient.RegionAndSecrets regionAndSecrets(SageMakerModel model) { + var secrets = model.awsSecretSettings(); + if (secrets.isEmpty()) { + assert false : "Cannot invoke a model without secrets"; + throw new ElasticsearchStatusException( + format("Attempting to infer using a model without API keys, inference id [%s]", model.getInferenceEntityId()), + RestStatus.INTERNAL_SERVER_ERROR + ); + } + return new SageMakerClient.RegionAndSecrets(model.region(), secrets.get()); + } + + 
private static ElasticsearchStatusException internalFailure(Model model, Exception cause) { + if (cause instanceof ElasticsearchStatusException ese) { + return ese; + } else { + return new ElasticsearchStatusException( + "Failed to call SageMaker for inference id [{}].", + RestStatus.INTERNAL_SERVER_ERROR, + cause, + model.getInferenceEntityId() + ); + } + } + + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof SageMakerModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + try { + var sageMakerModel = (SageMakerModel) model; + var regionAndSecrets = regionAndSecrets(sageMakerModel); + var schema = schemas.streamSchemaFor(sageMakerModel); + var sagemakerRequest = schema.chatCompletionStreamRequest(sageMakerModel, request); + client.invokeStream( + regionAndSecrets, + sagemakerRequest, + timeout, + ActionListener.wrap( + response -> listener.onResponse(schema.chatCompletionStreamResponse(sageMakerModel, response)), + e -> listener.onFailure(schema.chatCompletionError(sageMakerModel, e)) + ) + ); + } catch (Exception e) { + listener.onFailure(internalFailure(model, e)); + } + } + + @Override + public void chunkedInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + if (model instanceof SageMakerModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + try { + var sageMakerModel = ((SageMakerModel) model).override(taskSettings); + var batchedRequests = new EmbeddingRequestChunker<>( + input, + sageMakerModel.batchSize().orElse(DEFAULT_BATCH_SIZE), + sageMakerModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + var subscribableListener = SubscribableListener.newSucceeded(null); + for (var request : batchedRequests) { + subscribableListener 
= subscribableListener.andThen( + threadPool.executor(UTILITY_THREAD_POOL_NAME), + threadPool.getThreadContext(), + (l, ignored) -> infer( + sageMakerModel, + query, + null, // no return docs while chunking? + null, // no topN while chunking? + request.batch().inputs().get(), + false, // we never stream when chunking + null, // since we pass sageMakerModel as the model, we already overwrote the model with the task settings + inputType, + timeout, + ActionListener.runAfter(request.listener(), () -> l.onResponse(null)) + ) + ); + } + // if there were any errors trying to create the SubscribableListener chain, then forward that to the listener + // otherwise, BatchRequestAndListener will handle forwarding errors from the infer method + subscribableListener.addListener( + ActionListener.noop().delegateResponse((l, e) -> listener.onFailure(internalFailure(model, e))) + ); + } catch (Exception e) { + listener.onFailure(internalFailure(model, e)); + } + } + + @Override + public void start(Model model, TimeValue timeout, ActionListener listener) { + listener.onResponse(true); + } + + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof SageMakerModel sageMakerModel) { + return modelBuilder.updateModelWithEmbeddingDetails(sageMakerModel, embeddingSize); + } + + throw invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_SAGEMAKER; + } + + @Override + public void close() throws IOException { + client.close(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerConfiguration.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerConfiguration.java new file mode 100644 index 000000000000..d3eb74f56f41 --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerConfiguration.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.sagemaker.model; + +import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas; + +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class SageMakerConfiguration implements CheckedSupplier, RuntimeException> { + private final SageMakerSchemas schemas; + + public SageMakerConfiguration(SageMakerSchemas schemas) { + this.schemas = schemas; + } + + @Override + public Map get() { + return Stream.of( + AwsSecretSettings.configuration(schemas.supportedTaskTypes()), + SageMakerServiceSettings.configuration(schemas.supportedTaskTypes()), + SageMakerTaskSettings.configuration(schemas.supportedTaskTypes()) + ).flatMap(Function.identity()).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java new file mode 100644 index 000000000000..48e32c741a60 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java @@ -0,0 +1,146 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.sagemaker.model; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * This model represents all models in SageMaker. SageMaker maintains a base set of settings and configurations, and this model manages + * those. Any settings that are required for a specific model are stored in the {@link SageMakerStoredServiceSchema} and + * {@link SageMakerStoredTaskSchema}. + * Design: + * - Region is stored in ServiceSettings and is used to create the SageMaker client. + * - RateLimiting is based on AWS Service Quota, metered by account and region. The SDK client handles rate limiting internally. In order to + * rate limit appropriately, use the same AWS Access ID, Secret Key, and Region for each inference endpoint calling the same AWS account. + * - Within Elastic, you cannot change the model behind the endpoint, because we require the embedding shape remain consistent between + * invocations. So anything "model selection" related must remain consistent through the lifetime of the Elastic endpoint. 
SageMaker + * allows model changes via TargetModel, InferenceComponentName, and TargetContainerHostname, but these will remain static in Elastic + * within ServiceSettings. + * - CustomAttributes, EnableExplanations, InferenceId, SessionId, and TargetVariant are all request-time fields that can be saved. + * - SageMaker returns 4 headers, which Elastic will forward as-is: x-Amzn-Invoked-Production-Variant, X-Amzn-SageMaker-Custom-Attributes, + * X-Amzn-SageMaker-New-Session-Id, X-Amzn-SageMaker-Closed-Session-Id + */ +public class SageMakerModel extends Model { + private final SageMakerServiceSettings serviceSettings; + private final SageMakerTaskSettings taskSettings; + private final AwsSecretSettings awsSecretSettings; + + SageMakerModel( + ModelConfigurations configurations, + ModelSecrets secrets, + SageMakerServiceSettings serviceSettings, + SageMakerTaskSettings taskSettings, + AwsSecretSettings awsSecretSettings + ) { + super(configurations, secrets); + this.serviceSettings = serviceSettings; + this.taskSettings = taskSettings; + this.awsSecretSettings = awsSecretSettings; + } + + public Optional awsSecretSettings() { + return Optional.ofNullable(awsSecretSettings); + } + + public String region() { + return serviceSettings.region(); + } + + public String endpointName() { + return serviceSettings.endpointName(); + } + + public String api() { + return serviceSettings.api(); + } + + public Optional customAttributes() { + return Optional.ofNullable(taskSettings.customAttributes()); + } + + public Optional enableExplanations() { + return Optional.ofNullable(taskSettings.enableExplanations()); + } + + public Optional inferenceComponentName() { + return Optional.ofNullable(serviceSettings.inferenceComponentName()); + } + + public Optional inferenceIdForDataCapture() { + return Optional.ofNullable(taskSettings.inferenceIdForDataCapture()); + } + + public Optional sessionId() { + return Optional.ofNullable(taskSettings.sessionId()); + } + + public Optional 
targetContainerHostname() { + return Optional.ofNullable(serviceSettings.targetContainerHostname()); + } + + public Optional targetModel() { + return Optional.ofNullable(serviceSettings.targetModel()); + } + + public Optional targetVariant() { + return Optional.ofNullable(taskSettings.targetVariant()); + } + + public Optional batchSize() { + return Optional.ofNullable(serviceSettings.batchSize()); + } + + public SageMakerModel override(Map taskSettingsOverride) { + if (taskSettingsOverride == null || taskSettingsOverride.isEmpty()) { + return this; + } + + return new SageMakerModel( + getConfigurations(), + getSecrets(), + serviceSettings, + taskSettings.updatedTaskSettings(taskSettingsOverride), + awsSecretSettings + ); + } + + public static List namedWriteables() { + return List.of( + new NamedWriteableRegistry.Entry(ServiceSettings.class, SageMakerServiceSettings.NAME, SageMakerServiceSettings::new), + new NamedWriteableRegistry.Entry(TaskSettings.class, SageMakerTaskSettings.NAME, SageMakerTaskSettings::new) + ); + } + + public SageMakerStoredServiceSchema apiServiceSettings() { + return serviceSettings.apiServiceSettings(); + } + + public SageMakerStoredTaskSchema apiTaskSettings() { + return taskSettings.apiTaskSettings(); + } + + SageMakerServiceSettings serviceSettings() { + return serviceSettings; + } + + SageMakerTaskSettings taskSettings() { + return taskSettings; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModelBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModelBuilder.java new file mode 100644 index 000000000000..193f53cb3eb1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModelBuilder.java @@ -0,0 +1,135 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.sagemaker.model; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +public class SageMakerModelBuilder { + + private final SageMakerSchemas schemas; + + public SageMakerModelBuilder(SageMakerSchemas schemas) { + this.schemas = schemas; + } + + public SageMakerModel fromRequest(String inferenceEntityId, TaskType taskType, String service, Map requestMap) { + var validationException = new ValidationException(); + var serviceSettingsMap = removeFromMapOrThrowIfNull(requestMap, ModelConfigurations.SERVICE_SETTINGS); + var awsSecretSettings = AwsSecretSettings.fromMap(serviceSettingsMap); + var serviceSettings = SageMakerServiceSettings.fromMap(schemas, taskType, serviceSettingsMap); + + var schema = schemas.schemaFor(taskType, serviceSettings.api()); + + var taskSettingsMap = removeFromMapOrDefaultEmpty(requestMap, ModelConfigurations.TASK_SETTINGS); + var taskSettings = SageMakerTaskSettings.fromMap( + taskSettingsMap, + schema.apiTaskSettings(taskSettingsMap, validationException), + validationException + ); + + validationException.throwIfValidationErrorsExist(); + throwIfNotEmptyMap(serviceSettingsMap, 
service); + throwIfNotEmptyMap(taskSettingsMap, service); + + var modelConfigurations = new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings); + return new SageMakerModel( + modelConfigurations, + new ModelSecrets(awsSecretSettings), + serviceSettings, + taskSettings, + awsSecretSettings + ); + } + + public SageMakerModel fromStorage( + String inferenceEntityId, + TaskType taskType, + String service, + Map config, + Map secrets + ) { + var validationException = new ValidationException(); + var awsSecretSettings = secrets != null + ? AwsSecretSettings.fromMap(removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS)) + : null; + + var serviceSettings = SageMakerServiceSettings.fromMap( + schemas, + taskType, + removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS) + ); + + var schema = schemas.schemaFor(taskType, serviceSettings.api()); + var taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + + var taskSettings = SageMakerTaskSettings.fromMap( + taskSettingsMap, + schema.apiTaskSettings(taskSettingsMap, validationException), + validationException + ); + + validationException.throwIfValidationErrorsExist(); + throwIfNotEmptyMap(config, service); + throwIfNotEmptyMap(taskSettingsMap, service); + throwIfNotEmptyMap(secrets, service); + + var modelConfigurations = new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings); + return new SageMakerModel( + modelConfigurations, + new ModelSecrets(awsSecretSettings), + serviceSettings, + taskSettings, + awsSecretSettings + ); + } + + public SageMakerModel updateModelWithEmbeddingDetails(SageMakerModel model, int embeddingSize) { + var updatedApiServiceSettings = model.apiServiceSettings().updateModelWithEmbeddingDetails(embeddingSize); + + if (updatedApiServiceSettings == model.apiServiceSettings()) { + return model; + } + + var updatedServiceSettings = new SageMakerServiceSettings( + 
model.serviceSettings().endpointName(), + model.serviceSettings().region(), + model.serviceSettings().api(), + model.serviceSettings().targetModel(), + model.serviceSettings().targetContainerHostname(), + model.serviceSettings().inferenceComponentName(), + model.serviceSettings().batchSize(), + updatedApiServiceSettings + ); + + var modelConfigurations = new ModelConfigurations( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + updatedServiceSettings, + model.taskSettings() + ); + return new SageMakerModel( + modelConfigurations, + model.getSecrets(), + updatedServiceSettings, + model.taskSettings(), + model.awsSecretSettings().orElse(null) + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerServiceSettings.java new file mode 100644 index 000000000000..2caf97bdd05b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerServiceSettings.java @@ -0,0 +1,285 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.model; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Stream; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; + +/** + * Maintains the settings for SageMaker that cannot be changed without impacting semantic search and AI assistants. + * Model-specific settings are stored in {@link SageMakerStoredServiceSchema}. 
+ */ +record SageMakerServiceSettings( + String endpointName, + String region, + String api, + @Nullable String targetModel, + @Nullable String targetContainerHostname, + @Nullable String inferenceComponentName, + @Nullable Integer batchSize, + SageMakerStoredServiceSchema apiServiceSettings +) implements ServiceSettings { + + static final String NAME = "sage_maker_service_settings"; + private static final String API = "api"; + private static final String ENDPOINT_NAME = "endpoint_name"; + private static final String REGION = "region"; + private static final String TARGET_MODEL = "target_model"; + private static final String TARGET_CONTAINER_HOSTNAME = "target_container_hostname"; + private static final String INFERENCE_COMPONENT_NAME = "inference_component_name"; + private static final String BATCH_SIZE = "batch_size"; + + SageMakerServiceSettings { + Objects.requireNonNull(endpointName); + Objects.requireNonNull(region); + Objects.requireNonNull(api); + Objects.requireNonNull(apiServiceSettings); + } + + SageMakerServiceSettings(StreamInput in) throws IOException { + this( + in.readString(), + in.readString(), + in.readString(), + in.readOptionalString(), + in.readOptionalString(), + in.readOptionalString(), + in.readOptionalInt(), + in.readNamedWriteable(SageMakerStoredServiceSchema.class) + ); + } + + @Override + public String modelId() { + return apiServiceSettings.modelId(); + } + + @Override + public SimilarityMeasure similarity() { + return apiServiceSettings.similarity(); + } + + @Override + public Integer dimensions() { + return apiServiceSettings.dimensions(); + } + + @Override + public Boolean dimensionsSetByUser() { + return apiServiceSettings.dimensionsSetByUser(); + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return apiServiceSettings.elementType(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return 
TransportVersions.ML_INFERENCE_SAGEMAKER; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(endpointName()); + out.writeString(region()); + out.writeString(api()); + out.writeOptionalString(targetModel()); + out.writeOptionalString(targetContainerHostname()); + out.writeOptionalString(inferenceComponentName()); + out.writeOptionalInt(batchSize()); + out.writeNamedWriteable(apiServiceSettings); + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return this; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(ENDPOINT_NAME, endpointName()); + builder.field(REGION, region()); + builder.field(API, api()); + optionalField(TARGET_MODEL, targetModel(), builder); + optionalField(TARGET_CONTAINER_HOSTNAME, targetContainerHostname(), builder); + optionalField(INFERENCE_COMPONENT_NAME, inferenceComponentName(), builder); + optionalField(BATCH_SIZE, batchSize(), builder); + apiServiceSettings.toXContent(builder, params); + + return builder.endObject(); + } + + private static void optionalField(String name, T value, XContentBuilder builder) throws IOException { + if (value != null) { + builder.field(name, value); + } + } + + static SageMakerServiceSettings fromMap(SageMakerSchemas schemas, TaskType taskType, Map serviceSettingsMap) { + ValidationException validationException = new ValidationException(); + + var endpointName = extractRequiredString( + serviceSettingsMap, + ENDPOINT_NAME, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + var region = extractRequiredString(serviceSettingsMap, REGION, ModelConfigurations.SERVICE_SETTINGS, validationException); + var api = extractRequiredString(serviceSettingsMap, API, ModelConfigurations.SERVICE_SETTINGS, validationException); + var targetModel = extractOptionalString( + serviceSettingsMap, + TARGET_MODEL, + 
ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + var targetContainerHostname = extractOptionalString( + serviceSettingsMap, + TARGET_CONTAINER_HOSTNAME, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + var inferenceComponentName = extractOptionalString( + serviceSettingsMap, + INFERENCE_COMPONENT_NAME, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + var batchSize = extractOptionalPositiveInteger( + serviceSettingsMap, + BATCH_SIZE, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + validationException.throwIfValidationErrorsExist(); + + var schema = schemas.schemaFor(taskType, api); + var apiServiceSettings = schema.apiServiceSettings(serviceSettingsMap, validationException); + + validationException.throwIfValidationErrorsExist(); + + return new SageMakerServiceSettings( + endpointName, + region, + api, + targetModel, + targetContainerHostname, + inferenceComponentName, + batchSize, + apiServiceSettings + ); + } + + static Stream> configuration(EnumSet supportedTaskTypes) { + return Stream.of( + Map.entry( + API, + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The API format that your SageMaker Endpoint expects.") + .setLabel("API") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ), + Map.entry( + ENDPOINT_NAME, + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "The name specified when creating the SageMaker Endpoint." + ) + .setLabel("Endpoint Name") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ), + Map.entry( + REGION, + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "The AWS region that your model or ARN is deployed in." 
+ ) + .setLabel("Region") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ), + Map.entry( + TARGET_MODEL, + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "The model to request when calling a SageMaker multi-model Endpoint." + ) + .setLabel("Target Model") + .setRequired(false) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ), + Map.entry( + TARGET_CONTAINER_HOSTNAME, + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "The hostname of the container when calling a SageMaker multi-container Endpoint." + ) + .setLabel("Target Container Hostname") + .setRequired(false) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ), + Map.entry( + BATCH_SIZE, + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "The maximum size a single chunk of input can be when chunking input for semantic text." + ) + .setLabel("Batch Size") + .setRequired(false) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.INTEGER) + .build() + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettings.java new file mode 100644 index 000000000000..c1c244cc3705 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettings.java @@ -0,0 +1,236 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.model; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.Map; +import java.util.stream.Stream; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; + +/** + * Maintains mutable settings for SageMaker. Model-specific settings are stored in {@link SageMakerStoredTaskSchema}. 
+ */ +record SageMakerTaskSettings( + @Nullable String customAttributes, + @Nullable String enableExplanations, + @Nullable String inferenceIdForDataCapture, + @Nullable String sessionId, + @Nullable String targetVariant, + SageMakerStoredTaskSchema apiTaskSettings +) implements TaskSettings { + + static final String NAME = "sage_maker_task_settings"; + private static final String CUSTOM_ATTRIBUTES = "custom_attributes"; + private static final String ENABLE_EXPLANATIONS = "enable_explanations"; + private static final String INFERENCE_ID = "inference_id"; + private static final String SESSION_ID = "session_id"; + private static final String TARGET_VARIANT = "target_variant"; + + SageMakerTaskSettings(StreamInput in) throws IOException { + this( + in.readOptionalString(), + in.readOptionalString(), + in.readOptionalString(), + in.readOptionalString(), + in.readOptionalString(), + in.readNamedWriteable(SageMakerStoredTaskSchema.class) + ); + } + + @Override + public boolean isEmpty() { + return customAttributes == null + && enableExplanations == null + && inferenceIdForDataCapture == null + && sessionId == null + && targetVariant == null + && SageMakerStoredTaskSchema.NO_OP.equals(apiTaskSettings); + } + + @Override + public SageMakerTaskSettings updatedTaskSettings(Map newSettings) { + var validationException = new ValidationException(); + + var updateTaskSettings = fromMap(newSettings, apiTaskSettings.updatedTaskSettings(newSettings), validationException); + + validationException.throwIfValidationErrorsExist(); + + var updatedExtraTaskSettings = updateTaskSettings.apiTaskSettings().equals(SageMakerStoredTaskSchema.NO_OP) + ? 
apiTaskSettings + : updateTaskSettings.apiTaskSettings(); + + return new SageMakerTaskSettings( + firstNotNullOrNull(updateTaskSettings.customAttributes(), customAttributes), + firstNotNullOrNull(updateTaskSettings.enableExplanations(), enableExplanations), + firstNotNullOrNull(updateTaskSettings.inferenceIdForDataCapture(), inferenceIdForDataCapture), + firstNotNullOrNull(updateTaskSettings.sessionId(), sessionId), + firstNotNullOrNull(updateTaskSettings.targetVariant(), targetVariant), + updatedExtraTaskSettings + ); + } + + private static T firstNotNullOrNull(T first, T second) { + return first != null ? first : second; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_SAGEMAKER; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(customAttributes); + out.writeOptionalString(enableExplanations); + out.writeOptionalString(inferenceIdForDataCapture); + out.writeOptionalString(sessionId); + out.writeOptionalString(targetVariant); + out.writeNamedWriteable(apiTaskSettings); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + optionalField(CUSTOM_ATTRIBUTES, customAttributes, builder); + optionalField(ENABLE_EXPLANATIONS, enableExplanations, builder); + optionalField(INFERENCE_ID, inferenceIdForDataCapture, builder); + optionalField(SESSION_ID, sessionId, builder); + optionalField(TARGET_VARIANT, targetVariant, builder); + apiTaskSettings.toXContent(builder, params); + + return builder.endObject(); + } + + private static void optionalField(String name, T value, XContentBuilder builder) throws IOException { + if (value != null) { + builder.field(name, value); + } + } + + public static SageMakerTaskSettings fromMap( + Map taskSettingsMap, + SageMakerStoredTaskSchema apiTaskSettings, + 
ValidationException validationException + ) { + var customAttributes = extractOptionalString( + taskSettingsMap, + CUSTOM_ATTRIBUTES, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + var enableExplanations = extractOptionalString( + taskSettingsMap, + ENABLE_EXPLANATIONS, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + var inferenceIdForDataCapture = extractOptionalString( + taskSettingsMap, + INFERENCE_ID, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + var sessionId = extractOptionalString(taskSettingsMap, SESSION_ID, ModelConfigurations.TASK_SETTINGS, validationException); + var targetVariant = extractOptionalString(taskSettingsMap, TARGET_VARIANT, ModelConfigurations.TASK_SETTINGS, validationException); + + return new SageMakerTaskSettings( + customAttributes, + enableExplanations, + inferenceIdForDataCapture, + sessionId, + targetVariant, + apiTaskSettings + ); + } + + static Stream> configuration(EnumSet supportedTaskTypes) { + return Stream.of( + Map.entry( + CUSTOM_ATTRIBUTES, + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "An opaque informational value forwarded as-is to the model within SageMaker." + ) + .setLabel("Custom Attributes") + .setRequired(false) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ), + Map.entry( + ENABLE_EXPLANATIONS, + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "JMESPath expression overriding the ClarifyExplainerConfig in the SageMaker Endpoint Configuration." + ) + .setLabel("Enable Explanations") + .setRequired(false) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ), + Map.entry( + INFERENCE_ID, + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "Informational identifier for auditing requests within the SageMaker Endpoint." 
+ ) + .setLabel("Inference ID") + .setRequired(false) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ), + Map.entry( + SESSION_ID, + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "Creates or reuses an existing Session for SageMaker stateful models." + ) + .setLabel("Session ID") + .setRequired(false) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ), + Map.entry( + TARGET_VARIANT, + new SettingsConfiguration.Builder(supportedTaskTypes).setDescription( + "The production variant when calling the SageMaker Endpoint" + ) + .setLabel("Target Variant") + .setRequired(false) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchema.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchema.java new file mode 100644 index 000000000000..3a39bb804e23 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchema.java @@ -0,0 +1,215 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema; + +import software.amazon.awssdk.services.sagemakerruntime.model.InternalDependencyException; +import software.amazon.awssdk.services.sagemakerruntime.model.InternalFailureException; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse; +import software.amazon.awssdk.services.sagemakerruntime.model.ModelErrorException; +import software.amazon.awssdk.services.sagemakerruntime.model.ModelNotReadyException; +import software.amazon.awssdk.services.sagemakerruntime.model.SageMakerRuntimeException; +import software.amazon.awssdk.services.sagemakerruntime.model.ServiceUnavailableException; +import software.amazon.awssdk.services.sagemakerruntime.model.ValidationErrorException; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; + +import java.util.Map; +import java.util.stream.Stream; + +import static org.elasticsearch.core.Strings.format; + +/** + * All the logic that is required to call any SageMaker model is handled within this Schema class. + * Any model-specific logic is handled within the associated {@link SageMakerSchemaPayload}. + * This schema is specific for SageMaker's non-streaming API. For streaming, see {@link SageMakerStreamSchema}. 
+ */ +public class SageMakerSchema { + private static final String CUSTOM_ATTRIBUTES_HEADER = "X-elastic-sagemaker-custom-attributes"; + private static final String NEW_SESSION_HEADER = "X-elastic-sagemaker-new-session-id"; + private static final String CLOSED_SESSION_HEADER = "X-elastic-sagemaker-closed-session-id"; + + private static final String ACCESS_DENIED_CODE = "AccessDeniedException"; + private static final String INCOMPLETE_SIGNATURE = "IncompleteSignature"; + private static final String INVALID_ACTION = "InvalidAction"; + private static final String INVALID_CLIENT_TOKEN = "InvalidClientTokenId"; + private static final String NOT_AUTHORIZED = "NotAuthorized"; + private static final String OPT_IN_REQUIRED = "OptInRequired"; + private static final String REQUEST_EXPIRED = "RequestExpired"; + private static final String THROTTLING_EXCEPTION = "ThrottlingException"; + + private final SageMakerSchemaPayload schemaPayload; + + public SageMakerSchema(SageMakerSchemaPayload schemaPayload) { + this.schemaPayload = schemaPayload; + } + + public InvokeEndpointRequest request(SageMakerModel model, SageMakerInferenceRequest request) { + try { + return createRequest(model).accept(schemaPayload.accept(model)) + .contentType(schemaPayload.contentType(model)) + .body(schemaPayload.requestBytes(model, request)) + .build(); + } catch (ElasticsearchStatusException e) { + throw e; + } catch (Exception e) { + throw new ElasticsearchStatusException( + "Failed to create SageMaker request for [%s]", + RestStatus.INTERNAL_SERVER_ERROR, + e, + model.getInferenceEntityId() + ); + } + } + + private InvokeEndpointRequest.Builder createRequest(SageMakerModel model) { + var request = InvokeEndpointRequest.builder(); + request.endpointName(model.endpointName()); + model.customAttributes().ifPresent(request::customAttributes); + model.enableExplanations().ifPresent(request::enableExplanations); + model.inferenceComponentName().ifPresent(request::inferenceComponentName); + 
model.inferenceIdForDataCapture().ifPresent(request::inferenceId); + model.sessionId().ifPresent(request::sessionId); + model.targetContainerHostname().ifPresent(request::targetContainerHostname); + model.targetModel().ifPresent(request::targetModel); + model.targetVariant().ifPresent(request::targetVariant); + return request; + } + + public InferenceServiceResults response(SageMakerModel model, InvokeEndpointResponse response, ThreadContext threadContext) + throws Exception { + try { + addHeaders(response, threadContext); + return schemaPayload.responseBody(model, response); + } catch (ElasticsearchStatusException e) { + throw e; + } catch (Exception e) { + throw new ElasticsearchStatusException( + "Failed to translate SageMaker response for [%s]", + RestStatus.INTERNAL_SERVER_ERROR, + e, + model.getInferenceEntityId() + ); + } + } + + private void addHeaders(InvokeEndpointResponse response, ThreadContext threadContext) { + if (response.customAttributes() != null) { + threadContext.addResponseHeader(CUSTOM_ATTRIBUTES_HEADER, response.customAttributes()); + } + if (response.newSessionId() != null) { + threadContext.addResponseHeader(NEW_SESSION_HEADER, response.newSessionId()); + } + if (response.closedSessionId() != null) { + threadContext.addResponseHeader(CLOSED_SESSION_HEADER, response.closedSessionId()); + } + } + + public Exception error(SageMakerModel model, Exception e) { + if (e instanceof ElasticsearchStatusException ee) { + return ee; + } + var error = errorMessageAndStatus(model, e); + return new ElasticsearchStatusException(error.v1(), error.v2(), e); + } + + /** + * Protected because {@link SageMakerStreamSchema} will reuse this to create a Chat Completion error message. 
+ */ + protected Tuple errorMessageAndStatus(SageMakerModel model, Exception e) { + String errorMessage = null; + RestStatus restStatus = null; + if (e instanceof InternalDependencyException) { + errorMessage = format("Received an internal dependency error from SageMaker for [%s]", model.getInferenceEntityId()); + restStatus = RestStatus.INTERNAL_SERVER_ERROR; + } else if (e instanceof InternalFailureException) { + errorMessage = format("Received an internal failure from SageMaker for [%s]", model.getInferenceEntityId()); + restStatus = RestStatus.INTERNAL_SERVER_ERROR; + } else if (e instanceof ModelErrorException) { + errorMessage = format("Received a model error from SageMaker for [%s]", model.getInferenceEntityId()); + restStatus = RestStatus.FAILED_DEPENDENCY; + } else if (e instanceof ModelNotReadyException) { + errorMessage = format("Received a model not ready error from SageMaker for [%s]", model.getInferenceEntityId()); + restStatus = RestStatus.TOO_MANY_REQUESTS; + } else if (e instanceof ServiceUnavailableException) { + errorMessage = format("SageMaker is unavailable for [%s]", model.getInferenceEntityId()); + restStatus = RestStatus.SERVICE_UNAVAILABLE; + } else if (e instanceof ValidationErrorException) { + errorMessage = format("Received a validation error from SageMaker for [%s]", model.getInferenceEntityId()); + restStatus = RestStatus.BAD_REQUEST; + } + + // if we have a SageMakerRuntimeException that isn't one of the child exceptions, we can parse the error from AwsErrorDetails + // https://docs.aws.amazon.com/sagemaker/latest/APIReference/CommonErrors.html + if (errorMessage == null && e instanceof SageMakerRuntimeException re && re.awsErrorDetails() != null) { + switch (re.awsErrorDetails().errorCode()) { + case ACCESS_DENIED_CODE, NOT_AUTHORIZED -> { + errorMessage = format( + "Access and Secret key stored in [%s] do not have sufficient permissions.", + model.getInferenceEntityId() + ); + restStatus = RestStatus.BAD_REQUEST; + } + case 
INCOMPLETE_SIGNATURE -> { + errorMessage = format("The request signature does not conform to AWS standards [%s]", model.getInferenceEntityId()); + restStatus = RestStatus.INTERNAL_SERVER_ERROR; // this shouldn't happen and isn't anything the user can do about it + } + case INVALID_ACTION -> { + errorMessage = format("The requested action is not valid for [%s]", model.getInferenceEntityId()); + restStatus = RestStatus.BAD_REQUEST; + } + case INVALID_CLIENT_TOKEN -> { + errorMessage = format("Access key stored in [%s] does not exist in AWS", model.getInferenceEntityId()); + restStatus = RestStatus.FORBIDDEN; + } + case OPT_IN_REQUIRED -> { + errorMessage = format("Access key stored in [%s] needs a subscription for the service", model.getInferenceEntityId()); + restStatus = RestStatus.FORBIDDEN; + } + case REQUEST_EXPIRED -> { + errorMessage = format( + "The request reached SageMaker more than 15 minutes after the date stamp on the request for [%s]", + model.getInferenceEntityId() + ); + restStatus = RestStatus.BAD_REQUEST; + } + case THROTTLING_EXCEPTION -> { + errorMessage = format("SageMaker denied the request for [%s] due to request throttling", model.getInferenceEntityId()); + restStatus = RestStatus.BAD_REQUEST; + } + } + } + + if (errorMessage == null) { + errorMessage = format("Received an error from SageMaker for [%s]", model.getInferenceEntityId()); + restStatus = RestStatus.INTERNAL_SERVER_ERROR; + } + + return Tuple.tuple(errorMessage, restStatus); + } + + public SageMakerStoredServiceSchema apiServiceSettings(Map serviceSettings, ValidationException validationException) { + return schemaPayload.apiServiceSettings(serviceSettings, validationException); + } + + public SageMakerStoredTaskSchema apiTaskSettings(Map taskSettings, ValidationException validationException) { + return schemaPayload.apiTaskSettings(taskSettings, validationException); + } + + public Stream namedWriteables() { + return schemaPayload.namedWriteables(); + } +} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayload.java new file mode 100644 index 000000000000..146d27256a42 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayload.java @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema; + +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; + +import java.util.EnumSet; +import java.util.Map; +import java.util.stream.Stream; + +public interface SageMakerSchemaPayload { + + /** + * The model API keyword that users will supply in the service settings when creating the request. + * Automatically registered in {@link SageMakerSchemas}. + */ + String api(); + + /** + * The supported TaskTypes for this model API. + * Automatically registered in {@link SageMakerSchemas}. + */ + EnumSet supportedTasks(); + + /** + * Implement this if the model requires extra ServiceSettings that can be saved to the model index. 
+ * This can be accessed via {@link SageMakerModel#apiServiceSettings()}. + */ + default SageMakerStoredServiceSchema apiServiceSettings(Map serviceSettings, ValidationException validationException) { + return SageMakerStoredServiceSchema.NO_OP; + } + + /** + * Implement this if the model requires extra TaskSettings that can be saved to the model index. + * This can be accessed via {@link SageMakerModel#apiTaskSettings()}. + */ + default SageMakerStoredTaskSchema apiTaskSettings(Map taskSettings, ValidationException validationException) { + return SageMakerStoredTaskSchema.NO_OP; + } + + /** + * This must be thrown if {@link SageMakerModel#apiServiceSettings()} or {@link SageMakerModel#apiTaskSettings()} return the wrong + * object types. + */ + default Exception createUnsupportedSchemaException(SageMakerModel model) { + return new IllegalArgumentException( + Strings.format( + "Unsupported SageMaker settings for api [%s] and task type [%s]: [%s] and [%s]", + model.api(), + model.getTaskType(), + model.apiServiceSettings().getWriteableName(), + model.apiTaskSettings().getWriteableName() + ) + ); + } + + /** + * Automatically register the required registry entries with {@link org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider}. + */ + default Stream namedWriteables() { + return Stream.of(); + } + + /** + * The MIME type of the response from SageMaker. + */ + String accept(SageMakerModel model); + + /** + * The MIME type of the request to SageMaker. + */ + String contentType(SageMakerModel model); + + /** + * Translate to the body of the request in the MIME type specified by {@link #contentType(SageMakerModel)}. + */ + SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception; + + /** + * Translate from the body of the response in the MIME type specified by {@link #accept(SageMakerModel)}. 
+ */ + InferenceServiceResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception; + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java new file mode 100644 index 000000000000..cf3a17a7ae70 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemas.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload; + +import java.util.Arrays; +import java.util.EnumSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.elasticsearch.core.Strings.format; + +/** + * The mapping and registry for all supported model API. 
+ */ +public class SageMakerSchemas { + private static final Map schemas; + private static final Map streamSchemas; + private static final Map> tasksByApi; + private static final Map> streamingTasksByApi; + private static final Set supportedStreamingTasks; + private static final EnumSet supportedTaskTypes; + + static { + /* + * Add new model API to the register call. + */ + schemas = register(new OpenAiTextEmbeddingPayload()); + + streamSchemas = schemas.entrySet() + .stream() + .filter(e -> e.getValue() instanceof SageMakerStreamSchema) + .collect(Collectors.toMap(Map.Entry::getKey, e -> (SageMakerStreamSchema) e.getValue())); + + tasksByApi = schemas.keySet() + .stream() + .collect(Collectors.groupingBy(TaskAndApi::api, Collectors.mapping(TaskAndApi::taskType, Collectors.toSet()))); + streamingTasksByApi = streamSchemas.keySet() + .stream() + .collect(Collectors.groupingBy(TaskAndApi::api, Collectors.mapping(TaskAndApi::taskType, Collectors.toSet()))); + + supportedStreamingTasks = streamSchemas.keySet().stream().map(TaskAndApi::taskType).collect(Collectors.toSet()); + supportedTaskTypes = EnumSet.copyOf(schemas.keySet().stream().map(TaskAndApi::taskType).collect(Collectors.toSet())); + } + + private static Map register(SageMakerSchemaPayload... payloads) { + return Arrays.stream(payloads).flatMap(payload -> payload.supportedTasks().stream().map(taskType -> { + var key = new TaskAndApi(taskType, payload.api()); + SageMakerSchema value; + if (payload instanceof SageMakerStreamSchemaPayload streamPayload) { + value = new SageMakerStreamSchema(streamPayload); + } else { + value = new SageMakerSchema(payload); + } + return Map.entry(key, value); + })).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + /** + * Automatically register the stored Schema writeables with {@link org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider}. 
+ */ + public static List namedWriteables() { + return Stream.concat( + Stream.of( + new NamedWriteableRegistry.Entry( + SageMakerStoredServiceSchema.class, + SageMakerStoredServiceSchema.NO_OP.getWriteableName(), + in -> SageMakerStoredServiceSchema.NO_OP + ), + new NamedWriteableRegistry.Entry( + SageMakerStoredTaskSchema.class, + SageMakerStoredTaskSchema.NO_OP.getWriteableName(), + in -> SageMakerStoredTaskSchema.NO_OP + ) + ), + schemas.values().stream().flatMap(SageMakerSchema::namedWriteables) + ).toList(); + } + + public SageMakerSchema schemaFor(SageMakerModel model) throws ElasticsearchStatusException { + return schemaFor(model.getTaskType(), model.api()); + } + + public SageMakerSchema schemaFor(TaskType taskType, String api) throws ElasticsearchStatusException { + var schema = schemas.get(new TaskAndApi(taskType, api)); + if (schema == null) { + throw new ElasticsearchStatusException( + format( + "Task [%s] is not compatible for service [sagemaker] and api [%s]. Supported tasks: [%s]", + taskType.toString(), + api, + tasksByApi.getOrDefault(api, Set.of()) + ), + RestStatus.METHOD_NOT_ALLOWED + ); + } + return schema; + } + + public SageMakerStreamSchema streamSchemaFor(SageMakerModel model) throws ElasticsearchStatusException { + var schema = streamSchemas.get(new TaskAndApi(model.getTaskType(), model.api())); + if (schema == null) { + throw new ElasticsearchStatusException( + format( + "Streaming is not allowed for service [sagemaker], api [%s], and task [%s]. 
Supported streaming tasks: [%s]", + model.api(), + model.getTaskType().toString(), + streamingTasksByApi.getOrDefault(model.api(), Set.of()) + ), + RestStatus.METHOD_NOT_ALLOWED + ); + } + return schema; + } + + public EnumSet supportedTaskTypes() { + return supportedTaskTypes; + } + + public Set supportedStreamingTasks() { + return supportedStreamingTasks; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java new file mode 100644 index 000000000000..9fb320a2d364 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +/** + * Contains any model-specific settings that are stored in SageMakerServiceSettings. 
+ */ +public interface SageMakerStoredServiceSchema extends ServiceSettings { + SageMakerStoredServiceSchema NO_OP = new SageMakerStoredServiceSchema() { + + private static final String NAME = "noop_sagemaker_service_schema"; + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_SAGEMAKER; + } + + @Override + public void writeTo(StreamOutput out) { + + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) { + return builder; + } + }; + + /** + * The model is part of the SageMaker Endpoint definition and is not declared in the service settings. + */ + @Override + default String modelId() { + return null; + } + + @Override + default ToXContentObject getFilteredXContentObject() { + return this; + } + + /** + * These extra service settings serialize flatly alongside the overall SageMaker ServiceSettings. + */ + @Override + default boolean isFragment() { + return true; + } + + /** + * If this Schema supports Text Embeddings, then we need to implement this. + * {@link org.elasticsearch.xpack.inference.services.validation.TextEmbeddingModelValidator} will set the dimensions if the user + * does not do it, so we need to store the dimensions and flip the {@link #dimensionsSetByUser()} boolean. 
+ */ + default SageMakerStoredServiceSchema updateModelWithEmbeddingDetails(Integer dimensions) { + return this; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredTaskSchema.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredTaskSchema.java new file mode 100644 index 000000000000..2aa2f9556d41 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredTaskSchema.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.util.Map; + +/** + * Contains any model-specific settings that are stored in SageMakerTaskSettings. 
+ */ +public interface SageMakerStoredTaskSchema extends TaskSettings { + SageMakerStoredTaskSchema NO_OP = new SageMakerStoredTaskSchema() { + @Override + public boolean isEmpty() { + return true; + } + + @Override + public SageMakerStoredTaskSchema updatedTaskSettings(Map newSettings) { + return this; + } + + private static final String NAME = "noop_sagemaker_task_schema"; + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_SAGEMAKER; + } + + @Override + public void writeTo(StreamOutput out) {} + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) { + return builder; + } + }; + + /** + * These extra service settings serialize flatly alongside the overall SageMaker ServiceSettings. + */ + @Override + default boolean isFragment() { + return true; + } + + @Override + SageMakerStoredTaskSchema updatedTaskSettings(Map newSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStreamSchema.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStreamSchema.java new file mode 100644 index 000000000000..1eb84ecede37 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStreamSchema.java @@ -0,0 +1,152 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.sagemaker.schema;

import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.CheckedBiFunction;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;

import java.util.Locale;
import java.util.concurrent.Flow;
import java.util.function.BiFunction;

/**
 * All the logic that is required to call any SageMaker model is handled within this Schema class.
 * Any model-specific logic is handled within the associated {@link SageMakerStreamSchemaPayload}.
 * This schema is specific for SageMaker's streaming API. For non-streaming, see {@link SageMakerSchema}.
 */
public class SageMakerStreamSchema extends SageMakerSchema {

    private final SageMakerStreamSchemaPayload payload;

    public SageMakerStreamSchema(SageMakerStreamSchemaPayload payload) {
        super(payload);
        this.payload = payload;
    }

    /**
     * Builds the streaming invocation request for a Completion task, delegating body creation to the payload.
     */
    public InvokeEndpointWithResponseStreamRequest streamRequest(SageMakerModel model, SageMakerInferenceRequest request) {
        return streamRequest(model, () -> payload.requestBytes(model, request));
    }

    private InvokeEndpointWithResponseStreamRequest streamRequest(SageMakerModel model, CheckedSupplier<SdkBytes, Exception> body) {
        try {
            return createStreamRequest(model).accept(payload.accept(model))
                .contentType(payload.contentType(model))
                .body(body.get())
                .build();
        } catch (ElasticsearchStatusException e) {
            // already a well-formed user-facing error, do not re-wrap it
            throw e;
        } catch (Exception e) {
            // ElasticsearchException formats its message with LoggerMessageFormat, which substitutes {} (not %s) placeholders
            throw new ElasticsearchStatusException(
                "Failed to create SageMaker request for [{}]",
                RestStatus.INTERNAL_SERVER_ERROR,
                e,
                model.getInferenceEntityId()
            );
        }
    }

    /**
     * Adapts SageMaker's streaming response to Elastic's streaming results for a Completion task.
     */
    public InferenceServiceResults streamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) {
        return streamResponse(model, response, payload::streamResponseBody, this::error);
    }

    private InferenceServiceResults streamResponse(
        SageMakerModel model,
        SageMakerClient.SageMakerStream response,
        CheckedBiFunction<SageMakerModel, SdkBytes, InferenceServiceResults.Result, Exception> parseFunction,
        BiFunction<SageMakerModel, Exception, Exception> errorFunction
    ) {
        return new StreamingChatCompletionResults(downstream -> {
            response.responseStream().subscribe(new Flow.Subscriber<>() {
                private volatile Flow.Subscription upstream;

                @Override
                public void onSubscribe(Flow.Subscription subscription) {
                    this.upstream = subscription;
                    // the downstream drives demand directly against the SDK's subscription
                    downstream.onSubscribe(subscription);
                }

                @Override
                public void onNext(ResponseStream item) {
                    if (item.sdkEventType() == ResponseStream.EventType.PAYLOAD_PART) {
                        item.accept(InvokeEndpointWithResponseStreamResponseHandler.Visitor.builder().onPayloadPart(payloadPart -> {
                            try {
                                downstream.onNext(parseFunction.apply(model, payloadPart.bytes()));
                            } catch (Exception e) {
                                downstream.onError(errorFunction.apply(model, e));
                            }
                        }).build());
                    } else {
                        // non-payload events carry nothing for the caller; pull the next event to keep the stream moving
                        assert upstream != null : "upstream is unset";
                        upstream.request(1);
                    }
                }

                @Override
                public void onError(Throwable throwable) {
                    if (throwable instanceof Exception e) {
                        downstream.onError(errorFunction.apply(model, e));
                    } else {
                        // Errors (e.g. OOM) may require killing the node; otherwise forward a generic failure downstream
                        ExceptionsHelper.maybeError(throwable).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
                        var e = new RuntimeException("Fatal while streaming SageMaker response for [" + model.getInferenceEntityId() + "]");
                        downstream.onError(errorFunction.apply(model, e));
                    }

                }

                @Override
                public void onComplete() {
                    downstream.onComplete();
                }
            });
        });
    }

    /**
     * Builds the streaming invocation request for a Chat Completion (unified) task.
     */
    public InvokeEndpointWithResponseStreamRequest chatCompletionStreamRequest(SageMakerModel model, UnifiedCompletionRequest request) {
        return streamRequest(model, () -> payload.chatCompletionRequestBytes(model, request));
    }

    public InferenceServiceResults chatCompletionStreamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) {
        return streamResponse(model, response, payload::chatCompletionResponseBody, this::chatCompletionError);
    }

    /**
     * Wraps failures in the unified chat-completion error format, preserving errors that are already in that format.
     */
    public UnifiedChatCompletionException chatCompletionError(SageMakerModel model, Exception e) {
        if (e instanceof UnifiedChatCompletionException ucce) {
            return ucce;
        }

        // errorMessageAndStatus comes from SageMakerSchema; v1 is the message, v2 the RestStatus
        var error = errorMessageAndStatus(model, e);
        return new UnifiedChatCompletionException(error.v2(), error.v1(), "error", error.v2().name().toLowerCase(Locale.ROOT));
    }

    /**
     * Copies the model's optional SageMaker invocation attributes onto a fresh streaming request builder.
     */
    private InvokeEndpointWithResponseStreamRequest.Builder createStreamRequest(SageMakerModel model) {
        var request = InvokeEndpointWithResponseStreamRequest.builder();
        request.endpointName(model.endpointName());
        model.customAttributes().ifPresent(request::customAttributes);
        model.inferenceComponentName().ifPresent(request::inferenceComponentName);
        model.inferenceIdForDataCapture().ifPresent(request::inferenceId);
        model.sessionId().ifPresent(request::sessionId);
        model.targetContainerHostname().ifPresent(request::targetContainerHostname);
        model.targetVariant().ifPresent(request::targetVariant);
        return request;
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.sagemaker.schema;

import software.amazon.awssdk.core.SdkBytes;

import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;

import java.util.EnumSet;

/**
 * Implemented for models that support streaming.
 * This is an extension of {@link SageMakerSchemaPayload} because Elastic expects Completion tasks to handle both streaming and
 * non-streaming, and all models currently support toggling streaming on/off.
 */
public interface SageMakerStreamSchemaPayload extends SageMakerSchemaPayload {
    /**
     * We currently only support streaming for Completion and Chat Completion, and if we are going to implement one then we should implement
     * the other, so this interface requires both streaming input and streaming unified input.
     * If we ever allowed streaming for more than just Completion, then we'd probably break up this class so that Unified Chat Completion
     * was its own interface.
     */
    @Override
    default EnumSet<TaskType> supportedTasks() {
        return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
    }

    /**
     * Parses one streamed chunk into results. This API would only be called for Completion task types.
     * {@link #requestBytes(SageMakerModel, SageMakerInferenceRequest)} would handle the request translation
     * for both streaming and non-streaming.
     */
    InferenceServiceResults.Result streamResponseBody(SageMakerModel model, SdkBytes response) throws Exception;

    /**
     * Serializes a unified (OpenAI-style) chat completion request into the endpoint's wire format.
     */
    SdkBytes chatCompletionRequestBytes(SageMakerModel model, UnifiedCompletionRequest request) throws Exception;

    /**
     * Parses one streamed chat-completion chunk into results.
     */
    InferenceServiceResults.Result chatCompletionResponseBody(SageMakerModel model, SdkBytes response) throws Exception;
}
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema; + +import org.elasticsearch.inference.TaskType; + +record TaskAndApi(TaskType taskType, String api) {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java new file mode 100644 index 000000000000..7bd122a5922e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java @@ -0,0 +1,229 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai; + +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import 
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;

import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;

import java.io.IOException;
import java.util.EnumSet;
import java.util.Map;
import java.util.stream.Stream;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;

/**
 * Translates between Elastic's text-embedding request/response and a SageMaker endpoint that speaks the
 * OpenAI embeddings wire format (JSON in, JSON out).
 */
public class OpenAiTextEmbeddingPayload implements SageMakerSchemaPayload {

    private static final XContent jsonXContent = JsonXContent.jsonXContent;
    private static final String APPLICATION_JSON = jsonXContent.type().mediaTypeWithoutParameters();

    @Override
    public String api() {
        return "openai";
    }

    @Override
    public EnumSet<TaskType> supportedTasks() {
        return EnumSet.of(TaskType.TEXT_EMBEDDING);
    }

    @Override
    public SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
        return ApiServiceSettings.fromMap(serviceSettings, validationException);
    }

    @Override
    public SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
        return ApiTaskSettings.fromMap(taskSettings, validationException);
    }

    @Override
    public Stream<NamedWriteableRegistry.Entry> namedWriteables() {
        return Stream.of(
            new NamedWriteableRegistry.Entry(SageMakerStoredServiceSchema.class, ApiServiceSettings.NAME, ApiServiceSettings::new),
            new NamedWriteableRegistry.Entry(SageMakerStoredTaskSchema.class, ApiTaskSettings.NAME, ApiTaskSettings::new)
        );
    }

    @Override
    public String accept(SageMakerModel model) {
        return APPLICATION_JSON;
    }

    @Override
    public String contentType(SageMakerModel model) {
        return APPLICATION_JSON;
    }

    /**
     * Serializes the inference request into the OpenAI embeddings JSON body.
     * A single input is written as a scalar rather than a one-element array, matching OpenAI's accepted shapes.
     */
    @Override
    public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception {
        if (model.apiServiceSettings() instanceof ApiServiceSettings apiServiceSettings
            && model.apiTaskSettings() instanceof ApiTaskSettings apiTaskSettings) {
            try (var builder = JsonXContent.contentBuilder()) {
                builder.startObject();
                if (request.query() != null) {
                    builder.field("query", request.query());
                }
                if (request.input().size() == 1) {
                    builder.field("input", request.input().get(0));
                } else {
                    builder.field("input", request.input());
                }
                if (apiTaskSettings.user() != null) {
                    builder.field("user", apiTaskSettings.user());
                }
                // only forward dimensions the user explicitly chose; service-discovered dimensions stay out of the request
                if (apiServiceSettings.dimensionsSetByUser() && apiServiceSettings.dimensions() != null) {
                    builder.field("dimensions", apiServiceSettings.dimensions());
                }
                builder.endObject();
                return SdkBytes.fromUtf8String(Strings.toString(builder));
            }
        } else {
            throw createUnsupportedSchemaException(model);
        }
    }

    /**
     * Parses an OpenAI embeddings JSON response body into float embedding results.
     */
    @Override
    public TextEmbeddingFloatResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
        try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) {
            return OpenAiEmbeddingsResponseEntity.EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
        }
    }

    /**
     * Stored service settings for the OpenAI embeddings schema.
     * NOTE(review): dimensionsSetByUser is not written in toXContent, so a storage round-trip recomputes it in fromMap as
     * {@code dimensions != null} — a service-discovered dimension would be reloaded as if the user set it. Confirm intended.
     */
    record ApiServiceSettings(@Nullable Integer dimensions, boolean dimensionsSetByUser) implements SageMakerStoredServiceSchema {
        private static final String NAME = "sagemaker_openai_text_embeddings_service_settings";
        private static final String DIMENSIONS_FIELD = "dimensions";

        ApiServiceSettings(StreamInput in) throws IOException {
            this(in.readOptionalInt(), in.readBoolean());
        }

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.ML_INFERENCE_SAGEMAKER;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeOptionalInt(dimensions);
            out.writeBoolean(dimensionsSetByUser);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            if (dimensions != null) {
                builder.field(DIMENSIONS_FIELD, dimensions);
            }
            return builder;
        }

        static ApiServiceSettings fromMap(Map<String, Object> serviceSettings, ValidationException validationException) {
            var dimensions = extractOptionalPositiveInteger(
                serviceSettings,
                DIMENSIONS_FIELD,
                ModelConfigurations.SERVICE_SETTINGS,
                validationException
            );

            // a user-provided dimensions value implies the user chose it
            return new ApiServiceSettings(dimensions, dimensions != null);
        }

        @Override
        public SimilarityMeasure similarity() {
            return SimilarityMeasure.DOT_PRODUCT;
        }

        @Override
        public DenseVectorFieldMapper.ElementType elementType() {
            return DenseVectorFieldMapper.ElementType.FLOAT;
        }

        @Override
        public SageMakerStoredServiceSchema updateModelWithEmbeddingDetails(Integer dimensions) {
            // dimensions discovered by the service, not set by the user
            return new ApiServiceSettings(dimensions, false);
        }
    }

    /**
     * Stored task settings for the OpenAI embeddings schema; currently only the optional "user" field.
     */
    record ApiTaskSettings(@Nullable String user) implements SageMakerStoredTaskSchema {
        private static final String NAME = "sagemaker_openai_text_embeddings_task_settings";
        private static final String USER_FIELD = "user";

        ApiTaskSettings(StreamInput in) throws IOException {
            this(in.readOptionalString());
        }

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.ML_INFERENCE_SAGEMAKER;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeOptionalString(user);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            return user != null ? builder.field(USER_FIELD, user) : builder;
        }

        @Override
        public boolean isEmpty() {
            return user == null;
        }

        @Override
        public ApiTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
            var validationException = new ValidationException();
            var newTaskSettings = fromMap(newSettings, validationException);
            validationException.throwIfValidationErrorsExist();

            // new values win; absent values fall back to the current setting
            return new ApiTaskSettings(newTaskSettings.user() != null ? newTaskSettings.user() : user);
        }

        static ApiTaskSettings fromMap(Map<String, Object> map, ValidationException exception) {
            var user = extractOptionalString(map, USER_FIELD, ModelConfigurations.TASK_SETTINGS, exception);
            return new ApiTaskSettings(user);
        }
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

import static org.hamcrest.Matchers.equalTo;

/**
 * All Inference Setting classes derived from {@link org.elasticsearch.inference.ServiceSettings},
 * {@link org.elasticsearch.inference.TaskSettings}, and {@link org.elasticsearch.inference.SecretSettings}
 * must be able to read/write to an index via ToXContent as well as read/write between nodes via Writeable.
 */
public abstract class InferenceSettingsTestCase<T extends ToXContent & Writeable> extends AbstractBWCWireSerializationTestCase<T> {

    /**
     * This method is final because {@link org.elasticsearch.inference.ModelConfigurations} settings must be registered in
     * {@link InferenceNamedWriteablesProvider}.
     */
    @Override
    protected final NamedWriteableRegistry getNamedWriteableRegistry() {
        return new NamedWriteableRegistry(InferenceNamedWriteablesProvider.getNamedWriteables());
    }

    /**
     * Helper implementation since most mutates can be handled by continuously randomizing until we have another test instance.
     */
    @Override
    protected T mutateInstance(T instance) throws IOException {
        return randomValueOtherThan(instance, this::createTestInstance);
    }

    /**
     * Change this for BWC once the implementation requires difference objects depending on the transport version.
     */
    @Override
    protected T mutateInstanceForVersion(T instance, TransportVersion version) {
        return instance;
    }

    /**
     * Verify that we can write to XContent and then read from XContent.
     * This simulates saving the model to the index and then reading the model from the index.
     */
    public final void testXContentRoundTrip() throws IOException {
        var instance = createTestInstance();
        var instanceAsMap = toMap(instance);
        var roundTripInstance = fromMutableMap(new HashMap<>(instanceAsMap));
        assertThat(roundTripInstance, equalTo(instance));
    }

    /**
     * Reconstructs an instance from the (mutable) map produced by {@link #toMap(ToXContent)}.
     */
    protected abstract T fromMutableMap(Map<String, Object> mutableMap);

    /**
     * Renders the instance to JSON (wrapping fragments in an object) and parses it back into a generic map.
     */
    public static Map<String, Object> toMap(ToXContent instance) throws IOException {
        try (var builder = JsonXContent.contentBuilder()) {
            if (instance.isFragment()) {
                builder.startObject();
            }
            instance.toXContent(builder, ToXContent.EMPTY_PARAMS);
            if (instance.isFragment()) {
                builder.endObject();
            }
            var settingsJsonBytes = Strings.toString(builder).getBytes(StandardCharsets.UTF_8);
            try (var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, settingsJsonBytes)) {
                return parser.map();
            }
        }
    }

    /**
     * Helper, since most settings contain optional strings.
     */
    protected static String randomOptionalString() {
        return randomBoolean() ? randomString() : null;
    }

    /**
     * Helper, since most settings contain strings.
     */
    protected static String randomString() {
        return randomAlphaOfLength(randomIntBetween(4, 8));
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.sagemaker;

import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeAsyncClient;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.cache.CacheLoader;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentMatchers;
import org.reactivestreams.Subscriber;

import java.io.IOException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.assertArg;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Unit tests for {@link SageMakerClient}: client caching, invocation, streaming, timeouts, and shutdown.
 * The AWS async client is mocked; no network calls are made.
 */
public class SageMakerClientTests extends ESTestCase {
    private static final InvokeEndpointRequest REQUEST = InvokeEndpointRequest.builder().build();
    private static final InvokeEndpointWithResponseStreamRequest STREAM_REQUEST = InvokeEndpointWithResponseStreamRequest.builder().build();
    private SageMakerRuntimeAsyncClient awsClient;
    private CacheLoader<SageMakerClient.RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory;
    private SageMakerClient client;
    private ThreadPool threadPool;

    @Before
    public void setUp() throws Exception {
        super.setUp();
        threadPool = createThreadPool(inferenceUtilityPool());

        awsClient = mock();
        // spy the factory so tests can verify how often the cache actually creates a client
        clientFactory = spy(new CacheLoader<>() {
            @Override
            public SageMakerRuntimeAsyncClient load(SageMakerClient.RegionAndSecrets key) {
                return awsClient;
            }
        });
        client = new SageMakerClient(clientFactory, threadPool);
    }

    @After
    public void shutdown() throws IOException {
        terminate(threadPool);
    }

    public void testInvoke() throws Exception {
        var expectedResponse = InvokeEndpointResponse.builder().build();
        when(awsClient.invokeEndpoint(any(InvokeEndpointRequest.class))).thenReturn(CompletableFuture.completedFuture(expectedResponse));

        var listener = invoke(TimeValue.THIRTY_SECONDS);
        verify(clientFactory, times(1)).load(any());
        verify(listener, times(1)).onResponse(eq(expectedResponse));
    }

    private static SageMakerClient.RegionAndSecrets regionAndSecrets() {
        return new SageMakerClient.RegionAndSecrets(
            "us-east-1",
            new AwsSecretSettings(new SecureString("access"), new SecureString("secrets"))
        );
    }

    /** Invokes the endpoint and blocks until the listener fires, returning a spy for verification. */
    private ActionListener<InvokeEndpointResponse> invoke(TimeValue timeout) throws InterruptedException {
        var latch = new CountDownLatch(1);
        ActionListener<InvokeEndpointResponse> listener = spy(ActionListener.noop());
        client.invoke(regionAndSecrets(), REQUEST, timeout, ActionListener.runAfter(listener, latch::countDown));
        assertTrue("Timed out waiting for invoke call", latch.await(5, TimeUnit.SECONDS));
        return listener;
    }

    public void testInvokeCache() throws Exception {
        when(awsClient.invokeEndpoint(any(InvokeEndpointRequest.class))).thenReturn(
            CompletableFuture.completedFuture(InvokeEndpointResponse.builder().build())
        );

        invoke(TimeValue.THIRTY_SECONDS);
        invoke(TimeValue.THIRTY_SECONDS);
        // the second invoke must reuse the cached client
        verify(clientFactory, times(1)).load(any());
    }

    public void testInvokeTimeout() throws Exception {
        // a future that never completes forces the timeout path
        when(awsClient.invokeEndpoint(any(InvokeEndpointRequest.class))).thenReturn(new CompletableFuture<>());

        var listener = invoke(TimeValue.timeValueMillis(10));

        verify(clientFactory, times(1)).load(any());
        verifyTimeout(listener);
    }

    private static void verifyTimeout(ActionListener<?> listener) {
        verify(listener, times(1)).onFailure(assertArg(e -> assertThat(e.getMessage(), equalTo("Request timed out after [10ms]"))));
    }

    public void testInvokeStream() throws Exception {
        SdkPublisher<ResponseStream> publisher = mockPublisher();

        var listener = invokeStream(TimeValue.THIRTY_SECONDS);

        // the SDK publisher must stay unsubscribed until the caller subscribes to the returned stream
        verify(publisher, never()).subscribe(ArgumentMatchers.<Subscriber<? super ResponseStream>>any());
        verify(listener, times(1)).onResponse(assertArg(stream -> stream.responseStream().subscribe(mock())));
        verify(publisher, times(1)).subscribe(ArgumentMatchers.<Subscriber<? super ResponseStream>>any());
    }

    /** Wires the mocked AWS client so a stream invocation immediately delivers a response and this event publisher. */
    private SdkPublisher<ResponseStream> mockPublisher() {
        SdkPublisher<ResponseStream> publisher = mock();
        doAnswer(ans -> {
            InvokeEndpointWithResponseStreamResponseHandler handler = ans.getArgument(1);
            handler.responseReceived(InvokeEndpointWithResponseStreamResponse.builder().build());
            handler.onEventStream(publisher);
            return CompletableFuture.completedFuture((Void) null);
        }).when(awsClient).invokeEndpointWithResponseStream(any(InvokeEndpointWithResponseStreamRequest.class), any());
        return publisher;
    }

    private ActionListener<SageMakerClient.SageMakerStream> invokeStream(TimeValue timeout) throws Exception {
        var latch = new CountDownLatch(1);
        ActionListener<SageMakerClient.SageMakerStream> listener = spy(ActionListener.noop());
        client.invokeStream(regionAndSecrets(), STREAM_REQUEST, timeout, ActionListener.runAfter(listener, latch::countDown));
        assertTrue("Timed out waiting for invoke call", latch.await(5, TimeUnit.SECONDS));
        return listener;
    }

    public void testInvokeStreamCache() throws Exception {
        mockPublisher();

        invokeStream(TimeValue.THIRTY_SECONDS);
        invokeStream(TimeValue.THIRTY_SECONDS);

        verify(clientFactory, times(1)).load(any());
    }

    public void testInvokeStreamTimeout() throws Exception {
        when(awsClient.invokeEndpointWithResponseStream(any(InvokeEndpointWithResponseStreamRequest.class), any())).thenReturn(
            new CompletableFuture<>()
        );

        var listener = invokeStream(TimeValue.timeValueMillis(10));

        verify(clientFactory, times(1)).load(any());
        verifyTimeout(listener);
    }

    public void testClose() throws Exception {
        // load cache
        mockPublisher();
        invokeStream(TimeValue.THIRTY_SECONDS);
        // clear cache
        client.close();
        verify(awsClient, times(1)).close();
    }
}
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker; + +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests; +import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; +import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchema; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStreamSchema; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; + +import static 
org.elasticsearch.action.ActionListener.assertOnce; +import static org.elasticsearch.action.support.ActionTestUtils.assertNoFailureListener; +import static org.elasticsearch.action.support.ActionTestUtils.assertNoSuccessListener; +import static org.elasticsearch.core.TimeValue.THIRTY_SECONDS; +import static org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequestTests.randomUnifiedCompletionRequest; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.in; +import static org.hamcrest.Matchers.isA; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.assertArg; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.only; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class SageMakerServiceTests extends ESTestCase { + + private static final String QUERY = "query"; + private static final List INPUT = List.of("input"); + private static final InputType INPUT_TYPE = InputType.UNSPECIFIED; + + private SageMakerModelBuilder modelBuilder; + private SageMakerClient client; + private SageMakerSchemas schemas; + private SageMakerService sageMakerService; + + @Before + public void init() { + modelBuilder = mock(); + client = mock(); + schemas = mock(); + ThreadPool threadPool = mock(); + when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); + when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + sageMakerService = new SageMakerService(modelBuilder, client, schemas, 
threadPool, Map::of); + } + + public void testSupportedTaskTypes() { + sageMakerService.supportedTaskTypes(); + verify(schemas, only()).supportedTaskTypes(); + } + + public void testSupportedStreamingTasks() { + sageMakerService.supportedStreamingTasks(); + verify(schemas, only()).supportedStreamingTasks(); + } + + public void testParseRequestConfig() { + sageMakerService.parseRequestConfig("modelId", TaskType.ANY, Map.of(), assertNoFailureListener(model -> { + verify(modelBuilder, only()).fromRequest(eq("modelId"), eq(TaskType.ANY), eq(SageMakerService.NAME), eq(Map.of())); + })); + } + + public void testParsePersistedConfigWithSecrets() { + sageMakerService.parsePersistedConfigWithSecrets("modelId", TaskType.ANY, Map.of(), Map.of()); + verify(modelBuilder, only()).fromStorage(eq("modelId"), eq(TaskType.ANY), eq(SageMakerService.NAME), eq(Map.of()), eq(Map.of())); + } + + public void testParsePersistedConfig() { + sageMakerService.parsePersistedConfig("modelId", TaskType.ANY, Map.of()); + verify(modelBuilder, only()).fromStorage(eq("modelId"), eq(TaskType.ANY), eq(SageMakerService.NAME), eq(Map.of()), eq(null)); + } + + public void testInferWithWrongModel() { + sageMakerService.infer( + mockUnsupportedModel(), + QUERY, + false, + null, + INPUT, + false, + null, + INPUT_TYPE, + THIRTY_SECONDS, + assertUnsupportedModel() + ); + } + + private static Model mockUnsupportedModel() { + Model model = mock(); + ModelConfigurations modelConfigurations = mock(); + when(modelConfigurations.getService()).thenReturn("mockService"); + when(modelConfigurations.getInferenceEntityId()).thenReturn("mockInferenceId"); + when(model.getConfigurations()).thenReturn(modelConfigurations); + return model; + } + + private static ActionListener assertUnsupportedModel() { + return assertNoSuccessListener(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + equalTo( + "The internal model was invalid, please delete the service [mockService] " + + 
"with id [mockInferenceId] and add it again." + ) + ); + assertThat(((ElasticsearchStatusException) e).status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR)); + }); + } + + public void testInfer() { + var model = mockModel(); + + SageMakerSchema schema = mock(); + when(schemas.schemaFor(model)).thenReturn(schema); + mockInvoke(); + + sageMakerService.infer( + model, + QUERY, + null, + null, + INPUT, + false, + null, + INPUT_TYPE, + THIRTY_SECONDS, + assertNoFailureListener(ignored -> { + verify(schemas, only()).schemaFor(eq(model)); + verify(schema, times(1)).request(eq(model), assertRequest()); + verify(schema, times(1)).response(eq(model), any(), any()); + }) + ); + verify(client, only()).invoke(any(), any(), any(), any()); + verifyNoMoreInteractions(client, schemas, schema); + } + + private SageMakerModel mockModel() { + SageMakerModel model = mock(); + when(model.override(null)).thenReturn(model); + when(model.awsSecretSettings()).thenReturn( + Optional.of(new AwsSecretSettings(new SecureString("test-accessKey"), new SecureString("test-secretKey"))) + ); + return model; + } + + private void mockInvoke() { + doAnswer(ans -> { + ActionListener responseListener = ans.getArgument(3); + responseListener.onResponse(InvokeEndpointResponse.builder().build()); + return null; // Void + }).when(client).invoke(any(), any(), any(), any()); + } + + private static SageMakerInferenceRequest assertRequest() { + return assertArg(request -> { + assertThat(request.query(), equalTo(QUERY)); + assertThat(request.input(), containsInAnyOrder(INPUT.toArray())); + assertThat(request.inputType(), equalTo(InputType.UNSPECIFIED)); + assertNull(request.returnDocuments()); + assertNull(request.topN()); + }); + } + + public void testInferStream() { + SageMakerModel model = mockModel(); + + SageMakerStreamSchema schema = mock(); + when(schemas.streamSchemaFor(model)).thenReturn(schema); + mockInvokeStream(); + + sageMakerService.infer(model, QUERY, null, null, INPUT, true, null, INPUT_TYPE, 
THIRTY_SECONDS, assertNoFailureListener(ignored -> { + verify(schemas, only()).streamSchemaFor(eq(model)); + verify(schema, times(1)).streamRequest(eq(model), assertRequest()); + verify(schema, times(1)).streamResponse(eq(model), any()); + })); + verify(client, only()).invokeStream(any(), any(), any(), any()); + verifyNoMoreInteractions(client, schemas, schema); + } + + private void mockInvokeStream() { + doAnswer(ans -> { + ActionListener responseListener = ans.getArgument(3); + responseListener.onResponse( + new SageMakerClient.SageMakerStream(InvokeEndpointWithResponseStreamResponse.builder().build(), mock()) + ); + return null; // Void + }).when(client).invokeStream(any(), any(), any(), any()); + } + + public void testInferError() { + SageMakerModel model = mockModel(); + + var expectedException = new IllegalArgumentException("hola"); + SageMakerSchema schema = mock(); + when(schemas.schemaFor(model)).thenReturn(schema); + mockInvokeFailure(expectedException); + + sageMakerService.infer( + model, + QUERY, + null, + null, + INPUT, + false, + null, + INPUT_TYPE, + THIRTY_SECONDS, + assertNoSuccessListener(ignored -> { + verify(schemas, only()).schemaFor(eq(model)); + verify(schema, times(1)).request(eq(model), assertRequest()); + verify(schema, times(1)).error(eq(model), assertArg(e -> assertThat(e, equalTo(expectedException)))); + }) + ); + verify(client, only()).invoke(any(), any(), any(), any()); + verifyNoMoreInteractions(client, schemas, schema); + } + + private void mockInvokeFailure(Exception e) { + doAnswer(ans -> { + ActionListener responseListener = ans.getArgument(3); + responseListener.onFailure(e); + return null; // Void + }).when(client).invoke(any(), any(), any(), any()); + } + + public void testInferException() { + SageMakerModel model = mockModel(); + when(model.getInferenceEntityId()).thenReturn("some id"); + + SageMakerStreamSchema schema = mock(); + when(schemas.streamSchemaFor(model)).thenReturn(schema); + doThrow(new 
IllegalArgumentException("wow, really?")).when(client).invokeStream(any(), any(), any(), any()); + + sageMakerService.infer(model, QUERY, null, null, INPUT, true, null, INPUT_TYPE, THIRTY_SECONDS, assertNoSuccessListener(e -> { + verify(schemas, only()).streamSchemaFor(eq(model)); + verify(schema, times(1)).streamRequest(eq(model), assertRequest()); + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat(((ElasticsearchStatusException) e).status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR)); + assertThat(e.getMessage(), equalTo("Failed to call SageMaker for inference id [some id].")); + })); + verify(client, only()).invokeStream(any(), any(), any(), any()); + verifyNoMoreInteractions(client, schemas, schema); + } + + public void testUnifiedInferWithWrongModel() { + sageMakerService.unifiedCompletionInfer( + mockUnsupportedModel(), + randomUnifiedCompletionRequest(), + THIRTY_SECONDS, + assertUnsupportedModel() + ); + } + + public void testUnifiedInfer() { + var model = mockModel(); + + SageMakerStreamSchema schema = mock(); + when(schemas.streamSchemaFor(model)).thenReturn(schema); + mockInvokeStream(); + + sageMakerService.unifiedCompletionInfer( + model, + randomUnifiedCompletionRequest(), + THIRTY_SECONDS, + assertNoFailureListener(ignored -> { + verify(schemas, only()).streamSchemaFor(eq(model)); + verify(schema, times(1)).chatCompletionStreamRequest(eq(model), any()); + verify(schema, times(1)).chatCompletionStreamResponse(eq(model), any()); + }) + ); + verify(client, only()).invokeStream(any(), any(), any(), any()); + verifyNoMoreInteractions(client, schemas, schema); + } + + public void testUnifiedInferError() { + var model = mockModel(); + + var expectedException = new IllegalArgumentException("hola"); + SageMakerStreamSchema schema = mock(); + when(schemas.streamSchemaFor(model)).thenReturn(schema); + mockInvokeStreamFailure(expectedException); + + sageMakerService.unifiedCompletionInfer( + model, + randomUnifiedCompletionRequest(), + 
THIRTY_SECONDS, + assertNoSuccessListener(ignored -> { + verify(schemas, only()).streamSchemaFor(eq(model)); + verify(schema, times(1)).chatCompletionStreamRequest(eq(model), any()); + verify(schema, times(1)).chatCompletionError(eq(model), assertArg(e -> assertThat(e, equalTo(expectedException)))); + }) + ); + verify(client, only()).invokeStream(any(), any(), any(), any()); + verifyNoMoreInteractions(client, schemas, schema); + } + + private void mockInvokeStreamFailure(Exception e) { + doAnswer(ans -> { + ActionListener responseListener = ans.getArgument(3); + responseListener.onFailure(e); + return null; // Void + }).when(client).invokeStream(any(), any(), any(), any()); + } + + public void testUnifiedInferException() { + SageMakerModel model = mockModel(); + when(model.getInferenceEntityId()).thenReturn("some id"); + + SageMakerStreamSchema schema = mock(); + when(schemas.streamSchemaFor(model)).thenReturn(schema); + doThrow(new IllegalArgumentException("wow, rude")).when(client).invokeStream(any(), any(), any(), any()); + + sageMakerService.unifiedCompletionInfer(model, randomUnifiedCompletionRequest(), THIRTY_SECONDS, assertNoSuccessListener(e -> { + verify(schemas, only()).streamSchemaFor(eq(model)); + verify(schema, times(1)).chatCompletionStreamRequest(eq(model), any()); + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat(((ElasticsearchStatusException) e).status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR)); + assertThat(e.getMessage(), equalTo("Failed to call SageMaker for inference id [some id].")); + })); + verify(client, only()).invokeStream(any(), any(), any(), any()); + verifyNoMoreInteractions(client, schemas, schema); + } + + public void testChunkedInferWithWrongModel() { + sageMakerService.chunkedInfer( + mockUnsupportedModel(), + QUERY, + INPUT.stream().map(ChunkInferenceInput::new).toList(), + null, + INPUT_TYPE, + THIRTY_SECONDS, + assertUnsupportedModel() + ); + } + + public void testChunkedInfer() throws Exception { + var 
model = mockModelForChunking(); + + SageMakerSchema schema = mock(); + when(schema.response(any(), any(), any())).thenReturn(TextEmbeddingFloatResultsTests.createRandomResults()); + when(schemas.schemaFor(model)).thenReturn(schema); + mockInvoke(); + + var expectedInput = Set.of("first", "second"); + + sageMakerService.chunkedInfer( + model, + QUERY, + expectedInput.stream().map(ChunkInferenceInput::new).toList(), + null, + INPUT_TYPE, + THIRTY_SECONDS, + assertOnce(assertNoFailureListener(chunkedInferences -> { + verify(schemas, times(2)).schemaFor(eq(model)); + verify(schema, times(2)).request(eq(model), assertChunkRequest(expectedInput)); + verify(schema, times(2)).response(eq(model), any(), any()); + })) + ); + verify(client, times(2)).invoke(any(), any(), any(), any()); + verifyNoMoreInteractions(client, schemas, schema); + } + + private SageMakerModel mockModelForChunking() { + var model = mockModel(); + when(model.batchSize()).thenReturn(Optional.of(1)); + ModelConfigurations modelConfigurations = mock(); + when(modelConfigurations.getChunkingSettings()).thenReturn(new WordBoundaryChunkingSettings(1, 0)); + when(model.getConfigurations()).thenReturn(modelConfigurations); + return model; + } + + private static SageMakerInferenceRequest assertChunkRequest(Set expectedInput) { + return assertArg(request -> { + assertThat(request.query(), equalTo(QUERY)); + assertThat(request.input(), hasSize(1)); + assertThat(request.input().get(0), in(expectedInput)); + assertThat(request.inputType(), equalTo(InputType.UNSPECIFIED)); + assertNull(request.returnDocuments()); + assertNull(request.topN()); + assertFalse(request.stream()); + }); + } + + public void testChunkedInferError() { + var model = mockModelForChunking(); + + var expectedException = new IllegalArgumentException("hola"); + SageMakerSchema schema = mock(); + when(schema.error(any(), any())).thenReturn(expectedException); + when(schemas.schemaFor(model)).thenReturn(schema); + 
mockInvokeFailure(expectedException); + + var expectedInput = Set.of("first", "second"); + var expectedOutput = Stream.of(expectedException, expectedException).map(ChunkedInferenceError::new).toArray(); + + sageMakerService.chunkedInfer( + model, + QUERY, + expectedInput.stream().map(ChunkInferenceInput::new).toList(), + null, + INPUT_TYPE, + THIRTY_SECONDS, + assertOnce(assertNoFailureListener(chunkedInferences -> { + verify(schemas, times(2)).schemaFor(eq(model)); + verify(schema, times(2)).request(eq(model), assertChunkRequest(expectedInput)); + verify(schema, times(2)).error(eq(model), assertArg(e -> assertThat(e, equalTo(expectedException)))); + assertThat(chunkedInferences, containsInAnyOrder(expectedOutput)); + })) + ); + verify(client, times(2)).invoke(any(), any(), any(), any()); + verifyNoMoreInteractions(client, schemas, schema); + } + + public void testClose() throws IOException { + sageMakerService.close(); + verify(client, only()).close(); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModelBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModelBuilderTests.java new file mode 100644 index 000000000000..4228cb781a49 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModelBuilderTests.java @@ -0,0 +1,315 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.model; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemasTests; +import org.junit.Before; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.Optional; + +import static org.elasticsearch.inference.ModelConfigurations.USE_ID_FOR_INDEX; +import static org.hamcrest.Matchers.equalTo; + +public class SageMakerModelBuilderTests extends ESTestCase { + private static final String inferenceId = "inferenceId"; + private static final TaskType taskType = TaskType.ANY; + private static final String service = "service"; + private SageMakerModelBuilder builder; + + @Before + public void setUp() throws Exception { + super.setUp(); + builder = new SageMakerModelBuilder(SageMakerSchemasTests.mockSchemas()); + } + + public void testFromRequestWithRequiredFields() { + var model = fromRequest(""" + { + "service_settings": { + "access_key": "test-access-key", + "secret_key": "test-secret-key", + "region": "us-east-1", + "api": "test-api", + "endpoint_name": "test-endpoint" + } + } + """); + + assertNotNull(model); + assertTrue(model.awsSecretSettings().isPresent()); + assertThat(model.awsSecretSettings().get().accessKey().toString(), equalTo("test-access-key")); + 
assertThat(model.awsSecretSettings().get().secretKey().toString(), equalTo("test-secret-key")); + assertThat(model.region(), equalTo("us-east-1")); + assertThat(model.api(), equalTo("test-api")); + assertThat(model.endpointName(), equalTo("test-endpoint")); + + assertTrue(model.customAttributes().isEmpty()); + assertTrue(model.enableExplanations().isEmpty()); + assertTrue(model.inferenceComponentName().isEmpty()); + assertTrue(model.inferenceIdForDataCapture().isEmpty()); + assertTrue(model.sessionId().isEmpty()); + assertTrue(model.targetContainerHostname().isEmpty()); + assertTrue(model.targetModel().isEmpty()); + assertTrue(model.targetVariant().isEmpty()); + assertTrue(model.batchSize().isEmpty()); + } + + public void testFromRequestWithOptionalFields() { + var model = fromRequest(""" + { + "service_settings": { + "access_key": "test-access-key", + "secret_key": "test-secret-key", + "region": "us-east-1", + "api": "test-api", + "endpoint_name": "test-endpoint", + "target_model": "test-target", + "target_container_hostname": "test-target-container", + "inference_component_name": "test-inference-component", + "batch_size": 1234 + }, + "task_settings": { + "custom_attributes": "test-custom-attributes", + "enable_explanations": "test-enable-explanations", + "inference_id": "test-inference-id", + "session_id": "test-session-id", + "target_variant": "test-target-variant" + } + } + """); + + assertNotNull(model); + assertTrue(model.awsSecretSettings().isPresent()); + assertThat(model.awsSecretSettings().get().accessKey().toString(), equalTo("test-access-key")); + assertThat(model.awsSecretSettings().get().secretKey().toString(), equalTo("test-secret-key")); + assertThat(model.region(), equalTo("us-east-1")); + assertThat(model.api(), equalTo("test-api")); + assertThat(model.endpointName(), equalTo("test-endpoint")); + + assertPresent(model.customAttributes(), "test-custom-attributes"); + assertPresent(model.enableExplanations(), "test-enable-explanations"); + 
assertPresent(model.inferenceComponentName(), "test-inference-component"); + assertPresent(model.inferenceIdForDataCapture(), "test-inference-id"); + assertPresent(model.sessionId(), "test-session-id"); + assertPresent(model.targetContainerHostname(), "test-target-container"); + assertPresent(model.targetModel(), "test-target"); + assertPresent(model.targetVariant(), "test-target-variant"); + assertPresent(model.batchSize(), 1234); + } + + public void testFromRequestWithoutAccessKey() { + testExceptionFromRequest(""" + { + "service_settings": { + "secret_key": "test-secret-key", + "region": "us-east-1", + "api": "test-api", + "endpoint_name": "test-endpoint" + } + } + """, ValidationException.class, "Validation Failed: 1: [secret_settings] does not contain the required setting [access_key];"); + } + + public void testFromRequestWithoutSecretKey() { + testExceptionFromRequest(""" + { + "service_settings": { + "access_key": "test-access-key", + "region": "us-east-1", + "api": "test-api", + "endpoint_name": "test-endpoint" + } + } + """, ValidationException.class, "Validation Failed: 1: [secret_settings] does not contain the required setting [secret_key];"); + } + + public void testFromRequestWithoutRegion() { + testExceptionFromRequest(""" + { + "service_settings": { + "access_key": "test-access-key", + "secret_key": "test-secret-key", + "api": "test-api", + "endpoint_name": "test-endpoint" + } + } + """, ValidationException.class, "Validation Failed: 1: [service_settings] does not contain the required setting [region];"); + } + + public void testFromRequestWithoutApi() { + testExceptionFromRequest(""" + { + "service_settings": { + "access_key": "test-access-key", + "secret_key": "test-secret-key", + "region": "us-east-1", + "endpoint_name": "test-endpoint" + } + } + """, ValidationException.class, "Validation Failed: 1: [service_settings] does not contain the required setting [api];"); + } + + public void testFromRequestWithoutEndpointName() { + 
testExceptionFromRequest( + """ + { + "service_settings": { + "access_key": "test-access-key", + "secret_key": "test-secret-key", + "region": "us-east-1", + "api": "test-api" + } + } + """, + ValidationException.class, + "Validation Failed: 1: [service_settings] does not contain the required setting [endpoint_name];" + ); + } + + public void testFromRequestWithExtraServiceKeys() { + testExceptionFromRequest( + """ + { + "service_settings": { + "access_key": "test-access-key", + "secret_key": "test-secret-key", + "region": "us-east-1", + "api": "test-api", + "endpoint_name": "test-endpoint", + "hello": "there" + } + } + """, + ElasticsearchStatusException.class, + "Model configuration contains settings [{hello=there}] unknown to the [service] service" + ); + } + + public void testFromRequestWithExtraTaskKeys() { + testExceptionFromRequest( + """ + { + "service_settings": { + "access_key": "test-access-key", + "secret_key": "test-secret-key", + "region": "us-east-1", + "api": "test-api", + "endpoint_name": "test-endpoint" + }, + "task_settings": { + "hello": "there" + } + } + """, + ElasticsearchStatusException.class, + "Model configuration contains settings [{hello=there}] unknown to the [service] service" + ); + } + + public void testRoundTrip() throws IOException { + var expectedModel = fromRequest(""" + { + "service_settings": { + "access_key": "test-access-key", + "secret_key": "test-secret-key", + "region": "us-east-1", + "api": "test-api", + "endpoint_name": "test-endpoint", + "target_model": "test-target", + "target_container_hostname": "test-target-container", + "inference_component_name": "test-inference-component", + "batch_size": 1234 + }, + "task_settings": { + "custom_attributes": "test-custom-attributes", + "enable_explanations": "test-enable-explanations", + "inference_id": "test-inference-id", + "session_id": "test-session-id", + "target_variant": "test-target-variant" + } + } + """); + + var unparsedModelWithSecrets = 
unparsedModel(expectedModel.getConfigurations(), expectedModel.getSecrets()); + var modelWithSecrets = builder.fromStorage( + unparsedModelWithSecrets.inferenceEntityId(), + unparsedModelWithSecrets.taskType(), + unparsedModelWithSecrets.service(), + unparsedModelWithSecrets.settings(), + unparsedModelWithSecrets.secrets() + ); + assertThat(modelWithSecrets, equalTo(expectedModel)); + assertNotNull(modelWithSecrets.getSecrets().getSecretSettings()); + + var unparsedModelWithoutSecrets = unparsedModel(expectedModel.getConfigurations(), null); + var modelWithoutSecrets = builder.fromStorage( + unparsedModelWithoutSecrets.inferenceEntityId(), + unparsedModelWithoutSecrets.taskType(), + unparsedModelWithoutSecrets.service(), + unparsedModelWithoutSecrets.settings(), + unparsedModelWithoutSecrets.secrets() + ); + assertThat(modelWithoutSecrets.getConfigurations(), equalTo(expectedModel.getConfigurations())); + assertNull(modelWithoutSecrets.getSecrets().getSecretSettings()); + } + + private SageMakerModel fromRequest(String json) { + return builder.fromRequest(inferenceId, taskType, service, map(json)); + } + + private void testExceptionFromRequest(String json, Class exceptionClass, String message) { + var exception = assertThrows(exceptionClass, () -> fromRequest(json)); + assertThat(exception.getMessage(), equalTo(message)); + } + + private static void assertPresent(Optional optional, T expectedValue) { + assertTrue(optional.isPresent()); + assertThat(optional.get(), equalTo(expectedValue)); + } + + private static Map map(String json) { + try ( + var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, json.getBytes(StandardCharsets.UTF_8)) + ) { + return parser.map(); + } catch (IOException e) { + throw new AssertionError(e); + } + } + + private static UnparsedModel unparsedModel(ModelConfigurations modelConfigurations, ModelSecrets modelSecrets) throws IOException { + var modelConfigMap = new ModelRegistry.ModelConfigMap( + 
toJsonMap(modelConfigurations), + modelSecrets != null ? toJsonMap(modelSecrets) : null + ); + + return ModelRegistry.unparsedModelFromMap(modelConfigMap); + } + + private static Map toJsonMap(ToXContent toXContent) throws IOException { + try (var builder = JsonXContent.contentBuilder()) { + toXContent.toXContent(builder, new ToXContent.MapParams(Map.of(USE_ID_FOR_INDEX, "true"))); + return map(Strings.toString(builder)); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerServiceSettingsTests.java new file mode 100644 index 000000000000..de08871f354d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerServiceSettingsTests.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.model; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemasTests; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema; + +import java.util.Map; + +public class SageMakerServiceSettingsTests extends InferenceSettingsTestCase { + + @Override + protected Writeable.Reader instanceReader() { + return SageMakerServiceSettings::new; + } + + @Override + protected SageMakerServiceSettings createTestInstance() { + return new SageMakerServiceSettings( + randomString(), + randomString(), + randomString(), + randomOptionalString(), + randomOptionalString(), + randomOptionalString(), + randomBoolean() ? randomIntBetween(1, 1000) : null, + SageMakerStoredServiceSchema.NO_OP + ); + } + + @Override + protected SageMakerServiceSettings fromMutableMap(Map mutableMap) { + return SageMakerServiceSettings.fromMap(SageMakerSchemasTests.mockSchemas(), randomFrom(TaskType.values()), mutableMap); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettingsTests.java new file mode 100644 index 000000000000..e4f7118ed173 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettingsTests.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.model; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema; + +import java.util.Map; + +public class SageMakerTaskSettingsTests extends InferenceSettingsTestCase { + @Override + protected Writeable.Reader instanceReader() { + return SageMakerTaskSettings::new; + } + + @Override + protected SageMakerTaskSettings createTestInstance() { + return new SageMakerTaskSettings( + randomOptionalString(), + randomOptionalString(), + randomOptionalString(), + randomOptionalString(), + randomOptionalString(), + SageMakerStoredTaskSchema.NO_OP + ); + } + + @Override + protected SageMakerTaskSettings fromMutableMap(Map mutableMap) { + var validationException = new ValidationException(); + var taskSettings = SageMakerTaskSettings.fromMap(mutableMap, SageMakerStoredTaskSchema.NO_OP, validationException); + validationException.throwIfValidationErrorsExist(); + return taskSettings; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java new file mode 100644 index 000000000000..4e480ed4c17b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema; + +import software.amazon.awssdk.core.SdkBytes; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.VersionedNamedWriteable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.junit.Before; + +import java.io.IOException; +import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Stream; + +import static org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase.toMap; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.startsWith; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public abstract class SageMakerSchemaPayloadTestCase extends ESTestCase { + protected T payload; + + @Before + public void setUp() throws Exception { + super.setUp(); + payload = payload(); + } + + protected abstract T payload(); + + protected abstract String expectedApi(); + + protected abstract Set expectedSupportedTaskTypes(); + + protected abstract SageMakerStoredServiceSchema randomApiServiceSettings(); + + protected abstract SageMakerStoredTaskSchema randomApiTaskSettings(); + + public final void testApi() { + assertThat(payload.api(), equalTo(expectedApi())); + } + + public final void testSupportedTaskTypes() { + assertThat(payload.supportedTasks(), containsInAnyOrder(expectedSupportedTaskTypes().toArray())); + } + + public final void testApiServiceSettings() throws IOException { + var validationException = new ValidationException(); + var expectedApiServiceSettings = randomApiServiceSettings(); + var actualApiServiceSettings = 
payload.apiServiceSettings(toMap(expectedApiServiceSettings), validationException); + assertThat(actualApiServiceSettings, equalTo(expectedApiServiceSettings)); + validationException.throwIfValidationErrorsExist(); + } + + public final void testApiTaskSettings() throws IOException { + var validationException = new ValidationException(); + var expectedApiTaskSettings = randomApiTaskSettings(); + var actualApiTaskSettings = payload.apiTaskSettings(toMap(expectedApiTaskSettings), validationException); + assertThat(actualApiTaskSettings, equalTo(expectedApiTaskSettings)); + validationException.throwIfValidationErrorsExist(); + } + + public void testNamedWriteables() { + var namedWriteables = payload.namedWriteables().map(entry -> entry.name).toList(); + + var filteredNames = Set.of( + SageMakerStoredServiceSchema.NO_OP.getWriteableName(), + SageMakerStoredTaskSchema.NO_OP.getWriteableName() + ); + var expectedWriteables = Stream.of(randomApiServiceSettings(), randomApiTaskSettings()) + .map(VersionedNamedWriteable::getWriteableName) + .filter(Predicate.not(filteredNames::contains)) + .toArray(); + assertThat(namedWriteables, containsInAnyOrder(expectedWriteables)); + } + + public final void testWithUnknownApiServiceSettings() { + SageMakerModel model = mock(); + when(model.apiServiceSettings()).thenReturn(mock()); + when(model.apiTaskSettings()).thenReturn(randomApiTaskSettings()); + when(model.api()).thenReturn("serviceApi"); + when(model.getTaskType()).thenReturn(TaskType.ANY); + + var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest())); + + assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [serviceApi] and task type [any]:")); + } + + public final void testWithUnknownApiTaskSettings() { + SageMakerModel model = mock(); + when(model.apiServiceSettings()).thenReturn(randomApiServiceSettings()); + when(model.apiTaskSettings()).thenReturn(mock()); + when(model.api()).thenReturn("taskApi"); + 
when(model.getTaskType()).thenReturn(TaskType.ANY); + + var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest())); + + assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [taskApi] and task type [any]:")); + } + + public final void testUpdate() throws IOException { + var taskSettings = randomApiTaskSettings(); + if (taskSettings != SageMakerStoredTaskSchema.NO_OP) { + var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings); + + var updatedSettings = toMap(taskSettings.updatedTaskSettings(toMap(otherTaskSettings))); + + var initialSettings = toMap(taskSettings); + var newSettings = toMap(otherTaskSettings); + + newSettings.forEach((key, value) -> { + assertThat("Value should have been updated for key " + key, value, equalTo(updatedSettings.remove(key))); + }); + initialSettings.forEach((key, value) -> { + if (updatedSettings.containsKey(key)) { + assertThat("Value should not have been updated for key " + key, value, equalTo(updatedSettings.remove(key))); + } + }); + assertTrue("Map should be empty now that we verified all updated keys and all initial keys", updatedSettings.isEmpty()); + } + if (payload instanceof SageMakerStoredTaskSchema taskSchema) { + var otherTaskSettings = randomValueOtherThan(randomApiTaskSettings(), this::randomApiTaskSettings); + var otherTaskSettingsAsMap = toMap(otherTaskSettings); + + taskSchema.updatedTaskSettings(otherTaskSettingsAsMap); + } + } + + protected static SageMakerInferenceRequest randomRequest() { + return new SageMakerInferenceRequest( + randomBoolean() ? randomAlphaOfLengthBetween(4, 8) : null, + randomOptionalBoolean(), + randomBoolean() ? 
randomInt() : null, + randomList(randomIntBetween(2, 4), () -> randomAlphaOfLengthBetween(2, 4)), + randomBoolean(), + randomFrom(InputType.values()) + ); + } + + protected static void assertSdkBytes(SdkBytes sdkBytes, String expectedValue) { + assertThat(sdkBytes.asUtf8String(), equalTo(expectedValue)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemasTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemasTests.java new file mode 100644 index 000000000000..8e3c30a95e36 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemasTests.java @@ -0,0 +1,119 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload; + +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class SageMakerSchemasTests extends ESTestCase { + public static SageMakerSchemas mockSchemas() { + SageMakerSchemas schemas = mock(); + var schema = mockSchema(); + when(schemas.schemaFor(any(TaskType.class), anyString())).thenReturn(schema); + return schemas; + } + + public static SageMakerSchema mockSchema() { + SageMakerSchema schema = mock(); + when(schema.apiServiceSettings(anyMap(), any())).thenReturn(SageMakerStoredServiceSchema.NO_OP); + when(schema.apiTaskSettings(anyMap(), any())).thenReturn(SageMakerStoredTaskSchema.NO_OP); + return schema; + } + + private static final SageMakerSchemas schemas = new SageMakerSchemas(); + + public void testSupportedTaskTypes() { + assertThat(schemas.supportedTaskTypes(), containsInAnyOrder(TaskType.TEXT_EMBEDDING)); + } + + public void testSupportedStreamingTasks() { + assertThat(schemas.supportedStreamingTasks(), empty()); + } + + public void testSchemaFor() { + var payloads = Stream.of(new OpenAiTextEmbeddingPayload()); + payloads.forEach(payload -> { + payload.supportedTasks().forEach(taskType -> { + var model = mockModel(taskType, payload.api()); + assertNotNull(schemas.schemaFor(model)); + }); + }); + } + + public void 
testStreamSchemaFor() { + var payloads = Stream.of(/* For when we add support for streaming payloads */); + payloads.forEach(payload -> { + payload.supportedTasks().forEach(taskType -> { + var model = mockModel(taskType, payload.api()); + assertNotNull(schemas.streamSchemaFor(model)); + }); + }); + } + + private SageMakerModel mockModel(TaskType taskType, String api) { + SageMakerModel model = mock(); + when(model.getTaskType()).thenReturn(taskType); + when(model.api()).thenReturn(api); + return model; + } + + public void testMissingTaskTypeThrowsException() { + var knownPayload = new OpenAiTextEmbeddingPayload(); + var unknownTaskType = TaskType.COMPLETION; + var knownModel = mockModel(unknownTaskType, knownPayload.api()); + assertThrows( + "Task [completion] is not compatible for service [sagemaker] and api [openai]. Supported tasks: [text_embedding]", + ElasticsearchStatusException.class, + () -> schemas.schemaFor(knownModel) + ); + } + + public void testMissingSchemaThrowsException() { + var unknownModel = mockModel(TaskType.ANY, "blah"); + assertThrows( + "Task [any] is not compatible for service [sagemaker] and api [blah]. Supported tasks: []", + ElasticsearchStatusException.class, + () -> schemas.schemaFor(unknownModel) + ); + } + + public void testMissingStreamSchemaThrowsException() { + var unknownModel = mockModel(TaskType.ANY, "blah"); + assertThrows( + "Streaming is not allowed for service [sagemaker], api [blah], and task [any]. 
Supported streaming tasks: []", + ElasticsearchStatusException.class, + () -> schemas.streamSchemaFor(unknownModel) + ); + } + + public void testNamedWriteables() { + var namedWriteables = Stream.of(new OpenAiTextEmbeddingPayload().namedWriteables()); + + var expectedNamedWriteables = Stream.concat( + namedWriteables.flatMap(names -> names.map(entry -> entry.name)), + Stream.of(SageMakerStoredServiceSchema.NO_OP.getWriteableName(), SageMakerStoredTaskSchema.NO_OP.getWriteableName()) + ).distinct().toArray(); + + var actualRegisteredNames = SageMakerSchemas.namedWriteables().stream().map(entry -> entry.name).toList(); + + assertThat(actualRegisteredNames, containsInAnyOrder(expectedNamedWriteables)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java new file mode 100644 index 000000000000..7a85a5e05fab --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayloadTests.java @@ -0,0 +1,155 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai; + +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest; +import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayloadTestCase; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema; +import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema; + +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Set; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class OpenAiTextEmbeddingPayloadTests extends SageMakerSchemaPayloadTestCase { + @Override + protected OpenAiTextEmbeddingPayload payload() { + return new OpenAiTextEmbeddingPayload(); + } + + @Override + protected String expectedApi() { + return "openai"; + } + + @Override + protected Set expectedSupportedTaskTypes() { + return Set.of(TaskType.TEXT_EMBEDDING); + } + + @Override + protected SageMakerStoredServiceSchema randomApiServiceSettings() { + return SageMakerOpenAiServiceSettingsTests.randomApiServiceSettings(); + } + + @Override + protected SageMakerStoredTaskSchema randomApiTaskSettings() { + return SageMakerOpenAiTaskSettingsTests.randomApiTaskSettings(); + } + + public void testAccept() { + assertThat(payload.accept(mock()), equalTo("application/json")); + } + + public void testContentType() { + assertThat(payload.contentType(mock()), 
equalTo("application/json")); + } + + public void testRequestWithSingleInput() throws Exception { + SageMakerModel model = mock(); + when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(null, false)); + when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings((String) null)); + var request = new SageMakerInferenceRequest(null, null, null, List.of("hello"), randomBoolean(), randomFrom(InputType.values())); + + var sdkByes = payload.requestBytes(model, request); + assertSdkBytes(sdkByes, """ + {"input":"hello"}"""); + } + + public void testRequestWithArrayInput() throws Exception { + SageMakerModel model = mock(); + when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(null, false)); + when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings((String) null)); + var request = new SageMakerInferenceRequest( + null, + null, + null, + List.of("hello", "there"), + randomBoolean(), + randomFrom(InputType.values()) + ); + + var sdkByes = payload.requestBytes(model, request); + assertSdkBytes(sdkByes, """ + {"input":["hello","there"]}"""); + } + + public void testRequestWithDimensionsNotSetByUserIgnoreDimensions() throws Exception { + SageMakerModel model = mock(); + when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(123, false)); + when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings((String) null)); + var request = new SageMakerInferenceRequest( + null, + null, + null, + List.of("hello", "there"), + randomBoolean(), + randomFrom(InputType.values()) + ); + + var sdkByes = payload.requestBytes(model, request); + assertSdkBytes(sdkByes, """ + {"input":["hello","there"]}"""); + } + + public void testRequestWithOptionals() throws Exception { + SageMakerModel model = mock(); + when(model.apiServiceSettings()).thenReturn(new 
OpenAiTextEmbeddingPayload.ApiServiceSettings(1234, true)); + when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings("user")); + var request = new SageMakerInferenceRequest("query", null, null, List.of("hello"), randomBoolean(), randomFrom(InputType.values())); + + var sdkByes = payload.requestBytes(model, request); + assertSdkBytes(sdkByes, """ + {"query":"query","input":"hello","user":"user","dimensions":1234}"""); + } + + public void testResponse() throws Exception { + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.014539449, + -0.015288644 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + var invokeEndpointResponse = InvokeEndpointResponse.builder() + .body(SdkBytes.fromString(responseJson, StandardCharsets.UTF_8)) + .build(); + + var textEmbeddingFloatResults = payload.responseBody(mock(), invokeEndpointResponse); + + assertThat( + textEmbeddingFloatResults.embeddings(), + is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F }))) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/SageMakerOpenAiServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/SageMakerOpenAiServiceSettingsTests.java new file mode 100644 index 000000000000..ce8d27672249 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/SageMakerOpenAiServiceSettingsTests.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase; + +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.sameInstance; + +public class SageMakerOpenAiServiceSettingsTests extends InferenceSettingsTestCase { + @Override + protected OpenAiTextEmbeddingPayload.ApiServiceSettings fromMutableMap(Map mutableMap) { + var validationException = new ValidationException(); + var settings = OpenAiTextEmbeddingPayload.ApiServiceSettings.fromMap(mutableMap, validationException); + validationException.throwIfValidationErrorsExist(); + return settings; + } + + @Override + protected Writeable.Reader instanceReader() { + return OpenAiTextEmbeddingPayload.ApiServiceSettings::new; + } + + @Override + protected OpenAiTextEmbeddingPayload.ApiServiceSettings createTestInstance() { + return randomApiServiceSettings(); + } + + static OpenAiTextEmbeddingPayload.ApiServiceSettings randomApiServiceSettings() { + var dimensions = randomBoolean() ? 
randomIntBetween(1, 100) : null; + return new OpenAiTextEmbeddingPayload.ApiServiceSettings(dimensions, dimensions != null); + } + + public void testDimensionsSetByUser() { + var expectedDimensions = randomIntBetween(1, 100); + var dimensionlessSettings = new OpenAiTextEmbeddingPayload.ApiServiceSettings(null, false); + var updatedSettings = dimensionlessSettings.updateModelWithEmbeddingDetails(expectedDimensions); + assertThat(updatedSettings, not(sameInstance(dimensionlessSettings))); + assertThat(updatedSettings.dimensions(), equalTo(expectedDimensions)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/SageMakerOpenAiTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/SageMakerOpenAiTaskSettingsTests.java new file mode 100644 index 000000000000..1eaaf4100f5f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/openai/SageMakerOpenAiTaskSettingsTests.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase; + +import java.util.Map; + +public class SageMakerOpenAiTaskSettingsTests extends InferenceSettingsTestCase { + @Override + protected OpenAiTextEmbeddingPayload.ApiTaskSettings fromMutableMap(Map mutableMap) { + var validationException = new ValidationException(); + var settings = OpenAiTextEmbeddingPayload.ApiTaskSettings.fromMap(mutableMap, validationException); + validationException.throwIfValidationErrorsExist(); + return settings; + } + + @Override + protected Writeable.Reader instanceReader() { + return OpenAiTextEmbeddingPayload.ApiTaskSettings::new; + } + + @Override + protected OpenAiTextEmbeddingPayload.ApiTaskSettings createTestInstance() { + return randomApiTaskSettings(); + } + + static OpenAiTextEmbeddingPayload.ApiTaskSettings randomApiTaskSettings() { + return new OpenAiTextEmbeddingPayload.ApiTaskSettings(randomOptionalString()); + } +}