mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
[ML] Integrate SageMaker with OpenAI Embeddings (#126856)
Integrating with SageMaker. Current design: - SageMaker accepts any byte payload, which can be text, csv, or json. `api` represents the structure of the payload that we will send, for example `openai`, `elastic`, `common`, probably `cohere` or `huggingface` as well. - `api` implementations are extensions of `SageMakerSchemaPayload`, which supports: - "extra" service and task settings specific to the payload structure, so `cohere` would require `embedding_type` and `openai` would require `dimensions` in the `service_settings` - conversion logic from model, service settings, task settings, and input to `SdkBytes` - conversion logic from responding `SdkBytes` to `InferenceServiceResults` - Everything else is tunneling, there are a number of base `service_settings` and `task_settings` that are independent of the api format that we will store and set - We let the SDK do the bulk of the work in terms of connection details, rate limiting, retries, etc.
This commit is contained in:
parent
1b35cceacf
commit
245f5eebce
39 changed files with 4304 additions and 139 deletions
5
docs/changelog/126856.yaml
Normal file
5
docs/changelog/126856.yaml
Normal file
|
@ -0,0 +1,5 @@
|
|||
pr: 126856
|
||||
summary: "[ML] Integrate SageMaker with OpenAI Embeddings"
|
||||
area: Machine Learning
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -4912,6 +4912,11 @@
|
|||
<sha256 value="c83dd82a9d82ff8c7d2eb1bdb2ae9f9505b312dad9a6bf0b80bc0136653a3a24" origin="Generated by Gradle"/>
|
||||
</artifact>
|
||||
</component>
|
||||
<component group="software.amazon.awssdk" name="sagemakerruntime" version="2.30.38">
|
||||
<artifact name="sagemakerruntime-2.30.38.jar">
|
||||
<sha256 value="b26ee73fa06d047eab9a174e49627972e646c0bbe909f479c18dbff193b561f5" origin="Generated by Gradle"/>
|
||||
</artifact>
|
||||
</component>
|
||||
<component group="software.amazon.awssdk" name="sdk-core" version="2.30.38">
|
||||
<artifact name="sdk-core-2.30.38.jar">
|
||||
<sha256 value="556463b8c353408d93feab74719d141fcfda7fd3d7b7d1ad3a8a548b7cc2982d" origin="Generated by Gradle"/>
|
||||
|
|
|
@ -162,6 +162,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
|
||||
public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19);
|
||||
public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL_BACKPORT_8_19 = def(8_841_0_20);
|
||||
public static final TransportVersion ML_INFERENCE_SAGEMAKER_8_19 = def(8_841_0_21);
|
||||
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
|
||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
|
||||
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00);
|
||||
|
@ -232,6 +233,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion PROJECT_METADATA_SETTINGS = def(9_066_00_0);
|
||||
public static final TransportVersion AGGREGATE_METRIC_DOUBLE_BLOCK = def(9_067_00_0);
|
||||
public static final TransportVersion PINNED_RETRIEVER = def(9_068_0_00);
|
||||
public static final TransportVersion ML_INFERENCE_SAGEMAKER = def(9_069_0_00);
|
||||
|
||||
/*
|
||||
* STOP! READ THIS FIRST! No, really,
|
||||
|
|
|
@ -53,6 +53,12 @@ public class ValidationException extends IllegalArgumentException {
|
|||
return validationErrors;
|
||||
}
|
||||
|
||||
public final void throwIfValidationErrorsExist() {
|
||||
if (validationErrors().isEmpty() == false) {
|
||||
throw this;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public final String getMessage() {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
|
|
|
@ -62,6 +62,7 @@ dependencies {
|
|||
|
||||
/* AWS SDK v2 */
|
||||
implementation("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}")
|
||||
implementation("software.amazon.awssdk:sagemakerruntime:${versions.awsv2sdk}")
|
||||
api "software.amazon.awssdk:protocol-core:${versions.awsv2sdk}"
|
||||
api "software.amazon.awssdk:aws-json-protocol:${versions.awsv2sdk}"
|
||||
api "software.amazon.awssdk:third-party-jackson-core:${versions.awsv2sdk}"
|
||||
|
@ -142,6 +143,7 @@ tasks.named("dependencyLicenses").configure {
|
|||
mapping from: /json-utils.*/, to: 'aws-sdk-2'
|
||||
mapping from: /endpoints-spi.*/, to: 'aws-sdk-2'
|
||||
mapping from: /bedrockruntime.*/, to: 'aws-sdk-2'
|
||||
mapping from: /sagemakerruntime.*/, to: 'aws-sdk-2'
|
||||
mapping from: /netty-nio-client/, to: 'aws-sdk-2'
|
||||
/* Cannot use REGEX to match netty-* because netty-nio-client is an AWS package */
|
||||
mapping from: /netty-buffer/, to: 'netty'
|
||||
|
|
|
@ -18,163 +18,161 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
|
||||
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testGetServicesWithoutTaskType() throws IOException {
|
||||
List<Object> services = getAllServices();
|
||||
assertThat(services.size(), equalTo(21));
|
||||
assertThat(services.size(), equalTo(22));
|
||||
|
||||
String[] providers = new String[services.size()];
|
||||
for (int i = 0; i < services.size(); i++) {
|
||||
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
|
||||
providers[i] = (String) serviceConfig.get("service");
|
||||
}
|
||||
var providers = providers(services);
|
||||
|
||||
assertArrayEquals(
|
||||
List.of(
|
||||
"alibabacloud-ai-search",
|
||||
"amazonbedrock",
|
||||
"anthropic",
|
||||
"azureaistudio",
|
||||
"azureopenai",
|
||||
"cohere",
|
||||
"deepseek",
|
||||
"elastic",
|
||||
"elasticsearch",
|
||||
"googleaistudio",
|
||||
"googlevertexai",
|
||||
"hugging_face",
|
||||
"jinaai",
|
||||
"mistral",
|
||||
"openai",
|
||||
"streaming_completion_test_service",
|
||||
"test_reranking_service",
|
||||
"test_service",
|
||||
"text_embedding_test_service",
|
||||
"voyageai",
|
||||
"watsonxai"
|
||||
).toArray(),
|
||||
providers
|
||||
assertThat(
|
||||
providers,
|
||||
containsInAnyOrder(
|
||||
List.of(
|
||||
"alibabacloud-ai-search",
|
||||
"amazonbedrock",
|
||||
"anthropic",
|
||||
"azureaistudio",
|
||||
"azureopenai",
|
||||
"cohere",
|
||||
"deepseek",
|
||||
"elastic",
|
||||
"elasticsearch",
|
||||
"googleaistudio",
|
||||
"googlevertexai",
|
||||
"hugging_face",
|
||||
"jinaai",
|
||||
"mistral",
|
||||
"openai",
|
||||
"streaming_completion_test_service",
|
||||
"test_reranking_service",
|
||||
"test_service",
|
||||
"text_embedding_test_service",
|
||||
"voyageai",
|
||||
"watsonxai",
|
||||
"sagemaker"
|
||||
).toArray()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private Iterable<String> providers(List<Object> services) {
|
||||
return services.stream().map(service -> {
|
||||
var serviceConfig = (Map<String, Object>) service;
|
||||
return (String) serviceConfig.get("service");
|
||||
}).toList();
|
||||
}
|
||||
|
||||
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
|
||||
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
|
||||
assertThat(services.size(), equalTo(15));
|
||||
assertThat(services.size(), equalTo(16));
|
||||
|
||||
String[] providers = new String[services.size()];
|
||||
for (int i = 0; i < services.size(); i++) {
|
||||
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
|
||||
providers[i] = (String) serviceConfig.get("service");
|
||||
}
|
||||
var providers = providers(services);
|
||||
|
||||
assertArrayEquals(
|
||||
List.of(
|
||||
"alibabacloud-ai-search",
|
||||
"amazonbedrock",
|
||||
"azureaistudio",
|
||||
"azureopenai",
|
||||
"cohere",
|
||||
"elasticsearch",
|
||||
"googleaistudio",
|
||||
"googlevertexai",
|
||||
"hugging_face",
|
||||
"jinaai",
|
||||
"mistral",
|
||||
"openai",
|
||||
"text_embedding_test_service",
|
||||
"voyageai",
|
||||
"watsonxai"
|
||||
).toArray(),
|
||||
providers
|
||||
assertThat(
|
||||
providers,
|
||||
containsInAnyOrder(
|
||||
List.of(
|
||||
"alibabacloud-ai-search",
|
||||
"amazonbedrock",
|
||||
"azureaistudio",
|
||||
"azureopenai",
|
||||
"cohere",
|
||||
"elasticsearch",
|
||||
"googleaistudio",
|
||||
"googlevertexai",
|
||||
"hugging_face",
|
||||
"jinaai",
|
||||
"mistral",
|
||||
"openai",
|
||||
"text_embedding_test_service",
|
||||
"voyageai",
|
||||
"watsonxai",
|
||||
"sagemaker"
|
||||
).toArray()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testGetServicesWithRerankTaskType() throws IOException {
|
||||
List<Object> services = getServices(TaskType.RERANK);
|
||||
assertThat(services.size(), equalTo(7));
|
||||
|
||||
String[] providers = new String[services.size()];
|
||||
for (int i = 0; i < services.size(); i++) {
|
||||
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
|
||||
providers[i] = (String) serviceConfig.get("service");
|
||||
}
|
||||
var providers = providers(services);
|
||||
|
||||
assertArrayEquals(
|
||||
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai")
|
||||
.toArray(),
|
||||
providers
|
||||
assertThat(
|
||||
providers,
|
||||
containsInAnyOrder(
|
||||
List.of(
|
||||
"alibabacloud-ai-search",
|
||||
"cohere",
|
||||
"elasticsearch",
|
||||
"googlevertexai",
|
||||
"jinaai",
|
||||
"test_reranking_service",
|
||||
"voyageai"
|
||||
).toArray()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testGetServicesWithCompletionTaskType() throws IOException {
|
||||
List<Object> services = getServices(TaskType.COMPLETION);
|
||||
assertThat(services.size(), equalTo(10));
|
||||
|
||||
String[] providers = new String[services.size()];
|
||||
for (int i = 0; i < services.size(); i++) {
|
||||
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
|
||||
providers[i] = (String) serviceConfig.get("service");
|
||||
}
|
||||
var providers = providers(services);
|
||||
|
||||
assertArrayEquals(
|
||||
List.of(
|
||||
"alibabacloud-ai-search",
|
||||
"amazonbedrock",
|
||||
"anthropic",
|
||||
"azureaistudio",
|
||||
"azureopenai",
|
||||
"cohere",
|
||||
"deepseek",
|
||||
"googleaistudio",
|
||||
"openai",
|
||||
"streaming_completion_test_service"
|
||||
).toArray(),
|
||||
providers
|
||||
assertThat(
|
||||
providers,
|
||||
containsInAnyOrder(
|
||||
List.of(
|
||||
"alibabacloud-ai-search",
|
||||
"amazonbedrock",
|
||||
"anthropic",
|
||||
"azureaistudio",
|
||||
"azureopenai",
|
||||
"cohere",
|
||||
"deepseek",
|
||||
"googleaistudio",
|
||||
"openai",
|
||||
"streaming_completion_test_service"
|
||||
).toArray()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testGetServicesWithChatCompletionTaskType() throws IOException {
|
||||
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
|
||||
assertThat(services.size(), equalTo(4));
|
||||
|
||||
String[] providers = new String[services.size()];
|
||||
for (int i = 0; i < services.size(); i++) {
|
||||
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
|
||||
providers[i] = (String) serviceConfig.get("service");
|
||||
}
|
||||
var providers = providers(services);
|
||||
|
||||
assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
|
||||
assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray()));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
|
||||
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
|
||||
assertThat(services.size(), equalTo(6));
|
||||
|
||||
String[] providers = new String[services.size()];
|
||||
for (int i = 0; i < services.size(); i++) {
|
||||
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
|
||||
providers[i] = (String) serviceConfig.get("service");
|
||||
}
|
||||
var providers = providers(services);
|
||||
|
||||
assertArrayEquals(
|
||||
List.of(
|
||||
"alibabacloud-ai-search",
|
||||
"elastic",
|
||||
"elasticsearch",
|
||||
"hugging_face",
|
||||
"streaming_completion_test_service",
|
||||
"test_service"
|
||||
).toArray(),
|
||||
providers
|
||||
assertThat(
|
||||
providers,
|
||||
containsInAnyOrder(
|
||||
List.of(
|
||||
"alibabacloud-ai-search",
|
||||
"elastic",
|
||||
"elasticsearch",
|
||||
"hugging_face",
|
||||
"streaming_completion_test_service",
|
||||
"test_service"
|
||||
).toArray()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -36,6 +36,7 @@ module org.elasticsearch.inference {
|
|||
requires org.elasticsearch.logging;
|
||||
requires org.elasticsearch.sslconfig;
|
||||
requires org.apache.commons.text;
|
||||
requires software.amazon.awssdk.services.sagemakerruntime;
|
||||
|
||||
exports org.elasticsearch.xpack.inference.action;
|
||||
exports org.elasticsearch.xpack.inference.registry;
|
||||
|
|
|
@ -92,6 +92,8 @@ import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCo
|
|||
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
|
||||
|
@ -157,6 +159,8 @@ public class InferenceNamedWriteablesProvider {
|
|||
|
||||
namedWriteables.addAll(StreamingTaskManager.namedWriteables());
|
||||
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
|
||||
namedWriteables.addAll(SageMakerModel.namedWriteables());
|
||||
namedWriteables.addAll(SageMakerSchemas.namedWriteables());
|
||||
|
||||
return namedWriteables;
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import org.elasticsearch.common.settings.IndexScopedSettings;
|
|||
import org.elasticsearch.common.settings.Setting;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.settings.SettingsFilter;
|
||||
import org.elasticsearch.common.util.LazyInitializable;
|
||||
import org.elasticsearch.core.IOUtils;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.features.NodeFeature;
|
||||
|
@ -132,6 +133,11 @@ import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
|
|||
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
|
||||
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerService;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerConfiguration;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
|
||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
|
||||
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
|
||||
|
||||
|
@ -294,6 +300,8 @@ public class InferencePlugin extends Plugin
|
|||
services.threadPool()
|
||||
);
|
||||
|
||||
var sageMakerSchemas = new SageMakerSchemas();
|
||||
var sageMakerConfigurations = new LazyInitializable<>(new SageMakerConfiguration(sageMakerSchemas));
|
||||
inferenceServices.add(
|
||||
() -> List.of(
|
||||
context -> new ElasticInferenceService(
|
||||
|
@ -302,6 +310,16 @@ public class InferencePlugin extends Plugin
|
|||
inferenceServiceSettings,
|
||||
modelRegistry.get(),
|
||||
authorizationHandler
|
||||
),
|
||||
context -> new SageMakerService(
|
||||
new SageMakerModelBuilder(sageMakerSchemas),
|
||||
new SageMakerClient(
|
||||
new SageMakerClient.Factory(new HttpSettings(settings, services.clusterService())),
|
||||
services.threadPool()
|
||||
),
|
||||
sageMakerSchemas,
|
||||
services.threadPool(),
|
||||
sageMakerConfigurations::getOrCompute
|
||||
)
|
||||
)
|
||||
);
|
||||
|
|
|
@ -23,11 +23,12 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
|
|||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.EnumSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString;
|
||||
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.ACCESS_KEY_FIELD;
|
||||
|
@ -134,33 +135,39 @@ public class AwsSecretSettings implements SecretSettings {
|
|||
}
|
||||
|
||||
private static final LazyInitializable<Map<String, SettingsConfiguration>, RuntimeException> configuration =
|
||||
new LazyInitializable<>(() -> {
|
||||
var configurationMap = new HashMap<String, SettingsConfiguration>();
|
||||
configurationMap.put(
|
||||
ACCESS_KEY_FIELD,
|
||||
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription(
|
||||
"A valid AWS access key that has permissions to use Amazon Bedrock."
|
||||
)
|
||||
.setLabel("Access Key")
|
||||
.setRequired(true)
|
||||
.setSensitive(true)
|
||||
.setUpdatable(true)
|
||||
.setType(SettingsConfigurationFieldType.STRING)
|
||||
.build()
|
||||
);
|
||||
configurationMap.put(
|
||||
SECRET_KEY_FIELD,
|
||||
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription(
|
||||
"A valid AWS secret key that is paired with the access_key."
|
||||
)
|
||||
.setLabel("Secret Key")
|
||||
.setRequired(true)
|
||||
.setSensitive(true)
|
||||
.setUpdatable(true)
|
||||
.setType(SettingsConfigurationFieldType.STRING)
|
||||
.build()
|
||||
);
|
||||
return Collections.unmodifiableMap(configurationMap);
|
||||
});
|
||||
new LazyInitializable<>(
|
||||
() -> configuration(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).collect(
|
||||
Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public static Stream<Map.Entry<String, SettingsConfiguration>> configuration(EnumSet<TaskType> supportedTaskTypes) {
|
||||
return Stream.of(
|
||||
Map.entry(
|
||||
ACCESS_KEY_FIELD,
|
||||
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||
"A valid AWS access key that has permissions to use Amazon Bedrock."
|
||||
)
|
||||
.setLabel("Access Key")
|
||||
.setRequired(true)
|
||||
.setSensitive(true)
|
||||
.setUpdatable(true)
|
||||
.setType(SettingsConfigurationFieldType.STRING)
|
||||
.build()
|
||||
),
|
||||
Map.entry(
|
||||
SECRET_KEY_FIELD,
|
||||
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||
"A valid AWS secret key that is paired with the access_key."
|
||||
)
|
||||
.setLabel("Secret Key")
|
||||
.setRequired(true)
|
||||
.setSensitive(true)
|
||||
.setUpdatable(true)
|
||||
.setType(SettingsConfigurationFieldType.STRING)
|
||||
.build()
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.common.unit.ByteSizeUnit;
|
|||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
|
@ -55,6 +56,10 @@ public class HttpSettings {
|
|||
return connectionTimeout;
|
||||
}
|
||||
|
||||
public Duration connectionTimeoutDuration() {
|
||||
return Duration.ofMillis(connectionTimeout);
|
||||
}
|
||||
|
||||
private void setMaxResponseSize(ByteSizeValue maxResponseSize) {
|
||||
this.maxResponseSize = maxResponseSize;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,251 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker;
|
||||
|
||||
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
|
||||
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
|
||||
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
|
||||
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
|
||||
import software.amazon.awssdk.profiles.ProfileFile;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeAsyncClient;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.ExceptionsHelper;
|
||||
import org.elasticsearch.SpecialPermission;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.ContextPreservingActionListener;
|
||||
import org.elasticsearch.action.support.ListenerTimeouts;
|
||||
import org.elasticsearch.common.cache.Cache;
|
||||
import org.elasticsearch.common.cache.CacheBuilder;
|
||||
import org.elasticsearch.common.cache.CacheLoader;
|
||||
import org.elasticsearch.common.util.concurrent.FutureUtils;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.core.Tuple;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpSettings;
|
||||
import org.reactivestreams.FlowAdapters;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.security.AccessController;
|
||||
import java.security.PrivilegedExceptionAction;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CompletionException;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.Flow;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
|
||||
|
||||
public class SageMakerClient implements Closeable {
|
||||
private static final Logger log = LogManager.getLogger(SageMakerClient.class);
|
||||
private final Cache<RegionAndSecrets, SageMakerRuntimeAsyncClient> existingClients = CacheBuilder.<
|
||||
RegionAndSecrets,
|
||||
SageMakerRuntimeAsyncClient>builder()
|
||||
.removalListener(removal -> removal.getValue().close())
|
||||
.setExpireAfterAccess(TimeValue.timeValueMinutes(15))
|
||||
.build();
|
||||
|
||||
private final CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory;
|
||||
private final ThreadPool threadPool;
|
||||
|
||||
public SageMakerClient(CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory, ThreadPool threadPool) {
|
||||
this.clientFactory = clientFactory;
|
||||
this.threadPool = threadPool;
|
||||
}
|
||||
|
||||
public void invoke(
|
||||
RegionAndSecrets regionAndSecrets,
|
||||
InvokeEndpointRequest request,
|
||||
TimeValue timeout,
|
||||
ActionListener<InvokeEndpointResponse> listener
|
||||
) {
|
||||
SageMakerRuntimeAsyncClient asyncClient;
|
||||
try {
|
||||
asyncClient = existingClients.computeIfAbsent(regionAndSecrets, clientFactory);
|
||||
} catch (ExecutionException e) {
|
||||
listener.onFailure(clientFailure(regionAndSecrets, e));
|
||||
return;
|
||||
}
|
||||
|
||||
var contextPreservingListener = new ContextPreservingActionListener<>(
|
||||
threadPool.getThreadContext().newRestorableContext(false),
|
||||
listener
|
||||
);
|
||||
|
||||
var awsFuture = asyncClient.invokeEndpoint(request);
|
||||
var timeoutListener = ListenerTimeouts.wrapWithTimeout(
|
||||
threadPool,
|
||||
timeout,
|
||||
threadPool.executor(UTILITY_THREAD_POOL_NAME),
|
||||
contextPreservingListener,
|
||||
ignored -> {
|
||||
FutureUtils.cancel(awsFuture);
|
||||
contextPreservingListener.onFailure(
|
||||
new ElasticsearchStatusException("Request timed out after [{}]", RestStatus.REQUEST_TIMEOUT, timeout)
|
||||
);
|
||||
}
|
||||
);
|
||||
awsFuture.thenAcceptAsync(timeoutListener::onResponse, threadPool.executor(UTILITY_THREAD_POOL_NAME))
|
||||
.exceptionallyAsync(t -> failAndMaybeThrowError(t, timeoutListener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
|
||||
}
|
||||
|
||||
private static Exception clientFailure(RegionAndSecrets regionAndSecrets, Exception cause) {
|
||||
return new ElasticsearchStatusException(
|
||||
"failed to create SageMakerRuntime client for region [{}]",
|
||||
RestStatus.INTERNAL_SERVER_ERROR,
|
||||
cause,
|
||||
regionAndSecrets.region()
|
||||
);
|
||||
}
|
||||
|
||||
private Void failAndMaybeThrowError(Throwable t, ActionListener<?> listener) {
|
||||
if (t instanceof CompletionException ce) {
|
||||
t = ce.getCause();
|
||||
}
|
||||
if (t instanceof Exception e) {
|
||||
listener.onFailure(e);
|
||||
} else {
|
||||
ExceptionsHelper.maybeError(t).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
|
||||
log.atWarn().withThrowable(t).log("Unknown failure calling SageMaker.");
|
||||
listener.onFailure(new RuntimeException("Unknown failure calling SageMaker.", t));
|
||||
}
|
||||
return null; // Void
|
||||
}
|
||||
|
||||
public void invokeStream(
|
||||
RegionAndSecrets regionAndSecrets,
|
||||
InvokeEndpointWithResponseStreamRequest request,
|
||||
TimeValue timeout,
|
||||
ActionListener<SageMakerStream> listener
|
||||
) {
|
||||
SageMakerRuntimeAsyncClient asyncClient;
|
||||
try {
|
||||
asyncClient = existingClients.computeIfAbsent(regionAndSecrets, clientFactory);
|
||||
} catch (ExecutionException e) {
|
||||
listener.onFailure(clientFailure(regionAndSecrets, e));
|
||||
return;
|
||||
}
|
||||
|
||||
var contextPreservingListener = new ContextPreservingActionListener<>(
|
||||
threadPool.getThreadContext().newRestorableContext(false),
|
||||
listener
|
||||
);
|
||||
|
||||
var responseStreamProcessor = new SageMakerStreamingResponseProcessor();
|
||||
var cancelAwsRequestListener = new AtomicReference<CompletableFuture<?>>();
|
||||
var timeoutListener = ListenerTimeouts.wrapWithTimeout(
|
||||
threadPool,
|
||||
timeout,
|
||||
threadPool.executor(UTILITY_THREAD_POOL_NAME),
|
||||
contextPreservingListener,
|
||||
ignored -> {
|
||||
FutureUtils.cancel(cancelAwsRequestListener.get());
|
||||
contextPreservingListener.onFailure(
|
||||
new ElasticsearchStatusException("Request timed out after [{}]", RestStatus.REQUEST_TIMEOUT, timeout)
|
||||
);
|
||||
}
|
||||
);
|
||||
// To stay consistent with HTTP providers, we cancel the TimeoutListener onResponse because we are measuring the time it takes to
|
||||
// start receiving bytes.
|
||||
var responseStreamListener = InvokeEndpointWithResponseStreamResponseHandler.builder()
|
||||
.onResponse(response -> timeoutListener.onResponse(new SageMakerStream(response, responseStreamProcessor)))
|
||||
.onEventStream(publisher -> responseStreamProcessor.setPublisher(FlowAdapters.toFlowPublisher(publisher)))
|
||||
.build();
|
||||
var awsFuture = asyncClient.invokeEndpointWithResponseStream(request, responseStreamListener);
|
||||
cancelAwsRequestListener.set(awsFuture);
|
||||
awsFuture.exceptionallyAsync(t -> failAndMaybeThrowError(t, timeoutListener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
existingClients.invalidateAll(); // will close each cached client
|
||||
}
|
||||
|
||||
public record RegionAndSecrets(String region, AwsSecretSettings secretSettings) {}
|
||||
|
||||
public static class Factory implements CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> {
|
||||
private final HttpSettings httpSettings;
|
||||
|
||||
public Factory(HttpSettings httpSettings) {
|
||||
this.httpSettings = httpSettings;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SageMakerRuntimeAsyncClient load(RegionAndSecrets key) throws Exception {
|
||||
SpecialPermission.check();
|
||||
// TODO migrate to entitlements
|
||||
return AccessController.doPrivileged((PrivilegedExceptionAction<SageMakerRuntimeAsyncClient>) () -> {
|
||||
var credentials = AwsBasicCredentials.create(
|
||||
key.secretSettings().accessKey().toString(),
|
||||
key.secretSettings().secretKey().toString()
|
||||
);
|
||||
var credentialsProvider = StaticCredentialsProvider.create(credentials);
|
||||
var clientConfig = NettyNioAsyncHttpClient.builder().connectionTimeout(httpSettings.connectionTimeoutDuration());
|
||||
var override = ClientOverrideConfiguration.builder()
|
||||
// disable profileFile, user credentials will always come from the configured Model Secrets
|
||||
.defaultProfileFileSupplier(ProfileFile.aggregator()::build)
|
||||
.defaultProfileFile(ProfileFile.aggregator().build())
|
||||
.retryPolicy(retryPolicy -> retryPolicy.numRetries(3))
|
||||
.retryStrategy(retryStrategy -> retryStrategy.maxAttempts(3))
|
||||
.build();
|
||||
return SageMakerRuntimeAsyncClient.builder()
|
||||
.credentialsProvider(credentialsProvider)
|
||||
.region(Region.of(key.region()))
|
||||
.httpClientBuilder(clientConfig)
|
||||
.overrideConfiguration(override)
|
||||
.build();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Bridges the SDK's streaming response publisher to the subscriber that consumes inference events.
 * The publisher (from the SDK callback) and the subscriber (from the response handler) can arrive in
 * either order; whichever side registers second performs the actual {@code subscribe} call. The
 * handshake is coordinated through a single {@link AtomicReference} holding a tuple in which exactly
 * one side is non-null until the two are connected.
 */
private static class SageMakerStreamingResponseProcessor implements Flow.Publisher<ResponseStream> {
    private static final Logger log = LogManager.getLogger(SageMakerStreamingResponseProcessor.class);
    // Holds whichever side (publisher in v1, subscriber in v2) registered first; cleared once connected.
    private final AtomicReference<Tuple<Flow.Publisher<ResponseStream>, Flow.Subscriber<? super ResponseStream>>> holder =
        new AtomicReference<>(null);
    // Flow.Publisher contract: only one subscriber is ever accepted; later ones get onError.
    private final AtomicBoolean subscribeCalledOnce = new AtomicBoolean(false);

    @Override
    public void subscribe(Flow.Subscriber<? super ResponseStream> subscriber) {
        if (subscribeCalledOnce.compareAndSet(false, true) == false) {
            subscriber.onError(new IllegalStateException("Subscriber already set."));
            return;
        }
        if (holder.compareAndSet(null, Tuple.tuple(null, subscriber)) == false) {
            // CAS failed, so the publisher side registered first: connect immediately.
            log.debug("Subscriber connecting to publisher.");
            var publisher = holder.getAndSet(null).v1();
            publisher.subscribe(subscriber);
        } else {
            // Subscriber registered first; setPublisher will complete the connection later.
            log.debug("Subscriber waiting for connection.");
        }
    }

    private void setPublisher(Flow.Publisher<ResponseStream> publisher) {
        if (holder.compareAndSet(null, Tuple.tuple(publisher, null)) == false) {
            // CAS failed, so the subscriber side registered first: connect immediately.
            log.debug("Publisher connecting to subscriber.");
            var subscriber = holder.getAndSet(null).v2();
            publisher.subscribe(subscriber);
        } else {
            // Publisher registered first; subscribe will complete the connection later.
            log.debug("Publisher waiting for connection.");
        }
    }
}
|
||||
|
||||
/** Pairs the initial streaming invocation response metadata with the publisher of the subsequent stream events. */
public record SageMakerStream(InvokeEndpointWithResponseStreamResponse response, Flow.Publisher<ResponseStream> responseStream) {}
|
||||
}
|
|
@@ -0,0 +1,28 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker;
|
||||
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public record SageMakerInferenceRequest(
|
||||
@Nullable String query,
|
||||
@Nullable Boolean returnDocuments,
|
||||
@Nullable Integer topN,
|
||||
List<String> input,
|
||||
boolean stream,
|
||||
InputType inputType
|
||||
) {
|
||||
public SageMakerInferenceRequest {
|
||||
Objects.requireNonNull(input);
|
||||
Objects.requireNonNull(inputType);
|
||||
}
|
||||
}
|
|
@@ -0,0 +1,305 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.SubscribableListener;
|
||||
import org.elasticsearch.common.CheckedSupplier;
|
||||
import org.elasticsearch.common.util.LazyInitializable;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.ChunkInferenceInput;
|
||||
import org.elasticsearch.inference.ChunkedInference;
|
||||
import org.elasticsearch.inference.InferenceService;
|
||||
import org.elasticsearch.inference.InferenceServiceConfiguration;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.SettingsConfiguration;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.EnumSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.core.Strings.format;
|
||||
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails;
|
||||
|
||||
/**
 * The inference service for Amazon SageMaker endpoints. This class handles only the settings and
 * orchestration that are independent of the payload format: payload-specific request/response
 * translation is delegated to the {@link SageMakerSchemas} registry, and transport concerns
 * (credentials, retries, streaming) are handled by {@link SageMakerClient}.
 */
public class SageMakerService implements InferenceService {
    public static final String NAME = "sagemaker";
    // Upper bound on inputs per request while chunking, unless the model specifies its own batch size.
    private static final int DEFAULT_BATCH_SIZE = 256;
    private final SageMakerModelBuilder modelBuilder;
    private final SageMakerClient client;
    private final SageMakerSchemas schemas;
    private final ThreadPool threadPool;
    // Built lazily: assembling the configuration map is only needed when the config API is called.
    private final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration;

    public SageMakerService(
        SageMakerModelBuilder modelBuilder,
        SageMakerClient client,
        SageMakerSchemas schemas,
        ThreadPool threadPool,
        CheckedSupplier<Map<String, SettingsConfiguration>, RuntimeException> configurationMap
    ) {
        this.modelBuilder = modelBuilder;
        this.client = client;
        this.schemas = schemas;
        this.threadPool = threadPool;
        this.configuration = new LazyInitializable<>(
            () -> new InferenceServiceConfiguration.Builder().setService(NAME)
                .setName("Amazon SageMaker")
                .setTaskTypes(supportedTaskTypes())
                .setConfigurations(configurationMap.get())
                .build()
        );
    }

    @Override
    public String name() {
        return NAME;
    }

    /** Parses a create-endpoint request body into a {@link SageMakerModel}, reporting parse errors via the listener. */
    @Override
    public void parseRequestConfig(
        String modelId,
        TaskType taskType,
        Map<String, Object> config,
        ActionListener<Model> parsedModelListener
    ) {
        ActionListener.completeWith(parsedModelListener, () -> modelBuilder.fromRequest(modelId, taskType, NAME, config));
    }

    /** Rebuilds a model from persisted configuration plus its stored secrets. */
    @Override
    public Model parsePersistedConfigWithSecrets(
        String modelId,
        TaskType taskType,
        Map<String, Object> config,
        Map<String, Object> secrets
    ) {
        return modelBuilder.fromStorage(modelId, taskType, NAME, config, secrets);
    }

    /** Rebuilds a model from persisted configuration without secrets (e.g. for display). */
    @Override
    public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
        return modelBuilder.fromStorage(modelId, taskType, NAME, config, null);
    }

    @Override
    public InferenceServiceConfiguration getConfiguration() {
        return configuration.getOrCompute();
    }

    @Override
    public EnumSet<TaskType> supportedTaskTypes() {
        return schemas.supportedTaskTypes();
    }

    @Override
    public Set<TaskType> supportedStreamingTasks() {
        return schemas.supportedStreamingTasks();
    }

    /**
     * Runs inference against the model's SageMaker endpoint. Task settings overrides are applied to a
     * copy of the model, the format-specific schema builds the request, and the client performs the
     * invocation (streaming or not, depending on {@code stream}).
     */
    @Override
    public void infer(
        Model model,
        @Nullable String query,
        @Nullable Boolean returnDocuments,
        @Nullable Integer topN,
        List<String> input,
        boolean stream,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        if (model instanceof SageMakerModel == false) {
            listener.onFailure(createInvalidModelException(model));
            return;
        }

        var inferenceRequest = new SageMakerInferenceRequest(query, returnDocuments, topN, input, stream, inputType);

        try {
            // Apply per-request task settings without mutating the stored model.
            var sageMakerModel = ((SageMakerModel) model).override(taskSettings);
            var regionAndSecrets = regionAndSecrets(sageMakerModel);

            if (stream) {
                var schema = schemas.streamSchemaFor(sageMakerModel);
                var request = schema.streamRequest(sageMakerModel, inferenceRequest);
                client.invokeStream(
                    regionAndSecrets,
                    request,
                    timeout,
                    ActionListener.wrap(
                        response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)),
                        // Let the schema translate SDK failures into user-facing errors.
                        e -> listener.onFailure(schema.error(sageMakerModel, e))
                    )
                );
            } else {
                var schema = schemas.schemaFor(sageMakerModel);
                var request = schema.request(sageMakerModel, inferenceRequest);
                client.invoke(
                    regionAndSecrets,
                    request,
                    timeout,
                    ActionListener.wrap(
                        response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())),
                        e -> listener.onFailure(schema.error(sageMakerModel, e))
                    )
                );
            }
        } catch (Exception e) {
            listener.onFailure(internalFailure(model, e));
        }
    }

    /**
     * Extracts the region and credentials needed to build the SDK client.
     *
     * @throws ElasticsearchStatusException if the model was loaded without secrets, which should never
     *         happen for an invocable model (hence the assert).
     */
    private SageMakerClient.RegionAndSecrets regionAndSecrets(SageMakerModel model) {
        var secrets = model.awsSecretSettings();
        if (secrets.isEmpty()) {
            assert false : "Cannot invoke a model without secrets";
            throw new ElasticsearchStatusException(
                format("Attempting to infer using a model without API keys, inference id [%s]", model.getInferenceEntityId()),
                RestStatus.INTERNAL_SERVER_ERROR
            );
        }
        return new SageMakerClient.RegionAndSecrets(model.region(), secrets.get());
    }

    /** Wraps unexpected failures, passing through exceptions that already carry a status. */
    private static ElasticsearchStatusException internalFailure(Model model, Exception cause) {
        if (cause instanceof ElasticsearchStatusException ese) {
            return ese;
        } else {
            return new ElasticsearchStatusException(
                "Failed to call SageMaker for inference id [{}].",
                RestStatus.INTERNAL_SERVER_ERROR,
                cause,
                model.getInferenceEntityId()
            );
        }
    }

    /**
     * Streams a unified (OpenAI-style) chat completion through the model's streaming schema.
     * NOTE(review): unlike {@link #infer}, no task-settings override is applied here — the request
     * carries its own parameters.
     */
    @Override
    public void unifiedCompletionInfer(
        Model model,
        UnifiedCompletionRequest request,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        if (model instanceof SageMakerModel == false) {
            listener.onFailure(createInvalidModelException(model));
            return;
        }

        try {
            var sageMakerModel = (SageMakerModel) model;
            var regionAndSecrets = regionAndSecrets(sageMakerModel);
            var schema = schemas.streamSchemaFor(sageMakerModel);
            var sagemakerRequest = schema.chatCompletionStreamRequest(sageMakerModel, request);
            client.invokeStream(
                regionAndSecrets,
                sagemakerRequest,
                timeout,
                ActionListener.wrap(
                    response -> listener.onResponse(schema.chatCompletionStreamResponse(sageMakerModel, response)),
                    e -> listener.onFailure(schema.chatCompletionError(sageMakerModel, e))
                )
            );
        } catch (Exception e) {
            listener.onFailure(internalFailure(model, e));
        }
    }

    /**
     * Chunks the inputs into batches and runs them sequentially through {@link #infer}. Batches are
     * chained with a {@link SubscribableListener} so only one request is in flight at a time; each
     * batch's results are delivered to its own listener by the chunker.
     */
    @Override
    public void chunkedInfer(
        Model model,
        String query,
        List<ChunkInferenceInput> input,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<List<ChunkedInference>> listener
    ) {
        if (model instanceof SageMakerModel == false) {
            listener.onFailure(createInvalidModelException(model));
            return;
        }
        try {
            var sageMakerModel = ((SageMakerModel) model).override(taskSettings);
            var batchedRequests = new EmbeddingRequestChunker<>(
                input,
                sageMakerModel.batchSize().orElse(DEFAULT_BATCH_SIZE),
                sageMakerModel.getConfigurations().getChunkingSettings()
            ).batchRequestsWithListeners(listener);

            var subscribableListener = SubscribableListener.newSucceeded(null);
            for (var request : batchedRequests) {
                subscribableListener = subscribableListener.andThen(
                    threadPool.executor(UTILITY_THREAD_POOL_NAME),
                    threadPool.getThreadContext(),
                    (l, ignored) -> infer(
                        sageMakerModel,
                        query,
                        null, // no return docs while chunking?
                        null, // no topN while chunking?
                        request.batch().inputs().get(),
                        false, // we never stream when chunking
                        null, // since we pass sageMakerModel as the model, we already overwrote the model with the task settings
                        inputType,
                        timeout,
                        // Advance the chain only after this batch's listener has been notified.
                        ActionListener.runAfter(request.listener(), () -> l.onResponse(null))
                    )
                );
            }
            // if there were any errors trying to create the SubscribableListener chain, then forward that to the listener
            // otherwise, BatchRequestAndListener will handle forwarding errors from the infer method
            subscribableListener.addListener(
                ActionListener.noop().delegateResponse((l, e) -> listener.onFailure(internalFailure(model, e)))
            );
        } catch (Exception e) {
            listener.onFailure(internalFailure(model, e));
        }
    }

    /** SageMaker endpoints are always "started" from Elastic's perspective; there is nothing to deploy. */
    @Override
    public void start(Model model, TimeValue timeout, ActionListener<Boolean> listener) {
        listener.onResponse(true);
    }

    /** Records the discovered embedding size on the model's service settings, if the schema supports it. */
    @Override
    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
        if (model instanceof SageMakerModel sageMakerModel) {
            return modelBuilder.updateModelWithEmbeddingDetails(sageMakerModel, embeddingSize);
        }

        throw invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.ML_INFERENCE_SAGEMAKER;
    }

    @Override
    public void close() throws IOException {
        client.close();
    }
}
|
|
@@ -0,0 +1,35 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||
|
||||
import org.elasticsearch.common.CheckedSupplier;
|
||||
import org.elasticsearch.inference.SettingsConfiguration;
|
||||
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public class SageMakerConfiguration implements CheckedSupplier<Map<String, SettingsConfiguration>, RuntimeException> {
|
||||
private final SageMakerSchemas schemas;
|
||||
|
||||
public SageMakerConfiguration(SageMakerSchemas schemas) {
|
||||
this.schemas = schemas;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, SettingsConfiguration> get() {
|
||||
return Stream.of(
|
||||
AwsSecretSettings.configuration(schemas.supportedTaskTypes()),
|
||||
SageMakerServiceSettings.configuration(schemas.supportedTaskTypes()),
|
||||
SageMakerTaskSettings.configuration(schemas.supportedTaskTypes())
|
||||
).flatMap(Function.identity()).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||
}
|
||||
}
|
|
@@ -0,0 +1,146 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.ServiceSettings;
|
||||
import org.elasticsearch.inference.TaskSettings;
|
||||
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
 * This model represents all models in SageMaker. SageMaker maintains a base set of settings and configurations, and this model manages
 * those. Any settings that are required for a specific model are stored in the {@link SageMakerStoredServiceSchema} and
 * {@link SageMakerStoredTaskSchema}.
 * Design:
 * - Region is stored in ServiceSettings and is used to create the SageMaker client.
 * - RateLimiting is based on AWS Service Quota, metered by account and region. The SDK client handles rate limiting internally. In order to
 *   rate limit appropriately, use the same AWS Access ID, Secret Key, and Region for each inference endpoint calling the same AWS account.
 * - Within Elastic, you cannot change the model behind the endpoint, because we require the embedding shape remain consistent between
 *   invocations. So anything "model selection" related must remain consistent through the lifetime of the Elastic endpoint. SageMaker
 *   allows model changes via TargetModel, InferenceComponentName, and TargetContainerHostname, but these will remain static in Elastic
 *   within ServiceSettings.
 * - CustomAttributes, EnableExplanations, InferenceId, SessionId, and TargetVariant are all request-time fields that can be saved.
 * - SageMaker returns 4 headers, which Elastic will forward as-is: x-Amzn-Invoked-Production-Variant, X-Amzn-SageMaker-Custom-Attributes,
 *   X-Amzn-SageMaker-New-Session-Id, X-Amzn-SageMaker-Closed-Session-Id
 */
public class SageMakerModel extends Model {
    private final SageMakerServiceSettings serviceSettings;
    private final SageMakerTaskSettings taskSettings;
    // May be null when the model was loaded without secrets (e.g. for display-only purposes).
    private final AwsSecretSettings awsSecretSettings;

    SageMakerModel(
        ModelConfigurations configurations,
        ModelSecrets secrets,
        SageMakerServiceSettings serviceSettings,
        SageMakerTaskSettings taskSettings,
        AwsSecretSettings awsSecretSettings
    ) {
        super(configurations, secrets);
        this.serviceSettings = serviceSettings;
        this.taskSettings = taskSettings;
        this.awsSecretSettings = awsSecretSettings;
    }

    /** The AWS credentials, empty when this model was loaded without secrets. */
    public Optional<AwsSecretSettings> awsSecretSettings() {
        return Optional.ofNullable(awsSecretSettings);
    }

    /** The AWS region used to build the SageMaker client. */
    public String region() {
        return serviceSettings.region();
    }

    /** The SageMaker endpoint to invoke. */
    public String endpointName() {
        return serviceSettings.endpointName();
    }

    /** The payload format identifier (e.g. "openai") used to pick the request/response schema. */
    public String api() {
        return serviceSettings.api();
    }

    public Optional<String> customAttributes() {
        return Optional.ofNullable(taskSettings.customAttributes());
    }

    public Optional<String> enableExplanations() {
        return Optional.ofNullable(taskSettings.enableExplanations());
    }

    public Optional<String> inferenceComponentName() {
        return Optional.ofNullable(serviceSettings.inferenceComponentName());
    }

    public Optional<String> inferenceIdForDataCapture() {
        return Optional.ofNullable(taskSettings.inferenceIdForDataCapture());
    }

    public Optional<String> sessionId() {
        return Optional.ofNullable(taskSettings.sessionId());
    }

    public Optional<String> targetContainerHostname() {
        return Optional.ofNullable(serviceSettings.targetContainerHostname());
    }

    public Optional<String> targetModel() {
        return Optional.ofNullable(serviceSettings.targetModel());
    }

    public Optional<String> targetVariant() {
        return Optional.ofNullable(taskSettings.targetVariant());
    }

    /** The configured chunking batch size; callers fall back to a service default when empty. */
    public Optional<Integer> batchSize() {
        return Optional.ofNullable(serviceSettings.batchSize());
    }

    /**
     * Returns a copy of this model with the given task settings applied on top of the stored ones,
     * or this same instance when there is nothing to override.
     */
    public SageMakerModel override(Map<String, Object> taskSettingsOverride) {
        if (taskSettingsOverride == null || taskSettingsOverride.isEmpty()) {
            return this;
        }

        return new SageMakerModel(
            getConfigurations(),
            getSecrets(),
            serviceSettings,
            taskSettings.updatedTaskSettings(taskSettingsOverride),
            awsSecretSettings
        );
    }

    /** Wire serialization entries for the settings types this model persists. */
    public static List<NamedWriteableRegistry.Entry> namedWriteables() {
        return List.of(
            new NamedWriteableRegistry.Entry(ServiceSettings.class, SageMakerServiceSettings.NAME, SageMakerServiceSettings::new),
            new NamedWriteableRegistry.Entry(TaskSettings.class, SageMakerTaskSettings.NAME, SageMakerTaskSettings::new)
        );
    }

    /** The payload-format-specific service settings (e.g. dimensions for OpenAI embeddings). */
    public SageMakerStoredServiceSchema apiServiceSettings() {
        return serviceSettings.apiServiceSettings();
    }

    /** The payload-format-specific task settings. */
    public SageMakerStoredTaskSchema apiTaskSettings() {
        return taskSettings.apiTaskSettings();
    }

    SageMakerServiceSettings serviceSettings() {
        return serviceSettings;
    }

    SageMakerTaskSettings taskSettings() {
        return taskSettings;
    }
}
|
|
@@ -0,0 +1,135 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
|
||||
|
||||
/**
 * Builds {@link SageMakerModel} instances from user request bodies and from persisted storage,
 * delegating the payload-format-specific portions of the settings maps to the matching schema.
 */
public class SageMakerModelBuilder {

    private final SageMakerSchemas schemas;

    public SageMakerModelBuilder(SageMakerSchemas schemas) {
        this.schemas = schemas;
    }

    /**
     * Parses a create-endpoint request. Validation errors are accumulated and thrown together; any
     * keys left unconsumed in the settings maps are rejected as unknown.
     */
    public SageMakerModel fromRequest(String inferenceEntityId, TaskType taskType, String service, Map<String, Object> requestMap) {
        var validationException = new ValidationException();
        var serviceSettingsMap = removeFromMapOrThrowIfNull(requestMap, ModelConfigurations.SERVICE_SETTINGS);
        // Credentials are supplied inside service_settings but stored separately as secrets.
        var awsSecretSettings = AwsSecretSettings.fromMap(serviceSettingsMap);
        var serviceSettings = SageMakerServiceSettings.fromMap(schemas, taskType, serviceSettingsMap);

        // The api field selects which payload schema parses the remaining format-specific settings.
        var schema = schemas.schemaFor(taskType, serviceSettings.api());

        var taskSettingsMap = removeFromMapOrDefaultEmpty(requestMap, ModelConfigurations.TASK_SETTINGS);
        var taskSettings = SageMakerTaskSettings.fromMap(
            taskSettingsMap,
            schema.apiTaskSettings(taskSettingsMap, validationException),
            validationException
        );

        // Report accumulated validation errors before complaining about leftover (unknown) keys.
        validationException.throwIfValidationErrorsExist();
        throwIfNotEmptyMap(serviceSettingsMap, service);
        throwIfNotEmptyMap(taskSettingsMap, service);

        var modelConfigurations = new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
        return new SageMakerModel(
            modelConfigurations,
            new ModelSecrets(awsSecretSettings),
            serviceSettings,
            taskSettings,
            awsSecretSettings
        );
    }

    /**
     * Rebuilds a model from persisted cluster state. {@code secrets} may be null when the caller only
     * needs the configuration (the resulting model then has no AWS credentials).
     */
    public SageMakerModel fromStorage(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        Map<String, Object> config,
        Map<String, Object> secrets
    ) {
        var validationException = new ValidationException();
        var awsSecretSettings = secrets != null
            ? AwsSecretSettings.fromMap(removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS))
            : null;

        var serviceSettings = SageMakerServiceSettings.fromMap(
            schemas,
            taskType,
            removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS)
        );

        var schema = schemas.schemaFor(taskType, serviceSettings.api());
        var taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);

        var taskSettings = SageMakerTaskSettings.fromMap(
            taskSettingsMap,
            schema.apiTaskSettings(taskSettingsMap, validationException),
            validationException
        );

        validationException.throwIfValidationErrorsExist();
        throwIfNotEmptyMap(config, service);
        throwIfNotEmptyMap(taskSettingsMap, service);
        throwIfNotEmptyMap(secrets, service);

        var modelConfigurations = new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
        return new SageMakerModel(
            modelConfigurations,
            new ModelSecrets(awsSecretSettings),
            serviceSettings,
            taskSettings,
            awsSecretSettings
        );
    }

    /**
     * Returns a copy of the model with the discovered embedding size recorded in its api-specific
     * service settings, or the same instance when the schema reports no change (identity check below).
     */
    public SageMakerModel updateModelWithEmbeddingDetails(SageMakerModel model, int embeddingSize) {
        var updatedApiServiceSettings = model.apiServiceSettings().updateModelWithEmbeddingDetails(embeddingSize);

        // Reference equality is intentional: schemas return the same instance to signal "no update needed".
        if (updatedApiServiceSettings == model.apiServiceSettings()) {
            return model;
        }

        var updatedServiceSettings = new SageMakerServiceSettings(
            model.serviceSettings().endpointName(),
            model.serviceSettings().region(),
            model.serviceSettings().api(),
            model.serviceSettings().targetModel(),
            model.serviceSettings().targetContainerHostname(),
            model.serviceSettings().inferenceComponentName(),
            model.serviceSettings().batchSize(),
            updatedApiServiceSettings
        );

        var modelConfigurations = new ModelConfigurations(
            model.getInferenceEntityId(),
            model.getTaskType(),
            model.getConfigurations().getService(),
            updatedServiceSettings,
            model.taskSettings()
        );
        return new SageMakerModel(
            modelConfigurations,
            model.getSecrets(),
            updatedServiceSettings,
            model.taskSettings(),
            model.awsSecretSettings().orElse(null)
        );
    }
}
|
|
@@ -0,0 +1,285 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ServiceSettings;
|
||||
import org.elasticsearch.inference.SettingsConfiguration;
|
||||
import org.elasticsearch.inference.SimilarityMeasure;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
|
||||
import org.elasticsearch.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.EnumSet;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
|
||||
|
||||
/**
|
||||
* Maintains the settings for SageMaker that cannot be changed without impacting semantic search and AI assistants.
|
||||
* Model-specific settings are stored in {@link SageMakerStoredServiceSchema}.
|
||||
*/
|
||||
record SageMakerServiceSettings(
|
||||
String endpointName,
|
||||
String region,
|
||||
String api,
|
||||
@Nullable String targetModel,
|
||||
@Nullable String targetContainerHostname,
|
||||
@Nullable String inferenceComponentName,
|
||||
@Nullable Integer batchSize,
|
||||
SageMakerStoredServiceSchema apiServiceSettings
|
||||
) implements ServiceSettings {
|
||||
|
||||
static final String NAME = "sage_maker_service_settings";
|
||||
private static final String API = "api";
|
||||
private static final String ENDPOINT_NAME = "endpoint_name";
|
||||
private static final String REGION = "region";
|
||||
private static final String TARGET_MODEL = "target_model";
|
||||
private static final String TARGET_CONTAINER_HOSTNAME = "target_container_hostname";
|
||||
private static final String INFERENCE_COMPONENT_NAME = "inference_component_name";
|
||||
private static final String BATCH_SIZE = "batch_size";
|
||||
|
||||
// Compact constructor: endpoint_name, region, and api are required; the remaining
// components are optional routing hints and may be null.
SageMakerServiceSettings {
    Objects.requireNonNull(endpointName);
    Objects.requireNonNull(region);
    Objects.requireNonNull(api);
    Objects.requireNonNull(apiServiceSettings);
}
|
||||
|
||||
/**
 * Deserializes the settings from the wire. Read order must match the write order in
 * {@code writeTo} exactly.
 */
SageMakerServiceSettings(StreamInput in) throws IOException {
    this(
        in.readString(),                 // endpointName
        in.readString(),                 // region
        in.readString(),                 // api
        in.readOptionalString(),         // targetModel
        in.readOptionalString(),         // targetContainerHostname
        in.readOptionalString(),         // inferenceComponentName
        in.readOptionalInt(),            // batchSize
        in.readNamedWriteable(SageMakerStoredServiceSchema.class)
    );
}
|
||||
|
||||
@Override
|
||||
public String modelId() {
|
||||
return apiServiceSettings.modelId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SimilarityMeasure similarity() {
|
||||
return apiServiceSettings.similarity();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer dimensions() {
|
||||
return apiServiceSettings.dimensions();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean dimensionsSetByUser() {
|
||||
return apiServiceSettings.dimensionsSetByUser();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DenseVectorFieldMapper.ElementType elementType() {
|
||||
return apiServiceSettings.elementType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.ML_INFERENCE_SAGEMAKER;
|
||||
}
|
||||
|
||||
@Override
public void writeTo(StreamOutput out) throws IOException {
    // Write order must match the read order in the StreamInput constructor exactly.
    out.writeString(endpointName());
    out.writeString(region());
    out.writeString(api());
    out.writeOptionalString(targetModel());
    out.writeOptionalString(targetContainerHostname());
    out.writeOptionalString(inferenceComponentName());
    out.writeOptionalInt(batchSize());
    out.writeNamedWriteable(apiServiceSettings);
}
|
||||
|
||||
// Nothing here is sensitive, so the unfiltered representation can be returned as-is.
@Override
public ToXContentObject getFilteredXContentObject() {
    return this;
}
|
||||
|
||||
/**
 * Renders the required fields unconditionally, the optional routing fields only when present,
 * and then lets the api-specific settings append their own fields into the same object.
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    builder.startObject();

    builder.field(ENDPOINT_NAME, endpointName());
    builder.field(REGION, region());
    builder.field(API, api());
    optionalField(TARGET_MODEL, targetModel(), builder);
    optionalField(TARGET_CONTAINER_HOSTNAME, targetContainerHostname(), builder);
    optionalField(INFERENCE_COMPONENT_NAME, inferenceComponentName(), builder);
    optionalField(BATCH_SIZE, batchSize(), builder);
    apiServiceSettings.toXContent(builder, params);

    return builder.endObject();
}
|
||||
|
||||
private static <T> void optionalField(String name, T value, XContentBuilder builder) throws IOException {
|
||||
if (value != null) {
|
||||
builder.field(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Builds the settings from a user-supplied map, accumulating problems in a ValidationException.
 * Validation happens in two phases: first the base SageMaker settings, then — only if those are
 * valid, because {@code api} is needed to pick a schema — the api-specific settings.
 * Extraction removes the consumed keys from {@code serviceSettingsMap}.
 */
static SageMakerServiceSettings fromMap(SageMakerSchemas schemas, TaskType taskType, Map<String, Object> serviceSettingsMap) {
    ValidationException validationException = new ValidationException();

    var endpointName = extractRequiredString(
        serviceSettingsMap,
        ENDPOINT_NAME,
        ModelConfigurations.SERVICE_SETTINGS,
        validationException
    );
    var region = extractRequiredString(serviceSettingsMap, REGION, ModelConfigurations.SERVICE_SETTINGS, validationException);
    var api = extractRequiredString(serviceSettingsMap, API, ModelConfigurations.SERVICE_SETTINGS, validationException);
    var targetModel = extractOptionalString(
        serviceSettingsMap,
        TARGET_MODEL,
        ModelConfigurations.SERVICE_SETTINGS,
        validationException
    );
    var targetContainerHostname = extractOptionalString(
        serviceSettingsMap,
        TARGET_CONTAINER_HOSTNAME,
        ModelConfigurations.SERVICE_SETTINGS,
        validationException
    );
    var inferenceComponentName = extractOptionalString(
        serviceSettingsMap,
        INFERENCE_COMPONENT_NAME,
        ModelConfigurations.SERVICE_SETTINGS,
        validationException
    );
    var batchSize = extractOptionalPositiveInteger(
        serviceSettingsMap,
        BATCH_SIZE,
        ModelConfigurations.SERVICE_SETTINGS,
        validationException
    );

    // Phase one: fail fast before resolving a schema, since api may be missing/invalid.
    validationException.throwIfValidationErrorsExist();

    var schema = schemas.schemaFor(taskType, api);
    var apiServiceSettings = schema.apiServiceSettings(serviceSettingsMap, validationException);

    // Phase two: errors reported by the api-specific schema.
    validationException.throwIfValidationErrorsExist();

    return new SageMakerServiceSettings(
        endpointName,
        region,
        api,
        targetModel,
        targetContainerHostname,
        inferenceComponentName,
        batchSize,
        apiServiceSettings
    );
}
|
||||
|
||||
/**
 * Declares the user-facing configuration metadata for the base SageMaker service settings
 * (used to render the create-inference-endpoint UI). Only api, endpoint name, and region are
 * required; none of the settings are sensitive or updatable after creation.
 */
static Stream<Map.Entry<String, SettingsConfiguration>> configuration(EnumSet<TaskType> supportedTaskTypes) {
    return Stream.of(
        Map.entry(
            API,
            new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The API format that your SageMaker Endpoint expects.")
                .setLabel("API")
                .setRequired(true)
                .setSensitive(false)
                .setUpdatable(false)
                .setType(SettingsConfigurationFieldType.STRING)
                .build()
        ),
        Map.entry(
            ENDPOINT_NAME,
            new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
                "The name specified when creating the SageMaker Endpoint."
            )
                .setLabel("Endpoint Name")
                .setRequired(true)
                .setSensitive(false)
                .setUpdatable(false)
                .setType(SettingsConfigurationFieldType.STRING)
                .build()
        ),
        Map.entry(
            REGION,
            new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
                "The AWS region that your model or ARN is deployed in."
            )
                .setLabel("Region")
                .setRequired(true)
                .setSensitive(false)
                .setUpdatable(false)
                .setType(SettingsConfigurationFieldType.STRING)
                .build()
        ),
        Map.entry(
            TARGET_MODEL,
            new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
                "The model to request when calling a SageMaker multi-model Endpoint."
            )
                .setLabel("Target Model")
                .setRequired(false)
                .setSensitive(false)
                .setUpdatable(false)
                .setType(SettingsConfigurationFieldType.STRING)
                .build()
        ),
        Map.entry(
            TARGET_CONTAINER_HOSTNAME,
            new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
                "The hostname of the container when calling a SageMaker multi-container Endpoint."
            )
                .setLabel("Target Container Hostname")
                .setRequired(false)
                .setSensitive(false)
                .setUpdatable(false)
                .setType(SettingsConfigurationFieldType.STRING)
                .build()
        ),
        Map.entry(
            BATCH_SIZE,
            new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
                "The maximum size a single chunk of input can be when chunking input for semantic text."
            )
                .setLabel("Batch Size")
                .setRequired(false)
                .setSensitive(false)
                .setUpdatable(false)
                .setType(SettingsConfigurationFieldType.INTEGER)
                .build()
        )
    );
}
|
||||
}
|
|
@ -0,0 +1,236 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.SettingsConfiguration;
|
||||
import org.elasticsearch.inference.TaskSettings;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.EnumSet;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
|
||||
|
||||
/**
|
||||
* Maintains mutable settings for SageMaker. Model-specific settings are stored in {@link SageMakerStoredTaskSchema}.
|
||||
*/
|
||||
record SageMakerTaskSettings(
|
||||
@Nullable String customAttributes,
|
||||
@Nullable String enableExplanations,
|
||||
@Nullable String inferenceIdForDataCapture,
|
||||
@Nullable String sessionId,
|
||||
@Nullable String targetVariant,
|
||||
SageMakerStoredTaskSchema apiTaskSettings
|
||||
) implements TaskSettings {
|
||||
|
||||
static final String NAME = "sage_maker_task_settings";
|
||||
private static final String CUSTOM_ATTRIBUTES = "custom_attributes";
|
||||
private static final String ENABLE_EXPLANATIONS = "enable_explanations";
|
||||
private static final String INFERENCE_ID = "inference_id";
|
||||
private static final String SESSION_ID = "session_id";
|
||||
private static final String TARGET_VARIANT = "target_variant";
|
||||
|
||||
SageMakerTaskSettings(StreamInput in) throws IOException {
|
||||
this(
|
||||
in.readOptionalString(),
|
||||
in.readOptionalString(),
|
||||
in.readOptionalString(),
|
||||
in.readOptionalString(),
|
||||
in.readOptionalString(),
|
||||
in.readNamedWriteable(SageMakerStoredTaskSchema.class)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isEmpty() {
|
||||
return customAttributes == null
|
||||
&& enableExplanations == null
|
||||
&& inferenceIdForDataCapture == null
|
||||
&& sessionId == null
|
||||
&& targetVariant == null
|
||||
&& SageMakerStoredTaskSchema.NO_OP.equals(apiTaskSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SageMakerTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
|
||||
var validationException = new ValidationException();
|
||||
|
||||
var updateTaskSettings = fromMap(newSettings, apiTaskSettings.updatedTaskSettings(newSettings), validationException);
|
||||
|
||||
validationException.throwIfValidationErrorsExist();
|
||||
|
||||
var updatedExtraTaskSettings = updateTaskSettings.apiTaskSettings().equals(SageMakerStoredTaskSchema.NO_OP)
|
||||
? apiTaskSettings
|
||||
: updateTaskSettings.apiTaskSettings();
|
||||
|
||||
return new SageMakerTaskSettings(
|
||||
firstNotNullOrNull(updateTaskSettings.customAttributes(), customAttributes),
|
||||
firstNotNullOrNull(updateTaskSettings.enableExplanations(), enableExplanations),
|
||||
firstNotNullOrNull(updateTaskSettings.inferenceIdForDataCapture(), inferenceIdForDataCapture),
|
||||
firstNotNullOrNull(updateTaskSettings.sessionId(), sessionId),
|
||||
firstNotNullOrNull(updateTaskSettings.targetVariant(), targetVariant),
|
||||
updatedExtraTaskSettings
|
||||
);
|
||||
}
|
||||
|
||||
private static <T> T firstNotNullOrNull(T first, T second) {
|
||||
return first != null ? first : second;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.ML_INFERENCE_SAGEMAKER;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeOptionalString(customAttributes);
|
||||
out.writeOptionalString(enableExplanations);
|
||||
out.writeOptionalString(inferenceIdForDataCapture);
|
||||
out.writeOptionalString(sessionId);
|
||||
out.writeOptionalString(targetVariant);
|
||||
out.writeNamedWriteable(apiTaskSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
optionalField(CUSTOM_ATTRIBUTES, customAttributes, builder);
|
||||
optionalField(ENABLE_EXPLANATIONS, enableExplanations, builder);
|
||||
optionalField(INFERENCE_ID, inferenceIdForDataCapture, builder);
|
||||
optionalField(SESSION_ID, sessionId, builder);
|
||||
optionalField(TARGET_VARIANT, targetVariant, builder);
|
||||
apiTaskSettings.toXContent(builder, params);
|
||||
|
||||
return builder.endObject();
|
||||
}
|
||||
|
||||
private static <T> void optionalField(String name, T value, XContentBuilder builder) throws IOException {
|
||||
if (value != null) {
|
||||
builder.field(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
public static SageMakerTaskSettings fromMap(
|
||||
Map<String, Object> taskSettingsMap,
|
||||
SageMakerStoredTaskSchema apiTaskSettings,
|
||||
ValidationException validationException
|
||||
) {
|
||||
var customAttributes = extractOptionalString(
|
||||
taskSettingsMap,
|
||||
CUSTOM_ATTRIBUTES,
|
||||
ModelConfigurations.TASK_SETTINGS,
|
||||
validationException
|
||||
);
|
||||
var enableExplanations = extractOptionalString(
|
||||
taskSettingsMap,
|
||||
ENABLE_EXPLANATIONS,
|
||||
ModelConfigurations.TASK_SETTINGS,
|
||||
validationException
|
||||
);
|
||||
var inferenceIdForDataCapture = extractOptionalString(
|
||||
taskSettingsMap,
|
||||
INFERENCE_ID,
|
||||
ModelConfigurations.TASK_SETTINGS,
|
||||
validationException
|
||||
);
|
||||
var sessionId = extractOptionalString(taskSettingsMap, SESSION_ID, ModelConfigurations.TASK_SETTINGS, validationException);
|
||||
var targetVariant = extractOptionalString(taskSettingsMap, TARGET_VARIANT, ModelConfigurations.TASK_SETTINGS, validationException);
|
||||
|
||||
return new SageMakerTaskSettings(
|
||||
customAttributes,
|
||||
enableExplanations,
|
||||
inferenceIdForDataCapture,
|
||||
sessionId,
|
||||
targetVariant,
|
||||
apiTaskSettings
|
||||
);
|
||||
}
|
||||
|
||||
static Stream<Map.Entry<String, SettingsConfiguration>> configuration(EnumSet<TaskType> supportedTaskTypes) {
|
||||
return Stream.of(
|
||||
Map.entry(
|
||||
CUSTOM_ATTRIBUTES,
|
||||
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||
"An opaque informational value forwarded as-is to the model within SageMaker."
|
||||
)
|
||||
.setLabel("Custom Attributes")
|
||||
.setRequired(false)
|
||||
.setSensitive(false)
|
||||
.setUpdatable(false)
|
||||
.setType(SettingsConfigurationFieldType.STRING)
|
||||
.build()
|
||||
),
|
||||
Map.entry(
|
||||
ENABLE_EXPLANATIONS,
|
||||
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||
"JMESPath expression overriding the ClarifyingExplainerConfig in the SageMaker Endpoint Configuration."
|
||||
)
|
||||
.setLabel("Enable Explanations")
|
||||
.setRequired(false)
|
||||
.setSensitive(false)
|
||||
.setUpdatable(false)
|
||||
.setType(SettingsConfigurationFieldType.STRING)
|
||||
.build()
|
||||
),
|
||||
Map.entry(
|
||||
INFERENCE_ID,
|
||||
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||
"Informational identifying for auditing requests within the SageMaker Endpoint."
|
||||
)
|
||||
.setLabel("Inference ID")
|
||||
.setRequired(false)
|
||||
.setSensitive(false)
|
||||
.setUpdatable(false)
|
||||
.setType(SettingsConfigurationFieldType.STRING)
|
||||
.build()
|
||||
),
|
||||
Map.entry(
|
||||
SESSION_ID,
|
||||
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||
"Creates or reuses an existing Session for SageMaker stateful models."
|
||||
)
|
||||
.setLabel("Session ID")
|
||||
.setRequired(false)
|
||||
.setSensitive(false)
|
||||
.setUpdatable(false)
|
||||
.setType(SettingsConfigurationFieldType.STRING)
|
||||
.build()
|
||||
),
|
||||
Map.entry(
|
||||
TARGET_VARIANT,
|
||||
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||
"The production variant when calling the SageMaker Endpoint"
|
||||
)
|
||||
.setLabel("Target Variant")
|
||||
.setRequired(false)
|
||||
.setSensitive(false)
|
||||
.setUpdatable(false)
|
||||
.setType(SettingsConfigurationFieldType.STRING)
|
||||
.build()
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,215 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InternalDependencyException;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InternalFailureException;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.ModelErrorException;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.ModelNotReadyException;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.SageMakerRuntimeException;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.ServiceUnavailableException;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.ValidationErrorException;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
import org.elasticsearch.core.Tuple;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.elasticsearch.core.Strings.format;
|
||||
|
||||
/**
|
||||
* All the logic that is required to call any SageMaker model is handled within this Schema class.
|
||||
* Any model-specific logic is handled within the associated {@link SageMakerSchemaPayload}.
|
||||
* This schema is specific for SageMaker's non-streaming API. For streaming, see {@link SageMakerStreamSchema}.
|
||||
*/
|
||||
public class SageMakerSchema {
    // Response metadata from SageMaker is surfaced to clients as custom HTTP response headers.
    private static final String CUSTOM_ATTRIBUTES_HEADER = "X-elastic-sagemaker-custom-attributes";
    private static final String NEW_SESSION_HEADER = "X-elastic-sagemaker-new-session-id";
    private static final String CLOSED_SESSION_HEADER = "X-elastic-sagemaker-closed-session-id";

    // AWS common error codes, matched against AwsErrorDetails#errorCode in errorMessageAndStatus.
    private static final String ACCESS_DENIED_CODE = "AccessDeniedException";
    private static final String INCOMPLETE_SIGNATURE = "IncompleteSignature";
    private static final String INVALID_ACTION = "InvalidAction";
    private static final String INVALID_CLIENT_TOKEN = "InvalidClientTokenId";
    private static final String NOT_AUTHORIZED = "NotAuthorized";
    private static final String OPT_IN_REQUIRED = "OptInRequired";
    private static final String REQUEST_EXPIRED = "RequestExpired";
    private static final String THROTTLING_EXCEPTION = "ThrottlingException";

    // Handles the model/api-specific request and response body translation.
    private final SageMakerSchemaPayload schemaPayload;

    public SageMakerSchema(SageMakerSchemaPayload schemaPayload) {
        this.schemaPayload = schemaPayload;
    }

    /**
     * Builds the InvokeEndpoint request: routing settings come from the model, while the payload
     * supplies the MIME types and the serialized body. Any payload failure is wrapped as a 500
     * naming the inference endpoint; ElasticsearchStatusExceptions pass through unchanged.
     */
    public InvokeEndpointRequest request(SageMakerModel model, SageMakerInferenceRequest request) {
        try {
            return createRequest(model).accept(schemaPayload.accept(model))
                .contentType(schemaPayload.contentType(model))
                .body(schemaPayload.requestBytes(model, request))
                .build();
        } catch (ElasticsearchStatusException e) {
            throw e;
        } catch (Exception e) {
            throw new ElasticsearchStatusException(
                "Failed to create SageMaker request for [%s]",
                RestStatus.INTERNAL_SERVER_ERROR,
                e,
                model.getInferenceEntityId()
            );
        }
    }

    // Copies the model's routing/session settings onto the request builder; optional settings
    // are only set when present.
    private InvokeEndpointRequest.Builder createRequest(SageMakerModel model) {
        var request = InvokeEndpointRequest.builder();
        request.endpointName(model.endpointName());
        model.customAttributes().ifPresent(request::customAttributes);
        model.enableExplanations().ifPresent(request::enableExplanations);
        model.inferenceComponentName().ifPresent(request::inferenceComponentName);
        model.inferenceIdForDataCapture().ifPresent(request::inferenceId);
        model.sessionId().ifPresent(request::sessionId);
        model.targetContainerHostname().ifPresent(request::targetContainerHostname);
        model.targetModel().ifPresent(request::targetModel);
        model.targetVariant().ifPresent(request::targetVariant);
        return request;
    }

    /**
     * Translates the InvokeEndpoint response into inference results, first echoing SageMaker's
     * response metadata back to the client as response headers. Payload failures are wrapped as a
     * 500 naming the inference endpoint.
     */
    public InferenceServiceResults response(SageMakerModel model, InvokeEndpointResponse response, ThreadContext threadContext)
        throws Exception {
        try {
            addHeaders(response, threadContext);
            return schemaPayload.responseBody(model, response);
        } catch (ElasticsearchStatusException e) {
            throw e;
        } catch (Exception e) {
            throw new ElasticsearchStatusException(
                "Failed to translate SageMaker response for [%s]",
                RestStatus.INTERNAL_SERVER_ERROR,
                e,
                model.getInferenceEntityId()
            );
        }
    }

    // Surfaces custom attributes and session lifecycle info as response headers when present.
    private void addHeaders(InvokeEndpointResponse response, ThreadContext threadContext) {
        if (response.customAttributes() != null) {
            threadContext.addResponseHeader(CUSTOM_ATTRIBUTES_HEADER, response.customAttributes());
        }
        if (response.newSessionId() != null) {
            threadContext.addResponseHeader(NEW_SESSION_HEADER, response.newSessionId());
        }
        if (response.closedSessionId() != null) {
            threadContext.addResponseHeader(CLOSED_SESSION_HEADER, response.closedSessionId());
        }
    }

    /**
     * Maps an SDK exception to an ElasticsearchStatusException with a user-actionable message,
     * keeping the original exception as the cause. Already-mapped exceptions pass through.
     */
    public Exception error(SageMakerModel model, Exception e) {
        if (e instanceof ElasticsearchStatusException ee) {
            return ee;
        }
        var error = errorMessageAndStatus(model, e);
        return new ElasticsearchStatusException(error.v1(), error.v2(), e);
    }

    /**
     * Protected because {@link SageMakerStreamSchema} will reuse this to create a Chat Completion error message.
     * Checks the typed SageMaker Runtime exceptions first, then falls back to the AWS common error
     * codes, and finally to a generic 500.
     */
    protected Tuple<String, RestStatus> errorMessageAndStatus(SageMakerModel model, Exception e) {
        String errorMessage = null;
        RestStatus restStatus = null;
        if (e instanceof InternalDependencyException) {
            errorMessage = format("Received an internal dependency error from SageMaker for [%s]", model.getInferenceEntityId());
            restStatus = RestStatus.INTERNAL_SERVER_ERROR;
        } else if (e instanceof InternalFailureException) {
            errorMessage = format("Received an internal failure from SageMaker for [%s]", model.getInferenceEntityId());
            restStatus = RestStatus.INTERNAL_SERVER_ERROR;
        } else if (e instanceof ModelErrorException) {
            errorMessage = format("Received a model error from SageMaker for [%s]", model.getInferenceEntityId());
            restStatus = RestStatus.FAILED_DEPENDENCY;
        } else if (e instanceof ModelNotReadyException) {
            // 429 so the caller knows retrying later may succeed once the model warms up
            errorMessage = format("Received a model not ready error from SageMaker for [%s]", model.getInferenceEntityId());
            restStatus = RestStatus.TOO_MANY_REQUESTS;
        } else if (e instanceof ServiceUnavailableException) {
            errorMessage = format("SageMaker is unavailable for [%s]", model.getInferenceEntityId());
            restStatus = RestStatus.SERVICE_UNAVAILABLE;
        } else if (e instanceof ValidationErrorException) {
            errorMessage = format("Received a validation error from SageMaker for [%s]", model.getInferenceEntityId());
            restStatus = RestStatus.BAD_REQUEST;
        }

        // if we have a SageMakerRuntimeException that isn't one of the child exceptions, we can parse the error from AwsErrorDetails
        // https://docs.aws.amazon.com/sagemaker/latest/APIReference/CommonErrors.html
        if (errorMessage == null && e instanceof SageMakerRuntimeException re && re.awsErrorDetails() != null) {
            switch (re.awsErrorDetails().errorCode()) {
                case ACCESS_DENIED_CODE, NOT_AUTHORIZED -> {
                    errorMessage = format(
                        "Access and Secret key stored in [%s] do not have sufficient permissions.",
                        model.getInferenceEntityId()
                    );
                    restStatus = RestStatus.BAD_REQUEST;
                }
                case INCOMPLETE_SIGNATURE -> {
                    errorMessage = format("The request signature does not conform to AWS standards [%s]", model.getInferenceEntityId());
                    restStatus = RestStatus.INTERNAL_SERVER_ERROR; // this shouldn't happen and isn't anything the user can do about it
                }
                case INVALID_ACTION -> {
                    errorMessage = format("The requested action is not valid for [%s]", model.getInferenceEntityId());
                    restStatus = RestStatus.BAD_REQUEST;
                }
                case INVALID_CLIENT_TOKEN -> {
                    errorMessage = format("Access key stored in [%s] does not exist in AWS", model.getInferenceEntityId());
                    restStatus = RestStatus.FORBIDDEN;
                }
                case OPT_IN_REQUIRED -> {
                    errorMessage = format("Access key stored in [%s] needs a subscription for the service", model.getInferenceEntityId());
                    restStatus = RestStatus.FORBIDDEN;
                }
                case REQUEST_EXPIRED -> {
                    errorMessage = format(
                        "The request reached SageMaker more than 15 minutes after the date stamp on the request for [%s]",
                        model.getInferenceEntityId()
                    );
                    restStatus = RestStatus.BAD_REQUEST;
                }
                case THROTTLING_EXCEPTION -> {
                    errorMessage = format("SageMaker denied the request for [%s] due to request throttling", model.getInferenceEntityId());
                    restStatus = RestStatus.BAD_REQUEST;
                }
            }
        }

        // fallback for anything unrecognized
        if (errorMessage == null) {
            errorMessage = format("Received an error from SageMaker for [%s]", model.getInferenceEntityId());
            restStatus = RestStatus.INTERNAL_SERVER_ERROR;
        }

        return Tuple.tuple(errorMessage, restStatus);
    }

    /** Delegates parsing of api-specific service settings to the payload. */
    public SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
        return schemaPayload.apiServiceSettings(serviceSettings, validationException);
    }

    /** Delegates parsing of api-specific task settings to the payload. */
    public SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
        return schemaPayload.apiTaskSettings(taskSettings, validationException);
    }

    /** Exposes the payload's writeable registrations so SageMakerSchemas can aggregate them. */
    public Stream<NamedWriteableRegistry.Entry> namedWriteables() {
        return schemaPayload.namedWriteables();
    }
}
|
|
@ -0,0 +1,98 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||
|
||||
import software.amazon.awssdk.core.SdkBytes;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||
|
||||
import java.util.EnumSet;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
public interface SageMakerSchemaPayload {

    /**
     * The model API keyword that users will supply in the service settings when creating the request.
     * Automatically registered in {@link SageMakerSchemas}.
     */
    String api();

    /**
     * The supported TaskTypes for this model API.
     * Automatically registered in {@link SageMakerSchemas}.
     */
    EnumSet<TaskType> supportedTasks();

    /**
     * Implement this if the model requires extra ServiceSettings that can be saved to the model index.
     * This can be accessed via {@link SageMakerModel#apiServiceSettings()}.
     * The default is a no-op placeholder for APIs with no extra service settings.
     */
    default SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
        return SageMakerStoredServiceSchema.NO_OP;
    }

    /**
     * Implement this if the model requires extra TaskSettings that can be saved to the model index.
     * This can be accessed via {@link SageMakerModel#apiTaskSettings()}.
     * The default is a no-op placeholder for APIs with no extra task settings.
     */
    default SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
        return SageMakerStoredTaskSchema.NO_OP;
    }

    /**
     * This must be thrown if {@link SageMakerModel#apiServiceSettings()} or {@link SageMakerModel#apiTaskSettings()} return the wrong
     * object types.
     */
    default Exception createUnsupportedSchemaException(SageMakerModel model) {
        return new IllegalArgumentException(
            Strings.format(
                "Unsupported SageMaker settings for api [%s] and task type [%s]: [%s] and [%s]",
                model.api(),
                model.getTaskType(),
                model.apiServiceSettings().getWriteableName(),
                model.apiTaskSettings().getWriteableName()
            )
        );
    }

    /**
     * Automatically register the required registry entries with {@link org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider}.
     * The default registers nothing, matching the no-op settings defaults above.
     */
    default Stream<NamedWriteableRegistry.Entry> namedWriteables() {
        return Stream.of();
    }

    /**
     * The MIME type of the response from SageMaker.
     */
    String accept(SageMakerModel model);

    /**
     * The MIME type of the request to SageMaker.
     */
    String contentType(SageMakerModel model);

    /**
     * Translate to the body of the request in the MIME type specified by {@link #contentType(SageMakerModel)}.
     */
    SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception;

    /**
     * Translate from the body of the response in the MIME type specified by {@link #accept(SageMakerModel)}.
     */
    InferenceServiceResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception;

}
|
|
@ -0,0 +1,137 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.EnumSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.elasticsearch.core.Strings.format;
|
||||
|
||||
/**
|
||||
* The mapping and registry for all supported model API.
|
||||
*/
|
||||
public class SageMakerSchemas {
|
||||
private static final Map<TaskAndApi, SageMakerSchema> schemas;
|
||||
private static final Map<TaskAndApi, SageMakerStreamSchema> streamSchemas;
|
||||
private static final Map<String, Set<TaskType>> tasksByApi;
|
||||
private static final Map<String, Set<TaskType>> streamingTasksByApi;
|
||||
private static final Set<TaskType> supportedStreamingTasks;
|
||||
private static final EnumSet<TaskType> supportedTaskTypes;
|
||||
|
||||
static {
    /*
     * Add new model API to the register call.
     */
    schemas = register(new OpenAiTextEmbeddingPayload());

    // Streaming schemas are the subset of registered schemas built from streaming payloads.
    streamSchemas = schemas.entrySet()
        .stream()
        .filter(e -> e.getValue() instanceof SageMakerStreamSchema)
        .collect(Collectors.toMap(Map.Entry::getKey, e -> (SageMakerStreamSchema) e.getValue()));

    // Index the supported task types by api name for both schema maps.
    tasksByApi = schemas.keySet()
        .stream()
        .collect(Collectors.groupingBy(TaskAndApi::api, Collectors.mapping(TaskAndApi::taskType, Collectors.toSet())));
    streamingTasksByApi = streamSchemas.keySet()
        .stream()
        .collect(Collectors.groupingBy(TaskAndApi::api, Collectors.mapping(TaskAndApi::taskType, Collectors.toSet())));

    // Aggregate views across all apis.
    supportedStreamingTasks = streamSchemas.keySet().stream().map(TaskAndApi::taskType).collect(Collectors.toSet());
    supportedTaskTypes = EnumSet.copyOf(schemas.keySet().stream().map(TaskAndApi::taskType).collect(Collectors.toSet()));
}
|
||||
|
||||
private static Map<TaskAndApi, SageMakerSchema> register(SageMakerSchemaPayload... payloads) {
|
||||
return Arrays.stream(payloads).flatMap(payload -> payload.supportedTasks().stream().map(taskType -> {
|
||||
var key = new TaskAndApi(taskType, payload.api());
|
||||
SageMakerSchema value;
|
||||
if (payload instanceof SageMakerStreamSchemaPayload streamPayload) {
|
||||
value = new SageMakerStreamSchema(streamPayload);
|
||||
} else {
|
||||
value = new SageMakerSchema(payload);
|
||||
}
|
||||
return Map.entry(key, value);
|
||||
})).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||
}
|
||||
|
||||
/**
|
||||
* Automatically register the stored Schema writeables with {@link org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider}.
|
||||
*/
|
||||
public static List<NamedWriteableRegistry.Entry> namedWriteables() {
|
||||
return Stream.concat(
|
||||
Stream.of(
|
||||
new NamedWriteableRegistry.Entry(
|
||||
SageMakerStoredServiceSchema.class,
|
||||
SageMakerStoredServiceSchema.NO_OP.getWriteableName(),
|
||||
in -> SageMakerStoredServiceSchema.NO_OP
|
||||
),
|
||||
new NamedWriteableRegistry.Entry(
|
||||
SageMakerStoredTaskSchema.class,
|
||||
SageMakerStoredTaskSchema.NO_OP.getWriteableName(),
|
||||
in -> SageMakerStoredTaskSchema.NO_OP
|
||||
)
|
||||
),
|
||||
schemas.values().stream().flatMap(SageMakerSchema::namedWriteables)
|
||||
).toList();
|
||||
}
|
||||
|
||||
public SageMakerSchema schemaFor(SageMakerModel model) throws ElasticsearchStatusException {
|
||||
return schemaFor(model.getTaskType(), model.api());
|
||||
}
|
||||
|
||||
public SageMakerSchema schemaFor(TaskType taskType, String api) throws ElasticsearchStatusException {
|
||||
var schema = schemas.get(new TaskAndApi(taskType, api));
|
||||
if (schema == null) {
|
||||
throw new ElasticsearchStatusException(
|
||||
format(
|
||||
"Task [%s] is not compatible for service [sagemaker] and api [%s]. Supported tasks: [%s]",
|
||||
api,
|
||||
taskType.toString(),
|
||||
tasksByApi.getOrDefault(api, Set.of())
|
||||
),
|
||||
RestStatus.METHOD_NOT_ALLOWED
|
||||
);
|
||||
}
|
||||
return schema;
|
||||
}
|
||||
|
||||
public SageMakerStreamSchema streamSchemaFor(SageMakerModel model) throws ElasticsearchStatusException {
|
||||
var schema = streamSchemas.get(new TaskAndApi(model.getTaskType(), model.api()));
|
||||
if (schema == null) {
|
||||
throw new ElasticsearchStatusException(
|
||||
format(
|
||||
"Streaming is not allowed for service [sagemaker], api [%s], and task [%s]. Supported streaming tasks: [%s]",
|
||||
model.api(),
|
||||
model.getTaskType().toString(),
|
||||
streamingTasksByApi.getOrDefault(model.api(), Set.of())
|
||||
),
|
||||
RestStatus.METHOD_NOT_ALLOWED
|
||||
);
|
||||
}
|
||||
return schema;
|
||||
}
|
||||
|
||||
public EnumSet<TaskType> supportedTaskTypes() {
|
||||
return supportedTaskTypes;
|
||||
}
|
||||
|
||||
public Set<TaskType> supportedStreamingTasks() {
|
||||
return supportedStreamingTasks;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.inference.ServiceSettings;
|
||||
import org.elasticsearch.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
||||
/**
 * Contains any model-specific settings that are stored in SageMakerServiceSettings.
 * Implementations carry the "extra" service settings that are specific to one payload
 * format (e.g. OpenAI's {@code dimensions}) and serialize both to the index (XContent)
 * and across nodes (Writeable).
 */
public interface SageMakerStoredServiceSchema extends ServiceSettings {
    /**
     * Shared instance for payloads that store no extra service settings: it writes no
     * wire bytes and contributes no XContent fields.
     */
    SageMakerStoredServiceSchema NO_OP = new SageMakerStoredServiceSchema() {

        private static final String NAME = "noop_sagemaker_service_schema";

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.ML_INFERENCE_SAGEMAKER;
        }

        @Override
        public void writeTo(StreamOutput out) {
            // intentionally empty: the no-op schema has no state to serialize

        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) {
            return builder;
        }
    };

    /**
     * The model is part of the SageMaker Endpoint definition and is not declared in the service settings.
     */
    @Override
    default String modelId() {
        return null;
    }

    // Nothing sensitive is stored here, so the filtered view is the full object.
    @Override
    default ToXContentObject getFilteredXContentObject() {
        return this;
    }

    /**
     * These extra service settings serialize flatly alongside the overall SageMaker ServiceSettings.
     */
    @Override
    default boolean isFragment() {
        return true;
    }

    /**
     * If this Schema supports Text Embeddings, then we need to implement this.
     * {@link org.elasticsearch.xpack.inference.services.validation.TextEmbeddingModelValidator} will set the dimensions if the user
     * does not do it, so we need to store the dimensions and flip the {@link #dimensionsSetByUser()} boolean.
     * Default implementation is a no-op for schemas that do not support embeddings.
     */
    default SageMakerStoredServiceSchema updateModelWithEmbeddingDetails(Integer dimensions) {
        return this;
    }
}
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.inference.TaskSettings;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
 * Contains any model-specific settings that are stored in SageMakerTaskSettings.
 * Implementations carry the "extra" task settings that are specific to one payload format
 * (e.g. OpenAI's {@code user}) and serialize both to the index (XContent) and across nodes
 * (Writeable).
 */
public interface SageMakerStoredTaskSchema extends TaskSettings {
    /**
     * Shared instance for payloads that store no extra task settings: always empty,
     * writes no wire bytes, contributes no XContent fields, and ignores updates.
     */
    SageMakerStoredTaskSchema NO_OP = new SageMakerStoredTaskSchema() {
        @Override
        public boolean isEmpty() {
            return true;
        }

        @Override
        public SageMakerStoredTaskSchema updatedTaskSettings(Map<String, Object> newSettings) {
            // no stored fields, so there is nothing to merge
            return this;
        }

        private static final String NAME = "noop_sagemaker_task_schema";

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.ML_INFERENCE_SAGEMAKER;
        }

        @Override
        public void writeTo(StreamOutput out) {}

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) {
            return builder;
        }
    };

    /**
     * These extra service settings serialize flatly alongside the overall SageMaker ServiceSettings.
     */
    @Override
    default boolean isFragment() {
        return true;
    }

    // Covariant redeclaration: narrows TaskSettings#updatedTaskSettings to this interface so
    // callers merging settings keep the SageMaker-specific type.
    @Override
    SageMakerStoredTaskSchema updatedTaskSettings(Map<String, Object> newSettings);
}
|
|
@ -0,0 +1,152 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||
|
||||
import software.amazon.awssdk.core.SdkBytes;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.ExceptionsHelper;
|
||||
import org.elasticsearch.common.CheckedBiFunction;
|
||||
import org.elasticsearch.common.CheckedSupplier;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||
|
||||
import java.util.Locale;
|
||||
import java.util.concurrent.Flow;
|
||||
import java.util.function.BiFunction;
|
||||
|
||||
/**
 * All the logic that is required to call any SageMaker model is handled within this Schema class.
 * Any model-specific logic is handled within the associated {@link SageMakerStreamSchemaPayload}.
 * This schema is specific for SageMaker's streaming API. For non-streaming, see {@link SageMakerSchema}.
 */
public class SageMakerStreamSchema extends SageMakerSchema {

    private final SageMakerStreamSchemaPayload payload;

    public SageMakerStreamSchema(SageMakerStreamSchemaPayload payload) {
        super(payload);
        this.payload = payload;
    }

    /** Builds the streaming invoke request for a (non-unified) completion-style inference request. */
    public InvokeEndpointWithResponseStreamRequest streamRequest(SageMakerModel model, SageMakerInferenceRequest request) {
        return streamRequest(model, () -> payload.requestBytes(model, request));
    }

    /**
     * Builds the request from model routing details plus the payload-supplied body bytes.
     * Any failure from the body supplier is wrapped in a 500 ElasticsearchStatusException,
     * except ElasticsearchStatusException itself which is rethrown as-is.
     */
    private InvokeEndpointWithResponseStreamRequest streamRequest(SageMakerModel model, CheckedSupplier<SdkBytes, Exception> body) {
        try {
            return createStreamRequest(model).accept(payload.accept(model))
                .contentType(payload.contentType(model))
                .body(body.get())
                .build();
        } catch (ElasticsearchStatusException e) {
            throw e;
        } catch (Exception e) {
            throw new ElasticsearchStatusException(
                "Failed to create SageMaker request for [%s]",
                RestStatus.INTERNAL_SERVER_ERROR,
                e,
                model.getInferenceEntityId()
            );
        }
    }

    /** Adapts the SDK response stream into streaming completion results using the payload's parser. */
    public InferenceServiceResults streamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) {
        return streamResponse(model, response, payload::streamResponseBody, this::error);
    }

    /**
     * Bridges the SDK's response-event stream to the downstream inference results subscriber.
     *
     * @param parseFunction converts one payload-part's bytes into a result chunk
     * @param errorFunction wraps any failure into the error type the downstream expects
     */
    private InferenceServiceResults streamResponse(
        SageMakerModel model,
        SageMakerClient.SageMakerStream response,
        CheckedBiFunction<SageMakerModel, SdkBytes, InferenceServiceResults.Result, Exception> parseFunction,
        BiFunction<SageMakerModel, Exception, Exception> errorFunction
    ) {
        return new StreamingChatCompletionResults(downstream -> {
            response.responseStream().subscribe(new Flow.Subscriber<>() {
                private volatile Flow.Subscription upstream;

                @Override
                public void onSubscribe(Flow.Subscription subscription) {
                    // Hand the upstream subscription straight to the downstream so its demand
                    // drives the SDK stream; keep a reference for skipping non-payload events.
                    this.upstream = subscription;
                    downstream.onSubscribe(subscription);
                }

                @Override
                public void onNext(ResponseStream item) {
                    if (item.sdkEventType() == ResponseStream.EventType.PAYLOAD_PART) {
                        // Visitor extracts the payload-part bytes; parse failures are routed to onError.
                        item.accept(InvokeEndpointWithResponseStreamResponseHandler.Visitor.builder().onPayloadPart(payloadPart -> {
                            try {
                                downstream.onNext(parseFunction.apply(model, payloadPart.bytes()));
                            } catch (Exception e) {
                                downstream.onError(errorFunction.apply(model, e));
                            }
                        }).build());
                    } else {
                        // Non-payload event consumed a unit of downstream demand without producing
                        // output, so ask upstream for a replacement item.
                        assert upstream != null : "upstream is unset";
                        upstream.request(1);
                    }
                }

                @Override
                public void onError(Throwable throwable) {
                    if (throwable instanceof Exception e) {
                        downstream.onError(errorFunction.apply(model, e));
                    } else {
                        // Errors (e.g. OOM) may be fatal to the node; otherwise wrap in an Exception
                        // because the downstream error function only accepts Exceptions.
                        ExceptionsHelper.maybeError(throwable).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
                        var e = new RuntimeException("Fatal while streaming SageMaker response for [" + model.getInferenceEntityId() + "]");
                        downstream.onError(errorFunction.apply(model, e));
                    }

                }

                @Override
                public void onComplete() {
                    downstream.onComplete();
                }
            });
        });
    }

    /** Builds the streaming invoke request for a unified (chat) completion request. */
    public InvokeEndpointWithResponseStreamRequest chatCompletionStreamRequest(SageMakerModel model, UnifiedCompletionRequest request) {
        return streamRequest(model, () -> payload.chatCompletionRequestBytes(model, request));
    }

    /** Adapts the SDK response stream into unified chat completion results using the payload's parser. */
    public InferenceServiceResults chatCompletionStreamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) {
        return streamResponse(model, response, payload::chatCompletionResponseBody, this::chatCompletionError);
    }

    /**
     * Wraps an arbitrary failure into a {@link UnifiedChatCompletionException}; passes through
     * exceptions that are already of that type.
     */
    public UnifiedChatCompletionException chatCompletionError(SageMakerModel model, Exception e) {
        if (e instanceof UnifiedChatCompletionException ucce) {
            return ucce;
        }

        // errorMessageAndStatus (from the parent schema) supplies (message, status).
        var error = errorMessageAndStatus(model, e);
        return new UnifiedChatCompletionException(error.v2(), error.v1(), "error", error.v2().name().toLowerCase(Locale.ROOT));
    }

    /** Copies the model's endpoint routing details (endpoint name plus optional attributes) onto a request builder. */
    private InvokeEndpointWithResponseStreamRequest.Builder createStreamRequest(SageMakerModel model) {
        var request = InvokeEndpointWithResponseStreamRequest.builder();
        request.endpointName(model.endpointName());
        model.customAttributes().ifPresent(request::customAttributes);
        model.inferenceComponentName().ifPresent(request::inferenceComponentName);
        model.inferenceIdForDataCapture().ifPresent(request::inferenceId);
        model.sessionId().ifPresent(request::sessionId);
        model.targetContainerHostname().ifPresent(request::targetContainerHostname);
        model.targetVariant().ifPresent(request::targetVariant);
        return request;
    }
}
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||
|
||||
import software.amazon.awssdk.core.SdkBytes;
|
||||
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||
|
||||
import java.util.EnumSet;
|
||||
|
||||
/**
 * Implemented for models that support streaming.
 * This is an extension of {@link SageMakerSchemaPayload} because Elastic expects Completion tasks to handle both streaming and
 * non-streaming, and all models currently support toggling streaming on/off.
 */
public interface SageMakerStreamSchemaPayload extends SageMakerSchemaPayload {
    /**
     * We currently only support streaming for Completion and Chat Completion, and if we are going to implement one then we should implement
     * the other, so this interface requires both streaming input and streaming unified input.
     * If we ever allowed streaming for more than just Completion, then we'd probably break up this class so that Unified Chat Completion
     * was its own interface.
     */
    @Override
    default EnumSet<TaskType> supportedTasks() {
        return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
    }

    /**
     * This API would only be called for Completion task types. {@link #requestBytes(SageMakerModel, SageMakerInferenceRequest)} would
     * handle the request translation for both streaming and non-streaming.
     */
    InferenceServiceResults.Result streamResponseBody(SageMakerModel model, SdkBytes response) throws Exception;

    /** Translates a unified (chat) completion request into the raw bytes sent to the SageMaker endpoint. */
    SdkBytes chatCompletionRequestBytes(SageMakerModel model, UnifiedCompletionRequest request) throws Exception;

    /** Parses one streamed payload-part's bytes into a unified chat completion result chunk. */
    InferenceServiceResults.Result chatCompletionResponseBody(SageMakerModel model, SdkBytes response) throws Exception;
}
|
|
@ -0,0 +1,12 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
|
||||
/**
 * Registry key: a {@link TaskType} paired with the payload's api name (e.g. "openai").
 */
record TaskAndApi(TaskType taskType, String api) {}
|
|
@ -0,0 +1,229 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;
|
||||
|
||||
import software.amazon.awssdk.core.SdkBytes;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.SimilarityMeasure;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xcontent.XContent;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||
import org.elasticsearch.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
||||
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayload;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.EnumSet;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
|
||||
|
||||
/**
 * SageMaker payload for endpoints that speak the OpenAI text-embeddings request/response format.
 * Serializes the inference request as OpenAI-style JSON and parses the response with the existing
 * OpenAI embeddings response parser.
 */
public class OpenAiTextEmbeddingPayload implements SageMakerSchemaPayload {

    private static final XContent jsonXContent = JsonXContent.jsonXContent;
    // Both Accept and Content-Type headers use this media type.
    private static final String APPLICATION_JSON = jsonXContent.type().mediaTypeWithoutParameters();

    @Override
    public String api() {
        return "openai";
    }

    @Override
    public EnumSet<TaskType> supportedTasks() {
        return EnumSet.of(TaskType.TEXT_EMBEDDING);
    }

    @Override
    public SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
        return ApiServiceSettings.fromMap(serviceSettings, validationException);
    }

    @Override
    public SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
        return ApiTaskSettings.fromMap(taskSettings, validationException);
    }

    @Override
    public Stream<NamedWriteableRegistry.Entry> namedWriteables() {
        return Stream.of(
            new NamedWriteableRegistry.Entry(SageMakerStoredServiceSchema.class, ApiServiceSettings.NAME, ApiServiceSettings::new),
            new NamedWriteableRegistry.Entry(SageMakerStoredTaskSchema.class, ApiTaskSettings.NAME, ApiTaskSettings::new)
        );
    }

    @Override
    public String accept(SageMakerModel model) {
        return APPLICATION_JSON;
    }

    @Override
    public String contentType(SageMakerModel model) {
        return APPLICATION_JSON;
    }

    /**
     * Builds the JSON request body: optional "query", "input" (a single string when there is
     * exactly one input, otherwise a list), optional "user", and "dimensions" only when the
     * user explicitly set them.
     *
     * @throws Exception via createUnsupportedSchemaException when the model carries settings
     *         of a different schema type
     */
    @Override
    public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception {
        if (model.apiServiceSettings() instanceof ApiServiceSettings apiServiceSettings
            && model.apiTaskSettings() instanceof ApiTaskSettings apiTaskSettings) {
            try (var builder = JsonXContent.contentBuilder()) {
                builder.startObject();
                if (request.query() != null) {
                    builder.field("query", request.query());
                }
                // Single input is unwrapped to a bare string; multiple inputs stay a JSON array.
                if (request.input().size() == 1) {
                    builder.field("input", request.input().get(0));
                } else {
                    builder.field("input", request.input());
                }
                if (apiTaskSettings.user() != null) {
                    builder.field("user", apiTaskSettings.user());
                }
                // Only forward dimensions the user set themselves, not validator-inferred ones.
                if (apiServiceSettings.dimensionsSetByUser() && apiServiceSettings.dimensions() != null) {
                    builder.field("dimensions", apiServiceSettings.dimensions());
                }
                builder.endObject();
                return SdkBytes.fromUtf8String(Strings.toString(builder));
            }
        } else {
            throw createUnsupportedSchemaException(model);
        }
    }

    /** Parses the endpoint's JSON body with the OpenAI embeddings response parser. */
    @Override
    public TextEmbeddingFloatResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
        try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) {
            return OpenAiEmbeddingsResponseEntity.EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
        }
    }

    /**
     * Extra service settings for the OpenAI format: the embedding dimensions and whether the
     * user (rather than the embedding validator) supplied them.
     */
    record ApiServiceSettings(@Nullable Integer dimensions, Boolean dimensionsSetByUser) implements SageMakerStoredServiceSchema {
        private static final String NAME = "sagemaker_openai_text_embeddings_service_settings";
        private static final String DIMENSIONS_FIELD = "dimensions";

        // Wire order must match writeTo: optional int, then boolean.
        ApiServiceSettings(StreamInput in) throws IOException {
            this(in.readOptionalInt(), in.readBoolean());
        }

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.ML_INFERENCE_SAGEMAKER;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeOptionalInt(dimensions);
            out.writeBoolean(dimensionsSetByUser);
        }

        // Only dimensions are persisted to XContent; dimensionsSetByUser is derived on parse.
        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            if (dimensions != null) {
                builder.field(DIMENSIONS_FIELD, dimensions);
            }
            return builder;
        }

        static ApiServiceSettings fromMap(Map<String, Object> serviceSettings, ValidationException validationException) {
            var dimensions = extractOptionalPositiveInteger(
                serviceSettings,
                DIMENSIONS_FIELD,
                ModelConfigurations.SERVICE_SETTINGS,
                validationException
            );

            // Presence of the field in the map means the user set it.
            return new ApiServiceSettings(dimensions, dimensions != null);
        }

        @Override
        public SimilarityMeasure similarity() {
            return SimilarityMeasure.DOT_PRODUCT;
        }

        @Override
        public DenseVectorFieldMapper.ElementType elementType() {
            return DenseVectorFieldMapper.ElementType.FLOAT;
        }

        // Called when the validator infers dimensions, hence dimensionsSetByUser = false.
        @Override
        public SageMakerStoredServiceSchema updateModelWithEmbeddingDetails(Integer dimensions) {
            return new ApiServiceSettings(dimensions, false);
        }
    }

    /**
     * Extra task settings for the OpenAI format: the optional "user" identifier forwarded
     * in the request body.
     */
    record ApiTaskSettings(@Nullable String user) implements SageMakerStoredTaskSchema {
        private static final String NAME = "sagemaker_openai_text_embeddings_task_settings";
        private static final String USER_FIELD = "user";

        ApiTaskSettings(StreamInput in) throws IOException {
            this(in.readOptionalString());
        }

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.ML_INFERENCE_SAGEMAKER;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeOptionalString(user);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            return user != null ? builder.field(USER_FIELD, user) : builder;
        }

        @Override
        public boolean isEmpty() {
            return user == null;
        }

        /**
         * Merges in updated settings, keeping the current user when the update does not
         * provide one. Throws ValidationException when the new settings fail to parse.
         */
        @Override
        public ApiTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
            var validationException = new ValidationException();
            var newTaskSettings = fromMap(newSettings, validationException);
            validationException.throwIfValidationErrorsExist();

            return new ApiTaskSettings(newTaskSettings.user() != null ? newTaskSettings.user() : user);
        }

        static ApiTaskSettings fromMap(Map<String, Object> map, ValidationException exception) {
            var user = extractOptionalString(map, USER_FIELD, ModelConfigurations.TASK_SETTINGS, exception);
            return new ApiTaskSettings(user);
        }
    }
}
|
|
@ -0,0 +1,101 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.xcontent.ToXContent;
|
||||
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||
import org.elasticsearch.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
||||
import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
/**
|
||||
* All Inference Setting classes derived from {@link org.elasticsearch.inference.ServiceSettings},
|
||||
* {@link org.elasticsearch.inference.TaskSettings}, and {@link org.elasticsearch.inference.SecretSettings}
|
||||
* must be able to read/write to an index via ToXContent as well as read/write between nodes via Writeable.
|
||||
*/
|
||||
public abstract class InferenceSettingsTestCase<T extends Writeable & ToXContent> extends AbstractBWCWireSerializationTestCase<T> {
|
||||
|
||||
/**
|
||||
* This method is final because {@link org.elasticsearch.inference.ModelConfigurations} settings must be registered in
|
||||
* {@link InferenceNamedWriteablesProvider}.
|
||||
*/
|
||||
@Override
|
||||
protected final NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||
return new NamedWriteableRegistry(InferenceNamedWriteablesProvider.getNamedWriteables());
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper implementation since most mutates can be handled by continuously randomizing until we have another test instance.
|
||||
*/
|
||||
@Override
|
||||
protected T mutateInstance(T instance) throws IOException {
|
||||
return randomValueOtherThan(instance, this::createTestInstance);
|
||||
}
|
||||
|
||||
/**
|
||||
* Change this for BWC once the implementation requires difference objects depending on the transport version.
|
||||
*/
|
||||
@Override
|
||||
protected T mutateInstanceForVersion(T instance, TransportVersion version) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
/**
|
||||
* Verify that we can write to XContent and then read from XContent.
|
||||
* This simulates saving the model to the index and then reading the model from the index.
|
||||
*/
|
||||
public final void testXContentRoundTrip() throws IOException {
|
||||
var instance = createTestInstance();
|
||||
var instanceAsMap = toMap(instance);
|
||||
var roundTripInstance = fromMutableMap(new HashMap<>(instanceAsMap));
|
||||
assertThat(roundTripInstance, equalTo(instance));
|
||||
}
|
||||
|
||||
protected abstract T fromMutableMap(Map<String, Object> mutableMap);
|
||||
|
||||
public static Map<String, Object> toMap(ToXContent instance) throws IOException {
|
||||
try (var builder = JsonXContent.contentBuilder()) {
|
||||
if (instance.isFragment()) {
|
||||
builder.startObject();
|
||||
}
|
||||
instance.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
if (instance.isFragment()) {
|
||||
builder.endObject();
|
||||
}
|
||||
var taskSettingsBytes = Strings.toString(builder).getBytes(StandardCharsets.UTF_8);
|
||||
try (var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, taskSettingsBytes)) {
|
||||
return parser.map();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper, since most settings contain optional strings.
|
||||
*/
|
||||
protected static String randomOptionalString() {
|
||||
return randomBoolean() ? randomString() : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper, since most settings contain strings.
|
||||
*/
|
||||
protected static String randomString() {
|
||||
return randomAlphaOfLength(randomIntBetween(4, 8));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,180 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker;
|
||||
|
||||
import software.amazon.awssdk.core.async.SdkPublisher;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeAsyncClient;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.common.cache.CacheLoader;
|
||||
import org.elasticsearch.common.settings.SecureString;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.mockito.ArgumentMatchers;
|
||||
import org.reactivestreams.Subscriber;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.assertArg;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.spy;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
/**
 * Unit tests for {@code SageMakerClient}: verifies that the underlying AWS SDK async
 * client is created once per region+credentials and cached, that invoke/invokeStream
 * results and timeouts are propagated to the listener, and that {@code close()} closes
 * the cached SDK client.
 */
public class SageMakerClientTests extends ESTestCase {
    // Request payload contents are irrelevant here; the tests only exercise plumbing.
    private static final InvokeEndpointRequest REQUEST = InvokeEndpointRequest.builder().build();
    private static final InvokeEndpointWithResponseStreamRequest STREAM_REQUEST = InvokeEndpointWithResponseStreamRequest.builder().build();
    private SageMakerRuntimeAsyncClient awsClient;
    private CacheLoader<SageMakerClient.RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory;
    private SageMakerClient client;
    private ThreadPool threadPool;

    @Before
    public void setUp() throws Exception {
        super.setUp();
        threadPool = createThreadPool(inferenceUtilityPool());

        awsClient = mock();
        // Spy on the factory so the tests can count how often the cache actually loads a client.
        clientFactory = spy(new CacheLoader<>() {
            public SageMakerRuntimeAsyncClient load(SageMakerClient.RegionAndSecrets key) {
                return awsClient;
            }
        });
        client = new SageMakerClient(clientFactory, threadPool);
    }

    @After
    public void shutdown() throws IOException {
        terminate(threadPool);
    }

    public void testInvoke() throws Exception {
        var expectedResponse = InvokeEndpointResponse.builder().build();
        when(awsClient.invokeEndpoint(any(InvokeEndpointRequest.class))).thenReturn(CompletableFuture.completedFuture(expectedResponse));

        var listener = invoke(TimeValue.THIRTY_SECONDS);
        // One cache load, and the SDK response is forwarded unchanged to the listener.
        verify(clientFactory, times(1)).load(any());
        verify(listener, times(1)).onResponse(eq(expectedResponse));
    }

    private static SageMakerClient.RegionAndSecrets regionAndSecrets() {
        return new SageMakerClient.RegionAndSecrets(
            "us-east-1",
            new AwsSecretSettings(new SecureString("access"), new SecureString("secrets"))
        );
    }

    /** Invokes the client and blocks (up to 5s) until the listener completes; returns the spied listener. */
    private ActionListener<InvokeEndpointResponse> invoke(TimeValue timeout) throws InterruptedException {
        var latch = new CountDownLatch(1);
        ActionListener<InvokeEndpointResponse> listener = spy(ActionListener.noop());
        client.invoke(regionAndSecrets(), REQUEST, timeout, ActionListener.runAfter(listener, latch::countDown));
        assertTrue("Timed out waiting for invoke call", latch.await(5, TimeUnit.SECONDS));
        return listener;
    }

    public void testInvokeCache() throws Exception {
        when(awsClient.invokeEndpoint(any(InvokeEndpointRequest.class))).thenReturn(
            CompletableFuture.completedFuture(InvokeEndpointResponse.builder().build())
        );

        // Two invocations with the same region+secrets must reuse the cached SDK client.
        invoke(TimeValue.THIRTY_SECONDS);
        invoke(TimeValue.THIRTY_SECONDS);
        verify(clientFactory, times(1)).load(any());
    }

    public void testInvokeTimeout() throws Exception {
        // A future that never completes forces the client-side timeout path.
        when(awsClient.invokeEndpoint(any(InvokeEndpointRequest.class))).thenReturn(new CompletableFuture<>());

        var listener = invoke(TimeValue.timeValueMillis(10));

        verify(clientFactory, times(1)).load(any());
        verifyTimeout(listener);
    }

    private static void verifyTimeout(ActionListener<?> listener) {
        verify(listener, times(1)).onFailure(assertArg(e -> assertThat(e.getMessage(), equalTo("Request timed out after [10ms]"))));
    }

    public void testInvokeStream() throws Exception {
        SdkPublisher<ResponseStream> publisher = mockPublisher();

        var listener = invokeStream(TimeValue.THIRTY_SECONDS);

        // The event-stream publisher must stay unsubscribed until the caller subscribes
        // to the SageMakerStream handed to the listener.
        verify(publisher, never()).subscribe(ArgumentMatchers.<Subscriber<ResponseStream>>any());
        verify(listener, times(1)).onResponse(assertArg(stream -> stream.responseStream().subscribe(mock())));
        verify(publisher, times(1)).subscribe(ArgumentMatchers.<Subscriber<ResponseStream>>any());
    }

    /** Stubs the SDK streaming call to immediately deliver a response and an event-stream publisher. */
    private SdkPublisher<ResponseStream> mockPublisher() {
        SdkPublisher<ResponseStream> publisher = mock();
        doAnswer(ans -> {
            InvokeEndpointWithResponseStreamResponseHandler handler = ans.getArgument(1);
            handler.responseReceived(InvokeEndpointWithResponseStreamResponse.builder().build());
            handler.onEventStream(publisher);
            return CompletableFuture.completedFuture((Void) null);
        }).when(awsClient).invokeEndpointWithResponseStream(any(InvokeEndpointWithResponseStreamRequest.class), any());
        return publisher;
    }

    /** Streaming counterpart of {@link #invoke}: blocks until the listener completes and returns the spy. */
    private ActionListener<SageMakerClient.SageMakerStream> invokeStream(TimeValue timeout) throws Exception {
        var latch = new CountDownLatch(1);
        ActionListener<SageMakerClient.SageMakerStream> listener = spy(ActionListener.noop());
        client.invokeStream(regionAndSecrets(), STREAM_REQUEST, timeout, ActionListener.runAfter(listener, latch::countDown));
        assertTrue("Timed out waiting for invoke call", latch.await(5, TimeUnit.SECONDS));
        return listener;
    }

    public void testInvokeStreamCache() throws Exception {
        mockPublisher();

        // Same caching guarantee as the non-streaming path.
        invokeStream(TimeValue.THIRTY_SECONDS);
        invokeStream(TimeValue.THIRTY_SECONDS);

        verify(clientFactory, times(1)).load(any());
    }

    public void testInvokeStreamTimeout() throws Exception {
        when(awsClient.invokeEndpointWithResponseStream(any(InvokeEndpointWithResponseStreamRequest.class), any())).thenReturn(
            new CompletableFuture<>()
        );

        var listener = invokeStream(TimeValue.timeValueMillis(10));

        verify(clientFactory, times(1)).load(any());
        verifyTimeout(listener);
    }

    public void testClose() throws Exception {
        // Populate the cache with one SDK client...
        mockPublisher();
        invokeStream(TimeValue.THIRTY_SECONDS);
        // ...then closing the SageMakerClient must close the cached SDK client.
        client.close();
        verify(awsClient, times(1)).close();
    }
}
|
|
@ -0,0 +1,463 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker;
|
||||
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.common.settings.SecureString;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.concurrent.EsExecutors;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
import org.elasticsearch.inference.ChunkInferenceInput;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
|
||||
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
|
||||
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
|
||||
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchema;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStreamSchema;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.elasticsearch.action.ActionListener.assertOnce;
|
||||
import static org.elasticsearch.action.support.ActionTestUtils.assertNoFailureListener;
|
||||
import static org.elasticsearch.action.support.ActionTestUtils.assertNoSuccessListener;
|
||||
import static org.elasticsearch.core.TimeValue.THIRTY_SECONDS;
|
||||
import static org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequestTests.randomUnifiedCompletionRequest;
|
||||
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.in;
|
||||
import static org.hamcrest.Matchers.isA;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.ArgumentMatchers.assertArg;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.only;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
/**
 * Unit tests for {@code SageMakerService}. The service is exercised with mocked
 * {@code SageMakerModelBuilder}, {@code SageMakerClient}, and {@code SageMakerSchemas}
 * collaborators, so these tests verify the wiring between them: which schema is chosen,
 * how requests/responses are delegated, how errors and exceptions are translated, and
 * how chunked inference batches its inputs.
 */
public class SageMakerServiceTests extends ESTestCase {

    private static final String QUERY = "query";
    private static final List<String> INPUT = List.of("input");
    private static final InputType INPUT_TYPE = InputType.UNSPECIFIED;

    private SageMakerModelBuilder modelBuilder;
    private SageMakerClient client;
    private SageMakerSchemas schemas;
    private SageMakerService sageMakerService;

    @Before
    public void init() {
        modelBuilder = mock();
        client = mock();
        schemas = mock();
        ThreadPool threadPool = mock();
        // Run everything on the calling thread so listener assertions execute synchronously.
        when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
        when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
        sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of);
    }

    public void testSupportedTaskTypes() {
        // Supported task types are delegated entirely to the schemas registry.
        sageMakerService.supportedTaskTypes();
        verify(schemas, only()).supportedTaskTypes();
    }

    public void testSupportedStreamingTasks() {
        sageMakerService.supportedStreamingTasks();
        verify(schemas, only()).supportedStreamingTasks();
    }

    public void testParseRequestConfig() {
        // Parsing a request config is delegated to the model builder with the service name filled in.
        sageMakerService.parseRequestConfig("modelId", TaskType.ANY, Map.of(), assertNoFailureListener(model -> {
            verify(modelBuilder, only()).fromRequest(eq("modelId"), eq(TaskType.ANY), eq(SageMakerService.NAME), eq(Map.of()));
        }));
    }

    public void testParsePersistedConfigWithSecrets() {
        sageMakerService.parsePersistedConfigWithSecrets("modelId", TaskType.ANY, Map.of(), Map.of());
        verify(modelBuilder, only()).fromStorage(eq("modelId"), eq(TaskType.ANY), eq(SageMakerService.NAME), eq(Map.of()), eq(Map.of()));
    }

    public void testParsePersistedConfig() {
        // Without secrets, fromStorage is called with a null secrets map.
        sageMakerService.parsePersistedConfig("modelId", TaskType.ANY, Map.of());
        verify(modelBuilder, only()).fromStorage(eq("modelId"), eq(TaskType.ANY), eq(SageMakerService.NAME), eq(Map.of()), eq(null));
    }

    public void testInferWithWrongModel() {
        // A model that is not a SageMakerModel must be rejected with a 500 telling the user to recreate it.
        sageMakerService.infer(
            mockUnsupportedModel(),
            QUERY,
            false,
            null,
            INPUT,
            false,
            null,
            INPUT_TYPE,
            THIRTY_SECONDS,
            assertUnsupportedModel()
        );
    }

    /** A generic {@link Model} (not a SageMakerModel) with just enough config for the error message. */
    private static Model mockUnsupportedModel() {
        Model model = mock();
        ModelConfigurations modelConfigurations = mock();
        when(modelConfigurations.getService()).thenReturn("mockService");
        when(modelConfigurations.getInferenceEntityId()).thenReturn("mockInferenceId");
        when(model.getConfigurations()).thenReturn(modelConfigurations);
        return model;
    }

    /** Listener asserting the "invalid internal model" 500 failure for unsupported model types. */
    private static <T> ActionListener<T> assertUnsupportedModel() {
        return assertNoSuccessListener(e -> {
            assertThat(e, isA(ElasticsearchStatusException.class));
            assertThat(
                e.getMessage(),
                equalTo(
                    "The internal model was invalid, please delete the service [mockService] "
                        + "with id [mockInferenceId] and add it again."
                )
            );
            assertThat(((ElasticsearchStatusException) e).status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
        });
    }

    public void testInfer() {
        var model = mockModel();

        SageMakerSchema schema = mock();
        when(schemas.schemaFor(model)).thenReturn(schema);
        mockInvoke();

        sageMakerService.infer(
            model,
            QUERY,
            null,
            null,
            INPUT,
            false,
            null,
            INPUT_TYPE,
            THIRTY_SECONDS,
            assertNoFailureListener(ignored -> {
                // Non-streaming inference goes through the non-streaming schema: build request, parse response.
                verify(schemas, only()).schemaFor(eq(model));
                verify(schema, times(1)).request(eq(model), assertRequest());
                verify(schema, times(1)).response(eq(model), any(), any());
            })
        );
        verify(client, only()).invoke(any(), any(), any(), any());
        verifyNoMoreInteractions(client, schemas, schema);
    }

    /** A SageMakerModel whose override is a no-op and which carries test AWS credentials. */
    private SageMakerModel mockModel() {
        SageMakerModel model = mock();
        when(model.override(null)).thenReturn(model);
        when(model.awsSecretSettings()).thenReturn(
            Optional.of(new AwsSecretSettings(new SecureString("test-accessKey"), new SecureString("test-secretKey")))
        );
        return model;
    }

    /** Stubs the client to complete every invoke with an empty SDK response. */
    private void mockInvoke() {
        doAnswer(ans -> {
            ActionListener<InvokeEndpointResponse> responseListener = ans.getArgument(3);
            responseListener.onResponse(InvokeEndpointResponse.builder().build());
            return null; // Void
        }).when(client).invoke(any(), any(), any(), any());
    }

    /** Matcher for the SageMakerInferenceRequest built from the test constants. */
    private static SageMakerInferenceRequest assertRequest() {
        return assertArg(request -> {
            assertThat(request.query(), equalTo(QUERY));
            assertThat(request.input(), containsInAnyOrder(INPUT.toArray()));
            assertThat(request.inputType(), equalTo(InputType.UNSPECIFIED));
            assertNull(request.returnDocuments());
            assertNull(request.topN());
        });
    }

    public void testInferStream() {
        SageMakerModel model = mockModel();

        SageMakerStreamSchema schema = mock();
        when(schemas.streamSchemaFor(model)).thenReturn(schema);
        mockInvokeStream();

        // stream=true routes through the streaming schema and invokeStream.
        sageMakerService.infer(model, QUERY, null, null, INPUT, true, null, INPUT_TYPE, THIRTY_SECONDS, assertNoFailureListener(ignored -> {
            verify(schemas, only()).streamSchemaFor(eq(model));
            verify(schema, times(1)).streamRequest(eq(model), assertRequest());
            verify(schema, times(1)).streamResponse(eq(model), any());
        }));
        verify(client, only()).invokeStream(any(), any(), any(), any());
        verifyNoMoreInteractions(client, schemas, schema);
    }

    /** Stubs the client to complete every invokeStream with an empty streaming response. */
    private void mockInvokeStream() {
        doAnswer(ans -> {
            ActionListener<SageMakerClient.SageMakerStream> responseListener = ans.getArgument(3);
            responseListener.onResponse(
                new SageMakerClient.SageMakerStream(InvokeEndpointWithResponseStreamResponse.builder().build(), mock())
            );
            return null; // Void
        }).when(client).invokeStream(any(), any(), any(), any());
    }

    public void testInferError() {
        SageMakerModel model = mockModel();

        var expectedException = new IllegalArgumentException("hola");
        SageMakerSchema schema = mock();
        when(schemas.schemaFor(model)).thenReturn(schema);
        mockInvokeFailure(expectedException);

        sageMakerService.infer(
            model,
            QUERY,
            null,
            null,
            INPUT,
            false,
            null,
            INPUT_TYPE,
            THIRTY_SECONDS,
            assertNoSuccessListener(ignored -> {
                // Failures reported by the client are translated through the schema's error handler.
                verify(schemas, only()).schemaFor(eq(model));
                verify(schema, times(1)).request(eq(model), assertRequest());
                verify(schema, times(1)).error(eq(model), assertArg(e -> assertThat(e, equalTo(expectedException))));
            })
        );
        verify(client, only()).invoke(any(), any(), any(), any());
        verifyNoMoreInteractions(client, schemas, schema);
    }

    /** Stubs the client to fail every invoke with the given exception. */
    private void mockInvokeFailure(Exception e) {
        doAnswer(ans -> {
            ActionListener<?> responseListener = ans.getArgument(3);
            responseListener.onFailure(e);
            return null; // Void
        }).when(client).invoke(any(), any(), any(), any());
    }

    public void testInferException() {
        SageMakerModel model = mockModel();
        when(model.getInferenceEntityId()).thenReturn("some id");

        SageMakerStreamSchema schema = mock();
        when(schemas.streamSchemaFor(model)).thenReturn(schema);
        // An exception thrown (not delivered via listener) by the client must be wrapped in a 500.
        doThrow(new IllegalArgumentException("wow, really?")).when(client).invokeStream(any(), any(), any(), any());

        sageMakerService.infer(model, QUERY, null, null, INPUT, true, null, INPUT_TYPE, THIRTY_SECONDS, assertNoSuccessListener(e -> {
            verify(schemas, only()).streamSchemaFor(eq(model));
            verify(schema, times(1)).streamRequest(eq(model), assertRequest());
            assertThat(e, isA(ElasticsearchStatusException.class));
            assertThat(((ElasticsearchStatusException) e).status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
            assertThat(e.getMessage(), equalTo("Failed to call SageMaker for inference id [some id]."));
        }));
        verify(client, only()).invokeStream(any(), any(), any(), any());
        verifyNoMoreInteractions(client, schemas, schema);
    }

    public void testUnifiedInferWithWrongModel() {
        sageMakerService.unifiedCompletionInfer(
            mockUnsupportedModel(),
            randomUnifiedCompletionRequest(),
            THIRTY_SECONDS,
            assertUnsupportedModel()
        );
    }

    public void testUnifiedInfer() {
        var model = mockModel();

        SageMakerStreamSchema schema = mock();
        when(schemas.streamSchemaFor(model)).thenReturn(schema);
        mockInvokeStream();

        // Unified (chat) completion always streams via the chat-completion schema hooks.
        sageMakerService.unifiedCompletionInfer(
            model,
            randomUnifiedCompletionRequest(),
            THIRTY_SECONDS,
            assertNoFailureListener(ignored -> {
                verify(schemas, only()).streamSchemaFor(eq(model));
                verify(schema, times(1)).chatCompletionStreamRequest(eq(model), any());
                verify(schema, times(1)).chatCompletionStreamResponse(eq(model), any());
            })
        );
        verify(client, only()).invokeStream(any(), any(), any(), any());
        verifyNoMoreInteractions(client, schemas, schema);
    }

    public void testUnifiedInferError() {
        var model = mockModel();

        var expectedException = new IllegalArgumentException("hola");
        SageMakerStreamSchema schema = mock();
        when(schemas.streamSchemaFor(model)).thenReturn(schema);
        mockInvokeStreamFailure(expectedException);

        sageMakerService.unifiedCompletionInfer(
            model,
            randomUnifiedCompletionRequest(),
            THIRTY_SECONDS,
            assertNoSuccessListener(ignored -> {
                // Client failures are routed through the chat-completion error translator.
                verify(schemas, only()).streamSchemaFor(eq(model));
                verify(schema, times(1)).chatCompletionStreamRequest(eq(model), any());
                verify(schema, times(1)).chatCompletionError(eq(model), assertArg(e -> assertThat(e, equalTo(expectedException))));
            })
        );
        verify(client, only()).invokeStream(any(), any(), any(), any());
        verifyNoMoreInteractions(client, schemas, schema);
    }

    /** Stubs the client to fail every invokeStream with the given exception. */
    private void mockInvokeStreamFailure(Exception e) {
        doAnswer(ans -> {
            ActionListener<?> responseListener = ans.getArgument(3);
            responseListener.onFailure(e);
            return null; // Void
        }).when(client).invokeStream(any(), any(), any(), any());
    }

    public void testUnifiedInferException() {
        SageMakerModel model = mockModel();
        when(model.getInferenceEntityId()).thenReturn("some id");

        SageMakerStreamSchema schema = mock();
        when(schemas.streamSchemaFor(model)).thenReturn(schema);
        doThrow(new IllegalArgumentException("wow, rude")).when(client).invokeStream(any(), any(), any(), any());

        sageMakerService.unifiedCompletionInfer(model, randomUnifiedCompletionRequest(), THIRTY_SECONDS, assertNoSuccessListener(e -> {
            verify(schemas, only()).streamSchemaFor(eq(model));
            verify(schema, times(1)).chatCompletionStreamRequest(eq(model), any());
            assertThat(e, isA(ElasticsearchStatusException.class));
            assertThat(((ElasticsearchStatusException) e).status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
            assertThat(e.getMessage(), equalTo("Failed to call SageMaker for inference id [some id]."));
        }));
        verify(client, only()).invokeStream(any(), any(), any(), any());
        verifyNoMoreInteractions(client, schemas, schema);
    }

    public void testChunkedInferWithWrongModel() {
        sageMakerService.chunkedInfer(
            mockUnsupportedModel(),
            QUERY,
            INPUT.stream().map(ChunkInferenceInput::new).toList(),
            null,
            INPUT_TYPE,
            THIRTY_SECONDS,
            assertUnsupportedModel()
        );
    }

    public void testChunkedInfer() throws Exception {
        var model = mockModelForChunking();

        SageMakerSchema schema = mock();
        when(schema.response(any(), any(), any())).thenReturn(TextEmbeddingFloatResultsTests.createRandomResults());
        when(schemas.schemaFor(model)).thenReturn(schema);
        mockInvoke();

        var expectedInput = Set.of("first", "second");

        sageMakerService.chunkedInfer(
            model,
            QUERY,
            expectedInput.stream().map(ChunkInferenceInput::new).toList(),
            null,
            INPUT_TYPE,
            THIRTY_SECONDS,
            assertOnce(assertNoFailureListener(chunkedInferences -> {
                // batchSize=1 and one-word chunks => each input becomes its own request (two total).
                verify(schemas, times(2)).schemaFor(eq(model));
                verify(schema, times(2)).request(eq(model), assertChunkRequest(expectedInput));
                verify(schema, times(2)).response(eq(model), any(), any());
            }))
        );
        verify(client, times(2)).invoke(any(), any(), any(), any());
        verifyNoMoreInteractions(client, schemas, schema);
    }

    /** A model with batch size 1 and single-word chunking so every input produces exactly one request. */
    private SageMakerModel mockModelForChunking() {
        var model = mockModel();
        when(model.batchSize()).thenReturn(Optional.of(1));
        ModelConfigurations modelConfigurations = mock();
        when(modelConfigurations.getChunkingSettings()).thenReturn(new WordBoundaryChunkingSettings(1, 0));
        when(model.getConfigurations()).thenReturn(modelConfigurations);
        return model;
    }

    /** Matcher for per-chunk requests: one non-streaming input drawn from the expected set. */
    private static SageMakerInferenceRequest assertChunkRequest(Set<String> expectedInput) {
        return assertArg(request -> {
            assertThat(request.query(), equalTo(QUERY));
            assertThat(request.input(), hasSize(1));
            assertThat(request.input().get(0), in(expectedInput));
            assertThat(request.inputType(), equalTo(InputType.UNSPECIFIED));
            assertNull(request.returnDocuments());
            assertNull(request.topN());
            assertFalse(request.stream());
        });
    }

    public void testChunkedInferError() {
        var model = mockModelForChunking();

        var expectedException = new IllegalArgumentException("hola");
        SageMakerSchema schema = mock();
        when(schema.error(any(), any())).thenReturn(expectedException);
        when(schemas.schemaFor(model)).thenReturn(schema);
        mockInvokeFailure(expectedException);

        var expectedInput = Set.of("first", "second");
        // Each failed chunk surfaces as its own ChunkedInferenceError; the overall listener still succeeds.
        var expectedOutput = Stream.of(expectedException, expectedException).map(ChunkedInferenceError::new).toArray();

        sageMakerService.chunkedInfer(
            model,
            QUERY,
            expectedInput.stream().map(ChunkInferenceInput::new).toList(),
            null,
            INPUT_TYPE,
            THIRTY_SECONDS,
            assertOnce(assertNoFailureListener(chunkedInferences -> {
                verify(schemas, times(2)).schemaFor(eq(model));
                verify(schema, times(2)).request(eq(model), assertChunkRequest(expectedInput));
                verify(schema, times(2)).error(eq(model), assertArg(e -> assertThat(e, equalTo(expectedException))));
                assertThat(chunkedInferences, containsInAnyOrder(expectedOutput));
            }))
        );
        verify(client, times(2)).invoke(any(), any(), any(), any());
        verifyNoMoreInteractions(client, schemas, schema);
    }

    public void testClose() throws IOException {
        // Closing the service closes the underlying client.
        sageMakerService.close();
        verify(client, only()).close();
    }

}
|
|
@ -0,0 +1,315 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.UnparsedModel;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.ToXContent;
|
||||
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||
import org.elasticsearch.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemasTests;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.inference.ModelConfigurations.USE_ID_FOR_INDEX;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class SageMakerModelBuilderTests extends ESTestCase {
|
||||
private static final String inferenceId = "inferenceId";
|
||||
private static final TaskType taskType = TaskType.ANY;
|
||||
private static final String service = "service";
|
||||
private SageMakerModelBuilder builder;
|
||||
|
||||
@Before
public void setUp() throws Exception {
    super.setUp();
    // Use the shared mock schema registry so the builder can resolve any api/task type.
    builder = new SageMakerModelBuilder(SageMakerSchemasTests.mockSchemas());
}
|
||||
|
||||
public void testFromRequestWithRequiredFields() {
    // Only the five required service settings are provided; every optional
    // setting must come back empty.
    var model = fromRequest("""
        {
            "service_settings": {
                "access_key": "test-access-key",
                "secret_key": "test-secret-key",
                "region": "us-east-1",
                "api": "test-api",
                "endpoint_name": "test-endpoint"
            }
        }
        """);

    assertNotNull(model);
    assertTrue(model.awsSecretSettings().isPresent());
    assertThat(model.awsSecretSettings().get().accessKey().toString(), equalTo("test-access-key"));
    assertThat(model.awsSecretSettings().get().secretKey().toString(), equalTo("test-secret-key"));
    assertThat(model.region(), equalTo("us-east-1"));
    assertThat(model.api(), equalTo("test-api"));
    assertThat(model.endpointName(), equalTo("test-endpoint"));

    // Optional settings default to empty when omitted from the request.
    assertTrue(model.customAttributes().isEmpty());
    assertTrue(model.enableExplanations().isEmpty());
    assertTrue(model.inferenceComponentName().isEmpty());
    assertTrue(model.inferenceIdForDataCapture().isEmpty());
    assertTrue(model.sessionId().isEmpty());
    assertTrue(model.targetContainerHostname().isEmpty());
    assertTrue(model.targetModel().isEmpty());
    assertTrue(model.targetVariant().isEmpty());
    assertTrue(model.batchSize().isEmpty());
}
|
||||
|
||||
public void testFromRequestWithOptionalFields() {
    // All optional service and task settings are provided and must be parsed
    // into present Optionals alongside the required settings.
    var model = fromRequest("""
        {
            "service_settings": {
                "access_key": "test-access-key",
                "secret_key": "test-secret-key",
                "region": "us-east-1",
                "api": "test-api",
                "endpoint_name": "test-endpoint",
                "target_model": "test-target",
                "target_container_hostname": "test-target-container",
                "inference_component_name": "test-inference-component",
                "batch_size": 1234
            },
            "task_settings": {
                "custom_attributes": "test-custom-attributes",
                "enable_explanations": "test-enable-explanations",
                "inference_id": "test-inference-id",
                "session_id": "test-session-id",
                "target_variant": "test-target-variant"
            }
        }
        """);

    assertNotNull(model);
    assertTrue(model.awsSecretSettings().isPresent());
    assertThat(model.awsSecretSettings().get().accessKey().toString(), equalTo("test-access-key"));
    assertThat(model.awsSecretSettings().get().secretKey().toString(), equalTo("test-secret-key"));
    assertThat(model.region(), equalTo("us-east-1"));
    assertThat(model.api(), equalTo("test-api"));
    assertThat(model.endpointName(), equalTo("test-endpoint"));

    assertPresent(model.customAttributes(), "test-custom-attributes");
    assertPresent(model.enableExplanations(), "test-enable-explanations");
    assertPresent(model.inferenceComponentName(), "test-inference-component");
    assertPresent(model.inferenceIdForDataCapture(), "test-inference-id");
    assertPresent(model.sessionId(), "test-session-id");
    assertPresent(model.targetContainerHostname(), "test-target-container");
    assertPresent(model.targetModel(), "test-target");
    assertPresent(model.targetVariant(), "test-target-variant");
    assertPresent(model.batchSize(), 1234);
}
|
||||
|
||||
public void testFromRequestWithoutAccessKey() {
|
||||
testExceptionFromRequest("""
|
||||
{
|
||||
"service_settings": {
|
||||
"secret_key": "test-secret-key",
|
||||
"region": "us-east-1",
|
||||
"api": "test-api",
|
||||
"endpoint_name": "test-endpoint"
|
||||
}
|
||||
}
|
||||
""", ValidationException.class, "Validation Failed: 1: [secret_settings] does not contain the required setting [access_key];");
|
||||
}
|
||||
|
||||
public void testFromRequestWithoutSecretKey() {
|
||||
testExceptionFromRequest("""
|
||||
{
|
||||
"service_settings": {
|
||||
"access_key": "test-access-key",
|
||||
"region": "us-east-1",
|
||||
"api": "test-api",
|
||||
"endpoint_name": "test-endpoint"
|
||||
}
|
||||
}
|
||||
""", ValidationException.class, "Validation Failed: 1: [secret_settings] does not contain the required setting [secret_key];");
|
||||
}
|
||||
|
||||
public void testFromRequestWithoutRegion() {
|
||||
testExceptionFromRequest("""
|
||||
{
|
||||
"service_settings": {
|
||||
"access_key": "test-access-key",
|
||||
"secret_key": "test-secret-key",
|
||||
"api": "test-api",
|
||||
"endpoint_name": "test-endpoint"
|
||||
}
|
||||
}
|
||||
""", ValidationException.class, "Validation Failed: 1: [service_settings] does not contain the required setting [region];");
|
||||
}
|
||||
|
||||
public void testFromRequestWithoutApi() {
|
||||
testExceptionFromRequest("""
|
||||
{
|
||||
"service_settings": {
|
||||
"access_key": "test-access-key",
|
||||
"secret_key": "test-secret-key",
|
||||
"region": "us-east-1",
|
||||
"endpoint_name": "test-endpoint"
|
||||
}
|
||||
}
|
||||
""", ValidationException.class, "Validation Failed: 1: [service_settings] does not contain the required setting [api];");
|
||||
}
|
||||
|
||||
public void testFromRequestWithoutEndpointName() {
|
||||
testExceptionFromRequest(
|
||||
"""
|
||||
{
|
||||
"service_settings": {
|
||||
"access_key": "test-access-key",
|
||||
"secret_key": "test-secret-key",
|
||||
"region": "us-east-1",
|
||||
"api": "test-api"
|
||||
}
|
||||
}
|
||||
""",
|
||||
ValidationException.class,
|
||||
"Validation Failed: 1: [service_settings] does not contain the required setting [endpoint_name];"
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromRequestWithExtraServiceKeys() {
|
||||
testExceptionFromRequest(
|
||||
"""
|
||||
{
|
||||
"service_settings": {
|
||||
"access_key": "test-access-key",
|
||||
"secret_key": "test-secret-key",
|
||||
"region": "us-east-1",
|
||||
"api": "test-api",
|
||||
"endpoint_name": "test-endpoint",
|
||||
"hello": "there"
|
||||
}
|
||||
}
|
||||
""",
|
||||
ElasticsearchStatusException.class,
|
||||
"Model configuration contains settings [{hello=there}] unknown to the [service] service"
|
||||
);
|
||||
}
|
||||
|
||||
public void testFromRequestWithExtraTaskKeys() {
|
||||
testExceptionFromRequest(
|
||||
"""
|
||||
{
|
||||
"service_settings": {
|
||||
"access_key": "test-access-key",
|
||||
"secret_key": "test-secret-key",
|
||||
"region": "us-east-1",
|
||||
"api": "test-api",
|
||||
"endpoint_name": "test-endpoint"
|
||||
},
|
||||
"task_settings": {
|
||||
"hello": "there"
|
||||
}
|
||||
}
|
||||
""",
|
||||
ElasticsearchStatusException.class,
|
||||
"Model configuration contains settings [{hello=there}] unknown to the [service] service"
|
||||
);
|
||||
}
|
||||
|
||||
public void testRoundTrip() throws IOException {
|
||||
var expectedModel = fromRequest("""
|
||||
{
|
||||
"service_settings": {
|
||||
"access_key": "test-access-key",
|
||||
"secret_key": "test-secret-key",
|
||||
"region": "us-east-1",
|
||||
"api": "test-api",
|
||||
"endpoint_name": "test-endpoint",
|
||||
"target_model": "test-target",
|
||||
"target_container_hostname": "test-target-container",
|
||||
"inference_component_name": "test-inference-component",
|
||||
"batch_size": 1234
|
||||
},
|
||||
"task_settings": {
|
||||
"custom_attributes": "test-custom-attributes",
|
||||
"enable_explanations": "test-enable-explanations",
|
||||
"inference_id": "test-inference-id",
|
||||
"session_id": "test-session-id",
|
||||
"target_variant": "test-target-variant"
|
||||
}
|
||||
}
|
||||
""");
|
||||
|
||||
var unparsedModelWithSecrets = unparsedModel(expectedModel.getConfigurations(), expectedModel.getSecrets());
|
||||
var modelWithSecrets = builder.fromStorage(
|
||||
unparsedModelWithSecrets.inferenceEntityId(),
|
||||
unparsedModelWithSecrets.taskType(),
|
||||
unparsedModelWithSecrets.service(),
|
||||
unparsedModelWithSecrets.settings(),
|
||||
unparsedModelWithSecrets.secrets()
|
||||
);
|
||||
assertThat(modelWithSecrets, equalTo(expectedModel));
|
||||
assertNotNull(modelWithSecrets.getSecrets().getSecretSettings());
|
||||
|
||||
var unparsedModelWithoutSecrets = unparsedModel(expectedModel.getConfigurations(), null);
|
||||
var modelWithoutSecrets = builder.fromStorage(
|
||||
unparsedModelWithoutSecrets.inferenceEntityId(),
|
||||
unparsedModelWithoutSecrets.taskType(),
|
||||
unparsedModelWithoutSecrets.service(),
|
||||
unparsedModelWithoutSecrets.settings(),
|
||||
unparsedModelWithoutSecrets.secrets()
|
||||
);
|
||||
assertThat(modelWithoutSecrets.getConfigurations(), equalTo(expectedModel.getConfigurations()));
|
||||
assertNull(modelWithoutSecrets.getSecrets().getSecretSettings());
|
||||
}
|
||||
|
||||
private SageMakerModel fromRequest(String json) {
|
||||
return builder.fromRequest(inferenceId, taskType, service, map(json));
|
||||
}
|
||||
|
||||
private void testExceptionFromRequest(String json, Class<? extends Exception> exceptionClass, String message) {
|
||||
var exception = assertThrows(exceptionClass, () -> fromRequest(json));
|
||||
assertThat(exception.getMessage(), equalTo(message));
|
||||
}
|
||||
|
||||
private static <T> void assertPresent(Optional<T> optional, T expectedValue) {
|
||||
assertTrue(optional.isPresent());
|
||||
assertThat(optional.get(), equalTo(expectedValue));
|
||||
}
|
||||
|
||||
private static Map<String, Object> map(String json) {
|
||||
try (
|
||||
var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, json.getBytes(StandardCharsets.UTF_8))
|
||||
) {
|
||||
return parser.map();
|
||||
} catch (IOException e) {
|
||||
throw new AssertionError(e);
|
||||
}
|
||||
}
|
||||
|
||||
private static UnparsedModel unparsedModel(ModelConfigurations modelConfigurations, ModelSecrets modelSecrets) throws IOException {
|
||||
var modelConfigMap = new ModelRegistry.ModelConfigMap(
|
||||
toJsonMap(modelConfigurations),
|
||||
modelSecrets != null ? toJsonMap(modelSecrets) : null
|
||||
);
|
||||
|
||||
return ModelRegistry.unparsedModelFromMap(modelConfigMap);
|
||||
}
|
||||
|
||||
private static Map<String, Object> toJsonMap(ToXContent toXContent) throws IOException {
|
||||
try (var builder = JsonXContent.contentBuilder()) {
|
||||
toXContent.toXContent(builder, new ToXContent.MapParams(Map.of(USE_ID_FOR_INDEX, "true")));
|
||||
return map(Strings.toString(builder));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemasTests;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public class SageMakerServiceSettingsTests extends InferenceSettingsTestCase<SageMakerServiceSettings> {
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<SageMakerServiceSettings> instanceReader() {
|
||||
return SageMakerServiceSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected SageMakerServiceSettings createTestInstance() {
|
||||
return new SageMakerServiceSettings(
|
||||
randomString(),
|
||||
randomString(),
|
||||
randomString(),
|
||||
randomOptionalString(),
|
||||
randomOptionalString(),
|
||||
randomOptionalString(),
|
||||
randomBoolean() ? randomIntBetween(1, 1000) : null,
|
||||
SageMakerStoredServiceSchema.NO_OP
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected SageMakerServiceSettings fromMutableMap(Map<String, Object> mutableMap) {
|
||||
return SageMakerServiceSettings.fromMap(SageMakerSchemasTests.mockSchemas(), randomFrom(TaskType.values()), mutableMap);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public class SageMakerTaskSettingsTests extends InferenceSettingsTestCase<SageMakerTaskSettings> {
|
||||
@Override
|
||||
protected Writeable.Reader<SageMakerTaskSettings> instanceReader() {
|
||||
return SageMakerTaskSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected SageMakerTaskSettings createTestInstance() {
|
||||
return new SageMakerTaskSettings(
|
||||
randomOptionalString(),
|
||||
randomOptionalString(),
|
||||
randomOptionalString(),
|
||||
randomOptionalString(),
|
||||
randomOptionalString(),
|
||||
SageMakerStoredTaskSchema.NO_OP
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected SageMakerTaskSettings fromMutableMap(Map<String, Object> mutableMap) {
|
||||
var validationException = new ValidationException();
|
||||
var taskSettings = SageMakerTaskSettings.fromMap(mutableMap, SageMakerStoredTaskSchema.NO_OP, validationException);
|
||||
validationException.throwIfValidationErrorsExist();
|
||||
return taskSettings;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||
|
||||
import software.amazon.awssdk.core.SdkBytes;
|
||||
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Set;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase.toMap;
|
||||
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.startsWith;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public abstract class SageMakerSchemaPayloadTestCase<T extends SageMakerSchemaPayload> extends ESTestCase {
|
||||
protected T payload;
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
payload = payload();
|
||||
}
|
||||
|
||||
protected abstract T payload();
|
||||
|
||||
protected abstract String expectedApi();
|
||||
|
||||
protected abstract Set<TaskType> expectedSupportedTaskTypes();
|
||||
|
||||
protected abstract SageMakerStoredServiceSchema randomApiServiceSettings();
|
||||
|
||||
protected abstract SageMakerStoredTaskSchema randomApiTaskSettings();
|
||||
|
||||
public final void testApi() {
|
||||
assertThat(payload.api(), equalTo(expectedApi()));
|
||||
}
|
||||
|
||||
public final void testSupportedTaskTypes() {
|
||||
assertThat(payload.supportedTasks(), containsInAnyOrder(expectedSupportedTaskTypes().toArray()));
|
||||
}
|
||||
|
||||
public final void testApiServiceSettings() throws IOException {
|
||||
var validationException = new ValidationException();
|
||||
var expectedApiServiceSettings = randomApiServiceSettings();
|
||||
var actualApiServiceSettings = payload.apiServiceSettings(toMap(expectedApiServiceSettings), validationException);
|
||||
assertThat(actualApiServiceSettings, equalTo(expectedApiServiceSettings));
|
||||
validationException.throwIfValidationErrorsExist();
|
||||
}
|
||||
|
||||
public final void testApiTaskSettings() throws IOException {
|
||||
var validationException = new ValidationException();
|
||||
var expectedApiTaskSettings = randomApiTaskSettings();
|
||||
var actualApiTaskSettings = payload.apiTaskSettings(toMap(expectedApiTaskSettings), validationException);
|
||||
assertThat(actualApiTaskSettings, equalTo(expectedApiTaskSettings));
|
||||
validationException.throwIfValidationErrorsExist();
|
||||
}
|
||||
|
||||
public void testNamedWriteables() {
|
||||
var namedWriteables = payload.namedWriteables().map(entry -> entry.name).toList();
|
||||
|
||||
var filteredNames = Set.of(
|
||||
SageMakerStoredServiceSchema.NO_OP.getWriteableName(),
|
||||
SageMakerStoredTaskSchema.NO_OP.getWriteableName()
|
||||
);
|
||||
var expectedWriteables = Stream.of(randomApiServiceSettings(), randomApiTaskSettings())
|
||||
.map(VersionedNamedWriteable::getWriteableName)
|
||||
.filter(Predicate.not(filteredNames::contains))
|
||||
.toArray();
|
||||
assertThat(namedWriteables, containsInAnyOrder(expectedWriteables));
|
||||
}
|
||||
|
||||
public final void testWithUnknownApiServiceSettings() {
|
||||
SageMakerModel model = mock();
|
||||
when(model.apiServiceSettings()).thenReturn(mock());
|
||||
when(model.apiTaskSettings()).thenReturn(randomApiTaskSettings());
|
||||
when(model.api()).thenReturn("serviceApi");
|
||||
when(model.getTaskType()).thenReturn(TaskType.ANY);
|
||||
|
||||
var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest()));
|
||||
|
||||
assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [serviceApi] and task type [any]:"));
|
||||
}
|
||||
|
||||
public final void testWithUnknownApiTaskSettings() {
|
||||
SageMakerModel model = mock();
|
||||
when(model.apiServiceSettings()).thenReturn(randomApiServiceSettings());
|
||||
when(model.apiTaskSettings()).thenReturn(mock());
|
||||
when(model.api()).thenReturn("taskApi");
|
||||
when(model.getTaskType()).thenReturn(TaskType.ANY);
|
||||
|
||||
var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest()));
|
||||
|
||||
assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [taskApi] and task type [any]:"));
|
||||
}
|
||||
|
||||
public final void testUpdate() throws IOException {
|
||||
var taskSettings = randomApiTaskSettings();
|
||||
if (taskSettings != SageMakerStoredTaskSchema.NO_OP) {
|
||||
var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings);
|
||||
|
||||
var updatedSettings = toMap(taskSettings.updatedTaskSettings(toMap(otherTaskSettings)));
|
||||
|
||||
var initialSettings = toMap(taskSettings);
|
||||
var newSettings = toMap(otherTaskSettings);
|
||||
|
||||
newSettings.forEach((key, value) -> {
|
||||
assertThat("Value should have been updated for key " + key, value, equalTo(updatedSettings.remove(key)));
|
||||
});
|
||||
initialSettings.forEach((key, value) -> {
|
||||
if (updatedSettings.containsKey(key)) {
|
||||
assertThat("Value should not have been updated for key " + key, value, equalTo(updatedSettings.remove(key)));
|
||||
}
|
||||
});
|
||||
assertTrue("Map should be empty now that we verified all updated keys and all initial keys", updatedSettings.isEmpty());
|
||||
}
|
||||
if (payload instanceof SageMakerStoredTaskSchema taskSchema) {
|
||||
var otherTaskSettings = randomValueOtherThan(randomApiTaskSettings(), this::randomApiTaskSettings);
|
||||
var otherTaskSettingsAsMap = toMap(otherTaskSettings);
|
||||
|
||||
taskSchema.updatedTaskSettings(otherTaskSettingsAsMap);
|
||||
}
|
||||
}
|
||||
|
||||
protected static SageMakerInferenceRequest randomRequest() {
|
||||
return new SageMakerInferenceRequest(
|
||||
randomBoolean() ? randomAlphaOfLengthBetween(4, 8) : null,
|
||||
randomOptionalBoolean(),
|
||||
randomBoolean() ? randomInt() : null,
|
||||
randomList(randomIntBetween(2, 4), () -> randomAlphaOfLengthBetween(2, 4)),
|
||||
randomBoolean(),
|
||||
randomFrom(InputType.values())
|
||||
);
|
||||
}
|
||||
|
||||
protected static void assertSdkBytes(SdkBytes sdkBytes, String expectedValue) {
|
||||
assertThat(sdkBytes.asUtf8String(), equalTo(expectedValue));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload;
|
||||
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyMap;
|
||||
import static org.mockito.ArgumentMatchers.anyString;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class SageMakerSchemasTests extends ESTestCase {
|
||||
public static SageMakerSchemas mockSchemas() {
|
||||
SageMakerSchemas schemas = mock();
|
||||
var schema = mockSchema();
|
||||
when(schemas.schemaFor(any(TaskType.class), anyString())).thenReturn(schema);
|
||||
return schemas;
|
||||
}
|
||||
|
||||
public static SageMakerSchema mockSchema() {
|
||||
SageMakerSchema schema = mock();
|
||||
when(schema.apiServiceSettings(anyMap(), any())).thenReturn(SageMakerStoredServiceSchema.NO_OP);
|
||||
when(schema.apiTaskSettings(anyMap(), any())).thenReturn(SageMakerStoredTaskSchema.NO_OP);
|
||||
return schema;
|
||||
}
|
||||
|
||||
private static final SageMakerSchemas schemas = new SageMakerSchemas();
|
||||
|
||||
public void testSupportedTaskTypes() {
|
||||
assertThat(schemas.supportedTaskTypes(), containsInAnyOrder(TaskType.TEXT_EMBEDDING));
|
||||
}
|
||||
|
||||
public void testSupportedStreamingTasks() {
|
||||
assertThat(schemas.supportedStreamingTasks(), empty());
|
||||
}
|
||||
|
||||
public void testSchemaFor() {
|
||||
var payloads = Stream.of(new OpenAiTextEmbeddingPayload());
|
||||
payloads.forEach(payload -> {
|
||||
payload.supportedTasks().forEach(taskType -> {
|
||||
var model = mockModel(taskType, payload.api());
|
||||
assertNotNull(schemas.schemaFor(model));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
public void testStreamSchemaFor() {
|
||||
var payloads = Stream.<SageMakerStreamSchemaPayload>of(/* For when we add support for streaming payloads */);
|
||||
payloads.forEach(payload -> {
|
||||
payload.supportedTasks().forEach(taskType -> {
|
||||
var model = mockModel(taskType, payload.api());
|
||||
assertNotNull(schemas.streamSchemaFor(model));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
private SageMakerModel mockModel(TaskType taskType, String api) {
|
||||
SageMakerModel model = mock();
|
||||
when(model.getTaskType()).thenReturn(taskType);
|
||||
when(model.api()).thenReturn(api);
|
||||
return model;
|
||||
}
|
||||
|
||||
public void testMissingTaskTypeThrowsException() {
|
||||
var knownPayload = new OpenAiTextEmbeddingPayload();
|
||||
var unknownTaskType = TaskType.COMPLETION;
|
||||
var knownModel = mockModel(unknownTaskType, knownPayload.api());
|
||||
assertThrows(
|
||||
"Task [completion] is not compatible for service [sagemaker] and api [openai]. Supported tasks: [text_embedding]",
|
||||
ElasticsearchStatusException.class,
|
||||
() -> schemas.schemaFor(knownModel)
|
||||
);
|
||||
}
|
||||
|
||||
public void testMissingSchemaThrowsException() {
|
||||
var unknownModel = mockModel(TaskType.ANY, "blah");
|
||||
assertThrows(
|
||||
"Task [any] is not compatible for service [sagemaker] and api [blah]. Supported tasks: []",
|
||||
ElasticsearchStatusException.class,
|
||||
() -> schemas.schemaFor(unknownModel)
|
||||
);
|
||||
}
|
||||
|
||||
public void testMissingStreamSchemaThrowsException() {
|
||||
var unknownModel = mockModel(TaskType.ANY, "blah");
|
||||
assertThrows(
|
||||
"Streaming is not allowed for service [sagemaker], api [blah], and task [any]. Supported streaming tasks: []",
|
||||
ElasticsearchStatusException.class,
|
||||
() -> schemas.streamSchemaFor(unknownModel)
|
||||
);
|
||||
}
|
||||
|
||||
public void testNamedWriteables() {
|
||||
var namedWriteables = Stream.of(new OpenAiTextEmbeddingPayload().namedWriteables());
|
||||
|
||||
var expectedNamedWriteables = Stream.concat(
|
||||
namedWriteables.flatMap(names -> names.map(entry -> entry.name)),
|
||||
Stream.of(SageMakerStoredServiceSchema.NO_OP.getWriteableName(), SageMakerStoredTaskSchema.NO_OP.getWriteableName())
|
||||
).distinct().toArray();
|
||||
|
||||
var actualRegisteredNames = SageMakerSchemas.namedWriteables().stream().map(entry -> entry.name).toList();
|
||||
|
||||
assertThat(actualRegisteredNames, containsInAnyOrder(expectedNamedWriteables));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,155 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;
|
||||
|
||||
import software.amazon.awssdk.core.SdkBytes;
|
||||
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
|
||||
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayloadTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
|
||||
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class OpenAiTextEmbeddingPayloadTests extends SageMakerSchemaPayloadTestCase<OpenAiTextEmbeddingPayload> {
|
||||
@Override
|
||||
protected OpenAiTextEmbeddingPayload payload() {
|
||||
return new OpenAiTextEmbeddingPayload();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String expectedApi() {
|
||||
return "openai";
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Set<TaskType> expectedSupportedTaskTypes() {
|
||||
return Set.of(TaskType.TEXT_EMBEDDING);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected SageMakerStoredServiceSchema randomApiServiceSettings() {
|
||||
return SageMakerOpenAiServiceSettingsTests.randomApiServiceSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected SageMakerStoredTaskSchema randomApiTaskSettings() {
|
||||
return SageMakerOpenAiTaskSettingsTests.randomApiTaskSettings();
|
||||
}
|
||||
|
||||
public void testAccept() {
|
||||
assertThat(payload.accept(mock()), equalTo("application/json"));
|
||||
}
|
||||
|
||||
public void testContentType() {
|
||||
assertThat(payload.contentType(mock()), equalTo("application/json"));
|
||||
}
|
||||
|
||||
public void testRequestWithSingleInput() throws Exception {
|
||||
SageMakerModel model = mock();
|
||||
when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(null, false));
|
||||
when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings((String) null));
|
||||
var request = new SageMakerInferenceRequest(null, null, null, List.of("hello"), randomBoolean(), randomFrom(InputType.values()));
|
||||
|
||||
var sdkByes = payload.requestBytes(model, request);
|
||||
assertSdkBytes(sdkByes, """
|
||||
{"input":"hello"}""");
|
||||
}
|
||||
|
||||
public void testRequestWithArrayInput() throws Exception {
|
||||
SageMakerModel model = mock();
|
||||
when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(null, false));
|
||||
when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings((String) null));
|
||||
var request = new SageMakerInferenceRequest(
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
List.of("hello", "there"),
|
||||
randomBoolean(),
|
||||
randomFrom(InputType.values())
|
||||
);
|
||||
|
||||
var sdkByes = payload.requestBytes(model, request);
|
||||
assertSdkBytes(sdkByes, """
|
||||
{"input":["hello","there"]}""");
|
||||
}
|
||||
|
||||
public void testRequestWithDimensionsNotSetByUserIgnoreDimensions() throws Exception {
|
||||
SageMakerModel model = mock();
|
||||
when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(123, false));
|
||||
when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings((String) null));
|
||||
var request = new SageMakerInferenceRequest(
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
List.of("hello", "there"),
|
||||
randomBoolean(),
|
||||
randomFrom(InputType.values())
|
||||
);
|
||||
|
||||
var sdkByes = payload.requestBytes(model, request);
|
||||
assertSdkBytes(sdkByes, """
|
||||
{"input":["hello","there"]}""");
|
||||
}
|
||||
|
||||
public void testRequestWithOptionals() throws Exception {
|
||||
SageMakerModel model = mock();
|
||||
when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(1234, true));
|
||||
when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings("user"));
|
||||
var request = new SageMakerInferenceRequest("query", null, null, List.of("hello"), randomBoolean(), randomFrom(InputType.values()));
|
||||
|
||||
var sdkByes = payload.requestBytes(model, request);
|
||||
assertSdkBytes(sdkByes, """
|
||||
{"query":"query","input":"hello","user":"user","dimensions":1234}""");
|
||||
}
|
||||
|
||||
public void testResponse() throws Exception {
|
||||
String responseJson = """
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": 0,
|
||||
"embedding": [
|
||||
0.014539449,
|
||||
-0.015288644
|
||||
]
|
||||
}
|
||||
],
|
||||
"model": "text-embedding-ada-002-v2",
|
||||
"usage": {
|
||||
"prompt_tokens": 8,
|
||||
"total_tokens": 8
|
||||
}
|
||||
}
|
||||
""";
|
||||
var invokeEndpointResponse = InvokeEndpointResponse.builder()
|
||||
.body(SdkBytes.fromString(responseJson, StandardCharsets.UTF_8))
|
||||
.build();
|
||||
|
||||
var textEmbeddingFloatResults = payload.responseBody(mock(), invokeEndpointResponse);
|
||||
|
||||
assertThat(
|
||||
textEmbeddingFloatResults.embeddings(),
|
||||
is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
|
||||
);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;
|
||||
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
import static org.hamcrest.Matchers.sameInstance;
|
||||
|
||||
public class SageMakerOpenAiServiceSettingsTests extends InferenceSettingsTestCase<OpenAiTextEmbeddingPayload.ApiServiceSettings> {
|
||||
@Override
|
||||
protected OpenAiTextEmbeddingPayload.ApiServiceSettings fromMutableMap(Map<String, Object> mutableMap) {
|
||||
var validationException = new ValidationException();
|
||||
var settings = OpenAiTextEmbeddingPayload.ApiServiceSettings.fromMap(mutableMap, validationException);
|
||||
validationException.throwIfValidationErrorsExist();
|
||||
return settings;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<OpenAiTextEmbeddingPayload.ApiServiceSettings> instanceReader() {
|
||||
return OpenAiTextEmbeddingPayload.ApiServiceSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected OpenAiTextEmbeddingPayload.ApiServiceSettings createTestInstance() {
|
||||
return randomApiServiceSettings();
|
||||
}
|
||||
|
||||
static OpenAiTextEmbeddingPayload.ApiServiceSettings randomApiServiceSettings() {
|
||||
var dimensions = randomBoolean() ? randomIntBetween(1, 100) : null;
|
||||
return new OpenAiTextEmbeddingPayload.ApiServiceSettings(dimensions, dimensions != null);
|
||||
}
|
||||
|
||||
public void testDimensionsSetByUser() {
|
||||
var expectedDimensions = randomIntBetween(1, 100);
|
||||
var dimensionlessSettings = new OpenAiTextEmbeddingPayload.ApiServiceSettings(null, false);
|
||||
var updatedSettings = dimensionlessSettings.updateModelWithEmbeddingDetails(expectedDimensions);
|
||||
assertThat(updatedSettings, not(sameInstance(dimensionlessSettings)));
|
||||
assertThat(updatedSettings.dimensions(), equalTo(expectedDimensions));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;
|
||||
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public class SageMakerOpenAiTaskSettingsTests extends InferenceSettingsTestCase<OpenAiTextEmbeddingPayload.ApiTaskSettings> {
|
||||
@Override
|
||||
protected OpenAiTextEmbeddingPayload.ApiTaskSettings fromMutableMap(Map<String, Object> mutableMap) {
|
||||
var validationException = new ValidationException();
|
||||
var settings = OpenAiTextEmbeddingPayload.ApiTaskSettings.fromMap(mutableMap, validationException);
|
||||
validationException.throwIfValidationErrorsExist();
|
||||
return settings;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<OpenAiTextEmbeddingPayload.ApiTaskSettings> instanceReader() {
|
||||
return OpenAiTextEmbeddingPayload.ApiTaskSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected OpenAiTextEmbeddingPayload.ApiTaskSettings createTestInstance() {
|
||||
return randomApiTaskSettings();
|
||||
}
|
||||
|
||||
static OpenAiTextEmbeddingPayload.ApiTaskSettings randomApiTaskSettings() {
|
||||
return new OpenAiTextEmbeddingPayload.ApiTaskSettings(randomOptionalString());
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue