[ML] Integrate SageMaker with OpenAI Embeddings (#126856)

Integrating with SageMaker.

Current design:
- SageMaker accepts any byte payload, which can be text, csv, or json. `api` represents the structure of the payload that we will send, for example `openai`, `elastic`, `common`, probably `cohere` or `huggingface` as well.
- `api` implementations are extensions of `SageMakerSchemaPayload`, which supports:
  - "extra" service and task settings specific to the payload structure, so `cohere` would require `embedding_type` and `openai` would require `dimensions` in the `service_settings`
  - conversion logic from model, service settings, task settings, and input to `SdkBytes`
  - conversion logic from responding `SdkBytes` to `InferenceServiceResults`
- Everything else is tunneling: there are a number of base `service_settings` and `task_settings`, independent of the `api` format, that we will store and set
- We let the SDK do the bulk of the work in terms of connection details, rate limiting, retries, etc.
This commit is contained in:
Pat Whelan 2025-05-01 12:57:13 -04:00 committed by GitHub
parent 1b35cceacf
commit 245f5eebce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
39 changed files with 4304 additions and 139 deletions

View file

@ -0,0 +1,5 @@
pr: 126856
summary: "[ML] Integrate SageMaker with OpenAI Embeddings"
area: Machine Learning
type: enhancement
issues: []

View file

@ -4912,6 +4912,11 @@
<sha256 value="c83dd82a9d82ff8c7d2eb1bdb2ae9f9505b312dad9a6bf0b80bc0136653a3a24" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="software.amazon.awssdk" name="sagemakerruntime" version="2.30.38">
<artifact name="sagemakerruntime-2.30.38.jar">
<sha256 value="b26ee73fa06d047eab9a174e49627972e646c0bbe909f479c18dbff193b561f5" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="software.amazon.awssdk" name="sdk-core" version="2.30.38">
<artifact name="sdk-core-2.30.38.jar">
<sha256 value="556463b8c353408d93feab74719d141fcfda7fd3d7b7d1ad3a8a548b7cc2982d" origin="Generated by Gradle"/>

View file

@ -162,6 +162,7 @@ public class TransportVersions {
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19);
public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL_BACKPORT_8_19 = def(8_841_0_20);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_8_19 = def(8_841_0_21);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00);
@ -232,6 +233,7 @@ public class TransportVersions {
public static final TransportVersion PROJECT_METADATA_SETTINGS = def(9_066_00_0);
public static final TransportVersion AGGREGATE_METRIC_DOUBLE_BLOCK = def(9_067_00_0);
public static final TransportVersion PINNED_RETRIEVER = def(9_068_0_00);
public static final TransportVersion ML_INFERENCE_SAGEMAKER = def(9_069_0_00);
/*
* STOP! READ THIS FIRST! No, really,

View file

@ -53,6 +53,12 @@ public class ValidationException extends IllegalArgumentException {
return validationErrors;
}
/**
 * Throws this exception if any validation errors have been accumulated,
 * otherwise does nothing. Convenience guard so callers can accumulate
 * errors and fail once at the end.
 */
public final void throwIfValidationErrorsExist() {
    if (validationErrors().isEmpty()) {
        return;
    }
    throw this;
}
@Override
public final String getMessage() {
StringBuilder sb = new StringBuilder();

View file

@ -62,6 +62,7 @@ dependencies {
/* AWS SDK v2 */
implementation("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}")
implementation("software.amazon.awssdk:sagemakerruntime:${versions.awsv2sdk}")
api "software.amazon.awssdk:protocol-core:${versions.awsv2sdk}"
api "software.amazon.awssdk:aws-json-protocol:${versions.awsv2sdk}"
api "software.amazon.awssdk:third-party-jackson-core:${versions.awsv2sdk}"
@ -142,6 +143,7 @@ tasks.named("dependencyLicenses").configure {
mapping from: /json-utils.*/, to: 'aws-sdk-2'
mapping from: /endpoints-spi.*/, to: 'aws-sdk-2'
mapping from: /bedrockruntime.*/, to: 'aws-sdk-2'
mapping from: /sagemakerruntime.*/, to: 'aws-sdk-2'
mapping from: /netty-nio-client/, to: 'aws-sdk-2'
/* Cannot use REGEX to match netty-* because netty-nio-client is an AWS package */
mapping from: /netty-buffer/, to: 'netty'

View file

@ -18,163 +18,161 @@ import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
@SuppressWarnings("unchecked")
public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(21));
assertThat(services.size(), equalTo(22));
String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}
var providers = providers(services);
assertArrayEquals(
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"anthropic",
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"elastic",
"elasticsearch",
"googleaistudio",
"googlevertexai",
"hugging_face",
"jinaai",
"mistral",
"openai",
"streaming_completion_test_service",
"test_reranking_service",
"test_service",
"text_embedding_test_service",
"voyageai",
"watsonxai"
).toArray(),
providers
assertThat(
providers,
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"anthropic",
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"elastic",
"elasticsearch",
"googleaistudio",
"googlevertexai",
"hugging_face",
"jinaai",
"mistral",
"openai",
"streaming_completion_test_service",
"test_reranking_service",
"test_service",
"text_embedding_test_service",
"voyageai",
"watsonxai",
"sagemaker"
).toArray()
)
);
}
@SuppressWarnings("unchecked")
private Iterable<String> providers(List<Object> services) {
return services.stream().map(service -> {
var serviceConfig = (Map<String, Object>) service;
return (String) serviceConfig.get("service");
}).toList();
}
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
assertThat(services.size(), equalTo(15));
assertThat(services.size(), equalTo(16));
String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}
var providers = providers(services);
assertArrayEquals(
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"azureaistudio",
"azureopenai",
"cohere",
"elasticsearch",
"googleaistudio",
"googlevertexai",
"hugging_face",
"jinaai",
"mistral",
"openai",
"text_embedding_test_service",
"voyageai",
"watsonxai"
).toArray(),
providers
assertThat(
providers,
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"azureaistudio",
"azureopenai",
"cohere",
"elasticsearch",
"googleaistudio",
"googlevertexai",
"hugging_face",
"jinaai",
"mistral",
"openai",
"text_embedding_test_service",
"voyageai",
"watsonxai",
"sagemaker"
).toArray()
)
);
}
@SuppressWarnings("unchecked")
public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(7));
String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}
var providers = providers(services);
assertArrayEquals(
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai")
.toArray(),
providers
assertThat(
providers,
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"cohere",
"elasticsearch",
"googlevertexai",
"jinaai",
"test_reranking_service",
"voyageai"
).toArray()
)
);
}
@SuppressWarnings("unchecked")
public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(10));
String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}
var providers = providers(services);
assertArrayEquals(
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"anthropic",
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"googleaistudio",
"openai",
"streaming_completion_test_service"
).toArray(),
providers
assertThat(
providers,
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"anthropic",
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"googleaistudio",
"openai",
"streaming_completion_test_service"
).toArray()
)
);
}
@SuppressWarnings("unchecked")
public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(4));
String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}
var providers = providers(services);
assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray()));
}
@SuppressWarnings("unchecked")
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
assertThat(services.size(), equalTo(6));
String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}
var providers = providers(services);
assertArrayEquals(
List.of(
"alibabacloud-ai-search",
"elastic",
"elasticsearch",
"hugging_face",
"streaming_completion_test_service",
"test_service"
).toArray(),
providers
assertThat(
providers,
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"elastic",
"elasticsearch",
"hugging_face",
"streaming_completion_test_service",
"test_service"
).toArray()
)
);
}

View file

@ -36,6 +36,7 @@ module org.elasticsearch.inference {
requires org.elasticsearch.logging;
requires org.elasticsearch.sslconfig;
requires org.apache.commons.text;
requires software.amazon.awssdk.services.sagemakerruntime;
exports org.elasticsearch.xpack.inference.action;
exports org.elasticsearch.xpack.inference.registry;

View file

@ -92,6 +92,8 @@ import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCo
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
@ -157,6 +159,8 @@ public class InferenceNamedWriteablesProvider {
namedWriteables.addAll(StreamingTaskManager.namedWriteables());
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
namedWriteables.addAll(SageMakerModel.namedWriteables());
namedWriteables.addAll(SageMakerSchemas.namedWriteables());
return namedWriteables;
}

View file

@ -21,6 +21,7 @@ import org.elasticsearch.common.settings.IndexScopedSettings;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.SettingsFilter;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.features.NodeFeature;
@ -132,6 +133,11 @@ import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerService;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerConfiguration;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
@ -294,6 +300,8 @@ public class InferencePlugin extends Plugin
services.threadPool()
);
var sageMakerSchemas = new SageMakerSchemas();
var sageMakerConfigurations = new LazyInitializable<>(new SageMakerConfiguration(sageMakerSchemas));
inferenceServices.add(
() -> List.of(
context -> new ElasticInferenceService(
@ -302,6 +310,16 @@ public class InferencePlugin extends Plugin
inferenceServiceSettings,
modelRegistry.get(),
authorizationHandler
),
context -> new SageMakerService(
new SageMakerModelBuilder(sageMakerSchemas),
new SageMakerClient(
new SageMakerClient.Factory(new HttpSettings(settings, services.clusterService())),
services.threadPool()
),
sageMakerSchemas,
services.threadPool(),
sageMakerConfigurations::getOrCompute
)
)
);

View file

@ -23,11 +23,12 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.ACCESS_KEY_FIELD;
@ -134,33 +135,39 @@ public class AwsSecretSettings implements SecretSettings {
}
private static final LazyInitializable<Map<String, SettingsConfiguration>, RuntimeException> configuration =
new LazyInitializable<>(() -> {
var configurationMap = new HashMap<String, SettingsConfiguration>();
configurationMap.put(
ACCESS_KEY_FIELD,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription(
"A valid AWS access key that has permissions to use Amazon Bedrock."
)
.setLabel("Access Key")
.setRequired(true)
.setSensitive(true)
.setUpdatable(true)
.setType(SettingsConfigurationFieldType.STRING)
.build()
);
configurationMap.put(
SECRET_KEY_FIELD,
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription(
"A valid AWS secret key that is paired with the access_key."
)
.setLabel("Secret Key")
.setRequired(true)
.setSensitive(true)
.setUpdatable(true)
.setType(SettingsConfigurationFieldType.STRING)
.build()
);
return Collections.unmodifiableMap(configurationMap);
});
new LazyInitializable<>(
() -> configuration(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).collect(
Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)
)
);
}
/**
 * Shared {@link SettingsConfiguration} entries for the AWS credential secrets
 * ({@code access_key} and {@code secret_key}), parameterized by the task types the
 * calling service supports so multiple AWS-backed services can reuse them.
 * Both fields are required, sensitive, and updatable.
 *
 * NOTE(review): the access-key description still mentions "Amazon Bedrock" even though
 * this configuration is now shared with other AWS services (e.g. SageMaker) — confirm
 * whether the wording should be generalized.
 *
 * @param supportedTaskTypes task types to advertise for these settings
 * @return a stream of (field name, configuration) entries for the two AWS secrets
 */
public static Stream<Map.Entry<String, SettingsConfiguration>> configuration(EnumSet<TaskType> supportedTaskTypes) {
    return Stream.of(
        Map.entry(
            ACCESS_KEY_FIELD,
            new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
                "A valid AWS access key that has permissions to use Amazon Bedrock."
            )
                .setLabel("Access Key")
                .setRequired(true)
                .setSensitive(true)
                .setUpdatable(true)
                .setType(SettingsConfigurationFieldType.STRING)
                .build()
        ),
        Map.entry(
            SECRET_KEY_FIELD,
            new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
                "A valid AWS secret key that is paired with the access_key."
            )
                .setLabel("Secret Key")
                .setRequired(true)
                .setSensitive(true)
                .setUpdatable(true)
                .setType(SettingsConfigurationFieldType.STRING)
                .build()
        )
    );
}
}

View file

@ -14,6 +14,7 @@ import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.TimeValue;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
@ -55,6 +56,10 @@ public class HttpSettings {
return connectionTimeout;
}
/**
 * Connection timeout as a {@link java.time.Duration}, for clients (e.g. the AWS SDK)
 * that take java.time types.
 * NOTE(review): assumes {@code connectionTimeout} holds a millisecond value — the field
 * declaration is outside this hunk; confirm its type and units.
 */
public Duration connectionTimeoutDuration() {
    return Duration.ofMillis(connectionTimeout);
}
// Private setter; presumably registered as a dynamic cluster-setting update consumer
// (registration not visible in this hunk — verify against the constructor).
private void setMaxResponseSize(ByteSizeValue maxResponseSize) {
    this.maxResponseSize = maxResponseSize;
}

View file

@ -0,0 +1,251 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.profiles.ProfileFile;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeAsyncClient;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.SpecialPermission;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.ListenerTimeouts;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.cache.CacheLoader;
import org.elasticsearch.common.util.concurrent.FutureUtils;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.external.http.HttpSettings;
import org.reactivestreams.FlowAdapters;
import java.io.Closeable;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
/**
 * Async wrapper around the AWS SageMaker Runtime SDK. Caches one SDK client per
 * (region, credentials) pair, adapts the SDK's {@link CompletableFuture} results to
 * {@link ActionListener}s, and applies a request timeout that cancels the in-flight
 * SDK call. Supports both plain invocation and streaming responses.
 */
public class SageMakerClient implements Closeable {
    private static final Logger log = LogManager.getLogger(SageMakerClient.class);

    // One SDK client per (region, credentials) pair. Evicted entries are closed by the
    // removal listener; entries expire after 15 minutes without access so clients built
    // from stale credentials do not linger.
    private final Cache<RegionAndSecrets, SageMakerRuntimeAsyncClient> existingClients = CacheBuilder.<
        RegionAndSecrets,
        SageMakerRuntimeAsyncClient>builder()
        .removalListener(removal -> removal.getValue().close())
        .setExpireAfterAccess(TimeValue.timeValueMinutes(15))
        .build();
    private final CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory;
    private final ThreadPool threadPool;

    public SageMakerClient(CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory, ThreadPool threadPool) {
        this.clientFactory = clientFactory;
        this.threadPool = threadPool;
    }

    /**
     * Invokes a SageMaker endpoint and returns the complete (non-streaming) response.
     * Fails the listener with a 408 if the SDK future does not complete within {@code timeout},
     * cancelling the in-flight request.
     */
    public void invoke(
        RegionAndSecrets regionAndSecrets,
        InvokeEndpointRequest request,
        TimeValue timeout,
        ActionListener<InvokeEndpointResponse> listener
    ) {
        SageMakerRuntimeAsyncClient asyncClient;
        try {
            asyncClient = existingClients.computeIfAbsent(regionAndSecrets, clientFactory);
        } catch (ExecutionException e) {
            listener.onFailure(clientFailure(regionAndSecrets, e));
            return;
        }
        // The SDK completes futures on its own threads; restore the caller's thread context
        // before invoking the listener.
        var contextPreservingListener = new ContextPreservingActionListener<>(
            threadPool.getThreadContext().newRestorableContext(false),
            listener
        );
        var awsFuture = asyncClient.invokeEndpoint(request);
        // On timeout: cancel the SDK future and fail with REQUEST_TIMEOUT.
        var timeoutListener = ListenerTimeouts.wrapWithTimeout(
            threadPool,
            timeout,
            threadPool.executor(UTILITY_THREAD_POOL_NAME),
            contextPreservingListener,
            ignored -> {
                FutureUtils.cancel(awsFuture);
                contextPreservingListener.onFailure(
                    new ElasticsearchStatusException("Request timed out after [{}]", RestStatus.REQUEST_TIMEOUT, timeout)
                );
            }
        );
        // Complete the listener on our utility pool rather than on SDK threads.
        awsFuture.thenAcceptAsync(timeoutListener::onResponse, threadPool.executor(UTILITY_THREAD_POOL_NAME))
            .exceptionallyAsync(t -> failAndMaybeThrowError(t, timeoutListener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
    }

    // Wraps client-construction failures, including the region for diagnostics.
    private static Exception clientFailure(RegionAndSecrets regionAndSecrets, Exception cause) {
        return new ElasticsearchStatusException(
            "failed to create SageMakerRuntime client for region [{}]",
            RestStatus.INTERNAL_SERVER_ERROR,
            cause,
            regionAndSecrets.region()
        );
    }

    // Unwraps the CompletionException produced by the CompletableFuture chain and routes
    // Exceptions to the listener. For non-Exception Throwables (i.e. Errors), the Error is
    // rethrown on another thread via maybeDieOnAnotherThread and the listener is still
    // completed with a wrapping RuntimeException so the caller does not hang.
    private Void failAndMaybeThrowError(Throwable t, ActionListener<?> listener) {
        if (t instanceof CompletionException ce) {
            t = ce.getCause();
        }
        if (t instanceof Exception e) {
            listener.onFailure(e);
        } else {
            ExceptionsHelper.maybeError(t).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
            log.atWarn().withThrowable(t).log("Unknown failure calling SageMaker.");
            listener.onFailure(new RuntimeException("Unknown failure calling SageMaker.", t));
        }
        return null; // Void
    }

    /**
     * Invokes a SageMaker endpoint with a streaming response. The listener receives a
     * {@link SageMakerStream} once response headers arrive; the body is then delivered via
     * the stream's {@link Flow.Publisher}. The timeout covers only time-to-first-response.
     */
    public void invokeStream(
        RegionAndSecrets regionAndSecrets,
        InvokeEndpointWithResponseStreamRequest request,
        TimeValue timeout,
        ActionListener<SageMakerStream> listener
    ) {
        SageMakerRuntimeAsyncClient asyncClient;
        try {
            asyncClient = existingClients.computeIfAbsent(regionAndSecrets, clientFactory);
        } catch (ExecutionException e) {
            listener.onFailure(clientFailure(regionAndSecrets, e));
            return;
        }
        var contextPreservingListener = new ContextPreservingActionListener<>(
            threadPool.getThreadContext().newRestorableContext(false),
            listener
        );
        var responseStreamProcessor = new SageMakerStreamingResponseProcessor();
        // The SDK future does not exist yet when the timeout handler is built, so it is
        // handed to the handler indirectly through this reference.
        var cancelAwsRequestListener = new AtomicReference<CompletableFuture<?>>();
        var timeoutListener = ListenerTimeouts.wrapWithTimeout(
            threadPool,
            timeout,
            threadPool.executor(UTILITY_THREAD_POOL_NAME),
            contextPreservingListener,
            ignored -> {
                FutureUtils.cancel(cancelAwsRequestListener.get());
                contextPreservingListener.onFailure(
                    new ElasticsearchStatusException("Request timed out after [{}]", RestStatus.REQUEST_TIMEOUT, timeout)
                );
            }
        );
        // To stay consistent with HTTP providers, we cancel the TimeoutListener onResponse because we are measuring the time it takes to
        // start receiving bytes.
        var responseStreamListener = InvokeEndpointWithResponseStreamResponseHandler.builder()
            .onResponse(response -> timeoutListener.onResponse(new SageMakerStream(response, responseStreamProcessor)))
            .onEventStream(publisher -> responseStreamProcessor.setPublisher(FlowAdapters.toFlowPublisher(publisher)))
            .build();
        var awsFuture = asyncClient.invokeEndpointWithResponseStream(request, responseStreamListener);
        cancelAwsRequestListener.set(awsFuture);
        awsFuture.exceptionallyAsync(t -> failAndMaybeThrowError(t, timeoutListener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
    }

    @Override
    public void close() {
        existingClients.invalidateAll(); // will close each cached client
    }

    /** Cache key for SDK clients: a client is reused only for the same region and credentials. */
    public record RegionAndSecrets(String region, AwsSecretSettings secretSettings) {}

    /** Builds a SageMaker Runtime client for a given region and set of credentials. */
    public static class Factory implements CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> {
        private final HttpSettings httpSettings;

        public Factory(HttpSettings httpSettings) {
            this.httpSettings = httpSettings;
        }

        @Override
        public SageMakerRuntimeAsyncClient load(RegionAndSecrets key) throws Exception {
            SpecialPermission.check();
            // TODO migrate to entitlements
            return AccessController.doPrivileged((PrivilegedExceptionAction<SageMakerRuntimeAsyncClient>) () -> {
                var credentials = AwsBasicCredentials.create(
                    key.secretSettings().accessKey().toString(),
                    key.secretSettings().secretKey().toString()
                );
                var credentialsProvider = StaticCredentialsProvider.create(credentials);
                var clientConfig = NettyNioAsyncHttpClient.builder().connectionTimeout(httpSettings.connectionTimeoutDuration());
                var override = ClientOverrideConfiguration.builder()
                    // disable profileFile, user credentials will always come from the configured Model Secrets
                    .defaultProfileFileSupplier(ProfileFile.aggregator()::build)
                    .defaultProfileFile(ProfileFile.aggregator().build())
                    // NOTE(review): both the legacy retryPolicy and the newer retryStrategy are set;
                    // recent AWS SDK versions treat these as alternatives — confirm the SDK in use
                    // accepts both being configured.
                    .retryPolicy(retryPolicy -> retryPolicy.numRetries(3))
                    .retryStrategy(retryStrategy -> retryStrategy.maxAttempts(3))
                    .build();
                return SageMakerRuntimeAsyncClient.builder()
                    .credentialsProvider(credentialsProvider)
                    .region(Region.of(key.region()))
                    .httpClientBuilder(clientConfig)
                    .overrideConfiguration(override)
                    .build();
            });
        }
    }

    /**
     * Bridges the SDK's event-stream publisher to a downstream subscriber. The publisher
     * (from the SDK's onEventStream callback) and the subscriber (from downstream) arrive
     * asynchronously on different threads; whichever side arrives second performs the wiring.
     */
    private static class SageMakerStreamingResponseProcessor implements Flow.Publisher<ResponseStream> {
        private static final Logger log = LogManager.getLogger(SageMakerStreamingResponseProcessor.class);
        // Holds whichever side arrived first, with the other tuple slot null; the late
        // arriver swaps the tuple out (getAndSet) and connects the two.
        private final AtomicReference<Tuple<Flow.Publisher<ResponseStream>, Flow.Subscriber<? super ResponseStream>>> holder =
            new AtomicReference<>(null);
        // Only a single subscriber is supported.
        private final AtomicBoolean subscribeCalledOnce = new AtomicBoolean(false);

        @Override
        public void subscribe(Flow.Subscriber<? super ResponseStream> subscriber) {
            if (subscribeCalledOnce.compareAndSet(false, true) == false) {
                subscriber.onError(new IllegalStateException("Subscriber already set."));
                return;
            }
            if (holder.compareAndSet(null, Tuple.tuple(null, subscriber)) == false) {
                // Publisher arrived first; connect immediately.
                log.debug("Subscriber connecting to publisher.");
                var publisher = holder.getAndSet(null).v1();
                publisher.subscribe(subscriber);
            } else {
                log.debug("Subscriber waiting for connection.");
            }
        }

        private void setPublisher(Flow.Publisher<ResponseStream> publisher) {
            if (holder.compareAndSet(null, Tuple.tuple(publisher, null)) == false) {
                // Subscriber arrived first; connect immediately.
                log.debug("Publisher connecting to subscriber.");
                var subscriber = holder.getAndSet(null).v2();
                publisher.subscribe(subscriber);
            } else {
                log.debug("Publisher waiting for connection.");
            }
        }
    }

    /** Pairs the initial streaming response (headers/metadata) with the publisher of its body events. */
    public record SageMakerStream(InvokeEndpointWithResponseStreamResponse response, Flow.Publisher<ResponseStream> responseStream) {}
}

View file

@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;
import java.util.List;
import java.util.Objects;
/**
 * Immutable carrier for a single inference call into SageMaker, bundling the input
 * text(s) with task-specific options before they are converted to a payload.
 *
 * @param query           optional query text (e.g. for rerank tasks); may be null
 * @param returnDocuments optional flag to echo documents back in the response; may be null
 * @param topN            optional cap on the number of results; may be null
 * @param input           the input texts to run inference on; never null
 * @param stream          whether the caller requested a streaming response
 * @param inputType       how the input will be used (e.g. search vs. ingest); never null
 */
public record SageMakerInferenceRequest(
    @Nullable String query,
    @Nullable Boolean returnDocuments,
    @Nullable Integer topN,
    List<String> input,
    boolean stream,
    InputType inputType
) {
    public SageMakerInferenceRequest {
        Objects.requireNonNull(input);
        Objects.requireNonNull(inputType);
    }
}

View file

@ -0,0 +1,305 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import java.io.IOException;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails;
public class SageMakerService implements InferenceService {
/** Registered service name for this inference provider. */
public static final String NAME = "sagemaker";
// Batch-size constant; presumably caps inputs per chunked-embedding request — its use is
// below this view, verify against chunkedInfer.
private static final int DEFAULT_BATCH_SIZE = 256;
private final SageMakerModelBuilder modelBuilder;
private final SageMakerClient client;
private final SageMakerSchemas schemas;
private final ThreadPool threadPool;
// Built lazily on first access; see getConfiguration().
private final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration;

/**
 * @param modelBuilder     builds SageMaker models from request/storage maps
 * @param client           async client used to call SageMaker endpoints
 * @param schemas          registry of payload schemas keyed by model/api
 * @param threadPool       used for response handling and thread-context restoration
 * @param configurationMap supplier of the per-field settings configuration; evaluated lazily
 */
public SageMakerService(
    SageMakerModelBuilder modelBuilder,
    SageMakerClient client,
    SageMakerSchemas schemas,
    ThreadPool threadPool,
    CheckedSupplier<Map<String, SettingsConfiguration>, RuntimeException> configurationMap
) {
    this.modelBuilder = modelBuilder;
    this.client = client;
    this.schemas = schemas;
    this.threadPool = threadPool;
    this.configuration = new LazyInitializable<>(
        () -> new InferenceServiceConfiguration.Builder().setService(NAME)
            .setName("Amazon SageMaker")
            .setTaskTypes(supportedTaskTypes())
            .setConfigurations(configurationMap.get())
            .build()
    );
}
/** Service identifier used to route endpoint configurations to this service. */
@Override
public String name() {
    return NAME;
}
/**
 * Builds a model from a create-endpoint request body. Any validation failure thrown by
 * the model builder is routed to the listener via {@code ActionListener.completeWith}.
 */
@Override
public void parseRequestConfig(
    String modelId,
    TaskType taskType,
    Map<String, Object> config,
    ActionListener<Model> parsedModelListener
) {
    ActionListener.completeWith(parsedModelListener, () -> modelBuilder.fromRequest(modelId, taskType, NAME, config));
}
/** Rebuilds a model from persisted configuration plus its stored secrets. */
@Override
public Model parsePersistedConfigWithSecrets(
    String modelId,
    TaskType taskType,
    Map<String, Object> config,
    Map<String, Object> secrets
) {
    return modelBuilder.fromStorage(modelId, taskType, NAME, config, secrets);
}
@Override
public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
return modelBuilder.fromStorage(modelId, taskType, NAME, config, null);
}
@Override
public InferenceServiceConfiguration getConfiguration() {
return configuration.getOrCompute();
}
@Override
public EnumSet<TaskType> supportedTaskTypes() {
return schemas.supportedTaskTypes();
}
@Override
public Set<TaskType> supportedStreamingTasks() {
return schemas.supportedStreamingTasks();
}
@Override
public void infer(
Model model,
@Nullable String query,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
if (model instanceof SageMakerModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}
var inferenceRequest = new SageMakerInferenceRequest(query, returnDocuments, topN, input, stream, inputType);
try {
var sageMakerModel = ((SageMakerModel) model).override(taskSettings);
var regionAndSecrets = regionAndSecrets(sageMakerModel);
if (stream) {
var schema = schemas.streamSchemaFor(sageMakerModel);
var request = schema.streamRequest(sageMakerModel, inferenceRequest);
client.invokeStream(
regionAndSecrets,
request,
timeout,
ActionListener.wrap(
response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)),
e -> listener.onFailure(schema.error(sageMakerModel, e))
)
);
} else {
var schema = schemas.schemaFor(sageMakerModel);
var request = schema.request(sageMakerModel, inferenceRequest);
client.invoke(
regionAndSecrets,
request,
timeout,
ActionListener.wrap(
response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())),
e -> listener.onFailure(schema.error(sageMakerModel, e))
)
);
}
} catch (Exception e) {
listener.onFailure(internalFailure(model, e));
}
}
private SageMakerClient.RegionAndSecrets regionAndSecrets(SageMakerModel model) {
var secrets = model.awsSecretSettings();
if (secrets.isEmpty()) {
assert false : "Cannot invoke a model without secrets";
throw new ElasticsearchStatusException(
format("Attempting to infer using a model without API keys, inference id [%s]", model.getInferenceEntityId()),
RestStatus.INTERNAL_SERVER_ERROR
);
}
return new SageMakerClient.RegionAndSecrets(model.region(), secrets.get());
}
private static ElasticsearchStatusException internalFailure(Model model, Exception cause) {
if (cause instanceof ElasticsearchStatusException ese) {
return ese;
} else {
return new ElasticsearchStatusException(
"Failed to call SageMaker for inference id [{}].",
RestStatus.INTERNAL_SERVER_ERROR,
cause,
model.getInferenceEntityId()
);
}
}
@Override
public void unifiedCompletionInfer(
Model model,
UnifiedCompletionRequest request,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
if (model instanceof SageMakerModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}
try {
var sageMakerModel = (SageMakerModel) model;
var regionAndSecrets = regionAndSecrets(sageMakerModel);
var schema = schemas.streamSchemaFor(sageMakerModel);
var sagemakerRequest = schema.chatCompletionStreamRequest(sageMakerModel, request);
client.invokeStream(
regionAndSecrets,
sagemakerRequest,
timeout,
ActionListener.wrap(
response -> listener.onResponse(schema.chatCompletionStreamResponse(sageMakerModel, response)),
e -> listener.onFailure(schema.chatCompletionError(sageMakerModel, e))
)
);
} catch (Exception e) {
listener.onFailure(internalFailure(model, e));
}
}
@Override
public void chunkedInfer(
Model model,
String query,
List<ChunkInferenceInput> input,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
if (model instanceof SageMakerModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}
try {
var sageMakerModel = ((SageMakerModel) model).override(taskSettings);
var batchedRequests = new EmbeddingRequestChunker<>(
input,
sageMakerModel.batchSize().orElse(DEFAULT_BATCH_SIZE),
sageMakerModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);
var subscribableListener = SubscribableListener.newSucceeded(null);
for (var request : batchedRequests) {
subscribableListener = subscribableListener.andThen(
threadPool.executor(UTILITY_THREAD_POOL_NAME),
threadPool.getThreadContext(),
(l, ignored) -> infer(
sageMakerModel,
query,
null, // no return docs while chunking?
null, // no topN while chunking?
request.batch().inputs().get(),
false, // we never stream when chunking
null, // since we pass sageMakerModel as the model, we already overwrote the model with the task settings
inputType,
timeout,
ActionListener.runAfter(request.listener(), () -> l.onResponse(null))
)
);
}
// if there were any errors trying to create the SubscribableListener chain, then forward that to the listener
// otherwise, BatchRequestAndListener will handle forwarding errors from the infer method
subscribableListener.addListener(
ActionListener.noop().delegateResponse((l, e) -> listener.onFailure(internalFailure(model, e)))
);
} catch (Exception e) {
listener.onFailure(internalFailure(model, e));
}
}
@Override
public void start(Model model, TimeValue timeout, ActionListener<Boolean> listener) {
listener.onResponse(true);
}
@Override
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
if (model instanceof SageMakerModel sageMakerModel) {
return modelBuilder.updateModelWithEmbeddingDetails(sageMakerModel, embeddingSize);
}
throw invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_SAGEMAKER;
}
@Override
public void close() throws IOException {
client.close();
}
}

View file

@ -0,0 +1,35 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.model;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Supplies the user-facing settings configuration for the SageMaker service by merging the secret,
 * service, and task settings configurations into a single map. Duplicate keys across the three
 * sources are not expected and would fail the merge.
 */
public class SageMakerConfiguration implements CheckedSupplier<Map<String, SettingsConfiguration>, RuntimeException> {
    private final SageMakerSchemas schemas;

    public SageMakerConfiguration(SageMakerSchemas schemas) {
        this.schemas = schemas;
    }

    /** Builds the merged configuration map, scoped to the task types the registered schemas support. */
    @Override
    public Map<String, SettingsConfiguration> get() {
        var taskTypes = schemas.supportedTaskTypes();
        var secretAndService = Stream.concat(
            AwsSecretSettings.configuration(taskTypes),
            SageMakerServiceSettings.configuration(taskTypes)
        );
        return Stream.concat(secretAndService, SageMakerTaskSettings.configuration(taskTypes))
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }
}

View file

@ -0,0 +1,146 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.model;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
import java.util.List;
import java.util.Map;
import java.util.Optional;
/**
* This model represents all models in SageMaker. SageMaker maintains a base set of settings and configurations, and this model manages
* those. Any settings that are required for a specific model are stored in the {@link SageMakerStoredServiceSchema} and
* {@link SageMakerStoredTaskSchema}.
* Design:
* - Region is stored in ServiceSettings and is used to create the SageMaker client.
* - RateLimiting is based on AWS Service Quota, metered by account and region. The SDK client handles rate limiting internally. In order to
* rate limit appropriately, use the same AWS Access ID, Secret Key, and Region for each inference endpoint calling the same AWS account.
* - Within Elastic, you cannot change the model behind the endpoint, because we require the embedding shape remain consistent between
* invocations. So anything "model selection" related must remain consistent through the lifetime of the Elastic endpoint. SageMaker
* allows model changes via TargetModel, InferenceComponentName, and TargetContainerHostname, but these will remain static in Elastic
* within ServiceSettings.
* - CustomAttributes, EnableExplanations, InferenceId, SessionId, and TargetVariant are all request-time fields that can be saved.
* - SageMaker returns 4 headers, which Elastic will forward as-is: x-Amzn-Invoked-Production-Variant, X-Amzn-SageMaker-Custom-Attributes,
* X-Amzn-SageMaker-New-Session-Id, X-Amzn-SageMaker-Closed-Session-Id
*/
/**
 * The model representation for all SageMaker endpoints. SageMaker maintains a base set of settings and
 * configurations which this class manages; anything specific to a particular payload format lives in
 * {@link SageMakerStoredServiceSchema} and {@link SageMakerStoredTaskSchema}.
 *
 * Design notes:
 * - The region lives in the service settings and is used to construct the SageMaker client.
 * - Rate limiting follows AWS Service Quotas (per account and region) and is handled inside the SDK client.
 *   To rate limit correctly, reuse the same AWS Access ID, Secret Key, and Region for every inference
 *   endpoint that targets the same AWS account.
 * - Elastic does not allow swapping the model behind an endpoint because the embedding shape must stay
 *   consistent across invocations; the SageMaker model-selection fields (TargetModel,
 *   InferenceComponentName, TargetContainerHostname) are therefore immutable service settings.
 * - CustomAttributes, EnableExplanations, InferenceId, SessionId, and TargetVariant are request-time
 *   fields persisted in the task settings.
 * - SageMaker's four response headers are forwarded as-is: x-Amzn-Invoked-Production-Variant,
 *   X-Amzn-SageMaker-Custom-Attributes, X-Amzn-SageMaker-New-Session-Id, X-Amzn-SageMaker-Closed-Session-Id.
 */
public class SageMakerModel extends Model {
    private final SageMakerServiceSettings serviceSettings;
    private final SageMakerTaskSettings taskSettings;
    private final AwsSecretSettings awsSecretSettings;

    SageMakerModel(
        ModelConfigurations configurations,
        ModelSecrets secrets,
        SageMakerServiceSettings serviceSettings,
        SageMakerTaskSettings taskSettings,
        AwsSecretSettings awsSecretSettings
    ) {
        super(configurations, secrets);
        this.serviceSettings = serviceSettings;
        this.taskSettings = taskSettings;
        this.awsSecretSettings = awsSecretSettings;
    }

    /** AWS credentials, absent when the model was loaded without secrets. */
    public Optional<AwsSecretSettings> awsSecretSettings() {
        return Optional.ofNullable(awsSecretSettings);
    }

    // --- required service settings ---

    public String region() {
        return serviceSettings.region();
    }

    public String endpointName() {
        return serviceSettings.endpointName();
    }

    public String api() {
        return serviceSettings.api();
    }

    // --- optional service settings ---

    public Optional<String> inferenceComponentName() {
        return Optional.ofNullable(serviceSettings.inferenceComponentName());
    }

    public Optional<String> targetContainerHostname() {
        return Optional.ofNullable(serviceSettings.targetContainerHostname());
    }

    public Optional<String> targetModel() {
        return Optional.ofNullable(serviceSettings.targetModel());
    }

    public Optional<Integer> batchSize() {
        return Optional.ofNullable(serviceSettings.batchSize());
    }

    // --- optional task settings ---

    public Optional<String> customAttributes() {
        return Optional.ofNullable(taskSettings.customAttributes());
    }

    public Optional<String> enableExplanations() {
        return Optional.ofNullable(taskSettings.enableExplanations());
    }

    public Optional<String> inferenceIdForDataCapture() {
        return Optional.ofNullable(taskSettings.inferenceIdForDataCapture());
    }

    public Optional<String> sessionId() {
        return Optional.ofNullable(taskSettings.sessionId());
    }

    public Optional<String> targetVariant() {
        return Optional.ofNullable(taskSettings.targetVariant());
    }

    /**
     * Returns a copy of this model with the given request-time task settings layered over the stored ones,
     * or this same instance when there is nothing to override.
     */
    public SageMakerModel override(Map<String, Object> taskSettingsOverride) {
        if (taskSettingsOverride != null && taskSettingsOverride.isEmpty() == false) {
            return new SageMakerModel(
                getConfigurations(),
                getSecrets(),
                serviceSettings,
                taskSettings.updatedTaskSettings(taskSettingsOverride),
                awsSecretSettings
            );
        }
        return this;
    }

    /** Wire-serialization registrations for the SageMaker settings implementations. */
    public static List<NamedWriteableRegistry.Entry> namedWriteables() {
        return List.of(
            new NamedWriteableRegistry.Entry(ServiceSettings.class, SageMakerServiceSettings.NAME, SageMakerServiceSettings::new),
            new NamedWriteableRegistry.Entry(TaskSettings.class, SageMakerTaskSettings.NAME, SageMakerTaskSettings::new)
        );
    }

    /** Payload-format-specific service settings (e.g. dimensions for the openai api). */
    public SageMakerStoredServiceSchema apiServiceSettings() {
        return serviceSettings.apiServiceSettings();
    }

    /** Payload-format-specific task settings. */
    public SageMakerStoredTaskSchema apiTaskSettings() {
        return taskSettings.apiTaskSettings();
    }

    SageMakerServiceSettings serviceSettings() {
        return serviceSettings;
    }

    SageMakerTaskSettings taskSettings() {
        return taskSettings;
    }
}

View file

@ -0,0 +1,135 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.model;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
/**
 * Builds {@link SageMakerModel} instances from user requests, from persisted storage, and from
 * embedding-size updates. Request parsing is strict about leftover keys in every map it consumes;
 * storage parsing is strict about the top-level config/secrets/task-settings maps but tolerates
 * leftover service-settings keys (they are consumed in-line by
 * {@link SageMakerServiceSettings#fromMap}).
 */
public class SageMakerModelBuilder {
    private final SageMakerSchemas schemas;

    public SageMakerModelBuilder(SageMakerSchemas schemas) {
        this.schemas = schemas;
    }

    /**
     * Parses a user-supplied creation request.
     *
     * @param inferenceEntityId the inference endpoint id
     * @param taskType          the requested task type
     * @param service           the service name, used in leftover-key error messages
     * @param requestMap        the raw request body; consumed destructively
     * @throws org.elasticsearch.common.ValidationException if required settings are missing or invalid
     */
    public SageMakerModel fromRequest(String inferenceEntityId, TaskType taskType, String service, Map<String, Object> requestMap) {
        var validationException = new ValidationException();

        var serviceSettingsMap = removeFromMapOrThrowIfNull(requestMap, ModelConfigurations.SERVICE_SETTINGS);
        // secrets arrive inside service_settings on a creation request and are pulled out first
        var awsSecretSettings = AwsSecretSettings.fromMap(serviceSettingsMap);
        var serviceSettings = SageMakerServiceSettings.fromMap(schemas, taskType, serviceSettingsMap);
        var schema = schemas.schemaFor(taskType, serviceSettings.api());

        // task_settings is optional on a request, so default to an empty map rather than throwing
        var taskSettingsMap = removeFromMapOrDefaultEmpty(requestMap, ModelConfigurations.TASK_SETTINGS);
        var taskSettings = SageMakerTaskSettings.fromMap(
            taskSettingsMap,
            schema.apiTaskSettings(taskSettingsMap, validationException),
            validationException
        );
        // surface validation errors before complaining about leftover keys
        validationException.throwIfValidationErrorsExist();
        throwIfNotEmptyMap(serviceSettingsMap, service);
        throwIfNotEmptyMap(taskSettingsMap, service);

        return buildModel(inferenceEntityId, taskType, service, serviceSettings, taskSettings, awsSecretSettings);
    }

    /**
     * Rehydrates a model from persisted cluster state.
     *
     * @param secrets the persisted secrets map, or null when loading without secrets
     * @throws org.elasticsearch.common.ValidationException if persisted settings fail validation
     */
    public SageMakerModel fromStorage(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        Map<String, Object> config,
        Map<String, Object> secrets
    ) {
        var validationException = new ValidationException();

        var awsSecretSettings = secrets != null
            ? AwsSecretSettings.fromMap(removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS))
            : null;
        var serviceSettings = SageMakerServiceSettings.fromMap(
            schemas,
            taskType,
            removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS)
        );
        var schema = schemas.schemaFor(taskType, serviceSettings.api());

        // task_settings is always persisted (possibly empty), so its absence in storage is an error
        var taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
        var taskSettings = SageMakerTaskSettings.fromMap(
            taskSettingsMap,
            schema.apiTaskSettings(taskSettingsMap, validationException),
            validationException
        );
        validationException.throwIfValidationErrorsExist();
        throwIfNotEmptyMap(config, service);
        throwIfNotEmptyMap(taskSettingsMap, service);
        throwIfNotEmptyMap(secrets, service);

        return buildModel(inferenceEntityId, taskType, service, serviceSettings, taskSettings, awsSecretSettings);
    }

    /** Assembles the model from fully parsed settings; shared by the request and storage paths. */
    private static SageMakerModel buildModel(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        SageMakerServiceSettings serviceSettings,
        SageMakerTaskSettings taskSettings,
        AwsSecretSettings awsSecretSettings
    ) {
        var modelConfigurations = new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
        return new SageMakerModel(
            modelConfigurations,
            new ModelSecrets(awsSecretSettings),
            serviceSettings,
            taskSettings,
            awsSecretSettings
        );
    }

    /**
     * Returns a copy of the model whose api-specific service settings record the discovered embedding size,
     * or the same instance when the api settings report no change.
     */
    public SageMakerModel updateModelWithEmbeddingDetails(SageMakerModel model, int embeddingSize) {
        var updatedApiServiceSettings = model.apiServiceSettings().updateModelWithEmbeddingDetails(embeddingSize);
        // reference equality is the schema's "nothing changed" signal
        if (updatedApiServiceSettings == model.apiServiceSettings()) {
            return model;
        }
        var updatedServiceSettings = new SageMakerServiceSettings(
            model.serviceSettings().endpointName(),
            model.serviceSettings().region(),
            model.serviceSettings().api(),
            model.serviceSettings().targetModel(),
            model.serviceSettings().targetContainerHostname(),
            model.serviceSettings().inferenceComponentName(),
            model.serviceSettings().batchSize(),
            updatedApiServiceSettings
        );
        var modelConfigurations = new ModelConfigurations(
            model.getInferenceEntityId(),
            model.getTaskType(),
            model.getConfigurations().getService(),
            updatedServiceSettings,
            model.taskSettings()
        );
        // note: reuses the existing ModelSecrets rather than building a new one, unlike buildModel
        return new SageMakerModel(
            modelConfigurations,
            model.getSecrets(),
            updatedServiceSettings,
            model.taskSettings(),
            model.awsSecretSettings().orElse(null)
        );
    }
}

View file

@ -0,0 +1,285 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.model;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
import java.io.IOException;
import java.util.EnumSet;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
/**
* Maintains the settings for SageMaker that cannot be changed without impacting semantic search and AI assistants.
* Model-specific settings are stored in {@link SageMakerStoredServiceSchema}.
*/
/**
 * Maintains the settings for SageMaker that cannot be changed without impacting semantic search and AI assistants.
 * Model-specific settings are stored in {@link SageMakerStoredServiceSchema}.
 *
 * @param endpointName            the SageMaker Endpoint name (required)
 * @param region                  the AWS region the endpoint is deployed in (required)
 * @param api                     the payload format the endpoint expects, e.g. "openai" (required)
 * @param targetModel             model selector for multi-model endpoints (optional)
 * @param targetContainerHostname container selector for multi-container endpoints (optional)
 * @param inferenceComponentName  inference component selector (optional)
 * @param batchSize               max inputs per request when chunking (optional)
 * @param apiServiceSettings      api-format-specific settings (required, may be a no-op implementation)
 */
record SageMakerServiceSettings(
    String endpointName,
    String region,
    String api,
    @Nullable String targetModel,
    @Nullable String targetContainerHostname,
    @Nullable String inferenceComponentName,
    @Nullable Integer batchSize,
    SageMakerStoredServiceSchema apiServiceSettings
) implements ServiceSettings {
    static final String NAME = "sage_maker_service_settings";
    private static final String API = "api";
    private static final String ENDPOINT_NAME = "endpoint_name";
    private static final String REGION = "region";
    private static final String TARGET_MODEL = "target_model";
    private static final String TARGET_CONTAINER_HOSTNAME = "target_container_hostname";
    private static final String INFERENCE_COMPONENT_NAME = "inference_component_name";
    private static final String BATCH_SIZE = "batch_size";

    // Compact constructor: enforce the required fields for every construction path, including stream deserialization.
    SageMakerServiceSettings {
        Objects.requireNonNull(endpointName);
        Objects.requireNonNull(region);
        Objects.requireNonNull(api);
        Objects.requireNonNull(apiServiceSettings);
    }

    // Wire deserialization. Field order here MUST mirror writeTo exactly.
    SageMakerServiceSettings(StreamInput in) throws IOException {
        this(
            in.readString(),
            in.readString(),
            in.readString(),
            in.readOptionalString(),
            in.readOptionalString(),
            in.readOptionalString(),
            in.readOptionalInt(),
            in.readNamedWriteable(SageMakerStoredServiceSchema.class)
        );
    }

    // Embedding-related ServiceSettings accessors all delegate to the api-specific schema,
    // since embedding shape is a property of the payload format, not of SageMaker itself.
    @Override
    public String modelId() {
        return apiServiceSettings.modelId();
    }

    @Override
    public SimilarityMeasure similarity() {
        return apiServiceSettings.similarity();
    }

    @Override
    public Integer dimensions() {
        return apiServiceSettings.dimensions();
    }

    @Override
    public Boolean dimensionsSetByUser() {
        return apiServiceSettings.dimensionsSetByUser();
    }

    @Override
    public DenseVectorFieldMapper.ElementType elementType() {
        return apiServiceSettings.elementType();
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.ML_INFERENCE_SAGEMAKER;
    }

    // Wire serialization. Field order here MUST mirror the StreamInput constructor exactly.
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(endpointName());
        out.writeString(region());
        out.writeString(api());
        out.writeOptionalString(targetModel());
        out.writeOptionalString(targetContainerHostname());
        out.writeOptionalString(inferenceComponentName());
        out.writeOptionalInt(batchSize());
        out.writeNamedWriteable(apiServiceSettings);
    }

    // Nothing here is sensitive, so the filtered view is the full object.
    @Override
    public ToXContentObject getFilteredXContentObject() {
        return this;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(ENDPOINT_NAME, endpointName());
        builder.field(REGION, region());
        builder.field(API, api());
        optionalField(TARGET_MODEL, targetModel(), builder);
        optionalField(TARGET_CONTAINER_HOSTNAME, targetContainerHostname(), builder);
        optionalField(INFERENCE_COMPONENT_NAME, inferenceComponentName(), builder);
        optionalField(BATCH_SIZE, batchSize(), builder);
        // the api-specific settings render their fields inline into this same object
        apiServiceSettings.toXContent(builder, params);
        return builder.endObject();
    }

    // Writes the field only when the value is non-null, keeping the rendered XContent sparse.
    private static <T> void optionalField(String name, T value, XContentBuilder builder) throws IOException {
        if (value != null) {
            builder.field(name, value);
        }
    }

    /**
     * Parses service settings from a (destructively consumed) settings map. Leftover keys are the caller's
     * responsibility to check. Validation happens in two phases: base fields first — an invalid/missing "api"
     * must fail before the schema lookup that depends on it — then the api-specific fields.
     */
    static SageMakerServiceSettings fromMap(SageMakerSchemas schemas, TaskType taskType, Map<String, Object> serviceSettingsMap) {
        ValidationException validationException = new ValidationException();
        var endpointName = extractRequiredString(
            serviceSettingsMap,
            ENDPOINT_NAME,
            ModelConfigurations.SERVICE_SETTINGS,
            validationException
        );
        var region = extractRequiredString(serviceSettingsMap, REGION, ModelConfigurations.SERVICE_SETTINGS, validationException);
        var api = extractRequiredString(serviceSettingsMap, API, ModelConfigurations.SERVICE_SETTINGS, validationException);
        var targetModel = extractOptionalString(
            serviceSettingsMap,
            TARGET_MODEL,
            ModelConfigurations.SERVICE_SETTINGS,
            validationException
        );
        var targetContainerHostname = extractOptionalString(
            serviceSettingsMap,
            TARGET_CONTAINER_HOSTNAME,
            ModelConfigurations.SERVICE_SETTINGS,
            validationException
        );
        var inferenceComponentName = extractOptionalString(
            serviceSettingsMap,
            INFERENCE_COMPONENT_NAME,
            ModelConfigurations.SERVICE_SETTINGS,
            validationException
        );
        var batchSize = extractOptionalPositiveInteger(
            serviceSettingsMap,
            BATCH_SIZE,
            ModelConfigurations.SERVICE_SETTINGS,
            validationException
        );
        // fail now if "api" (or any base field) is invalid — schemaFor below requires a valid api string
        validationException.throwIfValidationErrorsExist();
        var schema = schemas.schemaFor(taskType, api);
        var apiServiceSettings = schema.apiServiceSettings(serviceSettingsMap, validationException);
        validationException.throwIfValidationErrorsExist();
        return new SageMakerServiceSettings(
            endpointName,
            region,
            api,
            targetModel,
            targetContainerHostname,
            inferenceComponentName,
            batchSize,
            apiServiceSettings
        );
    }

    /**
     * User-facing configuration entries for the base service settings.
     * NOTE(review): {@code inference_component_name} is parsed in fromMap but has no entry here, unlike every
     * other optional setting — confirm whether that omission is intentional.
     */
    static Stream<Map.Entry<String, SettingsConfiguration>> configuration(EnumSet<TaskType> supportedTaskTypes) {
        return Stream.of(
            Map.entry(
                API,
                new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The API format that your SageMaker Endpoint expects.")
                    .setLabel("API")
                    .setRequired(true)
                    .setSensitive(false)
                    .setUpdatable(false)
                    .setType(SettingsConfigurationFieldType.STRING)
                    .build()
            ),
            Map.entry(
                ENDPOINT_NAME,
                new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
                    "The name specified when creating the SageMaker Endpoint."
                )
                    .setLabel("Endpoint Name")
                    .setRequired(true)
                    .setSensitive(false)
                    .setUpdatable(false)
                    .setType(SettingsConfigurationFieldType.STRING)
                    .build()
            ),
            Map.entry(
                REGION,
                new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
                    "The AWS region that your model or ARN is deployed in."
                )
                    .setLabel("Region")
                    .setRequired(true)
                    .setSensitive(false)
                    .setUpdatable(false)
                    .setType(SettingsConfigurationFieldType.STRING)
                    .build()
            ),
            Map.entry(
                TARGET_MODEL,
                new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
                    "The model to request when calling a SageMaker multi-model Endpoint."
                )
                    .setLabel("Target Model")
                    .setRequired(false)
                    .setSensitive(false)
                    .setUpdatable(false)
                    .setType(SettingsConfigurationFieldType.STRING)
                    .build()
            ),
            Map.entry(
                TARGET_CONTAINER_HOSTNAME,
                new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
                    "The hostname of the container when calling a SageMaker multi-container Endpoint."
                )
                    .setLabel("Target Container Hostname")
                    .setRequired(false)
                    .setSensitive(false)
                    .setUpdatable(false)
                    .setType(SettingsConfigurationFieldType.STRING)
                    .build()
            ),
            Map.entry(
                BATCH_SIZE,
                new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
                    "The maximum size a single chunk of input can be when chunking input for semantic text."
                )
                    .setLabel("Batch Size")
                    .setRequired(false)
                    .setSensitive(false)
                    .setUpdatable(false)
                    .setType(SettingsConfigurationFieldType.INTEGER)
                    .build()
            )
        );
    }
}

View file

@ -0,0 +1,236 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.model;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
import java.io.IOException;
import java.util.EnumSet;
import java.util.Map;
import java.util.stream.Stream;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
/**
* Maintains mutable settings for SageMaker. Model-specific settings are stored in {@link SageMakerStoredTaskSchema}.
*/
record SageMakerTaskSettings(
@Nullable String customAttributes,
@Nullable String enableExplanations,
@Nullable String inferenceIdForDataCapture,
@Nullable String sessionId,
@Nullable String targetVariant,
SageMakerStoredTaskSchema apiTaskSettings
) implements TaskSettings {
static final String NAME = "sage_maker_task_settings";
private static final String CUSTOM_ATTRIBUTES = "custom_attributes";
private static final String ENABLE_EXPLANATIONS = "enable_explanations";
private static final String INFERENCE_ID = "inference_id";
private static final String SESSION_ID = "session_id";
private static final String TARGET_VARIANT = "target_variant";
SageMakerTaskSettings(StreamInput in) throws IOException {
this(
in.readOptionalString(),
in.readOptionalString(),
in.readOptionalString(),
in.readOptionalString(),
in.readOptionalString(),
in.readNamedWriteable(SageMakerStoredTaskSchema.class)
);
}
@Override
public boolean isEmpty() {
return customAttributes == null
&& enableExplanations == null
&& inferenceIdForDataCapture == null
&& sessionId == null
&& targetVariant == null
&& SageMakerStoredTaskSchema.NO_OP.equals(apiTaskSettings);
}
@Override
public SageMakerTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
var validationException = new ValidationException();
var updateTaskSettings = fromMap(newSettings, apiTaskSettings.updatedTaskSettings(newSettings), validationException);
validationException.throwIfValidationErrorsExist();
var updatedExtraTaskSettings = updateTaskSettings.apiTaskSettings().equals(SageMakerStoredTaskSchema.NO_OP)
? apiTaskSettings
: updateTaskSettings.apiTaskSettings();
return new SageMakerTaskSettings(
firstNotNullOrNull(updateTaskSettings.customAttributes(), customAttributes),
firstNotNullOrNull(updateTaskSettings.enableExplanations(), enableExplanations),
firstNotNullOrNull(updateTaskSettings.inferenceIdForDataCapture(), inferenceIdForDataCapture),
firstNotNullOrNull(updateTaskSettings.sessionId(), sessionId),
firstNotNullOrNull(updateTaskSettings.targetVariant(), targetVariant),
updatedExtraTaskSettings
);
}
private static <T> T firstNotNullOrNull(T first, T second) {
return first != null ? first : second;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_SAGEMAKER;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(customAttributes);
out.writeOptionalString(enableExplanations);
out.writeOptionalString(inferenceIdForDataCapture);
out.writeOptionalString(sessionId);
out.writeOptionalString(targetVariant);
out.writeNamedWriteable(apiTaskSettings);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
optionalField(CUSTOM_ATTRIBUTES, customAttributes, builder);
optionalField(ENABLE_EXPLANATIONS, enableExplanations, builder);
optionalField(INFERENCE_ID, inferenceIdForDataCapture, builder);
optionalField(SESSION_ID, sessionId, builder);
optionalField(TARGET_VARIANT, targetVariant, builder);
apiTaskSettings.toXContent(builder, params);
return builder.endObject();
}
private static <T> void optionalField(String name, T value, XContentBuilder builder) throws IOException {
if (value != null) {
builder.field(name, value);
}
}
public static SageMakerTaskSettings fromMap(
Map<String, Object> taskSettingsMap,
SageMakerStoredTaskSchema apiTaskSettings,
ValidationException validationException
) {
var customAttributes = extractOptionalString(
taskSettingsMap,
CUSTOM_ATTRIBUTES,
ModelConfigurations.TASK_SETTINGS,
validationException
);
var enableExplanations = extractOptionalString(
taskSettingsMap,
ENABLE_EXPLANATIONS,
ModelConfigurations.TASK_SETTINGS,
validationException
);
var inferenceIdForDataCapture = extractOptionalString(
taskSettingsMap,
INFERENCE_ID,
ModelConfigurations.TASK_SETTINGS,
validationException
);
var sessionId = extractOptionalString(taskSettingsMap, SESSION_ID, ModelConfigurations.TASK_SETTINGS, validationException);
var targetVariant = extractOptionalString(taskSettingsMap, TARGET_VARIANT, ModelConfigurations.TASK_SETTINGS, validationException);
return new SageMakerTaskSettings(
customAttributes,
enableExplanations,
inferenceIdForDataCapture,
sessionId,
targetVariant,
apiTaskSettings
);
}
/**
 * Declares the UI/configuration metadata for the base SageMaker task settings. All five settings
 * share the same shape (optional, non-sensitive, non-updatable string), so the builder boilerplate
 * is factored into {@link #optionalStringConfig}.
 */
static Stream<Map.Entry<String, SettingsConfiguration>> configuration(EnumSet<TaskType> supportedTaskTypes) {
    return Stream.of(
        optionalStringConfig(
            supportedTaskTypes,
            CUSTOM_ATTRIBUTES,
            "Custom Attributes",
            "An opaque informational value forwarded as-is to the model within SageMaker."
        ),
        optionalStringConfig(
            supportedTaskTypes,
            ENABLE_EXPLANATIONS,
            "Enable Explanations",
            "JMESPath expression overriding the ClarifyingExplainerConfig in the SageMaker Endpoint Configuration."
        ),
        optionalStringConfig(
            supportedTaskTypes,
            INFERENCE_ID,
            "Inference ID",
            "Informational identifier for auditing requests within the SageMaker Endpoint."
        ),
        optionalStringConfig(
            supportedTaskTypes,
            SESSION_ID,
            "Session ID",
            "Creates or reuses an existing Session for SageMaker stateful models."
        ),
        optionalStringConfig(
            supportedTaskTypes,
            TARGET_VARIANT,
            "Target Variant",
            "The production variant when calling the SageMaker Endpoint"
        )
    );
}

/**
 * Builds a single optional, non-sensitive, non-updatable string setting entry keyed by
 * {@code settingName}; the base SageMaker task settings differ only in name, label, and description.
 */
private static Map.Entry<String, SettingsConfiguration> optionalStringConfig(
    EnumSet<TaskType> supportedTaskTypes,
    String settingName,
    String label,
    String description
) {
    return Map.entry(
        settingName,
        new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(description)
            .setLabel(label)
            .setRequired(false)
            .setSensitive(false)
            .setUpdatable(false)
            .setType(SettingsConfigurationFieldType.STRING)
            .build()
    );
}
}

View file

@ -0,0 +1,215 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
import software.amazon.awssdk.services.sagemakerruntime.model.InternalDependencyException;
import software.amazon.awssdk.services.sagemakerruntime.model.InternalFailureException;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.ModelErrorException;
import software.amazon.awssdk.services.sagemakerruntime.model.ModelNotReadyException;
import software.amazon.awssdk.services.sagemakerruntime.model.SageMakerRuntimeException;
import software.amazon.awssdk.services.sagemakerruntime.model.ServiceUnavailableException;
import software.amazon.awssdk.services.sagemakerruntime.model.ValidationErrorException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import java.util.Map;
import java.util.stream.Stream;
import static org.elasticsearch.core.Strings.format;
/**
 * All the logic that is required to call any SageMaker model is handled within this Schema class.
 * Any model-specific logic is handled within the associated {@link SageMakerSchemaPayload}.
 * This schema is specific for SageMaker's non-streaming API. For streaming, see {@link SageMakerStreamSchema}.
 */
public class SageMakerSchema {
    // SageMaker response metadata is surfaced to the caller via these Elasticsearch response headers.
    private static final String CUSTOM_ATTRIBUTES_HEADER = "X-elastic-sagemaker-custom-attributes";
    private static final String NEW_SESSION_HEADER = "X-elastic-sagemaker-new-session-id";
    private static final String CLOSED_SESSION_HEADER = "X-elastic-sagemaker-closed-session-id";

    // AWS common error codes, see https://docs.aws.amazon.com/sagemaker/latest/APIReference/CommonErrors.html
    private static final String ACCESS_DENIED_CODE = "AccessDeniedException";
    private static final String INCOMPLETE_SIGNATURE = "IncompleteSignature";
    private static final String INVALID_ACTION = "InvalidAction";
    private static final String INVALID_CLIENT_TOKEN = "InvalidClientTokenId";
    private static final String NOT_AUTHORIZED = "NotAuthorized";
    private static final String OPT_IN_REQUIRED = "OptInRequired";
    private static final String REQUEST_EXPIRED = "RequestExpired";
    private static final String THROTTLING_EXCEPTION = "ThrottlingException";

    private final SageMakerSchemaPayload schemaPayload;

    public SageMakerSchema(SageMakerSchemaPayload schemaPayload) {
        this.schemaPayload = schemaPayload;
    }

    /**
     * Builds the SageMaker InvokeEndpoint request. Base request settings come from the model;
     * the payload implementation supplies the MIME types and the request body bytes.
     * Any payload translation failure is wrapped in a 500 {@link ElasticsearchStatusException}.
     */
    public InvokeEndpointRequest request(SageMakerModel model, SageMakerInferenceRequest request) {
        try {
            return createRequest(model).accept(schemaPayload.accept(model))
                .contentType(schemaPayload.contentType(model))
                .body(schemaPayload.requestBytes(model, request))
                .build();
        } catch (ElasticsearchStatusException e) {
            // Already carries a user-facing status and message; rethrow untouched.
            throw e;
        } catch (Exception e) {
            throw new ElasticsearchStatusException(
                "Failed to create SageMaker request for [%s]",
                RestStatus.INTERNAL_SERVER_ERROR,
                e,
                model.getInferenceEntityId()
            );
        }
    }

    /**
     * Copies the base (payload-independent) settings from the model onto the request builder.
     * Only the endpoint name is mandatory; everything else is set only when present.
     */
    private InvokeEndpointRequest.Builder createRequest(SageMakerModel model) {
        var request = InvokeEndpointRequest.builder();
        request.endpointName(model.endpointName());
        model.customAttributes().ifPresent(request::customAttributes);
        model.enableExplanations().ifPresent(request::enableExplanations);
        model.inferenceComponentName().ifPresent(request::inferenceComponentName);
        model.inferenceIdForDataCapture().ifPresent(request::inferenceId);
        model.sessionId().ifPresent(request::sessionId);
        model.targetContainerHostname().ifPresent(request::targetContainerHostname);
        model.targetModel().ifPresent(request::targetModel);
        model.targetVariant().ifPresent(request::targetVariant);
        return request;
    }

    /**
     * Translates the SageMaker response into {@link InferenceServiceResults} via the payload,
     * after first forwarding any SageMaker response metadata as response headers.
     */
    public InferenceServiceResults response(SageMakerModel model, InvokeEndpointResponse response, ThreadContext threadContext)
        throws Exception {
        try {
            addHeaders(response, threadContext);
            return schemaPayload.responseBody(model, response);
        } catch (ElasticsearchStatusException e) {
            throw e;
        } catch (Exception e) {
            throw new ElasticsearchStatusException(
                "Failed to translate SageMaker response for [%s]",
                RestStatus.INTERNAL_SERVER_ERROR,
                e,
                model.getInferenceEntityId()
            );
        }
    }

    // Forwards custom attributes and session lifecycle metadata to the caller as response headers.
    private void addHeaders(InvokeEndpointResponse response, ThreadContext threadContext) {
        if (response.customAttributes() != null) {
            threadContext.addResponseHeader(CUSTOM_ATTRIBUTES_HEADER, response.customAttributes());
        }
        if (response.newSessionId() != null) {
            threadContext.addResponseHeader(NEW_SESSION_HEADER, response.newSessionId());
        }
        if (response.closedSessionId() != null) {
            threadContext.addResponseHeader(CLOSED_SESSION_HEADER, response.closedSessionId());
        }
    }

    /**
     * Converts any exception from the SageMaker call into an {@link ElasticsearchStatusException}
     * with a user-friendly message and an appropriate REST status.
     */
    public Exception error(SageMakerModel model, Exception e) {
        if (e instanceof ElasticsearchStatusException ee) {
            return ee;
        }
        var error = errorMessageAndStatus(model, e);
        return new ElasticsearchStatusException(error.v1(), error.v2(), e);
    }

    /**
     * Protected because {@link SageMakerStreamSchema} will reuse this to create a Chat Completion error message.
     * Maps the SageMaker SDK exception hierarchy first, then falls back to AWS common error codes,
     * and finally to a generic 500 message. Returns (message, status).
     */
    protected Tuple<String, RestStatus> errorMessageAndStatus(SageMakerModel model, Exception e) {
        String errorMessage = null;
        RestStatus restStatus = null;
        if (e instanceof InternalDependencyException) {
            errorMessage = format("Received an internal dependency error from SageMaker for [%s]", model.getInferenceEntityId());
            restStatus = RestStatus.INTERNAL_SERVER_ERROR;
        } else if (e instanceof InternalFailureException) {
            errorMessage = format("Received an internal failure from SageMaker for [%s]", model.getInferenceEntityId());
            restStatus = RestStatus.INTERNAL_SERVER_ERROR;
        } else if (e instanceof ModelErrorException) {
            errorMessage = format("Received a model error from SageMaker for [%s]", model.getInferenceEntityId());
            restStatus = RestStatus.FAILED_DEPENDENCY;
        } else if (e instanceof ModelNotReadyException) {
            errorMessage = format("Received a model not ready error from SageMaker for [%s]", model.getInferenceEntityId());
            restStatus = RestStatus.TOO_MANY_REQUESTS;
        } else if (e instanceof ServiceUnavailableException) {
            errorMessage = format("SageMaker is unavailable for [%s]", model.getInferenceEntityId());
            restStatus = RestStatus.SERVICE_UNAVAILABLE;
        } else if (e instanceof ValidationErrorException) {
            errorMessage = format("Received a validation error from SageMaker for [%s]", model.getInferenceEntityId());
            restStatus = RestStatus.BAD_REQUEST;
        }
        // if we have a SageMakerRuntimeException that isn't one of the child exceptions, we can parse the error from AwsErrorDetails
        // https://docs.aws.amazon.com/sagemaker/latest/APIReference/CommonErrors.html
        if (errorMessage == null && e instanceof SageMakerRuntimeException re && re.awsErrorDetails() != null) {
            switch (re.awsErrorDetails().errorCode()) {
                case ACCESS_DENIED_CODE, NOT_AUTHORIZED -> {
                    errorMessage = format(
                        "Access and Secret key stored in [%s] do not have sufficient permissions.",
                        model.getInferenceEntityId()
                    );
                    restStatus = RestStatus.BAD_REQUEST;
                }
                case INCOMPLETE_SIGNATURE -> {
                    errorMessage = format("The request signature does not conform to AWS standards [%s]", model.getInferenceEntityId());
                    restStatus = RestStatus.INTERNAL_SERVER_ERROR; // this shouldn't happen and isn't anything the user can do about it
                }
                case INVALID_ACTION -> {
                    errorMessage = format("The requested action is not valid for [%s]", model.getInferenceEntityId());
                    restStatus = RestStatus.BAD_REQUEST;
                }
                case INVALID_CLIENT_TOKEN -> {
                    errorMessage = format("Access key stored in [%s] does not exist in AWS", model.getInferenceEntityId());
                    restStatus = RestStatus.FORBIDDEN;
                }
                case OPT_IN_REQUIRED -> {
                    errorMessage = format("Access key stored in [%s] needs a subscription for the service", model.getInferenceEntityId());
                    restStatus = RestStatus.FORBIDDEN;
                }
                case REQUEST_EXPIRED -> {
                    errorMessage = format(
                        "The request reached SageMaker more than 15 minutes after the date stamp on the request for [%s]",
                        model.getInferenceEntityId()
                    );
                    restStatus = RestStatus.BAD_REQUEST;
                }
                case THROTTLING_EXCEPTION -> {
                    errorMessage = format("SageMaker denied the request for [%s] due to request throttling", model.getInferenceEntityId());
                    restStatus = RestStatus.BAD_REQUEST;
                }
            }
        }
        // Fallback for anything unrecognized.
        if (errorMessage == null) {
            errorMessage = format("Received an error from SageMaker for [%s]", model.getInferenceEntityId());
            restStatus = RestStatus.INTERNAL_SERVER_ERROR;
        }
        return Tuple.tuple(errorMessage, restStatus);
    }

    /** Delegates parsing of api-specific service settings to the payload implementation. */
    public SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
        return schemaPayload.apiServiceSettings(serviceSettings, validationException);
    }

    /** Delegates parsing of api-specific task settings to the payload implementation. */
    public SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
        return schemaPayload.apiTaskSettings(taskSettings, validationException);
    }

    /** Exposes the payload's writeable registrations for {@link SageMakerSchemas#namedWriteables()}. */
    public Stream<NamedWriteableRegistry.Entry> namedWriteables() {
        return schemaPayload.namedWriteables();
    }
}

View file

@ -0,0 +1,98 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import java.util.EnumSet;
import java.util.Map;
import java.util.stream.Stream;
/**
 * Translates between Elastic inference requests/results and the byte payload exchanged with a
 * SageMaker endpoint for one specific model API format (e.g. openai). Implementations are
 * registered in {@link SageMakerSchemas}.
 */
public interface SageMakerSchemaPayload {

    /**
     * The model API keyword that users will supply in the service settings when creating the request.
     * Automatically registered in {@link SageMakerSchemas}.
     */
    String api();

    /**
     * The supported TaskTypes for this model API.
     * Automatically registered in {@link SageMakerSchemas}.
     */
    EnumSet<TaskType> supportedTasks();

    /**
     * Implement this if the model requires extra ServiceSettings that can be saved to the model index.
     * This can be accessed via {@link SageMakerModel#apiServiceSettings()}.
     * Defaults to a no-op schema with no extra settings.
     */
    default SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
        return SageMakerStoredServiceSchema.NO_OP;
    }

    /**
     * Implement this if the model requires extra TaskSettings that can be saved to the model index.
     * This can be accessed via {@link SageMakerModel#apiTaskSettings()}.
     * Defaults to a no-op schema with no extra settings.
     */
    default SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
        return SageMakerStoredTaskSchema.NO_OP;
    }

    /**
     * This must be thrown if {@link SageMakerModel#apiServiceSettings()} or {@link SageMakerModel#apiTaskSettings()} return the wrong
     * object types.
     */
    default Exception createUnsupportedSchemaException(SageMakerModel model) {
        return new IllegalArgumentException(
            Strings.format(
                "Unsupported SageMaker settings for api [%s] and task type [%s]: [%s] and [%s]",
                model.api(),
                model.getTaskType(),
                model.apiServiceSettings().getWriteableName(),
                model.apiTaskSettings().getWriteableName()
            )
        );
    }

    /**
     * Automatically register the required registry entries with {@link org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider}.
     * Defaults to no extra registrations.
     */
    default Stream<NamedWriteableRegistry.Entry> namedWriteables() {
        return Stream.of();
    }

    /**
     * The MIME type of the response from SageMaker.
     */
    String accept(SageMakerModel model);

    /**
     * The MIME type of the request to SageMaker.
     */
    String contentType(SageMakerModel model);

    /**
     * Translate to the body of the request in the MIME type specified by {@link #contentType(SageMakerModel)}.
     */
    SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception;

    /**
     * Translate from the body of the response in the MIME type specified by {@link #accept(SageMakerModel)}.
     */
    InferenceServiceResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception;
}

View file

@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.elasticsearch.core.Strings.format;
/**
 * The mapping and registry for all supported model APIs.
 * All registries are built once, statically, from the payloads passed to {@link #register}.
 */
public class SageMakerSchemas {
    /** All schemas, keyed by (task type, api). */
    private static final Map<TaskAndApi, SageMakerSchema> schemas;
    /** The subset of {@link #schemas} that support streaming, indexed separately for lookup. */
    private static final Map<TaskAndApi, SageMakerStreamSchema> streamSchemas;
    /** Task types grouped by api, used for error messages. */
    private static final Map<String, Set<TaskType>> tasksByApi;
    private static final Map<String, Set<TaskType>> streamingTasksByApi;
    private static final Set<TaskType> supportedStreamingTasks;
    private static final EnumSet<TaskType> supportedTaskTypes;

    static {
        /*
         * Add new model API to the register call.
         */
        schemas = register(new OpenAiTextEmbeddingPayload());

        streamSchemas = schemas.entrySet()
            .stream()
            .filter(e -> e.getValue() instanceof SageMakerStreamSchema)
            .collect(Collectors.toMap(Map.Entry::getKey, e -> (SageMakerStreamSchema) e.getValue()));

        tasksByApi = schemas.keySet()
            .stream()
            .collect(Collectors.groupingBy(TaskAndApi::api, Collectors.mapping(TaskAndApi::taskType, Collectors.toSet())));
        streamingTasksByApi = streamSchemas.keySet()
            .stream()
            .collect(Collectors.groupingBy(TaskAndApi::api, Collectors.mapping(TaskAndApi::taskType, Collectors.toSet())));

        supportedStreamingTasks = streamSchemas.keySet().stream().map(TaskAndApi::taskType).collect(Collectors.toSet());
        supportedTaskTypes = EnumSet.copyOf(schemas.keySet().stream().map(TaskAndApi::taskType).collect(Collectors.toSet()));
    }

    /**
     * Expands each payload into one schema entry per supported task type. Streaming payloads are
     * wrapped in {@link SageMakerStreamSchema} so they can also serve non-streaming calls.
     */
    private static Map<TaskAndApi, SageMakerSchema> register(SageMakerSchemaPayload... payloads) {
        return Arrays.stream(payloads).flatMap(payload -> payload.supportedTasks().stream().map(taskType -> {
            var key = new TaskAndApi(taskType, payload.api());
            SageMakerSchema value;
            if (payload instanceof SageMakerStreamSchemaPayload streamPayload) {
                value = new SageMakerStreamSchema(streamPayload);
            } else {
                value = new SageMakerSchema(payload);
            }
            return Map.entry(key, value);
        })).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    /**
     * Automatically register the stored Schema writeables with {@link org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider}.
     * Always includes the NO_OP placeholders plus any payload-specific writeables.
     */
    public static List<NamedWriteableRegistry.Entry> namedWriteables() {
        return Stream.concat(
            Stream.of(
                new NamedWriteableRegistry.Entry(
                    SageMakerStoredServiceSchema.class,
                    SageMakerStoredServiceSchema.NO_OP.getWriteableName(),
                    in -> SageMakerStoredServiceSchema.NO_OP
                ),
                new NamedWriteableRegistry.Entry(
                    SageMakerStoredTaskSchema.class,
                    SageMakerStoredTaskSchema.NO_OP.getWriteableName(),
                    in -> SageMakerStoredTaskSchema.NO_OP
                )
            ),
            schemas.values().stream().flatMap(SageMakerSchema::namedWriteables)
        ).toList();
    }

    /** Looks up the schema for the model's task type and api. */
    public SageMakerSchema schemaFor(SageMakerModel model) throws ElasticsearchStatusException {
        return schemaFor(model.getTaskType(), model.api());
    }

    /**
     * Looks up the schema for the given task type and api.
     * @throws ElasticsearchStatusException 405 when the combination is not registered
     */
    public SageMakerSchema schemaFor(TaskType taskType, String api) throws ElasticsearchStatusException {
        var schema = schemas.get(new TaskAndApi(taskType, api));
        if (schema == null) {
            // Note: arguments are ordered taskType then api to match the format placeholders.
            throw new ElasticsearchStatusException(
                format(
                    "Task [%s] is not compatible for service [sagemaker] and api [%s]. Supported tasks: [%s]",
                    taskType.toString(),
                    api,
                    tasksByApi.getOrDefault(api, Set.of())
                ),
                RestStatus.METHOD_NOT_ALLOWED
            );
        }
        return schema;
    }

    /**
     * Looks up the streaming schema for the model's task type and api.
     * @throws ElasticsearchStatusException 405 when streaming is not supported for the combination
     */
    public SageMakerStreamSchema streamSchemaFor(SageMakerModel model) throws ElasticsearchStatusException {
        var schema = streamSchemas.get(new TaskAndApi(model.getTaskType(), model.api()));
        if (schema == null) {
            throw new ElasticsearchStatusException(
                format(
                    "Streaming is not allowed for service [sagemaker], api [%s], and task [%s]. Supported streaming tasks: [%s]",
                    model.api(),
                    model.getTaskType().toString(),
                    streamingTasksByApi.getOrDefault(model.api(), Set.of())
                ),
                RestStatus.METHOD_NOT_ALLOWED
            );
        }
        return schema;
    }

    public EnumSet<TaskType> supportedTaskTypes() {
        return supportedTaskTypes;
    }

    public Set<TaskType> supportedStreamingTasks() {
        return supportedStreamingTasks;
    }
}

View file

@ -0,0 +1,75 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
/**
 * Contains any model-specific settings that are stored in SageMakerServiceSettings.
 */
public interface SageMakerStoredServiceSchema extends ServiceSettings {
    /**
     * Stateless placeholder used by payloads that declare no extra service settings.
     * Writes nothing to the wire and renders nothing in XContent.
     */
    SageMakerStoredServiceSchema NO_OP = new SageMakerStoredServiceSchema() {
        private static final String NAME = "noop_sagemaker_service_schema";

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.ML_INFERENCE_SAGEMAKER;
        }

        @Override
        public void writeTo(StreamOutput out) {
            // Nothing to serialize.
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) {
            return builder;
        }
    };

    /**
     * The model is part of the SageMaker Endpoint definition and is not declared in the service settings.
     */
    @Override
    default String modelId() {
        return null;
    }

    /**
     * Nothing is filtered by default; the full object is returned as-is.
     */
    @Override
    default ToXContentObject getFilteredXContentObject() {
        return this;
    }

    /**
     * These extra service settings serialize flatly alongside the overall SageMaker ServiceSettings.
     */
    @Override
    default boolean isFragment() {
        return true;
    }

    /**
     * If this Schema supports Text Embeddings, then we need to implement this.
     * {@link org.elasticsearch.xpack.inference.services.validation.TextEmbeddingModelValidator} will set the dimensions if the user
     * does not do it, so we need to store the dimensions and flip the {@link #dimensionsSetByUser()} boolean.
     * Defaults to a no-op for schemas without embedding dimensions.
     */
    default SageMakerStoredServiceSchema updateModelWithEmbeddingDetails(Integer dimensions) {
        return this;
    }
}

View file

@ -0,0 +1,64 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import java.util.Map;
/**
 * Contains any model-specific settings that are stored in SageMakerTaskSettings.
 */
public interface SageMakerStoredTaskSchema extends TaskSettings {
    /**
     * Stateless placeholder used by payloads that declare no extra task settings.
     * Writes nothing to the wire and renders nothing in XContent.
     */
    SageMakerStoredTaskSchema NO_OP = new SageMakerStoredTaskSchema() {
        @Override
        public boolean isEmpty() {
            return true;
        }

        @Override
        public SageMakerStoredTaskSchema updatedTaskSettings(Map<String, Object> newSettings) {
            // There are no settings to update, so incoming settings are ignored.
            return this;
        }

        private static final String NAME = "noop_sagemaker_task_schema";

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.ML_INFERENCE_SAGEMAKER;
        }

        @Override
        public void writeTo(StreamOutput out) {}

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) {
            return builder;
        }
    };

    /**
     * These extra task settings serialize flatly alongside the overall SageMaker TaskSettings.
     */
    @Override
    default boolean isFragment() {
        return true;
    }

    /**
     * Re-declared with a covariant return type so implementations stay within this schema hierarchy.
     */
    @Override
    SageMakerStoredTaskSchema updatedTaskSettings(Map<String, Object> newSettings);
}

View file

@ -0,0 +1,152 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.CheckedBiFunction;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import java.util.Locale;
import java.util.concurrent.Flow;
import java.util.function.BiFunction;
/**
 * All the logic that is required to call any SageMaker model is handled within this Schema class.
 * Any model-specific logic is handled within the associated {@link SageMakerStreamSchemaPayload}.
 * This schema is specific for SageMaker's streaming API. For non-streaming, see {@link SageMakerSchema}.
 */
public class SageMakerStreamSchema extends SageMakerSchema {

    private final SageMakerStreamSchemaPayload payload;

    public SageMakerStreamSchema(SageMakerStreamSchemaPayload payload) {
        super(payload);
        this.payload = payload;
    }

    /** Builds a streaming request for the (non-unified) Completion input. */
    public InvokeEndpointWithResponseStreamRequest streamRequest(SageMakerModel model, SageMakerInferenceRequest request) {
        return streamRequest(model, () -> payload.requestBytes(model, request));
    }

    /**
     * Shared request-building path: the body supplier differs between Completion and Chat Completion,
     * but the MIME types and error wrapping are identical.
     */
    private InvokeEndpointWithResponseStreamRequest streamRequest(SageMakerModel model, CheckedSupplier<SdkBytes, Exception> body) {
        try {
            return createStreamRequest(model).accept(payload.accept(model))
                .contentType(payload.contentType(model))
                .body(body.get())
                .build();
        } catch (ElasticsearchStatusException e) {
            // Already user-facing, rethrow untouched.
            throw e;
        } catch (Exception e) {
            throw new ElasticsearchStatusException(
                "Failed to create SageMaker request for [%s]",
                RestStatus.INTERNAL_SERVER_ERROR,
                e,
                model.getInferenceEntityId()
            );
        }
    }

    /** Adapts the SageMaker response stream into streaming Completion results. */
    public InferenceServiceResults streamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) {
        return streamResponse(model, response, payload::streamResponseBody, this::error);
    }

    /**
     * Bridges the SDK's {@link ResponseStream} events onto the downstream results subscriber.
     * Payload parts are translated via {@code parseFunction}; any failure is translated via
     * {@code errorFunction} (plain errors for Completion, unified errors for Chat Completion).
     */
    private InferenceServiceResults streamResponse(
        SageMakerModel model,
        SageMakerClient.SageMakerStream response,
        CheckedBiFunction<SageMakerModel, SdkBytes, InferenceServiceResults.Result, Exception> parseFunction,
        BiFunction<SageMakerModel, Exception, Exception> errorFunction
    ) {
        return new StreamingChatCompletionResults(downstream -> {
            response.responseStream().subscribe(new Flow.Subscriber<>() {
                private volatile Flow.Subscription upstream;

                @Override
                public void onSubscribe(Flow.Subscription subscription) {
                    // The downstream drives demand directly on the upstream subscription.
                    this.upstream = subscription;
                    downstream.onSubscribe(subscription);
                }

                @Override
                public void onNext(ResponseStream item) {
                    if (item.sdkEventType() == ResponseStream.EventType.PAYLOAD_PART) {
                        // Payload parts carry model output bytes; translate and forward them.
                        item.accept(InvokeEndpointWithResponseStreamResponseHandler.Visitor.builder().onPayloadPart(payloadPart -> {
                            try {
                                downstream.onNext(parseFunction.apply(model, payloadPart.bytes()));
                            } catch (Exception e) {
                                downstream.onError(errorFunction.apply(model, e));
                            }
                        }).build());
                    } else {
                        // Non-payload events are not forwarded; request another item so the stream keeps flowing.
                        assert upstream != null : "upstream is unset";
                        upstream.request(1);
                    }
                }

                @Override
                public void onError(Throwable throwable) {
                    if (throwable instanceof Exception e) {
                        downstream.onError(errorFunction.apply(model, e));
                    } else {
                        // Fatal JVM Errors may terminate the node; otherwise wrap and forward downstream.
                        ExceptionsHelper.maybeError(throwable).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
                        var e = new RuntimeException("Fatal while streaming SageMaker response for [" + model.getInferenceEntityId() + "]");
                        downstream.onError(errorFunction.apply(model, e));
                    }
                }

                @Override
                public void onComplete() {
                    downstream.onComplete();
                }
            });
        });
    }

    /** Builds a streaming request for the unified Chat Completion input. */
    public InvokeEndpointWithResponseStreamRequest chatCompletionStreamRequest(SageMakerModel model, UnifiedCompletionRequest request) {
        return streamRequest(model, () -> payload.chatCompletionRequestBytes(model, request));
    }

    /** Adapts the SageMaker response stream into streaming Chat Completion results. */
    public InferenceServiceResults chatCompletionStreamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) {
        return streamResponse(model, response, payload::chatCompletionResponseBody, this::chatCompletionError);
    }

    /**
     * Like {@link #error} but produces the unified chat completion error format, reusing the parent's
     * message/status mapping and using the lowercased status name as the error code.
     */
    public UnifiedChatCompletionException chatCompletionError(SageMakerModel model, Exception e) {
        if (e instanceof UnifiedChatCompletionException ucce) {
            return ucce;
        }
        var error = errorMessageAndStatus(model, e);
        return new UnifiedChatCompletionException(error.v2(), error.v1(), "error", error.v2().name().toLowerCase(Locale.ROOT));
    }

    /**
     * Copies the base settings from the model onto the streaming request builder.
     * NOTE(review): unlike {@link SageMakerSchema}'s non-streaming request, this does not set
     * enableExplanations or targetModel — presumably the streaming SDK request does not expose them; confirm.
     */
    private InvokeEndpointWithResponseStreamRequest.Builder createStreamRequest(SageMakerModel model) {
        var request = InvokeEndpointWithResponseStreamRequest.builder();
        request.endpointName(model.endpointName());
        model.customAttributes().ifPresent(request::customAttributes);
        model.inferenceComponentName().ifPresent(request::inferenceComponentName);
        model.inferenceIdForDataCapture().ifPresent(request::inferenceId);
        model.sessionId().ifPresent(request::sessionId);
        model.targetContainerHostname().ifPresent(request::targetContainerHostname);
        model.targetVariant().ifPresent(request::targetVariant);
        return request;
    }
}

View file

@ -0,0 +1,46 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
import software.amazon.awssdk.core.SdkBytes;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import java.util.EnumSet;
/**
 * Implemented for models that support streaming.
 * This is an extension of {@link SageMakerSchemaPayload} because Elastic expects Completion tasks to handle both streaming and
 * non-streaming, and all models currently support toggling streaming on/off.
 */
public interface SageMakerStreamSchemaPayload extends SageMakerSchemaPayload {

    /**
     * We currently only support streaming for Completion and Chat Completion, and if we are going to implement one then we should implement
     * the other, so this interface requires both streaming input and streaming unified input.
     * If we ever allowed streaming for more than just Completion, then we'd probably break up this class so that Unified Chat Completion
     * was its own interface.
     */
    @Override
    default EnumSet<TaskType> supportedTasks() {
        return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
    }

    /**
     * This API would only be called for Completion task types. {@link #requestBytes(SageMakerModel, SageMakerInferenceRequest)} would
     * handle the request translation for both streaming and non-streaming.
     */
    InferenceServiceResults.Result streamResponseBody(SageMakerModel model, SdkBytes response) throws Exception;

    /**
     * Translates a unified Chat Completion request into the byte payload sent to SageMaker.
     */
    SdkBytes chatCompletionRequestBytes(SageMakerModel model, UnifiedCompletionRequest request) throws Exception;

    /**
     * Translates one streamed chunk of a Chat Completion response body into a result.
     */
    InferenceServiceResults.Result chatCompletionResponseBody(SageMakerModel model, SdkBytes response) throws Exception;
}

View file

@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
import org.elasticsearch.inference.TaskType;
/**
 * Composite key pairing a task type with a payload api identifier; used to index the schema registries in SageMakerSchemas.
 */
record TaskAndApi(TaskType taskType, String api) {}

View file

@ -0,0 +1,229 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
import java.io.IOException;
import java.util.EnumSet;
import java.util.Map;
import java.util.stream.Stream;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
/**
 * Translates Elastic text-embedding requests into the OpenAI-compatible JSON payload that a SageMaker
 * endpoint hosting an OpenAI-style embedding model expects, and parses the OpenAI-shaped embedding
 * response back into {@link TextEmbeddingFloatResults}. Both request and response bodies are JSON.
 */
public class OpenAiTextEmbeddingPayload implements SageMakerSchemaPayload {

    private static final XContent jsonXContent = JsonXContent.jsonXContent;
    // Both the Accept and Content-Type headers for this payload are "application/json".
    private static final String APPLICATION_JSON = jsonXContent.type().mediaTypeWithoutParameters();

    @Override
    public String api() {
        return "openai";
    }

    @Override
    public EnumSet<TaskType> supportedTasks() {
        return EnumSet.of(TaskType.TEXT_EMBEDDING);
    }

    @Override
    public SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
        return ApiServiceSettings.fromMap(serviceSettings, validationException);
    }

    @Override
    public SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
        return ApiTaskSettings.fromMap(taskSettings, validationException);
    }

    // Registers the settings records so they can be (de)serialized between nodes by name.
    @Override
    public Stream<NamedWriteableRegistry.Entry> namedWriteables() {
        return Stream.of(
            new NamedWriteableRegistry.Entry(SageMakerStoredServiceSchema.class, ApiServiceSettings.NAME, ApiServiceSettings::new),
            new NamedWriteableRegistry.Entry(SageMakerStoredTaskSchema.class, ApiTaskSettings.NAME, ApiTaskSettings::new)
        );
    }

    @Override
    public String accept(SageMakerModel model) {
        return APPLICATION_JSON;
    }

    @Override
    public String contentType(SageMakerModel model) {
        return APPLICATION_JSON;
    }

    /**
     * Builds the OpenAI-style JSON request body: optional "query", "input" (a bare string for a
     * single input, otherwise an array), optional "user", and optional "dimensions" (only when the
     * user explicitly set them). Throws if the model's settings are not this payload's types.
     */
    @Override
    public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception {
        if (model.apiServiceSettings() instanceof ApiServiceSettings apiServiceSettings
            && model.apiTaskSettings() instanceof ApiTaskSettings apiTaskSettings) {
            try (var builder = JsonXContent.contentBuilder()) {
                builder.startObject();
                if (request.query() != null) {
                    builder.field("query", request.query());
                }
                // A single input is sent as a plain string rather than a one-element array.
                if (request.input().size() == 1) {
                    builder.field("input", request.input().get(0));
                } else {
                    builder.field("input", request.input());
                }
                if (apiTaskSettings.user() != null) {
                    builder.field("user", apiTaskSettings.user());
                }
                // Only forward dimensions the user explicitly configured, never inferred ones.
                if (apiServiceSettings.dimensionsSetByUser() && apiServiceSettings.dimensions() != null) {
                    builder.field("dimensions", apiServiceSettings.dimensions());
                }
                builder.endObject();
                return SdkBytes.fromUtf8String(Strings.toString(builder));
            }
        } else {
            throw createUnsupportedSchemaException(model);
        }
    }

    /**
     * Parses the endpoint's OpenAI-shaped embeddings JSON response into float embedding results.
     */
    @Override
    public TextEmbeddingFloatResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
        try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) {
            return OpenAiEmbeddingsResponseEntity.EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
        }
    }

    /**
     * OpenAI-specific service settings: the optional embedding "dimensions" and whether the user set
     * them (as opposed to Elastic discovering them from a test inference).
     */
    record ApiServiceSettings(@Nullable Integer dimensions, Boolean dimensionsSetByUser) implements SageMakerStoredServiceSchema {
        private static final String NAME = "sagemaker_openai_text_embeddings_service_settings";
        private static final String DIMENSIONS_FIELD = "dimensions";

        ApiServiceSettings(StreamInput in) throws IOException {
            this(in.readOptionalInt(), in.readBoolean());
        }

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.ML_INFERENCE_SAGEMAKER;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeOptionalInt(dimensions);
            out.writeBoolean(dimensionsSetByUser);
        }

        // NOTE(review): dimensionsSetByUser is not written here, while fromMap recomputes it as
        // dimensions != null. An instance created via updateModelWithEmbeddingDetails(dims, false)
        // would therefore read back as dimensionsSetByUser=true after an index round trip when
        // dimensions is non-null — confirm this is intended.
        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            if (dimensions != null) {
                builder.field(DIMENSIONS_FIELD, dimensions);
            }
            return builder;
        }

        static ApiServiceSettings fromMap(Map<String, Object> serviceSettings, ValidationException validationException) {
            var dimensions = extractOptionalPositiveInteger(
                serviceSettings,
                DIMENSIONS_FIELD,
                ModelConfigurations.SERVICE_SETTINGS,
                validationException
            );

            return new ApiServiceSettings(dimensions, dimensions != null);
        }

        @Override
        public SimilarityMeasure similarity() {
            return SimilarityMeasure.DOT_PRODUCT;
        }

        @Override
        public DenseVectorFieldMapper.ElementType elementType() {
            return DenseVectorFieldMapper.ElementType.FLOAT;
        }

        // Called when Elastic discovers the dimensions itself, hence dimensionsSetByUser=false.
        @Override
        public SageMakerStoredServiceSchema updateModelWithEmbeddingDetails(Integer dimensions) {
            return new ApiServiceSettings(dimensions, false);
        }
    }

    /**
     * OpenAI-specific task settings: the optional "user" identifier forwarded in the request body.
     */
    record ApiTaskSettings(@Nullable String user) implements SageMakerStoredTaskSchema {
        private static final String NAME = "sagemaker_openai_text_embeddings_task_settings";
        private static final String USER_FIELD = "user";

        ApiTaskSettings(StreamInput in) throws IOException {
            this(in.readOptionalString());
        }

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.ML_INFERENCE_SAGEMAKER;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeOptionalString(user);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            return user != null ? builder.field(USER_FIELD, user) : builder;
        }

        @Override
        public boolean isEmpty() {
            return user == null;
        }

        // Merges new settings over the current ones; the existing user is kept when absent from newSettings.
        @Override
        public ApiTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
            var validationException = new ValidationException();
            var newTaskSettings = fromMap(newSettings, validationException);
            validationException.throwIfValidationErrorsExist();

            return new ApiTaskSettings(newTaskSettings.user() != null ? newTaskSettings.user() : user);
        }

        static ApiTaskSettings fromMap(Map<String, Object> map, ValidationException exception) {
            var user = extractOptionalString(map, USER_FIELD, ModelConfigurations.TASK_SETTINGS, exception);
            return new ApiTaskSettings(user);
        }
    }
}

View file

@ -0,0 +1,101 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
/**
* All Inference Setting classes derived from {@link org.elasticsearch.inference.ServiceSettings},
* {@link org.elasticsearch.inference.TaskSettings}, and {@link org.elasticsearch.inference.SecretSettings}
* must be able to read/write to an index via ToXContent as well as read/write between nodes via Writeable.
*/
/**
 * Base test case for inference settings classes derived from {@link org.elasticsearch.inference.ServiceSettings},
 * {@link org.elasticsearch.inference.TaskSettings}, and {@link org.elasticsearch.inference.SecretSettings}.
 * Such classes must survive two round trips: to an index via ToXContent, and between nodes via Writeable.
 */
public abstract class InferenceSettingsTestCase<T extends Writeable & ToXContent> extends AbstractBWCWireSerializationTestCase<T> {

    /**
     * Final on purpose: every {@link org.elasticsearch.inference.ModelConfigurations} setting must be
     * registered centrally in {@link InferenceNamedWriteablesProvider}, so subclasses may not supply
     * their own registry.
     */
    @Override
    protected final NamedWriteableRegistry getNamedWriteableRegistry() {
        return new NamedWriteableRegistry(InferenceNamedWriteablesProvider.getNamedWriteables());
    }

    /**
     * Default mutation strategy: keep drawing random instances until one differs from {@code instance}.
     */
    @Override
    protected T mutateInstance(T instance) throws IOException {
        return randomValueOtherThan(instance, this::createTestInstance);
    }

    /**
     * No-op for now; override once the wire format diverges across transport versions.
     */
    @Override
    protected T mutateInstanceForVersion(T instance, TransportVersion version) {
        return instance;
    }

    /**
     * Round-trips an instance through XContent and back, simulating a write to and read from the index.
     */
    public final void testXContentRoundTrip() throws IOException {
        var original = createTestInstance();
        var asMap = toMap(original);
        var rebuilt = fromMutableMap(new HashMap<>(asMap));
        assertThat(rebuilt, equalTo(original));
    }

    protected abstract T fromMutableMap(Map<String, Object> mutableMap);

    /**
     * Renders {@code instance} as JSON (wrapping fragments in an object) and parses it back into a map.
     */
    public static Map<String, Object> toMap(ToXContent instance) throws IOException {
        try (var builder = JsonXContent.contentBuilder()) {
            var wrapInObject = instance.isFragment();
            if (wrapInObject) {
                builder.startObject();
            }
            instance.toXContent(builder, ToXContent.EMPTY_PARAMS);
            if (wrapInObject) {
                builder.endObject();
            }
            var jsonBytes = Strings.toString(builder).getBytes(StandardCharsets.UTF_8);
            try (var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, jsonBytes)) {
                return parser.map();
            }
        }
    }

    /**
     * Helper, since most settings contain optional strings.
     */
    protected static String randomOptionalString() {
        if (randomBoolean()) {
            return randomString();
        }
        return null;
    }

    /**
     * Helper, since most settings contain strings.
     */
    protected static String randomString() {
        return randomAlphaOfLength(randomIntBetween(4, 8));
    }
}

View file

@ -0,0 +1,180 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeAsyncClient;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.cache.CacheLoader;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentMatchers;
import org.reactivestreams.Subscriber;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.assertArg;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
 * Tests {@code SageMakerClient}: that AWS clients are created once per region+secrets and cached,
 * that invoke/invokeStream responses and timeouts are propagated to listeners, and that closing the
 * client closes the cached AWS clients.
 */
public class SageMakerClientTests extends ESTestCase {
    private static final InvokeEndpointRequest REQUEST = InvokeEndpointRequest.builder().build();
    private static final InvokeEndpointWithResponseStreamRequest STREAM_REQUEST = InvokeEndpointWithResponseStreamRequest.builder().build();

    private SageMakerRuntimeAsyncClient awsClient;
    private CacheLoader<SageMakerClient.RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory;
    private SageMakerClient client;
    private ThreadPool threadPool;

    @Before
    public void setUp() throws Exception {
        super.setUp();
        threadPool = createThreadPool(inferenceUtilityPool());
        awsClient = mock();
        // Spy so cache hits/misses can be counted via load(); every key resolves to the shared mock.
        clientFactory = spy(new CacheLoader<>() {
            public SageMakerRuntimeAsyncClient load(SageMakerClient.RegionAndSecrets key) {
                return awsClient;
            }
        });
        client = new SageMakerClient(clientFactory, threadPool);
    }

    @After
    public void shutdown() throws IOException {
        terminate(threadPool);
    }

    public void testInvoke() throws Exception {
        var expectedResponse = InvokeEndpointResponse.builder().build();
        when(awsClient.invokeEndpoint(any(InvokeEndpointRequest.class))).thenReturn(CompletableFuture.completedFuture(expectedResponse));

        var listener = invoke(TimeValue.THIRTY_SECONDS);

        verify(clientFactory, times(1)).load(any());
        verify(listener, times(1)).onResponse(eq(expectedResponse));
    }

    private static SageMakerClient.RegionAndSecrets regionAndSecrets() {
        return new SageMakerClient.RegionAndSecrets(
            "us-east-1",
            new AwsSecretSettings(new SecureString("access"), new SecureString("secrets"))
        );
    }

    // Invokes the client and blocks (up to 5s) until the listener fires, returning the spied listener for verification.
    private ActionListener<InvokeEndpointResponse> invoke(TimeValue timeout) throws InterruptedException {
        var latch = new CountDownLatch(1);
        ActionListener<InvokeEndpointResponse> listener = spy(ActionListener.noop());
        client.invoke(regionAndSecrets(), REQUEST, timeout, ActionListener.runAfter(listener, latch::countDown));
        assertTrue("Timed out waiting for invoke call", latch.await(5, TimeUnit.SECONDS));
        return listener;
    }

    public void testInvokeCache() throws Exception {
        when(awsClient.invokeEndpoint(any(InvokeEndpointRequest.class))).thenReturn(
            CompletableFuture.completedFuture(InvokeEndpointResponse.builder().build())
        );

        invoke(TimeValue.THIRTY_SECONDS);
        invoke(TimeValue.THIRTY_SECONDS);

        // Two invocations with the same region+secrets must create the AWS client only once.
        verify(clientFactory, times(1)).load(any());
    }

    public void testInvokeTimeout() throws Exception {
        // A future that never completes forces the 10ms timeout path.
        when(awsClient.invokeEndpoint(any(InvokeEndpointRequest.class))).thenReturn(new CompletableFuture<>());

        var listener = invoke(TimeValue.timeValueMillis(10));

        verify(clientFactory, times(1)).load(any());
        verifyTimeout(listener);
    }

    private static void verifyTimeout(ActionListener<?> listener) {
        verify(listener, times(1)).onFailure(assertArg(e -> assertThat(e.getMessage(), equalTo("Request timed out after [10ms]"))));
    }

    public void testInvokeStream() throws Exception {
        SdkPublisher<ResponseStream> publisher = mockPublisher();

        var listener = invokeStream(TimeValue.THIRTY_SECONDS);

        // The event publisher must stay unsubscribed until the caller subscribes to the returned stream.
        verify(publisher, never()).subscribe(ArgumentMatchers.<Subscriber<ResponseStream>>any());
        verify(listener, times(1)).onResponse(assertArg(stream -> stream.responseStream().subscribe(mock())));
        verify(publisher, times(1)).subscribe(ArgumentMatchers.<Subscriber<ResponseStream>>any());
    }

    // Stubs the streaming call: immediately delivers a response and an event-stream publisher to the handler.
    private SdkPublisher<ResponseStream> mockPublisher() {
        SdkPublisher<ResponseStream> publisher = mock();
        doAnswer(ans -> {
            InvokeEndpointWithResponseStreamResponseHandler handler = ans.getArgument(1);
            handler.responseReceived(InvokeEndpointWithResponseStreamResponse.builder().build());
            handler.onEventStream(publisher);
            return CompletableFuture.completedFuture((Void) null);
        }).when(awsClient).invokeEndpointWithResponseStream(any(InvokeEndpointWithResponseStreamRequest.class), any());
        return publisher;
    }

    // Streaming analogue of invoke(): blocks until the stream listener fires and returns the spied listener.
    private ActionListener<SageMakerClient.SageMakerStream> invokeStream(TimeValue timeout) throws Exception {
        var latch = new CountDownLatch(1);
        ActionListener<SageMakerClient.SageMakerStream> listener = spy(ActionListener.noop());
        client.invokeStream(regionAndSecrets(), STREAM_REQUEST, timeout, ActionListener.runAfter(listener, latch::countDown));
        assertTrue("Timed out waiting for invoke call", latch.await(5, TimeUnit.SECONDS));
        return listener;
    }

    public void testInvokeStreamCache() throws Exception {
        mockPublisher();

        invokeStream(TimeValue.THIRTY_SECONDS);
        invokeStream(TimeValue.THIRTY_SECONDS);

        // Same cache guarantee as testInvokeCache, via the streaming API.
        verify(clientFactory, times(1)).load(any());
    }

    public void testInvokeStreamTimeout() throws Exception {
        when(awsClient.invokeEndpointWithResponseStream(any(InvokeEndpointWithResponseStreamRequest.class), any())).thenReturn(
            new CompletableFuture<>()
        );

        var listener = invokeStream(TimeValue.timeValueMillis(10));

        verify(clientFactory, times(1)).load(any());
        verifyTimeout(listener);
    }

    public void testClose() throws Exception {
        // load cache
        mockPublisher();
        invokeStream(TimeValue.THIRTY_SECONDS);

        // clear cache
        client.close();

        // Closing the SageMakerClient must close the cached AWS client exactly once.
        verify(awsClient, times(1)).close();
    }
}

View file

@ -0,0 +1,463 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStreamSchema;
import org.junit.Before;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import static org.elasticsearch.action.ActionListener.assertOnce;
import static org.elasticsearch.action.support.ActionTestUtils.assertNoFailureListener;
import static org.elasticsearch.action.support.ActionTestUtils.assertNoSuccessListener;
import static org.elasticsearch.core.TimeValue.THIRTY_SECONDS;
import static org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequestTests.randomUnifiedCompletionRequest;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.in;
import static org.hamcrest.Matchers.isA;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.assertArg;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.only;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
public class SageMakerServiceTests extends ESTestCase {
private static final String QUERY = "query";
private static final List<String> INPUT = List.of("input");
private static final InputType INPUT_TYPE = InputType.UNSPECIFIED;

// Mocked collaborators; SageMakerService is the unit under test.
private SageMakerModelBuilder modelBuilder;
private SageMakerClient client;
private SageMakerSchemas schemas;
private SageMakerService sageMakerService;

@Before
public void init() {
    modelBuilder = mock();
    client = mock();
    schemas = mock();
    ThreadPool threadPool = mock();
    // Execute on the calling thread so listener assertions run synchronously within each test.
    when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
    when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
    sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of);
}

// supportedTaskTypes is delegated entirely to the schema registry.
public void testSupportedTaskTypes() {
    sageMakerService.supportedTaskTypes();
    verify(schemas, only()).supportedTaskTypes();
}

// supportedStreamingTasks is delegated entirely to the schema registry.
public void testSupportedStreamingTasks() {
    sageMakerService.supportedStreamingTasks();
    verify(schemas, only()).supportedStreamingTasks();
}

// Request-time config parsing is delegated to the model builder's fromRequest.
public void testParseRequestConfig() {
    sageMakerService.parseRequestConfig("modelId", TaskType.ANY, Map.of(), assertNoFailureListener(model -> {
        verify(modelBuilder, only()).fromRequest(eq("modelId"), eq(TaskType.ANY), eq(SageMakerService.NAME), eq(Map.of()));
    }));
}

// Persisted config (with secrets) parsing is delegated to the model builder's fromStorage.
public void testParsePersistedConfigWithSecrets() {
    sageMakerService.parsePersistedConfigWithSecrets("modelId", TaskType.ANY, Map.of(), Map.of());
    verify(modelBuilder, only()).fromStorage(eq("modelId"), eq(TaskType.ANY), eq(SageMakerService.NAME), eq(Map.of()), eq(Map.of()));
}

// Persisted config without secrets passes null secrets through to fromStorage.
public void testParsePersistedConfig() {
    sageMakerService.parsePersistedConfig("modelId", TaskType.ANY, Map.of());
    verify(modelBuilder, only()).fromStorage(eq("modelId"), eq(TaskType.ANY), eq(SageMakerService.NAME), eq(Map.of()), eq(null));
}
// A non-SageMaker Model must be rejected with a 500 telling the user to recreate the endpoint.
public void testInferWithWrongModel() {
    sageMakerService.infer(
        mockUnsupportedModel(),
        QUERY,
        false,
        null,
        INPUT,
        false,
        null,
        INPUT_TYPE,
        THIRTY_SECONDS,
        assertUnsupportedModel()
    );
}

// Builds a Model mock that is not a SageMakerModel, with just enough config for the error message.
private static Model mockUnsupportedModel() {
    Model model = mock();
    ModelConfigurations modelConfigurations = mock();
    when(modelConfigurations.getService()).thenReturn("mockService");
    when(modelConfigurations.getInferenceEntityId()).thenReturn("mockInferenceId");
    when(model.getConfigurations()).thenReturn(modelConfigurations);
    return model;
}

// Listener asserting the "invalid internal model" failure: a 500 with a delete-and-recreate message.
private static <T> ActionListener<T> assertUnsupportedModel() {
    return assertNoSuccessListener(e -> {
        assertThat(e, isA(ElasticsearchStatusException.class));
        assertThat(
            e.getMessage(),
            equalTo(
                "The internal model was invalid, please delete the service [mockService] "
                    + "with id [mockInferenceId] and add it again."
            )
        );
        assertThat(((ElasticsearchStatusException) e).status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
    });
}
// Happy-path non-streaming inference: schema translates request and response, client invoked once.
public void testInfer() {
    var model = mockModel();
    SageMakerSchema schema = mock();
    when(schemas.schemaFor(model)).thenReturn(schema);
    mockInvoke();

    sageMakerService.infer(
        model,
        QUERY,
        null,
        null,
        INPUT,
        false,
        null,
        INPUT_TYPE,
        THIRTY_SECONDS,
        assertNoFailureListener(ignored -> {
            verify(schemas, only()).schemaFor(eq(model));
            verify(schema, times(1)).request(eq(model), assertRequest());
            verify(schema, times(1)).response(eq(model), any(), any());
        })
    );

    verify(client, only()).invoke(any(), any(), any(), any());
    verifyNoMoreInteractions(client, schemas, schema);
}

// A SageMakerModel mock with AWS secrets present; override(null) returns itself.
private SageMakerModel mockModel() {
    SageMakerModel model = mock();
    when(model.override(null)).thenReturn(model);
    when(model.awsSecretSettings()).thenReturn(
        Optional.of(new AwsSecretSettings(new SecureString("test-accessKey"), new SecureString("test-secretKey")))
    );
    return model;
}

// Stubs client.invoke to immediately answer the listener (argument 3) with an empty response.
private void mockInvoke() {
    doAnswer(ans -> {
        ActionListener<InvokeEndpointResponse> responseListener = ans.getArgument(3);
        responseListener.onResponse(InvokeEndpointResponse.builder().build());
        return null; // Void
    }).when(client).invoke(any(), any(), any(), any());
}

// Matcher asserting the translated SageMakerInferenceRequest carries the test's query/input values.
private static SageMakerInferenceRequest assertRequest() {
    return assertArg(request -> {
        assertThat(request.query(), equalTo(QUERY));
        assertThat(request.input(), containsInAnyOrder(INPUT.toArray()));
        assertThat(request.inputType(), equalTo(InputType.UNSPECIFIED));
        assertNull(request.returnDocuments());
        assertNull(request.topN());
    });
}
// Happy-path streaming inference: stream schema translates request/response, invokeStream called once.
public void testInferStream() {
    SageMakerModel model = mockModel();
    SageMakerStreamSchema schema = mock();
    when(schemas.streamSchemaFor(model)).thenReturn(schema);
    mockInvokeStream();

    sageMakerService.infer(model, QUERY, null, null, INPUT, true, null, INPUT_TYPE, THIRTY_SECONDS, assertNoFailureListener(ignored -> {
        verify(schemas, only()).streamSchemaFor(eq(model));
        verify(schema, times(1)).streamRequest(eq(model), assertRequest());
        verify(schema, times(1)).streamResponse(eq(model), any());
    }));

    verify(client, only()).invokeStream(any(), any(), any(), any());
    verifyNoMoreInteractions(client, schemas, schema);
}

// Stubs client.invokeStream to immediately answer the listener with an empty streaming response.
private void mockInvokeStream() {
    doAnswer(ans -> {
        ActionListener<SageMakerClient.SageMakerStream> responseListener = ans.getArgument(3);
        responseListener.onResponse(
            new SageMakerClient.SageMakerStream(InvokeEndpointWithResponseStreamResponse.builder().build(), mock())
        );
        return null; // Void
    }).when(client).invokeStream(any(), any(), any(), any());
}

// A failure delivered via the listener is routed through the schema's error translation.
public void testInferError() {
    SageMakerModel model = mockModel();
    var expectedException = new IllegalArgumentException("hola");
    SageMakerSchema schema = mock();
    when(schemas.schemaFor(model)).thenReturn(schema);
    mockInvokeFailure(expectedException);

    sageMakerService.infer(
        model,
        QUERY,
        null,
        null,
        INPUT,
        false,
        null,
        INPUT_TYPE,
        THIRTY_SECONDS,
        assertNoSuccessListener(ignored -> {
            verify(schemas, only()).schemaFor(eq(model));
            verify(schema, times(1)).request(eq(model), assertRequest());
            verify(schema, times(1)).error(eq(model), assertArg(e -> assertThat(e, equalTo(expectedException))));
        })
    );

    verify(client, only()).invoke(any(), any(), any(), any());
    verifyNoMoreInteractions(client, schemas, schema);
}

// Stubs client.invoke to fail the listener (argument 3) with the given exception.
private void mockInvokeFailure(Exception e) {
    doAnswer(ans -> {
        ActionListener<?> responseListener = ans.getArgument(3);
        responseListener.onFailure(e);
        return null; // Void
    }).when(client).invoke(any(), any(), any(), any());
}

// An exception thrown synchronously by the client is wrapped in a 500 naming the inference id.
public void testInferException() {
    SageMakerModel model = mockModel();
    when(model.getInferenceEntityId()).thenReturn("some id");
    SageMakerStreamSchema schema = mock();
    when(schemas.streamSchemaFor(model)).thenReturn(schema);
    doThrow(new IllegalArgumentException("wow, really?")).when(client).invokeStream(any(), any(), any(), any());

    sageMakerService.infer(model, QUERY, null, null, INPUT, true, null, INPUT_TYPE, THIRTY_SECONDS, assertNoSuccessListener(e -> {
        verify(schemas, only()).streamSchemaFor(eq(model));
        verify(schema, times(1)).streamRequest(eq(model), assertRequest());
        assertThat(e, isA(ElasticsearchStatusException.class));
        assertThat(((ElasticsearchStatusException) e).status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
        assertThat(e.getMessage(), equalTo("Failed to call SageMaker for inference id [some id]."));
    }));

    verify(client, only()).invokeStream(any(), any(), any(), any());
    verifyNoMoreInteractions(client, schemas, schema);
}
// Unified (chat) completion rejects non-SageMaker models the same way infer does.
public void testUnifiedInferWithWrongModel() {
    sageMakerService.unifiedCompletionInfer(
        mockUnsupportedModel(),
        randomUnifiedCompletionRequest(),
        THIRTY_SECONDS,
        assertUnsupportedModel()
    );
}

// Happy-path unified completion: stream schema translates the chat request/response via invokeStream.
public void testUnifiedInfer() {
    var model = mockModel();
    SageMakerStreamSchema schema = mock();
    when(schemas.streamSchemaFor(model)).thenReturn(schema);
    mockInvokeStream();

    sageMakerService.unifiedCompletionInfer(
        model,
        randomUnifiedCompletionRequest(),
        THIRTY_SECONDS,
        assertNoFailureListener(ignored -> {
            verify(schemas, only()).streamSchemaFor(eq(model));
            verify(schema, times(1)).chatCompletionStreamRequest(eq(model), any());
            verify(schema, times(1)).chatCompletionStreamResponse(eq(model), any());
        })
    );

    verify(client, only()).invokeStream(any(), any(), any(), any());
    verifyNoMoreInteractions(client, schemas, schema);
}

// A listener-delivered failure is routed through the schema's chat-completion error translation.
public void testUnifiedInferError() {
    var model = mockModel();
    var expectedException = new IllegalArgumentException("hola");
    SageMakerStreamSchema schema = mock();
    when(schemas.streamSchemaFor(model)).thenReturn(schema);
    mockInvokeStreamFailure(expectedException);

    sageMakerService.unifiedCompletionInfer(
        model,
        randomUnifiedCompletionRequest(),
        THIRTY_SECONDS,
        assertNoSuccessListener(ignored -> {
            verify(schemas, only()).streamSchemaFor(eq(model));
            verify(schema, times(1)).chatCompletionStreamRequest(eq(model), any());
            verify(schema, times(1)).chatCompletionError(eq(model), assertArg(e -> assertThat(e, equalTo(expectedException))));
        })
    );

    verify(client, only()).invokeStream(any(), any(), any(), any());
    verifyNoMoreInteractions(client, schemas, schema);
}

// Stubs client.invokeStream to fail the listener (argument 3) with the given exception.
private void mockInvokeStreamFailure(Exception e) {
    doAnswer(ans -> {
        ActionListener<?> responseListener = ans.getArgument(3);
        responseListener.onFailure(e);
        return null; // Void
    }).when(client).invokeStream(any(), any(), any(), any());
}

// A synchronous client exception during unified completion is wrapped in a 500 naming the inference id.
public void testUnifiedInferException() {
    SageMakerModel model = mockModel();
    when(model.getInferenceEntityId()).thenReturn("some id");
    SageMakerStreamSchema schema = mock();
    when(schemas.streamSchemaFor(model)).thenReturn(schema);
    doThrow(new IllegalArgumentException("wow, rude")).when(client).invokeStream(any(), any(), any(), any());

    sageMakerService.unifiedCompletionInfer(model, randomUnifiedCompletionRequest(), THIRTY_SECONDS, assertNoSuccessListener(e -> {
        verify(schemas, only()).streamSchemaFor(eq(model));
        verify(schema, times(1)).chatCompletionStreamRequest(eq(model), any());
        assertThat(e, isA(ElasticsearchStatusException.class));
        assertThat(((ElasticsearchStatusException) e).status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
        assertThat(e.getMessage(), equalTo("Failed to call SageMaker for inference id [some id]."));
    }));

    verify(client, only()).invokeStream(any(), any(), any(), any());
    verifyNoMoreInteractions(client, schemas, schema);
}
public void testChunkedInferWithWrongModel() {
sageMakerService.chunkedInfer(
mockUnsupportedModel(),
QUERY,
INPUT.stream().map(ChunkInferenceInput::new).toList(),
null,
INPUT_TYPE,
THIRTY_SECONDS,
assertUnsupportedModel()
);
}
/**
 * With batch size 1 and one-word chunks (see {@code mockModelForChunking}), two inputs fan
 * out into two independent client invocations, each translated through the schema's
 * request/response conversion, before the listener fires exactly once.
 */
public void testChunkedInfer() throws Exception {
    var model = mockModelForChunking();
    SageMakerSchema schema = mock();
    when(schema.response(any(), any(), any())).thenReturn(TextEmbeddingFloatResultsTests.createRandomResults());
    when(schemas.schemaFor(model)).thenReturn(schema);
    mockInvoke();
    var expectedInput = Set.of("first", "second");
    sageMakerService.chunkedInfer(
        model,
        QUERY,
        expectedInput.stream().map(ChunkInferenceInput::new).toList(),
        null,
        INPUT_TYPE,
        THIRTY_SECONDS,
        // listener must complete exactly once, with no failure
        assertOnce(assertNoFailureListener(chunkedInferences -> {
            // one schema lookup + request/response conversion per chunk (two chunks total)
            verify(schemas, times(2)).schemaFor(eq(model));
            verify(schema, times(2)).request(eq(model), assertChunkRequest(expectedInput));
            verify(schema, times(2)).response(eq(model), any(), any());
        }))
    );
    verify(client, times(2)).invoke(any(), any(), any(), any());
    verifyNoMoreInteractions(client, schemas, schema);
}
/**
 * Builds a model mock configured for chunking: batch size 1 and word-boundary chunking of
 * at most one word per chunk, so every input string becomes its own request.
 */
private SageMakerModel mockModelForChunking() {
    ModelConfigurations configurations = mock();
    when(configurations.getChunkingSettings()).thenReturn(new WordBoundaryChunkingSettings(1, 0));
    var chunkingModel = mockModel();
    when(chunkingModel.batchSize()).thenReturn(Optional.of(1));
    when(chunkingModel.getConfigurations()).thenReturn(configurations);
    return chunkingModel;
}
/**
 * Argument matcher asserting each per-chunk request carries the original query, exactly one
 * input drawn from {@code expectedInput}, an UNSPECIFIED input type, no
 * return-documents/top-n overrides, and streaming disabled.
 */
private static SageMakerInferenceRequest assertChunkRequest(Set<String> expectedInput) {
    return assertArg(request -> {
        assertThat(request.query(), equalTo(QUERY));
        assertThat(request.input(), hasSize(1));
        assertThat(request.input().get(0), in(expectedInput));
        assertThat(request.inputType(), equalTo(InputType.UNSPECIFIED));
        assertNull(request.returnDocuments());
        assertNull(request.topN());
        assertFalse(request.stream());
    });
}
/**
 * When every invocation fails, each chunk is reported as a {@code ChunkedInferenceError}
 * wrapping the exception produced by {@code schema.error}, and the overall listener still
 * completes successfully (per-chunk errors do not fail the whole request).
 */
public void testChunkedInferError() {
    var model = mockModelForChunking();
    var expectedException = new IllegalArgumentException("hola");
    SageMakerSchema schema = mock();
    when(schema.error(any(), any())).thenReturn(expectedException);
    when(schemas.schemaFor(model)).thenReturn(schema);
    mockInvokeFailure(expectedException);
    var expectedInput = Set.of("first", "second");
    // one error entry per chunk - two chunks given batch size 1 and one-word chunking
    var expectedOutput = Stream.of(expectedException, expectedException).map(ChunkedInferenceError::new).toArray();
    sageMakerService.chunkedInfer(
        model,
        QUERY,
        expectedInput.stream().map(ChunkInferenceInput::new).toList(),
        null,
        INPUT_TYPE,
        THIRTY_SECONDS,
        assertOnce(assertNoFailureListener(chunkedInferences -> {
            verify(schemas, times(2)).schemaFor(eq(model));
            verify(schema, times(2)).request(eq(model), assertChunkRequest(expectedInput));
            verify(schema, times(2)).error(eq(model), assertArg(e -> assertThat(e, equalTo(expectedException))));
            assertThat(chunkedInferences, containsInAnyOrder(expectedOutput));
        }))
    );
    verify(client, times(2)).invoke(any(), any(), any(), any());
    verifyNoMoreInteractions(client, schemas, schema);
}
/** Closing the service must close the underlying SageMaker client and do nothing else. */
public void testClose() throws IOException {
    sageMakerService.close();
    verify(client, only()).close();
}
}

View file

@ -0,0 +1,315 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.model;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemasTests;
import org.junit.Before;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Optional;
import static org.elasticsearch.inference.ModelConfigurations.USE_ID_FOR_INDEX;
import static org.hamcrest.Matchers.equalTo;
public class SageMakerModelBuilderTests extends ESTestCase {
private static final String inferenceId = "inferenceId";
private static final TaskType taskType = TaskType.ANY;
private static final String service = "service";
private SageMakerModelBuilder builder;
@Before
public void setUp() throws Exception {
super.setUp();
builder = new SageMakerModelBuilder(SageMakerSchemasTests.mockSchemas());
}
public void testFromRequestWithRequiredFields() {
var model = fromRequest("""
{
"service_settings": {
"access_key": "test-access-key",
"secret_key": "test-secret-key",
"region": "us-east-1",
"api": "test-api",
"endpoint_name": "test-endpoint"
}
}
""");
assertNotNull(model);
assertTrue(model.awsSecretSettings().isPresent());
assertThat(model.awsSecretSettings().get().accessKey().toString(), equalTo("test-access-key"));
assertThat(model.awsSecretSettings().get().secretKey().toString(), equalTo("test-secret-key"));
assertThat(model.region(), equalTo("us-east-1"));
assertThat(model.api(), equalTo("test-api"));
assertThat(model.endpointName(), equalTo("test-endpoint"));
assertTrue(model.customAttributes().isEmpty());
assertTrue(model.enableExplanations().isEmpty());
assertTrue(model.inferenceComponentName().isEmpty());
assertTrue(model.inferenceIdForDataCapture().isEmpty());
assertTrue(model.sessionId().isEmpty());
assertTrue(model.targetContainerHostname().isEmpty());
assertTrue(model.targetModel().isEmpty());
assertTrue(model.targetVariant().isEmpty());
assertTrue(model.batchSize().isEmpty());
}
public void testFromRequestWithOptionalFields() {
var model = fromRequest("""
{
"service_settings": {
"access_key": "test-access-key",
"secret_key": "test-secret-key",
"region": "us-east-1",
"api": "test-api",
"endpoint_name": "test-endpoint",
"target_model": "test-target",
"target_container_hostname": "test-target-container",
"inference_component_name": "test-inference-component",
"batch_size": 1234
},
"task_settings": {
"custom_attributes": "test-custom-attributes",
"enable_explanations": "test-enable-explanations",
"inference_id": "test-inference-id",
"session_id": "test-session-id",
"target_variant": "test-target-variant"
}
}
""");
assertNotNull(model);
assertTrue(model.awsSecretSettings().isPresent());
assertThat(model.awsSecretSettings().get().accessKey().toString(), equalTo("test-access-key"));
assertThat(model.awsSecretSettings().get().secretKey().toString(), equalTo("test-secret-key"));
assertThat(model.region(), equalTo("us-east-1"));
assertThat(model.api(), equalTo("test-api"));
assertThat(model.endpointName(), equalTo("test-endpoint"));
assertPresent(model.customAttributes(), "test-custom-attributes");
assertPresent(model.enableExplanations(), "test-enable-explanations");
assertPresent(model.inferenceComponentName(), "test-inference-component");
assertPresent(model.inferenceIdForDataCapture(), "test-inference-id");
assertPresent(model.sessionId(), "test-session-id");
assertPresent(model.targetContainerHostname(), "test-target-container");
assertPresent(model.targetModel(), "test-target");
assertPresent(model.targetVariant(), "test-target-variant");
assertPresent(model.batchSize(), 1234);
}
public void testFromRequestWithoutAccessKey() {
testExceptionFromRequest("""
{
"service_settings": {
"secret_key": "test-secret-key",
"region": "us-east-1",
"api": "test-api",
"endpoint_name": "test-endpoint"
}
}
""", ValidationException.class, "Validation Failed: 1: [secret_settings] does not contain the required setting [access_key];");
}
public void testFromRequestWithoutSecretKey() {
testExceptionFromRequest("""
{
"service_settings": {
"access_key": "test-access-key",
"region": "us-east-1",
"api": "test-api",
"endpoint_name": "test-endpoint"
}
}
""", ValidationException.class, "Validation Failed: 1: [secret_settings] does not contain the required setting [secret_key];");
}
public void testFromRequestWithoutRegion() {
testExceptionFromRequest("""
{
"service_settings": {
"access_key": "test-access-key",
"secret_key": "test-secret-key",
"api": "test-api",
"endpoint_name": "test-endpoint"
}
}
""", ValidationException.class, "Validation Failed: 1: [service_settings] does not contain the required setting [region];");
}
public void testFromRequestWithoutApi() {
testExceptionFromRequest("""
{
"service_settings": {
"access_key": "test-access-key",
"secret_key": "test-secret-key",
"region": "us-east-1",
"endpoint_name": "test-endpoint"
}
}
""", ValidationException.class, "Validation Failed: 1: [service_settings] does not contain the required setting [api];");
}
public void testFromRequestWithoutEndpointName() {
testExceptionFromRequest(
"""
{
"service_settings": {
"access_key": "test-access-key",
"secret_key": "test-secret-key",
"region": "us-east-1",
"api": "test-api"
}
}
""",
ValidationException.class,
"Validation Failed: 1: [service_settings] does not contain the required setting [endpoint_name];"
);
}
public void testFromRequestWithExtraServiceKeys() {
testExceptionFromRequest(
"""
{
"service_settings": {
"access_key": "test-access-key",
"secret_key": "test-secret-key",
"region": "us-east-1",
"api": "test-api",
"endpoint_name": "test-endpoint",
"hello": "there"
}
}
""",
ElasticsearchStatusException.class,
"Model configuration contains settings [{hello=there}] unknown to the [service] service"
);
}
public void testFromRequestWithExtraTaskKeys() {
testExceptionFromRequest(
"""
{
"service_settings": {
"access_key": "test-access-key",
"secret_key": "test-secret-key",
"region": "us-east-1",
"api": "test-api",
"endpoint_name": "test-endpoint"
},
"task_settings": {
"hello": "there"
}
}
""",
ElasticsearchStatusException.class,
"Model configuration contains settings [{hello=there}] unknown to the [service] service"
);
}
public void testRoundTrip() throws IOException {
var expectedModel = fromRequest("""
{
"service_settings": {
"access_key": "test-access-key",
"secret_key": "test-secret-key",
"region": "us-east-1",
"api": "test-api",
"endpoint_name": "test-endpoint",
"target_model": "test-target",
"target_container_hostname": "test-target-container",
"inference_component_name": "test-inference-component",
"batch_size": 1234
},
"task_settings": {
"custom_attributes": "test-custom-attributes",
"enable_explanations": "test-enable-explanations",
"inference_id": "test-inference-id",
"session_id": "test-session-id",
"target_variant": "test-target-variant"
}
}
""");
var unparsedModelWithSecrets = unparsedModel(expectedModel.getConfigurations(), expectedModel.getSecrets());
var modelWithSecrets = builder.fromStorage(
unparsedModelWithSecrets.inferenceEntityId(),
unparsedModelWithSecrets.taskType(),
unparsedModelWithSecrets.service(),
unparsedModelWithSecrets.settings(),
unparsedModelWithSecrets.secrets()
);
assertThat(modelWithSecrets, equalTo(expectedModel));
assertNotNull(modelWithSecrets.getSecrets().getSecretSettings());
var unparsedModelWithoutSecrets = unparsedModel(expectedModel.getConfigurations(), null);
var modelWithoutSecrets = builder.fromStorage(
unparsedModelWithoutSecrets.inferenceEntityId(),
unparsedModelWithoutSecrets.taskType(),
unparsedModelWithoutSecrets.service(),
unparsedModelWithoutSecrets.settings(),
unparsedModelWithoutSecrets.secrets()
);
assertThat(modelWithoutSecrets.getConfigurations(), equalTo(expectedModel.getConfigurations()));
assertNull(modelWithoutSecrets.getSecrets().getSecretSettings());
}
private SageMakerModel fromRequest(String json) {
return builder.fromRequest(inferenceId, taskType, service, map(json));
}
private void testExceptionFromRequest(String json, Class<? extends Exception> exceptionClass, String message) {
var exception = assertThrows(exceptionClass, () -> fromRequest(json));
assertThat(exception.getMessage(), equalTo(message));
}
private static <T> void assertPresent(Optional<T> optional, T expectedValue) {
assertTrue(optional.isPresent());
assertThat(optional.get(), equalTo(expectedValue));
}
private static Map<String, Object> map(String json) {
try (
var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, json.getBytes(StandardCharsets.UTF_8))
) {
return parser.map();
} catch (IOException e) {
throw new AssertionError(e);
}
}
private static UnparsedModel unparsedModel(ModelConfigurations modelConfigurations, ModelSecrets modelSecrets) throws IOException {
var modelConfigMap = new ModelRegistry.ModelConfigMap(
toJsonMap(modelConfigurations),
modelSecrets != null ? toJsonMap(modelSecrets) : null
);
return ModelRegistry.unparsedModelFromMap(modelConfigMap);
}
private static Map<String, Object> toJsonMap(ToXContent toXContent) throws IOException {
try (var builder = JsonXContent.contentBuilder()) {
toXContent.toXContent(builder, new ToXContent.MapParams(Map.of(USE_ID_FOR_INDEX, "true")));
return map(Strings.toString(builder));
}
}
}

View file

@ -0,0 +1,43 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.model;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemasTests;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
import java.util.Map;
/**
 * Round-trip serialization tests for {@link SageMakerServiceSettings}, driven by the shared
 * {@link InferenceSettingsTestCase} harness.
 */
public class SageMakerServiceSettingsTests extends InferenceSettingsTestCase<SageMakerServiceSettings> {
    @Override
    protected Writeable.Reader<SageMakerServiceSettings> instanceReader() {
        return SageMakerServiceSettings::new;
    }

    @Override
    protected SageMakerServiceSettings createTestInstance() {
        // Positional args: three required strings, three optional strings, an optional
        // batch size (1-1000 when present), and the api-specific schema (NO_OP here).
        // NOTE(review): presumably region/api/endpoint_name followed by the optional
        // target settings - confirm against the SageMakerServiceSettings constructor.
        return new SageMakerServiceSettings(
            randomString(),
            randomString(),
            randomString(),
            randomOptionalString(),
            randomOptionalString(),
            randomOptionalString(),
            randomBoolean() ? randomIntBetween(1, 1000) : null,
            SageMakerStoredServiceSchema.NO_OP
        );
    }

    @Override
    protected SageMakerServiceSettings fromMutableMap(Map<String, Object> mutableMap) {
        // parsing needs the schema registry and a task type to resolve api-specific settings
        return SageMakerServiceSettings.fromMap(SageMakerSchemasTests.mockSchemas(), randomFrom(TaskType.values()), mutableMap);
    }
}

View file

@ -0,0 +1,42 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.model;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
import java.util.Map;
/**
 * Round-trip serialization tests for {@link SageMakerTaskSettings}, using the no-op
 * api-specific task schema throughout.
 */
public class SageMakerTaskSettingsTests extends InferenceSettingsTestCase<SageMakerTaskSettings> {
    @Override
    protected Writeable.Reader<SageMakerTaskSettings> instanceReader() {
        return SageMakerTaskSettings::new;
    }

    @Override
    protected SageMakerTaskSettings createTestInstance() {
        // five optional base task settings plus the NO_OP api-specific schema
        return new SageMakerTaskSettings(
            randomOptionalString(),
            randomOptionalString(),
            randomOptionalString(),
            randomOptionalString(),
            randomOptionalString(),
            SageMakerStoredTaskSchema.NO_OP
        );
    }

    @Override
    protected SageMakerTaskSettings fromMutableMap(Map<String, Object> mutableMap) {
        var errors = new ValidationException();
        var parsed = SageMakerTaskSettings.fromMap(mutableMap, SageMakerStoredTaskSchema.NO_OP, errors);
        errors.throwIfValidationErrorsExist();
        return parsed;
    }
}

View file

@ -0,0 +1,156 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
import software.amazon.awssdk.core.SdkBytes;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.junit.Before;
import java.io.IOException;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Stream;
import static org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase.toMap;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.startsWith;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public abstract class SageMakerSchemaPayloadTestCase<T extends SageMakerSchemaPayload> extends ESTestCase {
protected T payload;
@Before
public void setUp() throws Exception {
super.setUp();
payload = payload();
}
protected abstract T payload();
protected abstract String expectedApi();
protected abstract Set<TaskType> expectedSupportedTaskTypes();
protected abstract SageMakerStoredServiceSchema randomApiServiceSettings();
protected abstract SageMakerStoredTaskSchema randomApiTaskSettings();
public final void testApi() {
assertThat(payload.api(), equalTo(expectedApi()));
}
public final void testSupportedTaskTypes() {
assertThat(payload.supportedTasks(), containsInAnyOrder(expectedSupportedTaskTypes().toArray()));
}
public final void testApiServiceSettings() throws IOException {
var validationException = new ValidationException();
var expectedApiServiceSettings = randomApiServiceSettings();
var actualApiServiceSettings = payload.apiServiceSettings(toMap(expectedApiServiceSettings), validationException);
assertThat(actualApiServiceSettings, equalTo(expectedApiServiceSettings));
validationException.throwIfValidationErrorsExist();
}
public final void testApiTaskSettings() throws IOException {
var validationException = new ValidationException();
var expectedApiTaskSettings = randomApiTaskSettings();
var actualApiTaskSettings = payload.apiTaskSettings(toMap(expectedApiTaskSettings), validationException);
assertThat(actualApiTaskSettings, equalTo(expectedApiTaskSettings));
validationException.throwIfValidationErrorsExist();
}
public void testNamedWriteables() {
var namedWriteables = payload.namedWriteables().map(entry -> entry.name).toList();
var filteredNames = Set.of(
SageMakerStoredServiceSchema.NO_OP.getWriteableName(),
SageMakerStoredTaskSchema.NO_OP.getWriteableName()
);
var expectedWriteables = Stream.of(randomApiServiceSettings(), randomApiTaskSettings())
.map(VersionedNamedWriteable::getWriteableName)
.filter(Predicate.not(filteredNames::contains))
.toArray();
assertThat(namedWriteables, containsInAnyOrder(expectedWriteables));
}
public final void testWithUnknownApiServiceSettings() {
SageMakerModel model = mock();
when(model.apiServiceSettings()).thenReturn(mock());
when(model.apiTaskSettings()).thenReturn(randomApiTaskSettings());
when(model.api()).thenReturn("serviceApi");
when(model.getTaskType()).thenReturn(TaskType.ANY);
var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest()));
assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [serviceApi] and task type [any]:"));
}
public final void testWithUnknownApiTaskSettings() {
SageMakerModel model = mock();
when(model.apiServiceSettings()).thenReturn(randomApiServiceSettings());
when(model.apiTaskSettings()).thenReturn(mock());
when(model.api()).thenReturn("taskApi");
when(model.getTaskType()).thenReturn(TaskType.ANY);
var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest()));
assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [taskApi] and task type [any]:"));
}
public final void testUpdate() throws IOException {
var taskSettings = randomApiTaskSettings();
if (taskSettings != SageMakerStoredTaskSchema.NO_OP) {
var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings);
var updatedSettings = toMap(taskSettings.updatedTaskSettings(toMap(otherTaskSettings)));
var initialSettings = toMap(taskSettings);
var newSettings = toMap(otherTaskSettings);
newSettings.forEach((key, value) -> {
assertThat("Value should have been updated for key " + key, value, equalTo(updatedSettings.remove(key)));
});
initialSettings.forEach((key, value) -> {
if (updatedSettings.containsKey(key)) {
assertThat("Value should not have been updated for key " + key, value, equalTo(updatedSettings.remove(key)));
}
});
assertTrue("Map should be empty now that we verified all updated keys and all initial keys", updatedSettings.isEmpty());
}
if (payload instanceof SageMakerStoredTaskSchema taskSchema) {
var otherTaskSettings = randomValueOtherThan(randomApiTaskSettings(), this::randomApiTaskSettings);
var otherTaskSettingsAsMap = toMap(otherTaskSettings);
taskSchema.updatedTaskSettings(otherTaskSettingsAsMap);
}
}
protected static SageMakerInferenceRequest randomRequest() {
return new SageMakerInferenceRequest(
randomBoolean() ? randomAlphaOfLengthBetween(4, 8) : null,
randomOptionalBoolean(),
randomBoolean() ? randomInt() : null,
randomList(randomIntBetween(2, 4), () -> randomAlphaOfLengthBetween(2, 4)),
randomBoolean(),
randomFrom(InputType.values())
);
}
protected static void assertSdkBytes(SdkBytes sdkBytes, String expectedValue) {
assertThat(sdkBytes.asUtf8String(), equalTo(expectedValue));
}
}

View file

@ -0,0 +1,119 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.empty;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Tests the {@link SageMakerSchemas} registry: schema lookup by (task type, api), streaming
 * lookup, error messages for unknown combinations, and named writeable registration. Also
 * provides schema mocks shared by other SageMaker tests.
 */
public class SageMakerSchemasTests extends ESTestCase {
    /** Schemas mock whose every (task type, api) lookup resolves to {@link #mockSchema()}. */
    public static SageMakerSchemas mockSchemas() {
        SageMakerSchemas schemas = mock();
        var schema = mockSchema();
        when(schemas.schemaFor(any(TaskType.class), anyString())).thenReturn(schema);
        return schemas;
    }

    /** Schema mock that parses any settings map into the NO_OP service/task schemas. */
    public static SageMakerSchema mockSchema() {
        SageMakerSchema schema = mock();
        when(schema.apiServiceSettings(anyMap(), any())).thenReturn(SageMakerStoredServiceSchema.NO_OP);
        when(schema.apiTaskSettings(anyMap(), any())).thenReturn(SageMakerStoredTaskSchema.NO_OP);
        return schema;
    }

    // the real registry under test (not a mock)
    private static final SageMakerSchemas schemas = new SageMakerSchemas();

    public void testSupportedTaskTypes() {
        assertThat(schemas.supportedTaskTypes(), containsInAnyOrder(TaskType.TEXT_EMBEDDING));
    }

    public void testSupportedStreamingTasks() {
        // no streaming payloads are registered yet
        assertThat(schemas.supportedStreamingTasks(), empty());
    }

    /** Every registered payload's (task type, api) pairs must resolve to a schema. */
    public void testSchemaFor() {
        var payloads = Stream.of(new OpenAiTextEmbeddingPayload());
        payloads.forEach(payload -> {
            payload.supportedTasks().forEach(taskType -> {
                var model = mockModel(taskType, payload.api());
                assertNotNull(schemas.schemaFor(model));
            });
        });
    }

    /** Same as {@link #testSchemaFor()} for streaming payloads (currently none registered). */
    public void testStreamSchemaFor() {
        var payloads = Stream.<SageMakerStreamSchemaPayload>of(/* For when we add support for streaming payloads */);
        payloads.forEach(payload -> {
            payload.supportedTasks().forEach(taskType -> {
                var model = mockModel(taskType, payload.api());
                assertNotNull(schemas.streamSchemaFor(model));
            });
        });
    }

    private SageMakerModel mockModel(TaskType taskType, String api) {
        SageMakerModel model = mock();
        when(model.getTaskType()).thenReturn(taskType);
        when(model.api()).thenReturn(api);
        return model;
    }

    public void testMissingTaskTypeThrowsException() {
        var knownPayload = new OpenAiTextEmbeddingPayload();
        var unknownTaskType = TaskType.COMPLETION;
        var knownModel = mockModel(unknownTaskType, knownPayload.api());
        // NOTE(review): in this assertThrows overload the String argument is the failure
        // description, not an expected exception message - only the exception type is
        // verified here. Consider capturing the exception and asserting getMessage().
        assertThrows(
            "Task [completion] is not compatible for service [sagemaker] and api [openai]. Supported tasks: [text_embedding]",
            ElasticsearchStatusException.class,
            () -> schemas.schemaFor(knownModel)
        );
    }

    public void testMissingSchemaThrowsException() {
        var unknownModel = mockModel(TaskType.ANY, "blah");
        // NOTE(review): same assertThrows caveat as above - the message is not asserted.
        assertThrows(
            "Task [any] is not compatible for service [sagemaker] and api [blah]. Supported tasks: []",
            ElasticsearchStatusException.class,
            () -> schemas.schemaFor(unknownModel)
        );
    }

    public void testMissingStreamSchemaThrowsException() {
        var unknownModel = mockModel(TaskType.ANY, "blah");
        // NOTE(review): same assertThrows caveat as above - the message is not asserted.
        assertThrows(
            "Streaming is not allowed for service [sagemaker], api [blah], and task [any]. Supported streaming tasks: []",
            ElasticsearchStatusException.class,
            () -> schemas.streamSchemaFor(unknownModel)
        );
    }

    /**
     * The registry's registered writeables must exactly match the payloads' writeables plus
     * the two NO_OP schemas, with duplicates collapsed.
     */
    public void testNamedWriteables() {
        var namedWriteables = Stream.of(new OpenAiTextEmbeddingPayload().namedWriteables());
        var expectedNamedWriteables = Stream.concat(
            namedWriteables.flatMap(names -> names.map(entry -> entry.name)),
            Stream.of(SageMakerStoredServiceSchema.NO_OP.getWriteableName(), SageMakerStoredTaskSchema.NO_OP.getWriteableName())
        ).distinct().toArray();
        var actualRegisteredNames = SageMakerSchemas.namedWriteables().stream().map(entry -> entry.name).toList();
        assertThat(actualRegisteredNames, containsInAnyOrder(expectedNamedWriteables));
    }
}

View file

@ -0,0 +1,155 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayloadTestCase;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Set;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Tests the OpenAI-compatible embedding payload: JSON request serialization to
 * {@link SdkBytes} and embedding extraction from an {@link InvokeEndpointResponse} body.
 */
public class OpenAiTextEmbeddingPayloadTests extends SageMakerSchemaPayloadTestCase<OpenAiTextEmbeddingPayload> {
    @Override
    protected OpenAiTextEmbeddingPayload payload() {
        return new OpenAiTextEmbeddingPayload();
    }

    @Override
    protected String expectedApi() {
        return "openai";
    }

    @Override
    protected Set<TaskType> expectedSupportedTaskTypes() {
        return Set.of(TaskType.TEXT_EMBEDDING);
    }

    @Override
    protected SageMakerStoredServiceSchema randomApiServiceSettings() {
        return SageMakerOpenAiServiceSettingsTests.randomApiServiceSettings();
    }

    @Override
    protected SageMakerStoredTaskSchema randomApiTaskSettings() {
        return SageMakerOpenAiTaskSettingsTests.randomApiTaskSettings();
    }

    public void testAccept() {
        assertThat(payload.accept(mock()), equalTo("application/json"));
    }

    public void testContentType() {
        assertThat(payload.contentType(mock()), equalTo("application/json"));
    }

    /** Builds a model mock returning the given OpenAI api-specific service/task settings. */
    private static SageMakerModel embeddingModel(Integer dimensions, boolean dimensionsSetByUser, String user) {
        SageMakerModel model = mock();
        when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(dimensions, dimensionsSetByUser));
        when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings(user));
        return model;
    }

    /** A single input serializes as a bare string rather than an array. */
    public void testRequestWithSingleInput() throws Exception {
        var model = embeddingModel(null, false, null);
        var request = new SageMakerInferenceRequest(null, null, null, List.of("hello"), randomBoolean(), randomFrom(InputType.values()));
        var requestBytes = payload.requestBytes(model, request);
        assertSdkBytes(requestBytes, """
            {"input":"hello"}""");
    }

    /** Multiple inputs serialize as a JSON array. */
    public void testRequestWithArrayInput() throws Exception {
        var model = embeddingModel(null, false, null);
        var request = new SageMakerInferenceRequest(
            null,
            null,
            null,
            List.of("hello", "there"),
            randomBoolean(),
            randomFrom(InputType.values())
        );
        var requestBytes = payload.requestBytes(model, request);
        assertSdkBytes(requestBytes, """
            {"input":["hello","there"]}""");
    }

    /** Dimensions not explicitly set by the user are omitted from the request body. */
    public void testRequestWithDimensionsNotSetByUserIgnoreDimensions() throws Exception {
        var model = embeddingModel(123, false, null);
        var request = new SageMakerInferenceRequest(
            null,
            null,
            null,
            List.of("hello", "there"),
            randomBoolean(),
            randomFrom(InputType.values())
        );
        var requestBytes = payload.requestBytes(model, request);
        assertSdkBytes(requestBytes, """
            {"input":["hello","there"]}""");
    }

    /** Query, user, and user-set dimensions all appear in the serialized request. */
    public void testRequestWithOptionals() throws Exception {
        var model = embeddingModel(1234, true, "user");
        var request = new SageMakerInferenceRequest("query", null, null, List.of("hello"), randomBoolean(), randomFrom(InputType.values()));
        var requestBytes = payload.requestBytes(model, request);
        assertSdkBytes(requestBytes, """
            {"query":"query","input":"hello","user":"user","dimensions":1234}""");
    }

    /** An OpenAI-format embeddings response body parses into float embeddings. */
    public void testResponse() throws Exception {
        String responseJson = """
            {
                "object": "list",
                "data": [
                    {
                        "object": "embedding",
                        "index": 0,
                        "embedding": [
                            0.014539449,
                            -0.015288644
                        ]
                    }
                ],
                "model": "text-embedding-ada-002-v2",
                "usage": {
                    "prompt_tokens": 8,
                    "total_tokens": 8
                }
            }
            """;
        var invokeEndpointResponse = InvokeEndpointResponse.builder()
            .body(SdkBytes.fromString(responseJson, StandardCharsets.UTF_8))
            .build();
        var textEmbeddingFloatResults = payload.responseBody(mock(), invokeEndpointResponse);
        assertThat(
            textEmbeddingFloatResults.embeddings(),
            is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
        );
    }
}

View file

@ -0,0 +1,51 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.sameInstance;
/**
 * Serialization and parsing tests for the OpenAI api-specific service settings:
 * an optional {@code dimensions} value plus the derived dimensions-set-by-user flag.
 */
public class SageMakerOpenAiServiceSettingsTests extends InferenceSettingsTestCase<OpenAiTextEmbeddingPayload.ApiServiceSettings> {

    /** Random settings; when dimensions are present they are, by construction, user supplied. */
    static OpenAiTextEmbeddingPayload.ApiServiceSettings randomApiServiceSettings() {
        if (randomBoolean()) {
            return new OpenAiTextEmbeddingPayload.ApiServiceSettings(randomIntBetween(1, 100), true);
        }
        return new OpenAiTextEmbeddingPayload.ApiServiceSettings(null, false);
    }

    @Override
    protected OpenAiTextEmbeddingPayload.ApiServiceSettings createTestInstance() {
        return randomApiServiceSettings();
    }

    @Override
    protected Writeable.Reader<OpenAiTextEmbeddingPayload.ApiServiceSettings> instanceReader() {
        return OpenAiTextEmbeddingPayload.ApiServiceSettings::new;
    }

    @Override
    protected OpenAiTextEmbeddingPayload.ApiServiceSettings fromMutableMap(Map<String, Object> mutableMap) {
        var errors = new ValidationException();
        var parsed = OpenAiTextEmbeddingPayload.ApiServiceSettings.fromMap(mutableMap, errors);
        errors.throwIfValidationErrorsExist();
        return parsed;
    }

    /** Updating dimension-less settings with embedding details yields a new instance carrying the dimensions. */
    public void testDimensionsSetByUser() {
        var requestedDimensions = randomIntBetween(1, 100);
        var dimensionlessSettings = new OpenAiTextEmbeddingPayload.ApiServiceSettings(null, false);
        var updatedSettings = dimensionlessSettings.updateModelWithEmbeddingDetails(requestedDimensions);
        assertThat(updatedSettings, not(sameInstance(dimensionlessSettings)));
        assertThat(updatedSettings.dimensions(), equalTo(requestedDimensions));
    }
}

View file

@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase;
import java.util.Map;
public class SageMakerOpenAiTaskSettingsTests extends InferenceSettingsTestCase<OpenAiTextEmbeddingPayload.ApiTaskSettings> {
@Override
protected OpenAiTextEmbeddingPayload.ApiTaskSettings fromMutableMap(Map<String, Object> mutableMap) {
var validationException = new ValidationException();
var settings = OpenAiTextEmbeddingPayload.ApiTaskSettings.fromMap(mutableMap, validationException);
validationException.throwIfValidationErrorsExist();
return settings;
}
@Override
protected Writeable.Reader<OpenAiTextEmbeddingPayload.ApiTaskSettings> instanceReader() {
return OpenAiTextEmbeddingPayload.ApiTaskSettings::new;
}
@Override
protected OpenAiTextEmbeddingPayload.ApiTaskSettings createTestInstance() {
return randomApiTaskSettings();
}
static OpenAiTextEmbeddingPayload.ApiTaskSettings randomApiTaskSettings() {
return new OpenAiTextEmbeddingPayload.ApiTaskSettings(randomOptionalString());
}
}