mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
[ML] Integrate SageMaker with OpenAI Embeddings (#126856)
Integrating with SageMaker. Current design: - SageMaker accepts any byte payload, which can be text, csv, or json. `api` represents the structure of the payload that we will send, for example `openai`, `elastic`, `common`, probably `cohere` or `huggingface` as well. - `api` implementations are extensions of `SageMakerSchemaPayload`, which supports: - "extra" service and task settings specific to the payload structure, so `cohere` would require `embedding_type` and `openai` would require `dimensions` in the `service_settings` - conversion logic from model, service settings, task settings, and input to `SdkBytes` - conversion logic from responding `SdkBytes` to `InferenceServiceResults` - Everything else is tunneling, there are a number of base `service_settings` and `task_settings` that are independent of the api format that we will store and set - We let the SDK do the bulk of the work in terms of connection details, rate limiting, retries, etc.
This commit is contained in:
parent
1b35cceacf
commit
245f5eebce
39 changed files with 4304 additions and 139 deletions
5
docs/changelog/126856.yaml
Normal file
5
docs/changelog/126856.yaml
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
pr: 126856
|
||||||
|
summary: "[ML] Integrate SageMaker with OpenAI Embeddings"
|
||||||
|
area: Machine Learning
|
||||||
|
type: enhancement
|
||||||
|
issues: []
|
|
@ -4912,6 +4912,11 @@
|
||||||
<sha256 value="c83dd82a9d82ff8c7d2eb1bdb2ae9f9505b312dad9a6bf0b80bc0136653a3a24" origin="Generated by Gradle"/>
|
<sha256 value="c83dd82a9d82ff8c7d2eb1bdb2ae9f9505b312dad9a6bf0b80bc0136653a3a24" origin="Generated by Gradle"/>
|
||||||
</artifact>
|
</artifact>
|
||||||
</component>
|
</component>
|
||||||
|
<component group="software.amazon.awssdk" name="sagemakerruntime" version="2.30.38">
|
||||||
|
<artifact name="sagemakerruntime-2.30.38.jar">
|
||||||
|
<sha256 value="b26ee73fa06d047eab9a174e49627972e646c0bbe909f479c18dbff193b561f5" origin="Generated by Gradle"/>
|
||||||
|
</artifact>
|
||||||
|
</component>
|
||||||
<component group="software.amazon.awssdk" name="sdk-core" version="2.30.38">
|
<component group="software.amazon.awssdk" name="sdk-core" version="2.30.38">
|
||||||
<artifact name="sdk-core-2.30.38.jar">
|
<artifact name="sdk-core-2.30.38.jar">
|
||||||
<sha256 value="556463b8c353408d93feab74719d141fcfda7fd3d7b7d1ad3a8a548b7cc2982d" origin="Generated by Gradle"/>
|
<sha256 value="556463b8c353408d93feab74719d141fcfda7fd3d7b7d1ad3a8a548b7cc2982d" origin="Generated by Gradle"/>
|
||||||
|
|
|
@ -162,6 +162,7 @@ public class TransportVersions {
|
||||||
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
|
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
|
||||||
public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19);
|
public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19);
|
||||||
public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL_BACKPORT_8_19 = def(8_841_0_20);
|
public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL_BACKPORT_8_19 = def(8_841_0_20);
|
||||||
|
public static final TransportVersion ML_INFERENCE_SAGEMAKER_8_19 = def(8_841_0_21);
|
||||||
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
|
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
|
||||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
|
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
|
||||||
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00);
|
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00);
|
||||||
|
@ -232,6 +233,7 @@ public class TransportVersions {
|
||||||
public static final TransportVersion PROJECT_METADATA_SETTINGS = def(9_066_00_0);
|
public static final TransportVersion PROJECT_METADATA_SETTINGS = def(9_066_00_0);
|
||||||
public static final TransportVersion AGGREGATE_METRIC_DOUBLE_BLOCK = def(9_067_00_0);
|
public static final TransportVersion AGGREGATE_METRIC_DOUBLE_BLOCK = def(9_067_00_0);
|
||||||
public static final TransportVersion PINNED_RETRIEVER = def(9_068_0_00);
|
public static final TransportVersion PINNED_RETRIEVER = def(9_068_0_00);
|
||||||
|
public static final TransportVersion ML_INFERENCE_SAGEMAKER = def(9_069_0_00);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* STOP! READ THIS FIRST! No, really,
|
* STOP! READ THIS FIRST! No, really,
|
||||||
|
|
|
@ -53,6 +53,12 @@ public class ValidationException extends IllegalArgumentException {
|
||||||
return validationErrors;
|
return validationErrors;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public final void throwIfValidationErrorsExist() {
|
||||||
|
if (validationErrors().isEmpty() == false) {
|
||||||
|
throw this;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final String getMessage() {
|
public final String getMessage() {
|
||||||
StringBuilder sb = new StringBuilder();
|
StringBuilder sb = new StringBuilder();
|
||||||
|
|
|
@ -62,6 +62,7 @@ dependencies {
|
||||||
|
|
||||||
/* AWS SDK v2 */
|
/* AWS SDK v2 */
|
||||||
implementation("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}")
|
implementation("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}")
|
||||||
|
implementation("software.amazon.awssdk:sagemakerruntime:${versions.awsv2sdk}")
|
||||||
api "software.amazon.awssdk:protocol-core:${versions.awsv2sdk}"
|
api "software.amazon.awssdk:protocol-core:${versions.awsv2sdk}"
|
||||||
api "software.amazon.awssdk:aws-json-protocol:${versions.awsv2sdk}"
|
api "software.amazon.awssdk:aws-json-protocol:${versions.awsv2sdk}"
|
||||||
api "software.amazon.awssdk:third-party-jackson-core:${versions.awsv2sdk}"
|
api "software.amazon.awssdk:third-party-jackson-core:${versions.awsv2sdk}"
|
||||||
|
@ -142,6 +143,7 @@ tasks.named("dependencyLicenses").configure {
|
||||||
mapping from: /json-utils.*/, to: 'aws-sdk-2'
|
mapping from: /json-utils.*/, to: 'aws-sdk-2'
|
||||||
mapping from: /endpoints-spi.*/, to: 'aws-sdk-2'
|
mapping from: /endpoints-spi.*/, to: 'aws-sdk-2'
|
||||||
mapping from: /bedrockruntime.*/, to: 'aws-sdk-2'
|
mapping from: /bedrockruntime.*/, to: 'aws-sdk-2'
|
||||||
|
mapping from: /sagemakerruntime.*/, to: 'aws-sdk-2'
|
||||||
mapping from: /netty-nio-client/, to: 'aws-sdk-2'
|
mapping from: /netty-nio-client/, to: 'aws-sdk-2'
|
||||||
/* Cannot use REGEX to match netty-* because netty-nio-client is an AWS package */
|
/* Cannot use REGEX to match netty-* because netty-nio-client is an AWS package */
|
||||||
mapping from: /netty-buffer/, to: 'netty'
|
mapping from: /netty-buffer/, to: 'netty'
|
||||||
|
|
|
@ -18,163 +18,161 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
|
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
|
||||||
|
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
public void testGetServicesWithoutTaskType() throws IOException {
|
public void testGetServicesWithoutTaskType() throws IOException {
|
||||||
List<Object> services = getAllServices();
|
List<Object> services = getAllServices();
|
||||||
assertThat(services.size(), equalTo(21));
|
assertThat(services.size(), equalTo(22));
|
||||||
|
|
||||||
String[] providers = new String[services.size()];
|
var providers = providers(services);
|
||||||
for (int i = 0; i < services.size(); i++) {
|
|
||||||
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
|
|
||||||
providers[i] = (String) serviceConfig.get("service");
|
|
||||||
}
|
|
||||||
|
|
||||||
assertArrayEquals(
|
assertThat(
|
||||||
List.of(
|
providers,
|
||||||
"alibabacloud-ai-search",
|
containsInAnyOrder(
|
||||||
"amazonbedrock",
|
List.of(
|
||||||
"anthropic",
|
"alibabacloud-ai-search",
|
||||||
"azureaistudio",
|
"amazonbedrock",
|
||||||
"azureopenai",
|
"anthropic",
|
||||||
"cohere",
|
"azureaistudio",
|
||||||
"deepseek",
|
"azureopenai",
|
||||||
"elastic",
|
"cohere",
|
||||||
"elasticsearch",
|
"deepseek",
|
||||||
"googleaistudio",
|
"elastic",
|
||||||
"googlevertexai",
|
"elasticsearch",
|
||||||
"hugging_face",
|
"googleaistudio",
|
||||||
"jinaai",
|
"googlevertexai",
|
||||||
"mistral",
|
"hugging_face",
|
||||||
"openai",
|
"jinaai",
|
||||||
"streaming_completion_test_service",
|
"mistral",
|
||||||
"test_reranking_service",
|
"openai",
|
||||||
"test_service",
|
"streaming_completion_test_service",
|
||||||
"text_embedding_test_service",
|
"test_reranking_service",
|
||||||
"voyageai",
|
"test_service",
|
||||||
"watsonxai"
|
"text_embedding_test_service",
|
||||||
).toArray(),
|
"voyageai",
|
||||||
providers
|
"watsonxai",
|
||||||
|
"sagemaker"
|
||||||
|
).toArray()
|
||||||
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
|
private Iterable<String> providers(List<Object> services) {
|
||||||
|
return services.stream().map(service -> {
|
||||||
|
var serviceConfig = (Map<String, Object>) service;
|
||||||
|
return (String) serviceConfig.get("service");
|
||||||
|
}).toList();
|
||||||
|
}
|
||||||
|
|
||||||
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
|
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
|
||||||
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
|
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
|
||||||
assertThat(services.size(), equalTo(15));
|
assertThat(services.size(), equalTo(16));
|
||||||
|
|
||||||
String[] providers = new String[services.size()];
|
var providers = providers(services);
|
||||||
for (int i = 0; i < services.size(); i++) {
|
|
||||||
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
|
|
||||||
providers[i] = (String) serviceConfig.get("service");
|
|
||||||
}
|
|
||||||
|
|
||||||
assertArrayEquals(
|
assertThat(
|
||||||
List.of(
|
providers,
|
||||||
"alibabacloud-ai-search",
|
containsInAnyOrder(
|
||||||
"amazonbedrock",
|
List.of(
|
||||||
"azureaistudio",
|
"alibabacloud-ai-search",
|
||||||
"azureopenai",
|
"amazonbedrock",
|
||||||
"cohere",
|
"azureaistudio",
|
||||||
"elasticsearch",
|
"azureopenai",
|
||||||
"googleaistudio",
|
"cohere",
|
||||||
"googlevertexai",
|
"elasticsearch",
|
||||||
"hugging_face",
|
"googleaistudio",
|
||||||
"jinaai",
|
"googlevertexai",
|
||||||
"mistral",
|
"hugging_face",
|
||||||
"openai",
|
"jinaai",
|
||||||
"text_embedding_test_service",
|
"mistral",
|
||||||
"voyageai",
|
"openai",
|
||||||
"watsonxai"
|
"text_embedding_test_service",
|
||||||
).toArray(),
|
"voyageai",
|
||||||
providers
|
"watsonxai",
|
||||||
|
"sagemaker"
|
||||||
|
).toArray()
|
||||||
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
public void testGetServicesWithRerankTaskType() throws IOException {
|
public void testGetServicesWithRerankTaskType() throws IOException {
|
||||||
List<Object> services = getServices(TaskType.RERANK);
|
List<Object> services = getServices(TaskType.RERANK);
|
||||||
assertThat(services.size(), equalTo(7));
|
assertThat(services.size(), equalTo(7));
|
||||||
|
|
||||||
String[] providers = new String[services.size()];
|
var providers = providers(services);
|
||||||
for (int i = 0; i < services.size(); i++) {
|
|
||||||
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
|
|
||||||
providers[i] = (String) serviceConfig.get("service");
|
|
||||||
}
|
|
||||||
|
|
||||||
assertArrayEquals(
|
assertThat(
|
||||||
List.of("alibabacloud-ai-search", "cohere", "elasticsearch", "googlevertexai", "jinaai", "test_reranking_service", "voyageai")
|
providers,
|
||||||
.toArray(),
|
containsInAnyOrder(
|
||||||
providers
|
List.of(
|
||||||
|
"alibabacloud-ai-search",
|
||||||
|
"cohere",
|
||||||
|
"elasticsearch",
|
||||||
|
"googlevertexai",
|
||||||
|
"jinaai",
|
||||||
|
"test_reranking_service",
|
||||||
|
"voyageai"
|
||||||
|
).toArray()
|
||||||
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
public void testGetServicesWithCompletionTaskType() throws IOException {
|
public void testGetServicesWithCompletionTaskType() throws IOException {
|
||||||
List<Object> services = getServices(TaskType.COMPLETION);
|
List<Object> services = getServices(TaskType.COMPLETION);
|
||||||
assertThat(services.size(), equalTo(10));
|
assertThat(services.size(), equalTo(10));
|
||||||
|
|
||||||
String[] providers = new String[services.size()];
|
var providers = providers(services);
|
||||||
for (int i = 0; i < services.size(); i++) {
|
|
||||||
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
|
|
||||||
providers[i] = (String) serviceConfig.get("service");
|
|
||||||
}
|
|
||||||
|
|
||||||
assertArrayEquals(
|
assertThat(
|
||||||
List.of(
|
providers,
|
||||||
"alibabacloud-ai-search",
|
containsInAnyOrder(
|
||||||
"amazonbedrock",
|
List.of(
|
||||||
"anthropic",
|
"alibabacloud-ai-search",
|
||||||
"azureaistudio",
|
"amazonbedrock",
|
||||||
"azureopenai",
|
"anthropic",
|
||||||
"cohere",
|
"azureaistudio",
|
||||||
"deepseek",
|
"azureopenai",
|
||||||
"googleaistudio",
|
"cohere",
|
||||||
"openai",
|
"deepseek",
|
||||||
"streaming_completion_test_service"
|
"googleaistudio",
|
||||||
).toArray(),
|
"openai",
|
||||||
providers
|
"streaming_completion_test_service"
|
||||||
|
).toArray()
|
||||||
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
public void testGetServicesWithChatCompletionTaskType() throws IOException {
|
public void testGetServicesWithChatCompletionTaskType() throws IOException {
|
||||||
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
|
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
|
||||||
assertThat(services.size(), equalTo(4));
|
assertThat(services.size(), equalTo(4));
|
||||||
|
|
||||||
String[] providers = new String[services.size()];
|
var providers = providers(services);
|
||||||
for (int i = 0; i < services.size(); i++) {
|
|
||||||
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
|
|
||||||
providers[i] = (String) serviceConfig.get("service");
|
|
||||||
}
|
|
||||||
|
|
||||||
assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
|
assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
|
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
|
||||||
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
|
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
|
||||||
assertThat(services.size(), equalTo(6));
|
assertThat(services.size(), equalTo(6));
|
||||||
|
|
||||||
String[] providers = new String[services.size()];
|
var providers = providers(services);
|
||||||
for (int i = 0; i < services.size(); i++) {
|
|
||||||
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
|
|
||||||
providers[i] = (String) serviceConfig.get("service");
|
|
||||||
}
|
|
||||||
|
|
||||||
assertArrayEquals(
|
assertThat(
|
||||||
List.of(
|
providers,
|
||||||
"alibabacloud-ai-search",
|
containsInAnyOrder(
|
||||||
"elastic",
|
List.of(
|
||||||
"elasticsearch",
|
"alibabacloud-ai-search",
|
||||||
"hugging_face",
|
"elastic",
|
||||||
"streaming_completion_test_service",
|
"elasticsearch",
|
||||||
"test_service"
|
"hugging_face",
|
||||||
).toArray(),
|
"streaming_completion_test_service",
|
||||||
providers
|
"test_service"
|
||||||
|
).toArray()
|
||||||
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,7 @@ module org.elasticsearch.inference {
|
||||||
requires org.elasticsearch.logging;
|
requires org.elasticsearch.logging;
|
||||||
requires org.elasticsearch.sslconfig;
|
requires org.elasticsearch.sslconfig;
|
||||||
requires org.apache.commons.text;
|
requires org.apache.commons.text;
|
||||||
|
requires software.amazon.awssdk.services.sagemakerruntime;
|
||||||
|
|
||||||
exports org.elasticsearch.xpack.inference.action;
|
exports org.elasticsearch.xpack.inference.action;
|
||||||
exports org.elasticsearch.xpack.inference.registry;
|
exports org.elasticsearch.xpack.inference.registry;
|
||||||
|
|
|
@ -92,6 +92,8 @@ import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCo
|
||||||
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
|
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
|
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
|
import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
|
||||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
|
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIServiceSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
|
import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings;
|
||||||
|
@ -157,6 +159,8 @@ public class InferenceNamedWriteablesProvider {
|
||||||
|
|
||||||
namedWriteables.addAll(StreamingTaskManager.namedWriteables());
|
namedWriteables.addAll(StreamingTaskManager.namedWriteables());
|
||||||
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
|
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());
|
||||||
|
namedWriteables.addAll(SageMakerModel.namedWriteables());
|
||||||
|
namedWriteables.addAll(SageMakerSchemas.namedWriteables());
|
||||||
|
|
||||||
return namedWriteables;
|
return namedWriteables;
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ import org.elasticsearch.common.settings.IndexScopedSettings;
|
||||||
import org.elasticsearch.common.settings.Setting;
|
import org.elasticsearch.common.settings.Setting;
|
||||||
import org.elasticsearch.common.settings.Settings;
|
import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.common.settings.SettingsFilter;
|
import org.elasticsearch.common.settings.SettingsFilter;
|
||||||
|
import org.elasticsearch.common.util.LazyInitializable;
|
||||||
import org.elasticsearch.core.IOUtils;
|
import org.elasticsearch.core.IOUtils;
|
||||||
import org.elasticsearch.core.TimeValue;
|
import org.elasticsearch.core.TimeValue;
|
||||||
import org.elasticsearch.features.NodeFeature;
|
import org.elasticsearch.features.NodeFeature;
|
||||||
|
@ -132,6 +133,11 @@ import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
|
||||||
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
|
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService;
|
||||||
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
|
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
|
||||||
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
|
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerService;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerConfiguration;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
|
||||||
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
|
import org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService;
|
||||||
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
|
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
|
||||||
|
|
||||||
|
@ -294,6 +300,8 @@ public class InferencePlugin extends Plugin
|
||||||
services.threadPool()
|
services.threadPool()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
var sageMakerSchemas = new SageMakerSchemas();
|
||||||
|
var sageMakerConfigurations = new LazyInitializable<>(new SageMakerConfiguration(sageMakerSchemas));
|
||||||
inferenceServices.add(
|
inferenceServices.add(
|
||||||
() -> List.of(
|
() -> List.of(
|
||||||
context -> new ElasticInferenceService(
|
context -> new ElasticInferenceService(
|
||||||
|
@ -302,6 +310,16 @@ public class InferencePlugin extends Plugin
|
||||||
inferenceServiceSettings,
|
inferenceServiceSettings,
|
||||||
modelRegistry.get(),
|
modelRegistry.get(),
|
||||||
authorizationHandler
|
authorizationHandler
|
||||||
|
),
|
||||||
|
context -> new SageMakerService(
|
||||||
|
new SageMakerModelBuilder(sageMakerSchemas),
|
||||||
|
new SageMakerClient(
|
||||||
|
new SageMakerClient.Factory(new HttpSettings(settings, services.clusterService())),
|
||||||
|
services.threadPool()
|
||||||
|
),
|
||||||
|
sageMakerSchemas,
|
||||||
|
services.threadPool(),
|
||||||
|
sageMakerConfigurations::getOrCompute
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
|
@ -23,11 +23,12 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
|
||||||
import org.elasticsearch.xcontent.XContentBuilder;
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.EnumSet;
|
import java.util.EnumSet;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString;
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString;
|
||||||
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.ACCESS_KEY_FIELD;
|
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.ACCESS_KEY_FIELD;
|
||||||
|
@ -134,33 +135,39 @@ public class AwsSecretSettings implements SecretSettings {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final LazyInitializable<Map<String, SettingsConfiguration>, RuntimeException> configuration =
|
private static final LazyInitializable<Map<String, SettingsConfiguration>, RuntimeException> configuration =
|
||||||
new LazyInitializable<>(() -> {
|
new LazyInitializable<>(
|
||||||
var configurationMap = new HashMap<String, SettingsConfiguration>();
|
() -> configuration(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).collect(
|
||||||
configurationMap.put(
|
Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)
|
||||||
ACCESS_KEY_FIELD,
|
)
|
||||||
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription(
|
);
|
||||||
"A valid AWS access key that has permissions to use Amazon Bedrock."
|
}
|
||||||
)
|
|
||||||
.setLabel("Access Key")
|
public static Stream<Map.Entry<String, SettingsConfiguration>> configuration(EnumSet<TaskType> supportedTaskTypes) {
|
||||||
.setRequired(true)
|
return Stream.of(
|
||||||
.setSensitive(true)
|
Map.entry(
|
||||||
.setUpdatable(true)
|
ACCESS_KEY_FIELD,
|
||||||
.setType(SettingsConfigurationFieldType.STRING)
|
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||||
.build()
|
"A valid AWS access key that has permissions to use Amazon Bedrock."
|
||||||
);
|
)
|
||||||
configurationMap.put(
|
.setLabel("Access Key")
|
||||||
SECRET_KEY_FIELD,
|
.setRequired(true)
|
||||||
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).setDescription(
|
.setSensitive(true)
|
||||||
"A valid AWS secret key that is paired with the access_key."
|
.setUpdatable(true)
|
||||||
)
|
.setType(SettingsConfigurationFieldType.STRING)
|
||||||
.setLabel("Secret Key")
|
.build()
|
||||||
.setRequired(true)
|
),
|
||||||
.setSensitive(true)
|
Map.entry(
|
||||||
.setUpdatable(true)
|
SECRET_KEY_FIELD,
|
||||||
.setType(SettingsConfigurationFieldType.STRING)
|
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||||
.build()
|
"A valid AWS secret key that is paired with the access_key."
|
||||||
);
|
)
|
||||||
return Collections.unmodifiableMap(configurationMap);
|
.setLabel("Secret Key")
|
||||||
});
|
.setRequired(true)
|
||||||
|
.setSensitive(true)
|
||||||
|
.setUpdatable(true)
|
||||||
|
.setType(SettingsConfigurationFieldType.STRING)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.common.unit.ByteSizeUnit;
|
||||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||||
import org.elasticsearch.core.TimeValue;
|
import org.elasticsearch.core.TimeValue;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
|
@ -55,6 +56,10 @@ public class HttpSettings {
|
||||||
return connectionTimeout;
|
return connectionTimeout;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Duration connectionTimeoutDuration() {
|
||||||
|
return Duration.ofMillis(connectionTimeout);
|
||||||
|
}
|
||||||
|
|
||||||
private void setMaxResponseSize(ByteSizeValue maxResponseSize) {
|
private void setMaxResponseSize(ByteSizeValue maxResponseSize) {
|
||||||
this.maxResponseSize = maxResponseSize;
|
this.maxResponseSize = maxResponseSize;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,251 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker;
|
||||||
|
|
||||||
|
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
|
||||||
|
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
|
||||||
|
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
|
||||||
|
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
|
||||||
|
import software.amazon.awssdk.profiles.ProfileFile;
|
||||||
|
import software.amazon.awssdk.regions.Region;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeAsyncClient;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;
|
||||||
|
|
||||||
|
import org.apache.logging.log4j.LogManager;
|
||||||
|
import org.apache.logging.log4j.Logger;
|
||||||
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
|
import org.elasticsearch.ExceptionsHelper;
|
||||||
|
import org.elasticsearch.SpecialPermission;
|
||||||
|
import org.elasticsearch.action.ActionListener;
|
||||||
|
import org.elasticsearch.action.support.ContextPreservingActionListener;
|
||||||
|
import org.elasticsearch.action.support.ListenerTimeouts;
|
||||||
|
import org.elasticsearch.common.cache.Cache;
|
||||||
|
import org.elasticsearch.common.cache.CacheBuilder;
|
||||||
|
import org.elasticsearch.common.cache.CacheLoader;
|
||||||
|
import org.elasticsearch.common.util.concurrent.FutureUtils;
|
||||||
|
import org.elasticsearch.core.TimeValue;
|
||||||
|
import org.elasticsearch.core.Tuple;
|
||||||
|
import org.elasticsearch.rest.RestStatus;
|
||||||
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
|
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.HttpSettings;
|
||||||
|
import org.reactivestreams.FlowAdapters;
|
||||||
|
|
||||||
|
import java.io.Closeable;
|
||||||
|
import java.security.AccessController;
|
||||||
|
import java.security.PrivilegedExceptionAction;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.concurrent.CompletionException;
|
||||||
|
import java.util.concurrent.ExecutionException;
|
||||||
|
import java.util.concurrent.Flow;
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
|
||||||
|
|
||||||
|
public class SageMakerClient implements Closeable {
|
||||||
|
private static final Logger log = LogManager.getLogger(SageMakerClient.class);
|
||||||
|
private final Cache<RegionAndSecrets, SageMakerRuntimeAsyncClient> existingClients = CacheBuilder.<
|
||||||
|
RegionAndSecrets,
|
||||||
|
SageMakerRuntimeAsyncClient>builder()
|
||||||
|
.removalListener(removal -> removal.getValue().close())
|
||||||
|
.setExpireAfterAccess(TimeValue.timeValueMinutes(15))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
private final CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory;
|
||||||
|
private final ThreadPool threadPool;
|
||||||
|
|
||||||
|
public SageMakerClient(CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory, ThreadPool threadPool) {
|
||||||
|
this.clientFactory = clientFactory;
|
||||||
|
this.threadPool = threadPool;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void invoke(
|
||||||
|
RegionAndSecrets regionAndSecrets,
|
||||||
|
InvokeEndpointRequest request,
|
||||||
|
TimeValue timeout,
|
||||||
|
ActionListener<InvokeEndpointResponse> listener
|
||||||
|
) {
|
||||||
|
SageMakerRuntimeAsyncClient asyncClient;
|
||||||
|
try {
|
||||||
|
asyncClient = existingClients.computeIfAbsent(regionAndSecrets, clientFactory);
|
||||||
|
} catch (ExecutionException e) {
|
||||||
|
listener.onFailure(clientFailure(regionAndSecrets, e));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
var contextPreservingListener = new ContextPreservingActionListener<>(
|
||||||
|
threadPool.getThreadContext().newRestorableContext(false),
|
||||||
|
listener
|
||||||
|
);
|
||||||
|
|
||||||
|
var awsFuture = asyncClient.invokeEndpoint(request);
|
||||||
|
var timeoutListener = ListenerTimeouts.wrapWithTimeout(
|
||||||
|
threadPool,
|
||||||
|
timeout,
|
||||||
|
threadPool.executor(UTILITY_THREAD_POOL_NAME),
|
||||||
|
contextPreservingListener,
|
||||||
|
ignored -> {
|
||||||
|
FutureUtils.cancel(awsFuture);
|
||||||
|
contextPreservingListener.onFailure(
|
||||||
|
new ElasticsearchStatusException("Request timed out after [{}]", RestStatus.REQUEST_TIMEOUT, timeout)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
awsFuture.thenAcceptAsync(timeoutListener::onResponse, threadPool.executor(UTILITY_THREAD_POOL_NAME))
|
||||||
|
.exceptionallyAsync(t -> failAndMaybeThrowError(t, timeoutListener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Exception clientFailure(RegionAndSecrets regionAndSecrets, Exception cause) {
|
||||||
|
return new ElasticsearchStatusException(
|
||||||
|
"failed to create SageMakerRuntime client for region [{}]",
|
||||||
|
RestStatus.INTERNAL_SERVER_ERROR,
|
||||||
|
cause,
|
||||||
|
regionAndSecrets.region()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Void failAndMaybeThrowError(Throwable t, ActionListener<?> listener) {
|
||||||
|
if (t instanceof CompletionException ce) {
|
||||||
|
t = ce.getCause();
|
||||||
|
}
|
||||||
|
if (t instanceof Exception e) {
|
||||||
|
listener.onFailure(e);
|
||||||
|
} else {
|
||||||
|
ExceptionsHelper.maybeError(t).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
|
||||||
|
log.atWarn().withThrowable(t).log("Unknown failure calling SageMaker.");
|
||||||
|
listener.onFailure(new RuntimeException("Unknown failure calling SageMaker.", t));
|
||||||
|
}
|
||||||
|
return null; // Void
|
||||||
|
}
|
||||||
|
|
||||||
|
public void invokeStream(
|
||||||
|
RegionAndSecrets regionAndSecrets,
|
||||||
|
InvokeEndpointWithResponseStreamRequest request,
|
||||||
|
TimeValue timeout,
|
||||||
|
ActionListener<SageMakerStream> listener
|
||||||
|
) {
|
||||||
|
SageMakerRuntimeAsyncClient asyncClient;
|
||||||
|
try {
|
||||||
|
asyncClient = existingClients.computeIfAbsent(regionAndSecrets, clientFactory);
|
||||||
|
} catch (ExecutionException e) {
|
||||||
|
listener.onFailure(clientFailure(regionAndSecrets, e));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
var contextPreservingListener = new ContextPreservingActionListener<>(
|
||||||
|
threadPool.getThreadContext().newRestorableContext(false),
|
||||||
|
listener
|
||||||
|
);
|
||||||
|
|
||||||
|
var responseStreamProcessor = new SageMakerStreamingResponseProcessor();
|
||||||
|
var cancelAwsRequestListener = new AtomicReference<CompletableFuture<?>>();
|
||||||
|
var timeoutListener = ListenerTimeouts.wrapWithTimeout(
|
||||||
|
threadPool,
|
||||||
|
timeout,
|
||||||
|
threadPool.executor(UTILITY_THREAD_POOL_NAME),
|
||||||
|
contextPreservingListener,
|
||||||
|
ignored -> {
|
||||||
|
FutureUtils.cancel(cancelAwsRequestListener.get());
|
||||||
|
contextPreservingListener.onFailure(
|
||||||
|
new ElasticsearchStatusException("Request timed out after [{}]", RestStatus.REQUEST_TIMEOUT, timeout)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
// To stay consistent with HTTP providers, we cancel the TimeoutListener onResponse because we are measuring the time it takes to
|
||||||
|
// start receiving bytes.
|
||||||
|
var responseStreamListener = InvokeEndpointWithResponseStreamResponseHandler.builder()
|
||||||
|
.onResponse(response -> timeoutListener.onResponse(new SageMakerStream(response, responseStreamProcessor)))
|
||||||
|
.onEventStream(publisher -> responseStreamProcessor.setPublisher(FlowAdapters.toFlowPublisher(publisher)))
|
||||||
|
.build();
|
||||||
|
var awsFuture = asyncClient.invokeEndpointWithResponseStream(request, responseStreamListener);
|
||||||
|
cancelAwsRequestListener.set(awsFuture);
|
||||||
|
awsFuture.exceptionallyAsync(t -> failAndMaybeThrowError(t, timeoutListener), threadPool.executor(UTILITY_THREAD_POOL_NAME));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
existingClients.invalidateAll(); // will close each cached client
|
||||||
|
}
|
||||||
|
|
||||||
|
public record RegionAndSecrets(String region, AwsSecretSettings secretSettings) {}
|
||||||
|
|
||||||
|
public static class Factory implements CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> {
|
||||||
|
private final HttpSettings httpSettings;
|
||||||
|
|
||||||
|
public Factory(HttpSettings httpSettings) {
|
||||||
|
this.httpSettings = httpSettings;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SageMakerRuntimeAsyncClient load(RegionAndSecrets key) throws Exception {
|
||||||
|
SpecialPermission.check();
|
||||||
|
// TODO migrate to entitlements
|
||||||
|
return AccessController.doPrivileged((PrivilegedExceptionAction<SageMakerRuntimeAsyncClient>) () -> {
|
||||||
|
var credentials = AwsBasicCredentials.create(
|
||||||
|
key.secretSettings().accessKey().toString(),
|
||||||
|
key.secretSettings().secretKey().toString()
|
||||||
|
);
|
||||||
|
var credentialsProvider = StaticCredentialsProvider.create(credentials);
|
||||||
|
var clientConfig = NettyNioAsyncHttpClient.builder().connectionTimeout(httpSettings.connectionTimeoutDuration());
|
||||||
|
var override = ClientOverrideConfiguration.builder()
|
||||||
|
// disable profileFile, user credentials will always come from the configured Model Secrets
|
||||||
|
.defaultProfileFileSupplier(ProfileFile.aggregator()::build)
|
||||||
|
.defaultProfileFile(ProfileFile.aggregator().build())
|
||||||
|
.retryPolicy(retryPolicy -> retryPolicy.numRetries(3))
|
||||||
|
.retryStrategy(retryStrategy -> retryStrategy.maxAttempts(3))
|
||||||
|
.build();
|
||||||
|
return SageMakerRuntimeAsyncClient.builder()
|
||||||
|
.credentialsProvider(credentialsProvider)
|
||||||
|
.region(Region.of(key.region()))
|
||||||
|
.httpClientBuilder(clientConfig)
|
||||||
|
.overrideConfiguration(override)
|
||||||
|
.build();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class SageMakerStreamingResponseProcessor implements Flow.Publisher<ResponseStream> {
|
||||||
|
private static final Logger log = LogManager.getLogger(SageMakerStreamingResponseProcessor.class);
|
||||||
|
private final AtomicReference<Tuple<Flow.Publisher<ResponseStream>, Flow.Subscriber<? super ResponseStream>>> holder =
|
||||||
|
new AtomicReference<>(null);
|
||||||
|
private final AtomicBoolean subscribeCalledOnce = new AtomicBoolean(false);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void subscribe(Flow.Subscriber<? super ResponseStream> subscriber) {
|
||||||
|
if (subscribeCalledOnce.compareAndSet(false, true) == false) {
|
||||||
|
subscriber.onError(new IllegalStateException("Subscriber already set."));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (holder.compareAndSet(null, Tuple.tuple(null, subscriber)) == false) {
|
||||||
|
log.debug("Subscriber connecting to publisher.");
|
||||||
|
var publisher = holder.getAndSet(null).v1();
|
||||||
|
publisher.subscribe(subscriber);
|
||||||
|
} else {
|
||||||
|
log.debug("Subscriber waiting for connection.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setPublisher(Flow.Publisher<ResponseStream> publisher) {
|
||||||
|
if (holder.compareAndSet(null, Tuple.tuple(publisher, null)) == false) {
|
||||||
|
log.debug("Publisher connecting to subscriber.");
|
||||||
|
var subscriber = holder.getAndSet(null).v2();
|
||||||
|
publisher.subscribe(subscriber);
|
||||||
|
} else {
|
||||||
|
log.debug("Publisher waiting for connection.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public record SageMakerStream(InvokeEndpointWithResponseStreamResponse response, Flow.Publisher<ResponseStream> responseStream) {}
|
||||||
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker;
|
||||||
|
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.inference.InputType;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
public record SageMakerInferenceRequest(
|
||||||
|
@Nullable String query,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
List<String> input,
|
||||||
|
boolean stream,
|
||||||
|
InputType inputType
|
||||||
|
) {
|
||||||
|
public SageMakerInferenceRequest {
|
||||||
|
Objects.requireNonNull(input);
|
||||||
|
Objects.requireNonNull(inputType);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,305 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
|
import org.elasticsearch.TransportVersion;
|
||||||
|
import org.elasticsearch.TransportVersions;
|
||||||
|
import org.elasticsearch.action.ActionListener;
|
||||||
|
import org.elasticsearch.action.support.SubscribableListener;
|
||||||
|
import org.elasticsearch.common.CheckedSupplier;
|
||||||
|
import org.elasticsearch.common.util.LazyInitializable;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.core.TimeValue;
|
||||||
|
import org.elasticsearch.inference.ChunkInferenceInput;
|
||||||
|
import org.elasticsearch.inference.ChunkedInference;
|
||||||
|
import org.elasticsearch.inference.InferenceService;
|
||||||
|
import org.elasticsearch.inference.InferenceServiceConfiguration;
|
||||||
|
import org.elasticsearch.inference.InferenceServiceResults;
|
||||||
|
import org.elasticsearch.inference.InputType;
|
||||||
|
import org.elasticsearch.inference.Model;
|
||||||
|
import org.elasticsearch.inference.SettingsConfiguration;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||||
|
import org.elasticsearch.rest.RestStatus;
|
||||||
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
|
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.EnumSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import static org.elasticsearch.core.Strings.format;
|
||||||
|
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails;
|
||||||
|
|
||||||
|
/**
 * The {@link InferenceService} implementation for Amazon SageMaker. Model parsing is delegated to
 * {@link SageMakerModelBuilder}, payload translation to the {@link SageMakerSchemas} registry, and
 * transport to {@link SageMakerClient}.
 */
public class SageMakerService implements InferenceService {
    public static final String NAME = "sagemaker";
    // batch size used by chunkedInfer when the model does not configure its own
    private static final int DEFAULT_BATCH_SIZE = 256;
    private final SageMakerModelBuilder modelBuilder;
    private final SageMakerClient client;
    private final SageMakerSchemas schemas;
    private final ThreadPool threadPool;
    // built lazily because the configuration map supplier may be expensive; computed at most once
    private final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration;

    public SageMakerService(
        SageMakerModelBuilder modelBuilder,
        SageMakerClient client,
        SageMakerSchemas schemas,
        ThreadPool threadPool,
        CheckedSupplier<Map<String, SettingsConfiguration>, RuntimeException> configurationMap
    ) {
        this.modelBuilder = modelBuilder;
        this.client = client;
        this.schemas = schemas;
        this.threadPool = threadPool;
        this.configuration = new LazyInitializable<>(
            () -> new InferenceServiceConfiguration.Builder().setService(NAME)
                .setName("Amazon SageMaker")
                .setTaskTypes(supportedTaskTypes())
                .setConfigurations(configurationMap.get())
                .build()
        );
    }

    @Override
    public String name() {
        return NAME;
    }

    /** Builds a model from a user-supplied create request, completing the listener synchronously. */
    @Override
    public void parseRequestConfig(
        String modelId,
        TaskType taskType,
        Map<String, Object> config,
        ActionListener<Model> parsedModelListener
    ) {
        ActionListener.completeWith(parsedModelListener, () -> modelBuilder.fromRequest(modelId, taskType, NAME, config));
    }

    /** Rebuilds a model from persisted configuration plus its stored secrets. */
    @Override
    public Model parsePersistedConfigWithSecrets(
        String modelId,
        TaskType taskType,
        Map<String, Object> config,
        Map<String, Object> secrets
    ) {
        return modelBuilder.fromStorage(modelId, taskType, NAME, config, secrets);
    }

    /** Rebuilds a model from persisted configuration without secrets (passes null for them). */
    @Override
    public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
        return modelBuilder.fromStorage(modelId, taskType, NAME, config, null);
    }

    @Override
    public InferenceServiceConfiguration getConfiguration() {
        return configuration.getOrCompute();
    }

    @Override
    public EnumSet<TaskType> supportedTaskTypes() {
        return schemas.supportedTaskTypes();
    }

    @Override
    public Set<TaskType> supportedStreamingTasks() {
        return schemas.supportedStreamingTasks();
    }

    /**
     * Runs inference by applying any task-settings overrides to the model, translating the request
     * through the model's schema (stream or non-stream), and invoking SageMaker via the client.
     * Schema translation failures and any other synchronous failures are routed to the listener.
     */
    @Override
    public void infer(
        Model model,
        @Nullable String query,
        @Nullable Boolean returnDocuments,
        @Nullable Integer topN,
        List<String> input,
        boolean stream,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        if (model instanceof SageMakerModel == false) {
            listener.onFailure(createInvalidModelException(model));
            return;
        }

        var inferenceRequest = new SageMakerInferenceRequest(query, returnDocuments, topN, input, stream, inputType);

        try {
            // per-request task settings override the stored ones for this call only
            var sageMakerModel = ((SageMakerModel) model).override(taskSettings);
            var regionAndSecrets = regionAndSecrets(sageMakerModel);

            if (stream) {
                var schema = schemas.streamSchemaFor(sageMakerModel);
                var request = schema.streamRequest(sageMakerModel, inferenceRequest);
                client.invokeStream(
                    regionAndSecrets,
                    request,
                    timeout,
                    ActionListener.wrap(
                        response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)),
                        // let the schema translate AWS errors into inference errors
                        e -> listener.onFailure(schema.error(sageMakerModel, e))
                    )
                );
            } else {
                var schema = schemas.schemaFor(sageMakerModel);
                var request = schema.request(sageMakerModel, inferenceRequest);
                client.invoke(
                    regionAndSecrets,
                    request,
                    timeout,
                    ActionListener.wrap(
                        response -> listener.onResponse(schema.response(sageMakerModel, response, threadPool.getThreadContext())),
                        e -> listener.onFailure(schema.error(sageMakerModel, e))
                    )
                );
            }
        } catch (Exception e) {
            listener.onFailure(internalFailure(model, e));
        }
    }

    /**
     * Extracts the region and secrets needed to build/look up the SDK client. A model without
     * secrets should be impossible at this point, so that case asserts and fails with a 500.
     */
    private SageMakerClient.RegionAndSecrets regionAndSecrets(SageMakerModel model) {
        var secrets = model.awsSecretSettings();
        if (secrets.isEmpty()) {
            assert false : "Cannot invoke a model without secrets";
            throw new ElasticsearchStatusException(
                format("Attempting to infer using a model without API keys, inference id [%s]", model.getInferenceEntityId()),
                RestStatus.INTERNAL_SERVER_ERROR
            );
        }
        return new SageMakerClient.RegionAndSecrets(model.region(), secrets.get());
    }

    /** Passes through existing status exceptions; wraps anything else as a 500 with the inference id. */
    private static ElasticsearchStatusException internalFailure(Model model, Exception cause) {
        if (cause instanceof ElasticsearchStatusException ese) {
            return ese;
        } else {
            return new ElasticsearchStatusException(
                "Failed to call SageMaker for inference id [{}].",
                RestStatus.INTERNAL_SERVER_ERROR,
                cause,
                model.getInferenceEntityId()
            );
        }
    }

    /**
     * Streams a unified (chat completion style) request through the model's streaming schema.
     * Unlike {@link #infer}, no task-settings override is applied here.
     */
    @Override
    public void unifiedCompletionInfer(
        Model model,
        UnifiedCompletionRequest request,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        if (model instanceof SageMakerModel == false) {
            listener.onFailure(createInvalidModelException(model));
            return;
        }

        try {
            var sageMakerModel = (SageMakerModel) model;
            var regionAndSecrets = regionAndSecrets(sageMakerModel);
            var schema = schemas.streamSchemaFor(sageMakerModel);
            var sagemakerRequest = schema.chatCompletionStreamRequest(sageMakerModel, request);
            client.invokeStream(
                regionAndSecrets,
                sagemakerRequest,
                timeout,
                ActionListener.wrap(
                    response -> listener.onResponse(schema.chatCompletionStreamResponse(sageMakerModel, response)),
                    e -> listener.onFailure(schema.chatCompletionError(sageMakerModel, e))
                )
            );
        } catch (Exception e) {
            listener.onFailure(internalFailure(model, e));
        }
    }

    /**
     * Chunks the input via {@link EmbeddingRequestChunker}, then runs one {@link #infer} call per
     * batch. Batches are chained sequentially with a {@link SubscribableListener} so each one starts
     * only after the previous completes; per-batch results/errors flow through the chunker's
     * listeners back to the caller.
     */
    @Override
    public void chunkedInfer(
        Model model,
        String query,
        List<ChunkInferenceInput> input,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<List<ChunkedInference>> listener
    ) {
        if (model instanceof SageMakerModel == false) {
            listener.onFailure(createInvalidModelException(model));
            return;
        }
        try {
            var sageMakerModel = ((SageMakerModel) model).override(taskSettings);
            var batchedRequests = new EmbeddingRequestChunker<>(
                input,
                sageMakerModel.batchSize().orElse(DEFAULT_BATCH_SIZE),
                sageMakerModel.getConfigurations().getChunkingSettings()
            ).batchRequestsWithListeners(listener);

            // start the chain with an already-completed listener, then append one step per batch
            var subscribableListener = SubscribableListener.newSucceeded(null);
            for (var request : batchedRequests) {
                subscribableListener = subscribableListener.andThen(
                    threadPool.executor(UTILITY_THREAD_POOL_NAME),
                    threadPool.getThreadContext(),
                    (l, ignored) -> infer(
                        sageMakerModel,
                        query,
                        null, // no return docs while chunking?
                        null, // no topN while chunking?
                        request.batch().inputs().get(),
                        false, // we never stream when chunking
                        null, // since we pass sageMakerModel as the model, we already overwrote the model with the task settings
                        inputType,
                        timeout,
                        ActionListener.runAfter(request.listener(), () -> l.onResponse(null))
                    )
                );
            }
            // if there were any errors trying to create the SubscribableListener chain, then forward that to the listener
            // otherwise, BatchRequestAndListener will handle forwarding errors from the infer method
            subscribableListener.addListener(
                ActionListener.noop().delegateResponse((l, e) -> listener.onFailure(internalFailure(model, e)))
            );
        } catch (Exception e) {
            listener.onFailure(internalFailure(model, e));
        }
    }

    /** SageMaker endpoints are always "started"; there is nothing to provision here. */
    @Override
    public void start(Model model, TimeValue timeout, ActionListener<Boolean> listener) {
        listener.onResponse(true);
    }

    @Override
    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
        if (model instanceof SageMakerModel sageMakerModel) {
            return modelBuilder.updateModelWithEmbeddingDetails(sageMakerModel, embeddingSize);
        }

        throw invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.ML_INFERENCE_SAGEMAKER;
    }

    @Override
    public void close() throws IOException {
        client.close();
    }
}
|
|
@ -0,0 +1,35 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.CheckedSupplier;
|
||||||
|
import org.elasticsearch.inference.SettingsConfiguration;
|
||||||
|
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.function.Function;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
public class SageMakerConfiguration implements CheckedSupplier<Map<String, SettingsConfiguration>, RuntimeException> {
|
||||||
|
private final SageMakerSchemas schemas;
|
||||||
|
|
||||||
|
public SageMakerConfiguration(SageMakerSchemas schemas) {
|
||||||
|
this.schemas = schemas;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<String, SettingsConfiguration> get() {
|
||||||
|
return Stream.of(
|
||||||
|
AwsSecretSettings.configuration(schemas.supportedTaskTypes()),
|
||||||
|
SageMakerServiceSettings.configuration(schemas.supportedTaskTypes()),
|
||||||
|
SageMakerTaskSettings.configuration(schemas.supportedTaskTypes())
|
||||||
|
).flatMap(Function.identity()).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,146 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||||
|
import org.elasticsearch.inference.Model;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.ModelSecrets;
|
||||||
|
import org.elasticsearch.inference.ServiceSettings;
|
||||||
|
import org.elasticsearch.inference.TaskSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This model represents all models in SageMaker. SageMaker maintains a base set of settings and configurations, and this model manages
|
||||||
|
* those. Any settings that are required for a specific model are stored in the {@link SageMakerStoredServiceSchema} and
|
||||||
|
* {@link SageMakerStoredTaskSchema}.
|
||||||
|
* Design:
|
||||||
|
* - Region is stored in ServiceSettings and is used to create the SageMaker client.
|
||||||
|
* - RateLimiting is based on AWS Service Quota, metered by account and region. The SDK client handles rate limiting internally. In order to
|
||||||
|
* rate limit appropriately, use the same AWS Access ID, Secret Key, and Region for each inference endpoint calling the same AWS account.
|
||||||
|
* - Within Elastic, you cannot change the model behind the endpoint, because we require the embedding shape remain consistent between
|
||||||
|
* invocations. So anything "model selection" related must remain consistent through the lifetime of the Elastic endpoint. SageMaker
|
||||||
|
* allows model changes via TargetModel, InferenceComponentName, and TargetContainerHostname, but these will remain static in Elastic
|
||||||
|
* within ServiceSettings.
|
||||||
|
* - CustomAttributes, EnableExplanations, InferenceId, SessionId, and TargetVariant are all request-time fields that can be saved.
|
||||||
|
* - SageMaker returns 4 headers, which Elastic will forward as-is: x-Amzn-Invoked-Production-Variant, X-Amzn-SageMaker-Custom-Attributes,
|
||||||
|
* X-Amzn-SageMaker-New-Session-Id, X-Amzn-SageMaker-Closed-Session-Id
|
||||||
|
*/
|
||||||
|
public class SageMakerModel extends Model {
    // base settings shared by all SageMaker models; api-specific extras live inside them
    private final SageMakerServiceSettings serviceSettings;
    private final SageMakerTaskSettings taskSettings;
    // may be null when the model was parsed without secrets (see Optional accessor below)
    private final AwsSecretSettings awsSecretSettings;

    SageMakerModel(
        ModelConfigurations configurations,
        ModelSecrets secrets,
        SageMakerServiceSettings serviceSettings,
        SageMakerTaskSettings taskSettings,
        AwsSecretSettings awsSecretSettings
    ) {
        super(configurations, secrets);
        this.serviceSettings = serviceSettings;
        this.taskSettings = taskSettings;
        this.awsSecretSettings = awsSecretSettings;
    }

    /** AWS credentials for this model, empty when the model was loaded without secrets. */
    public Optional<AwsSecretSettings> awsSecretSettings() {
        return Optional.ofNullable(awsSecretSettings);
    }

    /** The AWS region the endpoint lives in; used to build the SDK client. */
    public String region() {
        return serviceSettings.region();
    }

    /** The SageMaker endpoint name to invoke. */
    public String endpointName() {
        return serviceSettings.endpointName();
    }

    /** The payload format identifier (e.g. which schema translates requests/responses). */
    public String api() {
        return serviceSettings.api();
    }

    // The following accessors wrap nullable base settings in Optional for callers.

    public Optional<String> customAttributes() {
        return Optional.ofNullable(taskSettings.customAttributes());
    }

    public Optional<String> enableExplanations() {
        return Optional.ofNullable(taskSettings.enableExplanations());
    }

    public Optional<String> inferenceComponentName() {
        return Optional.ofNullable(serviceSettings.inferenceComponentName());
    }

    public Optional<String> inferenceIdForDataCapture() {
        return Optional.ofNullable(taskSettings.inferenceIdForDataCapture());
    }

    public Optional<String> sessionId() {
        return Optional.ofNullable(taskSettings.sessionId());
    }

    public Optional<String> targetContainerHostname() {
        return Optional.ofNullable(serviceSettings.targetContainerHostname());
    }

    public Optional<String> targetModel() {
        return Optional.ofNullable(serviceSettings.targetModel());
    }

    public Optional<String> targetVariant() {
        return Optional.ofNullable(taskSettings.targetVariant());
    }

    /** Optional batch size for chunked inference; callers fall back to a default when empty. */
    public Optional<Integer> batchSize() {
        return Optional.ofNullable(serviceSettings.batchSize());
    }

    /**
     * Returns a copy of this model with per-request task-settings overrides applied, or this same
     * instance when there is nothing to override. Service settings and secrets are unchanged.
     */
    public SageMakerModel override(Map<String, Object> taskSettingsOverride) {
        if (taskSettingsOverride == null || taskSettingsOverride.isEmpty()) {
            return this;
        }

        return new SageMakerModel(
            getConfigurations(),
            getSecrets(),
            serviceSettings,
            taskSettings.updatedTaskSettings(taskSettingsOverride),
            awsSecretSettings
        );
    }

    /** Wire serialization entries for the base service and task settings. */
    public static List<NamedWriteableRegistry.Entry> namedWriteables() {
        return List.of(
            new NamedWriteableRegistry.Entry(ServiceSettings.class, SageMakerServiceSettings.NAME, SageMakerServiceSettings::new),
            new NamedWriteableRegistry.Entry(TaskSettings.class, SageMakerTaskSettings.NAME, SageMakerTaskSettings::new)
        );
    }

    /** The api-specific service settings stored inside the base service settings. */
    public SageMakerStoredServiceSchema apiServiceSettings() {
        return serviceSettings.apiServiceSettings();
    }

    /** The api-specific task settings stored inside the base task settings. */
    public SageMakerStoredTaskSchema apiTaskSettings() {
        return taskSettings.apiTaskSettings();
    }

    // package-private accessors for builders/tests within this package
    SageMakerServiceSettings serviceSettings() {
        return serviceSettings;
    }

    SageMakerTaskSettings taskSettings() {
        return taskSettings;
    }
}
|
|
@ -0,0 +1,135 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.ModelSecrets;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
|
||||||
|
|
||||||
|
public class SageMakerModelBuilder {
|
||||||
|
|
||||||
|
private final SageMakerSchemas schemas;
|
||||||
|
|
||||||
|
public SageMakerModelBuilder(SageMakerSchemas schemas) {
|
||||||
|
this.schemas = schemas;
|
||||||
|
}
|
||||||
|
|
||||||
|
public SageMakerModel fromRequest(String inferenceEntityId, TaskType taskType, String service, Map<String, Object> requestMap) {
|
||||||
|
var validationException = new ValidationException();
|
||||||
|
var serviceSettingsMap = removeFromMapOrThrowIfNull(requestMap, ModelConfigurations.SERVICE_SETTINGS);
|
||||||
|
var awsSecretSettings = AwsSecretSettings.fromMap(serviceSettingsMap);
|
||||||
|
var serviceSettings = SageMakerServiceSettings.fromMap(schemas, taskType, serviceSettingsMap);
|
||||||
|
|
||||||
|
var schema = schemas.schemaFor(taskType, serviceSettings.api());
|
||||||
|
|
||||||
|
var taskSettingsMap = removeFromMapOrDefaultEmpty(requestMap, ModelConfigurations.TASK_SETTINGS);
|
||||||
|
var taskSettings = SageMakerTaskSettings.fromMap(
|
||||||
|
taskSettingsMap,
|
||||||
|
schema.apiTaskSettings(taskSettingsMap, validationException),
|
||||||
|
validationException
|
||||||
|
);
|
||||||
|
|
||||||
|
validationException.throwIfValidationErrorsExist();
|
||||||
|
throwIfNotEmptyMap(serviceSettingsMap, service);
|
||||||
|
throwIfNotEmptyMap(taskSettingsMap, service);
|
||||||
|
|
||||||
|
var modelConfigurations = new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
|
||||||
|
return new SageMakerModel(
|
||||||
|
modelConfigurations,
|
||||||
|
new ModelSecrets(awsSecretSettings),
|
||||||
|
serviceSettings,
|
||||||
|
taskSettings,
|
||||||
|
awsSecretSettings
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SageMakerModel fromStorage(
|
||||||
|
String inferenceEntityId,
|
||||||
|
TaskType taskType,
|
||||||
|
String service,
|
||||||
|
Map<String, Object> config,
|
||||||
|
Map<String, Object> secrets
|
||||||
|
) {
|
||||||
|
var validationException = new ValidationException();
|
||||||
|
var awsSecretSettings = secrets != null
|
||||||
|
? AwsSecretSettings.fromMap(removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS))
|
||||||
|
: null;
|
||||||
|
|
||||||
|
var serviceSettings = SageMakerServiceSettings.fromMap(
|
||||||
|
schemas,
|
||||||
|
taskType,
|
||||||
|
removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS)
|
||||||
|
);
|
||||||
|
|
||||||
|
var schema = schemas.schemaFor(taskType, serviceSettings.api());
|
||||||
|
var taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
|
||||||
|
|
||||||
|
var taskSettings = SageMakerTaskSettings.fromMap(
|
||||||
|
taskSettingsMap,
|
||||||
|
schema.apiTaskSettings(taskSettingsMap, validationException),
|
||||||
|
validationException
|
||||||
|
);
|
||||||
|
|
||||||
|
validationException.throwIfValidationErrorsExist();
|
||||||
|
throwIfNotEmptyMap(config, service);
|
||||||
|
throwIfNotEmptyMap(taskSettingsMap, service);
|
||||||
|
throwIfNotEmptyMap(secrets, service);
|
||||||
|
|
||||||
|
var modelConfigurations = new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
|
||||||
|
return new SageMakerModel(
|
||||||
|
modelConfigurations,
|
||||||
|
new ModelSecrets(awsSecretSettings),
|
||||||
|
serviceSettings,
|
||||||
|
taskSettings,
|
||||||
|
awsSecretSettings
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SageMakerModel updateModelWithEmbeddingDetails(SageMakerModel model, int embeddingSize) {
|
||||||
|
var updatedApiServiceSettings = model.apiServiceSettings().updateModelWithEmbeddingDetails(embeddingSize);
|
||||||
|
|
||||||
|
if (updatedApiServiceSettings == model.apiServiceSettings()) {
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
var updatedServiceSettings = new SageMakerServiceSettings(
|
||||||
|
model.serviceSettings().endpointName(),
|
||||||
|
model.serviceSettings().region(),
|
||||||
|
model.serviceSettings().api(),
|
||||||
|
model.serviceSettings().targetModel(),
|
||||||
|
model.serviceSettings().targetContainerHostname(),
|
||||||
|
model.serviceSettings().inferenceComponentName(),
|
||||||
|
model.serviceSettings().batchSize(),
|
||||||
|
updatedApiServiceSettings
|
||||||
|
);
|
||||||
|
|
||||||
|
var modelConfigurations = new ModelConfigurations(
|
||||||
|
model.getInferenceEntityId(),
|
||||||
|
model.getTaskType(),
|
||||||
|
model.getConfigurations().getService(),
|
||||||
|
updatedServiceSettings,
|
||||||
|
model.taskSettings()
|
||||||
|
);
|
||||||
|
return new SageMakerModel(
|
||||||
|
modelConfigurations,
|
||||||
|
model.getSecrets(),
|
||||||
|
updatedServiceSettings,
|
||||||
|
model.taskSettings(),
|
||||||
|
model.awsSecretSettings().orElse(null)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,285 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||||
|
|
||||||
|
import org.elasticsearch.TransportVersion;
|
||||||
|
import org.elasticsearch.TransportVersions;
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.ServiceSettings;
|
||||||
|
import org.elasticsearch.inference.SettingsConfiguration;
|
||||||
|
import org.elasticsearch.inference.SimilarityMeasure;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
|
||||||
|
import org.elasticsearch.xcontent.ToXContentObject;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.EnumSet;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maintains the settings for SageMaker that cannot be changed without impacting semantic search and AI assistants.
|
||||||
|
* Model-specific settings are stored in {@link SageMakerStoredServiceSchema}.
|
||||||
|
*/
|
||||||
|
record SageMakerServiceSettings(
|
||||||
|
String endpointName,
|
||||||
|
String region,
|
||||||
|
String api,
|
||||||
|
@Nullable String targetModel,
|
||||||
|
@Nullable String targetContainerHostname,
|
||||||
|
@Nullable String inferenceComponentName,
|
||||||
|
@Nullable Integer batchSize,
|
||||||
|
SageMakerStoredServiceSchema apiServiceSettings
|
||||||
|
) implements ServiceSettings {
|
||||||
|
|
||||||
|
static final String NAME = "sage_maker_service_settings";
|
||||||
|
private static final String API = "api";
|
||||||
|
private static final String ENDPOINT_NAME = "endpoint_name";
|
||||||
|
private static final String REGION = "region";
|
||||||
|
private static final String TARGET_MODEL = "target_model";
|
||||||
|
private static final String TARGET_CONTAINER_HOSTNAME = "target_container_hostname";
|
||||||
|
private static final String INFERENCE_COMPONENT_NAME = "inference_component_name";
|
||||||
|
private static final String BATCH_SIZE = "batch_size";
|
||||||
|
|
||||||
|
SageMakerServiceSettings {
|
||||||
|
Objects.requireNonNull(endpointName);
|
||||||
|
Objects.requireNonNull(region);
|
||||||
|
Objects.requireNonNull(api);
|
||||||
|
Objects.requireNonNull(apiServiceSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
SageMakerServiceSettings(StreamInput in) throws IOException {
|
||||||
|
this(
|
||||||
|
in.readString(),
|
||||||
|
in.readString(),
|
||||||
|
in.readString(),
|
||||||
|
in.readOptionalString(),
|
||||||
|
in.readOptionalString(),
|
||||||
|
in.readOptionalString(),
|
||||||
|
in.readOptionalInt(),
|
||||||
|
in.readNamedWriteable(SageMakerStoredServiceSchema.class)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String modelId() {
|
||||||
|
return apiServiceSettings.modelId();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SimilarityMeasure similarity() {
|
||||||
|
return apiServiceSettings.similarity();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Integer dimensions() {
|
||||||
|
return apiServiceSettings.dimensions();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Boolean dimensionsSetByUser() {
|
||||||
|
return apiServiceSettings.dimensionsSetByUser();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DenseVectorFieldMapper.ElementType elementType() {
|
||||||
|
return apiServiceSettings.elementType();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getWriteableName() {
|
||||||
|
return NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TransportVersion getMinimalSupportedVersion() {
|
||||||
|
return TransportVersions.ML_INFERENCE_SAGEMAKER;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
|
out.writeString(endpointName());
|
||||||
|
out.writeString(region());
|
||||||
|
out.writeString(api());
|
||||||
|
out.writeOptionalString(targetModel());
|
||||||
|
out.writeOptionalString(targetContainerHostname());
|
||||||
|
out.writeOptionalString(inferenceComponentName());
|
||||||
|
out.writeOptionalInt(batchSize());
|
||||||
|
out.writeNamedWriteable(apiServiceSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ToXContentObject getFilteredXContentObject() {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
|
||||||
|
builder.field(ENDPOINT_NAME, endpointName());
|
||||||
|
builder.field(REGION, region());
|
||||||
|
builder.field(API, api());
|
||||||
|
optionalField(TARGET_MODEL, targetModel(), builder);
|
||||||
|
optionalField(TARGET_CONTAINER_HOSTNAME, targetContainerHostname(), builder);
|
||||||
|
optionalField(INFERENCE_COMPONENT_NAME, inferenceComponentName(), builder);
|
||||||
|
optionalField(BATCH_SIZE, batchSize(), builder);
|
||||||
|
apiServiceSettings.toXContent(builder, params);
|
||||||
|
|
||||||
|
return builder.endObject();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static <T> void optionalField(String name, T value, XContentBuilder builder) throws IOException {
|
||||||
|
if (value != null) {
|
||||||
|
builder.field(name, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static SageMakerServiceSettings fromMap(SageMakerSchemas schemas, TaskType taskType, Map<String, Object> serviceSettingsMap) {
|
||||||
|
ValidationException validationException = new ValidationException();
|
||||||
|
|
||||||
|
var endpointName = extractRequiredString(
|
||||||
|
serviceSettingsMap,
|
||||||
|
ENDPOINT_NAME,
|
||||||
|
ModelConfigurations.SERVICE_SETTINGS,
|
||||||
|
validationException
|
||||||
|
);
|
||||||
|
var region = extractRequiredString(serviceSettingsMap, REGION, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||||
|
var api = extractRequiredString(serviceSettingsMap, API, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||||
|
var targetModel = extractOptionalString(
|
||||||
|
serviceSettingsMap,
|
||||||
|
TARGET_MODEL,
|
||||||
|
ModelConfigurations.SERVICE_SETTINGS,
|
||||||
|
validationException
|
||||||
|
);
|
||||||
|
var targetContainerHostname = extractOptionalString(
|
||||||
|
serviceSettingsMap,
|
||||||
|
TARGET_CONTAINER_HOSTNAME,
|
||||||
|
ModelConfigurations.SERVICE_SETTINGS,
|
||||||
|
validationException
|
||||||
|
);
|
||||||
|
var inferenceComponentName = extractOptionalString(
|
||||||
|
serviceSettingsMap,
|
||||||
|
INFERENCE_COMPONENT_NAME,
|
||||||
|
ModelConfigurations.SERVICE_SETTINGS,
|
||||||
|
validationException
|
||||||
|
);
|
||||||
|
var batchSize = extractOptionalPositiveInteger(
|
||||||
|
serviceSettingsMap,
|
||||||
|
BATCH_SIZE,
|
||||||
|
ModelConfigurations.SERVICE_SETTINGS,
|
||||||
|
validationException
|
||||||
|
);
|
||||||
|
|
||||||
|
validationException.throwIfValidationErrorsExist();
|
||||||
|
|
||||||
|
var schema = schemas.schemaFor(taskType, api);
|
||||||
|
var apiServiceSettings = schema.apiServiceSettings(serviceSettingsMap, validationException);
|
||||||
|
|
||||||
|
validationException.throwIfValidationErrorsExist();
|
||||||
|
|
||||||
|
return new SageMakerServiceSettings(
|
||||||
|
endpointName,
|
||||||
|
region,
|
||||||
|
api,
|
||||||
|
targetModel,
|
||||||
|
targetContainerHostname,
|
||||||
|
inferenceComponentName,
|
||||||
|
batchSize,
|
||||||
|
apiServiceSettings
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Stream<Map.Entry<String, SettingsConfiguration>> configuration(EnumSet<TaskType> supportedTaskTypes) {
|
||||||
|
return Stream.of(
|
||||||
|
Map.entry(
|
||||||
|
API,
|
||||||
|
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The API format that your SageMaker Endpoint expects.")
|
||||||
|
.setLabel("API")
|
||||||
|
.setRequired(true)
|
||||||
|
.setSensitive(false)
|
||||||
|
.setUpdatable(false)
|
||||||
|
.setType(SettingsConfigurationFieldType.STRING)
|
||||||
|
.build()
|
||||||
|
),
|
||||||
|
Map.entry(
|
||||||
|
ENDPOINT_NAME,
|
||||||
|
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||||
|
"The name specified when creating the SageMaker Endpoint."
|
||||||
|
)
|
||||||
|
.setLabel("Endpoint Name")
|
||||||
|
.setRequired(true)
|
||||||
|
.setSensitive(false)
|
||||||
|
.setUpdatable(false)
|
||||||
|
.setType(SettingsConfigurationFieldType.STRING)
|
||||||
|
.build()
|
||||||
|
),
|
||||||
|
Map.entry(
|
||||||
|
REGION,
|
||||||
|
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||||
|
"The AWS region that your model or ARN is deployed in."
|
||||||
|
)
|
||||||
|
.setLabel("Region")
|
||||||
|
.setRequired(true)
|
||||||
|
.setSensitive(false)
|
||||||
|
.setUpdatable(false)
|
||||||
|
.setType(SettingsConfigurationFieldType.STRING)
|
||||||
|
.build()
|
||||||
|
),
|
||||||
|
Map.entry(
|
||||||
|
TARGET_MODEL,
|
||||||
|
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||||
|
"The model to request when calling a SageMaker multi-model Endpoint."
|
||||||
|
)
|
||||||
|
.setLabel("Target Model")
|
||||||
|
.setRequired(false)
|
||||||
|
.setSensitive(false)
|
||||||
|
.setUpdatable(false)
|
||||||
|
.setType(SettingsConfigurationFieldType.STRING)
|
||||||
|
.build()
|
||||||
|
),
|
||||||
|
Map.entry(
|
||||||
|
TARGET_CONTAINER_HOSTNAME,
|
||||||
|
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||||
|
"The hostname of the container when calling a SageMaker multi-container Endpoint."
|
||||||
|
)
|
||||||
|
.setLabel("Target Container Hostname")
|
||||||
|
.setRequired(false)
|
||||||
|
.setSensitive(false)
|
||||||
|
.setUpdatable(false)
|
||||||
|
.setType(SettingsConfigurationFieldType.STRING)
|
||||||
|
.build()
|
||||||
|
),
|
||||||
|
Map.entry(
|
||||||
|
BATCH_SIZE,
|
||||||
|
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||||
|
"The maximum size a single chunk of input can be when chunking input for semantic text."
|
||||||
|
)
|
||||||
|
.setLabel("Batch Size")
|
||||||
|
.setRequired(false)
|
||||||
|
.setSensitive(false)
|
||||||
|
.setUpdatable(false)
|
||||||
|
.setType(SettingsConfigurationFieldType.INTEGER)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,236 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||||
|
|
||||||
|
import org.elasticsearch.TransportVersion;
|
||||||
|
import org.elasticsearch.TransportVersions;
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.SettingsConfiguration;
|
||||||
|
import org.elasticsearch.inference.TaskSettings;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.EnumSet;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maintains mutable settings for SageMaker. Model-specific settings are stored in {@link SageMakerStoredTaskSchema}.
|
||||||
|
*/
|
||||||
|
record SageMakerTaskSettings(
|
||||||
|
@Nullable String customAttributes,
|
||||||
|
@Nullable String enableExplanations,
|
||||||
|
@Nullable String inferenceIdForDataCapture,
|
||||||
|
@Nullable String sessionId,
|
||||||
|
@Nullable String targetVariant,
|
||||||
|
SageMakerStoredTaskSchema apiTaskSettings
|
||||||
|
) implements TaskSettings {
|
||||||
|
|
||||||
|
static final String NAME = "sage_maker_task_settings";
|
||||||
|
private static final String CUSTOM_ATTRIBUTES = "custom_attributes";
|
||||||
|
private static final String ENABLE_EXPLANATIONS = "enable_explanations";
|
||||||
|
private static final String INFERENCE_ID = "inference_id";
|
||||||
|
private static final String SESSION_ID = "session_id";
|
||||||
|
private static final String TARGET_VARIANT = "target_variant";
|
||||||
|
|
||||||
|
SageMakerTaskSettings(StreamInput in) throws IOException {
|
||||||
|
this(
|
||||||
|
in.readOptionalString(),
|
||||||
|
in.readOptionalString(),
|
||||||
|
in.readOptionalString(),
|
||||||
|
in.readOptionalString(),
|
||||||
|
in.readOptionalString(),
|
||||||
|
in.readNamedWriteable(SageMakerStoredTaskSchema.class)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isEmpty() {
|
||||||
|
return customAttributes == null
|
||||||
|
&& enableExplanations == null
|
||||||
|
&& inferenceIdForDataCapture == null
|
||||||
|
&& sessionId == null
|
||||||
|
&& targetVariant == null
|
||||||
|
&& SageMakerStoredTaskSchema.NO_OP.equals(apiTaskSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SageMakerTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
|
||||||
|
var validationException = new ValidationException();
|
||||||
|
|
||||||
|
var updateTaskSettings = fromMap(newSettings, apiTaskSettings.updatedTaskSettings(newSettings), validationException);
|
||||||
|
|
||||||
|
validationException.throwIfValidationErrorsExist();
|
||||||
|
|
||||||
|
var updatedExtraTaskSettings = updateTaskSettings.apiTaskSettings().equals(SageMakerStoredTaskSchema.NO_OP)
|
||||||
|
? apiTaskSettings
|
||||||
|
: updateTaskSettings.apiTaskSettings();
|
||||||
|
|
||||||
|
return new SageMakerTaskSettings(
|
||||||
|
firstNotNullOrNull(updateTaskSettings.customAttributes(), customAttributes),
|
||||||
|
firstNotNullOrNull(updateTaskSettings.enableExplanations(), enableExplanations),
|
||||||
|
firstNotNullOrNull(updateTaskSettings.inferenceIdForDataCapture(), inferenceIdForDataCapture),
|
||||||
|
firstNotNullOrNull(updateTaskSettings.sessionId(), sessionId),
|
||||||
|
firstNotNullOrNull(updateTaskSettings.targetVariant(), targetVariant),
|
||||||
|
updatedExtraTaskSettings
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static <T> T firstNotNullOrNull(T first, T second) {
|
||||||
|
return first != null ? first : second;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getWriteableName() {
|
||||||
|
return NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TransportVersion getMinimalSupportedVersion() {
|
||||||
|
return TransportVersions.ML_INFERENCE_SAGEMAKER;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
|
out.writeOptionalString(customAttributes);
|
||||||
|
out.writeOptionalString(enableExplanations);
|
||||||
|
out.writeOptionalString(inferenceIdForDataCapture);
|
||||||
|
out.writeOptionalString(sessionId);
|
||||||
|
out.writeOptionalString(targetVariant);
|
||||||
|
out.writeNamedWriteable(apiTaskSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
optionalField(CUSTOM_ATTRIBUTES, customAttributes, builder);
|
||||||
|
optionalField(ENABLE_EXPLANATIONS, enableExplanations, builder);
|
||||||
|
optionalField(INFERENCE_ID, inferenceIdForDataCapture, builder);
|
||||||
|
optionalField(SESSION_ID, sessionId, builder);
|
||||||
|
optionalField(TARGET_VARIANT, targetVariant, builder);
|
||||||
|
apiTaskSettings.toXContent(builder, params);
|
||||||
|
|
||||||
|
return builder.endObject();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static <T> void optionalField(String name, T value, XContentBuilder builder) throws IOException {
|
||||||
|
if (value != null) {
|
||||||
|
builder.field(name, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static SageMakerTaskSettings fromMap(
|
||||||
|
Map<String, Object> taskSettingsMap,
|
||||||
|
SageMakerStoredTaskSchema apiTaskSettings,
|
||||||
|
ValidationException validationException
|
||||||
|
) {
|
||||||
|
var customAttributes = extractOptionalString(
|
||||||
|
taskSettingsMap,
|
||||||
|
CUSTOM_ATTRIBUTES,
|
||||||
|
ModelConfigurations.TASK_SETTINGS,
|
||||||
|
validationException
|
||||||
|
);
|
||||||
|
var enableExplanations = extractOptionalString(
|
||||||
|
taskSettingsMap,
|
||||||
|
ENABLE_EXPLANATIONS,
|
||||||
|
ModelConfigurations.TASK_SETTINGS,
|
||||||
|
validationException
|
||||||
|
);
|
||||||
|
var inferenceIdForDataCapture = extractOptionalString(
|
||||||
|
taskSettingsMap,
|
||||||
|
INFERENCE_ID,
|
||||||
|
ModelConfigurations.TASK_SETTINGS,
|
||||||
|
validationException
|
||||||
|
);
|
||||||
|
var sessionId = extractOptionalString(taskSettingsMap, SESSION_ID, ModelConfigurations.TASK_SETTINGS, validationException);
|
||||||
|
var targetVariant = extractOptionalString(taskSettingsMap, TARGET_VARIANT, ModelConfigurations.TASK_SETTINGS, validationException);
|
||||||
|
|
||||||
|
return new SageMakerTaskSettings(
|
||||||
|
customAttributes,
|
||||||
|
enableExplanations,
|
||||||
|
inferenceIdForDataCapture,
|
||||||
|
sessionId,
|
||||||
|
targetVariant,
|
||||||
|
apiTaskSettings
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Stream<Map.Entry<String, SettingsConfiguration>> configuration(EnumSet<TaskType> supportedTaskTypes) {
|
||||||
|
return Stream.of(
|
||||||
|
Map.entry(
|
||||||
|
CUSTOM_ATTRIBUTES,
|
||||||
|
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||||
|
"An opaque informational value forwarded as-is to the model within SageMaker."
|
||||||
|
)
|
||||||
|
.setLabel("Custom Attributes")
|
||||||
|
.setRequired(false)
|
||||||
|
.setSensitive(false)
|
||||||
|
.setUpdatable(false)
|
||||||
|
.setType(SettingsConfigurationFieldType.STRING)
|
||||||
|
.build()
|
||||||
|
),
|
||||||
|
Map.entry(
|
||||||
|
ENABLE_EXPLANATIONS,
|
||||||
|
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||||
|
"JMESPath expression overriding the ClarifyingExplainerConfig in the SageMaker Endpoint Configuration."
|
||||||
|
)
|
||||||
|
.setLabel("Enable Explanations")
|
||||||
|
.setRequired(false)
|
||||||
|
.setSensitive(false)
|
||||||
|
.setUpdatable(false)
|
||||||
|
.setType(SettingsConfigurationFieldType.STRING)
|
||||||
|
.build()
|
||||||
|
),
|
||||||
|
Map.entry(
|
||||||
|
INFERENCE_ID,
|
||||||
|
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||||
|
"Informational identifying for auditing requests within the SageMaker Endpoint."
|
||||||
|
)
|
||||||
|
.setLabel("Inference ID")
|
||||||
|
.setRequired(false)
|
||||||
|
.setSensitive(false)
|
||||||
|
.setUpdatable(false)
|
||||||
|
.setType(SettingsConfigurationFieldType.STRING)
|
||||||
|
.build()
|
||||||
|
),
|
||||||
|
Map.entry(
|
||||||
|
SESSION_ID,
|
||||||
|
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||||
|
"Creates or reuses an existing Session for SageMaker stateful models."
|
||||||
|
)
|
||||||
|
.setLabel("Session ID")
|
||||||
|
.setRequired(false)
|
||||||
|
.setSensitive(false)
|
||||||
|
.setUpdatable(false)
|
||||||
|
.setType(SettingsConfigurationFieldType.STRING)
|
||||||
|
.build()
|
||||||
|
),
|
||||||
|
Map.entry(
|
||||||
|
TARGET_VARIANT,
|
||||||
|
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription(
|
||||||
|
"The production variant when calling the SageMaker Endpoint"
|
||||||
|
)
|
||||||
|
.setLabel("Target Variant")
|
||||||
|
.setRequired(false)
|
||||||
|
.setSensitive(false)
|
||||||
|
.setUpdatable(false)
|
||||||
|
.setType(SettingsConfigurationFieldType.STRING)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,215 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||||
|
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InternalDependencyException;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InternalFailureException;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.ModelErrorException;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.ModelNotReadyException;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.SageMakerRuntimeException;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.ServiceUnavailableException;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.ValidationErrorException;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||||
|
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||||
|
import org.elasticsearch.core.Tuple;
|
||||||
|
import org.elasticsearch.inference.InferenceServiceResults;
|
||||||
|
import org.elasticsearch.rest.RestStatus;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
import static org.elasticsearch.core.Strings.format;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* All the logic that is required to call any SageMaker model is handled within this Schema class.
|
||||||
|
* Any model-specific logic is handled within the associated {@link SageMakerSchemaPayload}.
|
||||||
|
* This schema is specific for SageMaker's non-streaming API. For streaming, see {@link SageMakerStreamSchema}.
|
||||||
|
*/
|
||||||
|
public class SageMakerSchema {
|
||||||
|
private static final String CUSTOM_ATTRIBUTES_HEADER = "X-elastic-sagemaker-custom-attributes";
|
||||||
|
private static final String NEW_SESSION_HEADER = "X-elastic-sagemaker-new-session-id";
|
||||||
|
private static final String CLOSED_SESSION_HEADER = "X-elastic-sagemaker-closed-session-id";
|
||||||
|
|
||||||
|
private static final String ACCESS_DENIED_CODE = "AccessDeniedException";
|
||||||
|
private static final String INCOMPLETE_SIGNATURE = "IncompleteSignature";
|
||||||
|
private static final String INVALID_ACTION = "InvalidAction";
|
||||||
|
private static final String INVALID_CLIENT_TOKEN = "InvalidClientTokenId";
|
||||||
|
private static final String NOT_AUTHORIZED = "NotAuthorized";
|
||||||
|
private static final String OPT_IN_REQUIRED = "OptInRequired";
|
||||||
|
private static final String REQUEST_EXPIRED = "RequestExpired";
|
||||||
|
private static final String THROTTLING_EXCEPTION = "ThrottlingException";
|
||||||
|
|
||||||
|
private final SageMakerSchemaPayload schemaPayload;
|
||||||
|
|
||||||
|
/**
 * @param schemaPayload the api-format-specific payload translator used to build request bodies
 *                      and parse response bodies
 */
public SageMakerSchema(SageMakerSchemaPayload schemaPayload) {
    this.schemaPayload = schemaPayload;
}
|
||||||
|
|
||||||
|
public InvokeEndpointRequest request(SageMakerModel model, SageMakerInferenceRequest request) {
|
||||||
|
try {
|
||||||
|
return createRequest(model).accept(schemaPayload.accept(model))
|
||||||
|
.contentType(schemaPayload.contentType(model))
|
||||||
|
.body(schemaPayload.requestBytes(model, request))
|
||||||
|
.build();
|
||||||
|
} catch (ElasticsearchStatusException e) {
|
||||||
|
throw e;
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new ElasticsearchStatusException(
|
||||||
|
"Failed to create SageMaker request for [%s]",
|
||||||
|
RestStatus.INTERNAL_SERVER_ERROR,
|
||||||
|
e,
|
||||||
|
model.getInferenceEntityId()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private InvokeEndpointRequest.Builder createRequest(SageMakerModel model) {
|
||||||
|
var request = InvokeEndpointRequest.builder();
|
||||||
|
request.endpointName(model.endpointName());
|
||||||
|
model.customAttributes().ifPresent(request::customAttributes);
|
||||||
|
model.enableExplanations().ifPresent(request::enableExplanations);
|
||||||
|
model.inferenceComponentName().ifPresent(request::inferenceComponentName);
|
||||||
|
model.inferenceIdForDataCapture().ifPresent(request::inferenceId);
|
||||||
|
model.sessionId().ifPresent(request::sessionId);
|
||||||
|
model.targetContainerHostname().ifPresent(request::targetContainerHostname);
|
||||||
|
model.targetModel().ifPresent(request::targetModel);
|
||||||
|
model.targetVariant().ifPresent(request::targetVariant);
|
||||||
|
return request;
|
||||||
|
}
|
||||||
|
|
||||||
|
public InferenceServiceResults response(SageMakerModel model, InvokeEndpointResponse response, ThreadContext threadContext)
|
||||||
|
throws Exception {
|
||||||
|
try {
|
||||||
|
addHeaders(response, threadContext);
|
||||||
|
return schemaPayload.responseBody(model, response);
|
||||||
|
} catch (ElasticsearchStatusException e) {
|
||||||
|
throw e;
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new ElasticsearchStatusException(
|
||||||
|
"Failed to translate SageMaker response for [%s]",
|
||||||
|
RestStatus.INTERNAL_SERVER_ERROR,
|
||||||
|
e,
|
||||||
|
model.getInferenceEntityId()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addHeaders(InvokeEndpointResponse response, ThreadContext threadContext) {
|
||||||
|
if (response.customAttributes() != null) {
|
||||||
|
threadContext.addResponseHeader(CUSTOM_ATTRIBUTES_HEADER, response.customAttributes());
|
||||||
|
}
|
||||||
|
if (response.newSessionId() != null) {
|
||||||
|
threadContext.addResponseHeader(NEW_SESSION_HEADER, response.newSessionId());
|
||||||
|
}
|
||||||
|
if (response.closedSessionId() != null) {
|
||||||
|
threadContext.addResponseHeader(CLOSED_SESSION_HEADER, response.closedSessionId());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public Exception error(SageMakerModel model, Exception e) {
|
||||||
|
if (e instanceof ElasticsearchStatusException ee) {
|
||||||
|
return ee;
|
||||||
|
}
|
||||||
|
var error = errorMessageAndStatus(model, e);
|
||||||
|
return new ElasticsearchStatusException(error.v1(), error.v2(), e);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Protected because {@link SageMakerStreamSchema} will reuse this to create a Chat Completion error message.
|
||||||
|
*/
|
||||||
|
protected Tuple<String, RestStatus> errorMessageAndStatus(SageMakerModel model, Exception e) {
|
||||||
|
String errorMessage = null;
|
||||||
|
RestStatus restStatus = null;
|
||||||
|
if (e instanceof InternalDependencyException) {
|
||||||
|
errorMessage = format("Received an internal dependency error from SageMaker for [%s]", model.getInferenceEntityId());
|
||||||
|
restStatus = RestStatus.INTERNAL_SERVER_ERROR;
|
||||||
|
} else if (e instanceof InternalFailureException) {
|
||||||
|
errorMessage = format("Received an internal failure from SageMaker for [%s]", model.getInferenceEntityId());
|
||||||
|
restStatus = RestStatus.INTERNAL_SERVER_ERROR;
|
||||||
|
} else if (e instanceof ModelErrorException) {
|
||||||
|
errorMessage = format("Received a model error from SageMaker for [%s]", model.getInferenceEntityId());
|
||||||
|
restStatus = RestStatus.FAILED_DEPENDENCY;
|
||||||
|
} else if (e instanceof ModelNotReadyException) {
|
||||||
|
errorMessage = format("Received a model not ready error from SageMaker for [%s]", model.getInferenceEntityId());
|
||||||
|
restStatus = RestStatus.TOO_MANY_REQUESTS;
|
||||||
|
} else if (e instanceof ServiceUnavailableException) {
|
||||||
|
errorMessage = format("SageMaker is unavailable for [%s]", model.getInferenceEntityId());
|
||||||
|
restStatus = RestStatus.SERVICE_UNAVAILABLE;
|
||||||
|
} else if (e instanceof ValidationErrorException) {
|
||||||
|
errorMessage = format("Received a validation error from SageMaker for [%s]", model.getInferenceEntityId());
|
||||||
|
restStatus = RestStatus.BAD_REQUEST;
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we have a SageMakerRuntimeException that isn't one of the child exceptions, we can parse the error from AwsErrorDetails
|
||||||
|
// https://docs.aws.amazon.com/sagemaker/latest/APIReference/CommonErrors.html
|
||||||
|
if (errorMessage == null && e instanceof SageMakerRuntimeException re && re.awsErrorDetails() != null) {
|
||||||
|
switch (re.awsErrorDetails().errorCode()) {
|
||||||
|
case ACCESS_DENIED_CODE, NOT_AUTHORIZED -> {
|
||||||
|
errorMessage = format(
|
||||||
|
"Access and Secret key stored in [%s] do not have sufficient permissions.",
|
||||||
|
model.getInferenceEntityId()
|
||||||
|
);
|
||||||
|
restStatus = RestStatus.BAD_REQUEST;
|
||||||
|
}
|
||||||
|
case INCOMPLETE_SIGNATURE -> {
|
||||||
|
errorMessage = format("The request signature does not conform to AWS standards [%s]", model.getInferenceEntityId());
|
||||||
|
restStatus = RestStatus.INTERNAL_SERVER_ERROR; // this shouldn't happen and isn't anything the user can do about it
|
||||||
|
}
|
||||||
|
case INVALID_ACTION -> {
|
||||||
|
errorMessage = format("The requested action is not valid for [%s]", model.getInferenceEntityId());
|
||||||
|
restStatus = RestStatus.BAD_REQUEST;
|
||||||
|
}
|
||||||
|
case INVALID_CLIENT_TOKEN -> {
|
||||||
|
errorMessage = format("Access key stored in [%s] does not exist in AWS", model.getInferenceEntityId());
|
||||||
|
restStatus = RestStatus.FORBIDDEN;
|
||||||
|
}
|
||||||
|
case OPT_IN_REQUIRED -> {
|
||||||
|
errorMessage = format("Access key stored in [%s] needs a subscription for the service", model.getInferenceEntityId());
|
||||||
|
restStatus = RestStatus.FORBIDDEN;
|
||||||
|
}
|
||||||
|
case REQUEST_EXPIRED -> {
|
||||||
|
errorMessage = format(
|
||||||
|
"The request reached SageMaker more than 15 minutes after the date stamp on the request for [%s]",
|
||||||
|
model.getInferenceEntityId()
|
||||||
|
);
|
||||||
|
restStatus = RestStatus.BAD_REQUEST;
|
||||||
|
}
|
||||||
|
case THROTTLING_EXCEPTION -> {
|
||||||
|
errorMessage = format("SageMaker denied the request for [%s] due to request throttling", model.getInferenceEntityId());
|
||||||
|
restStatus = RestStatus.BAD_REQUEST;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (errorMessage == null) {
|
||||||
|
errorMessage = format("Received an error from SageMaker for [%s]", model.getInferenceEntityId());
|
||||||
|
restStatus = RestStatus.INTERNAL_SERVER_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Tuple.tuple(errorMessage, restStatus);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
|
||||||
|
return schemaPayload.apiServiceSettings(serviceSettings, validationException);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
|
||||||
|
return schemaPayload.apiTaskSettings(taskSettings, validationException);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Stream<NamedWriteableRegistry.Entry> namedWriteables() {
|
||||||
|
return schemaPayload.namedWriteables();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,98 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||||
|
|
||||||
|
import software.amazon.awssdk.core.SdkBytes;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||||
|
import org.elasticsearch.inference.InferenceServiceResults;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||||
|
|
||||||
|
import java.util.EnumSet;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
/**
 * Model-API-specific translation logic for SageMaker endpoints: serializing requests to
 * {@link SdkBytes} and parsing {@link InvokeEndpointResponse} bodies back into inference results.
 * One implementation exists per supported payload format (e.g. OpenAI-shaped JSON); instances are
 * discovered and registered by {@link SageMakerSchemas}.
 */
public interface SageMakerSchemaPayload {

    /**
     * The model API keyword that users will supply in the service settings when creating the request.
     * Automatically registered in {@link SageMakerSchemas}.
     */
    String api();

    /**
     * The supported TaskTypes for this model API.
     * Automatically registered in {@link SageMakerSchemas}.
     */
    EnumSet<TaskType> supportedTasks();

    /**
     * Implement this if the model requires extra ServiceSettings that can be saved to the model index.
     * This can be accessed via {@link SageMakerModel#apiServiceSettings()}.
     * Defaults to a no-op schema (no extra service settings).
     */
    default SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
        return SageMakerStoredServiceSchema.NO_OP;
    }

    /**
     * Implement this if the model requires extra TaskSettings that can be saved to the model index.
     * This can be accessed via {@link SageMakerModel#apiTaskSettings()}.
     * Defaults to a no-op schema (no extra task settings).
     */
    default SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
        return SageMakerStoredTaskSchema.NO_OP;
    }

    /**
     * This must be thrown if {@link SageMakerModel#apiServiceSettings()} or {@link SageMakerModel#apiTaskSettings()} return the wrong
     * object types. The message names the api, task type, and the offending settings' writeable names.
     */
    default Exception createUnsupportedSchemaException(SageMakerModel model) {
        return new IllegalArgumentException(
            Strings.format(
                "Unsupported SageMaker settings for api [%s] and task type [%s]: [%s] and [%s]",
                model.api(),
                model.getTaskType(),
                model.apiServiceSettings().getWriteableName(),
                model.apiTaskSettings().getWriteableName()
            )
        );
    }

    /**
     * Automatically register the required registry entries with {@link org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider}.
     * Defaults to no extra entries.
     */
    default Stream<NamedWriteableRegistry.Entry> namedWriteables() {
        return Stream.of();
    }

    /**
     * The MIME type of the response from SageMaker.
     */
    String accept(SageMakerModel model);

    /**
     * The MIME type of the request to SageMaker.
     */
    String contentType(SageMakerModel model);

    /**
     * Translate to the body of the request in the MIME type specified by {@link #contentType(SageMakerModel)}.
     */
    SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception;

    /**
     * Translate from the body of the response in the MIME type specified by {@link #accept(SageMakerModel)}.
     */
    InferenceServiceResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception;

}
|
|
@ -0,0 +1,137 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
|
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.rest.RestStatus;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.EnumSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
import static org.elasticsearch.core.Strings.format;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The mapping and registry for all supported model API.
|
||||||
|
*/
|
||||||
|
public class SageMakerSchemas {
|
||||||
|
private static final Map<TaskAndApi, SageMakerSchema> schemas;
|
||||||
|
private static final Map<TaskAndApi, SageMakerStreamSchema> streamSchemas;
|
||||||
|
private static final Map<String, Set<TaskType>> tasksByApi;
|
||||||
|
private static final Map<String, Set<TaskType>> streamingTasksByApi;
|
||||||
|
private static final Set<TaskType> supportedStreamingTasks;
|
||||||
|
private static final EnumSet<TaskType> supportedTaskTypes;
|
||||||
|
|
||||||
|
static {
|
||||||
|
/*
|
||||||
|
* Add new model API to the register call.
|
||||||
|
*/
|
||||||
|
schemas = register(new OpenAiTextEmbeddingPayload());
|
||||||
|
|
||||||
|
streamSchemas = schemas.entrySet()
|
||||||
|
.stream()
|
||||||
|
.filter(e -> e.getValue() instanceof SageMakerStreamSchema)
|
||||||
|
.collect(Collectors.toMap(Map.Entry::getKey, e -> (SageMakerStreamSchema) e.getValue()));
|
||||||
|
|
||||||
|
tasksByApi = schemas.keySet()
|
||||||
|
.stream()
|
||||||
|
.collect(Collectors.groupingBy(TaskAndApi::api, Collectors.mapping(TaskAndApi::taskType, Collectors.toSet())));
|
||||||
|
streamingTasksByApi = streamSchemas.keySet()
|
||||||
|
.stream()
|
||||||
|
.collect(Collectors.groupingBy(TaskAndApi::api, Collectors.mapping(TaskAndApi::taskType, Collectors.toSet())));
|
||||||
|
|
||||||
|
supportedStreamingTasks = streamSchemas.keySet().stream().map(TaskAndApi::taskType).collect(Collectors.toSet());
|
||||||
|
supportedTaskTypes = EnumSet.copyOf(schemas.keySet().stream().map(TaskAndApi::taskType).collect(Collectors.toSet()));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Map<TaskAndApi, SageMakerSchema> register(SageMakerSchemaPayload... payloads) {
|
||||||
|
return Arrays.stream(payloads).flatMap(payload -> payload.supportedTasks().stream().map(taskType -> {
|
||||||
|
var key = new TaskAndApi(taskType, payload.api());
|
||||||
|
SageMakerSchema value;
|
||||||
|
if (payload instanceof SageMakerStreamSchemaPayload streamPayload) {
|
||||||
|
value = new SageMakerStreamSchema(streamPayload);
|
||||||
|
} else {
|
||||||
|
value = new SageMakerSchema(payload);
|
||||||
|
}
|
||||||
|
return Map.entry(key, value);
|
||||||
|
})).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Automatically register the stored Schema writeables with {@link org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider}.
|
||||||
|
*/
|
||||||
|
public static List<NamedWriteableRegistry.Entry> namedWriteables() {
|
||||||
|
return Stream.concat(
|
||||||
|
Stream.of(
|
||||||
|
new NamedWriteableRegistry.Entry(
|
||||||
|
SageMakerStoredServiceSchema.class,
|
||||||
|
SageMakerStoredServiceSchema.NO_OP.getWriteableName(),
|
||||||
|
in -> SageMakerStoredServiceSchema.NO_OP
|
||||||
|
),
|
||||||
|
new NamedWriteableRegistry.Entry(
|
||||||
|
SageMakerStoredTaskSchema.class,
|
||||||
|
SageMakerStoredTaskSchema.NO_OP.getWriteableName(),
|
||||||
|
in -> SageMakerStoredTaskSchema.NO_OP
|
||||||
|
)
|
||||||
|
),
|
||||||
|
schemas.values().stream().flatMap(SageMakerSchema::namedWriteables)
|
||||||
|
).toList();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SageMakerSchema schemaFor(SageMakerModel model) throws ElasticsearchStatusException {
|
||||||
|
return schemaFor(model.getTaskType(), model.api());
|
||||||
|
}
|
||||||
|
|
||||||
|
public SageMakerSchema schemaFor(TaskType taskType, String api) throws ElasticsearchStatusException {
|
||||||
|
var schema = schemas.get(new TaskAndApi(taskType, api));
|
||||||
|
if (schema == null) {
|
||||||
|
throw new ElasticsearchStatusException(
|
||||||
|
format(
|
||||||
|
"Task [%s] is not compatible for service [sagemaker] and api [%s]. Supported tasks: [%s]",
|
||||||
|
api,
|
||||||
|
taskType.toString(),
|
||||||
|
tasksByApi.getOrDefault(api, Set.of())
|
||||||
|
),
|
||||||
|
RestStatus.METHOD_NOT_ALLOWED
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return schema;
|
||||||
|
}
|
||||||
|
|
||||||
|
public SageMakerStreamSchema streamSchemaFor(SageMakerModel model) throws ElasticsearchStatusException {
|
||||||
|
var schema = streamSchemas.get(new TaskAndApi(model.getTaskType(), model.api()));
|
||||||
|
if (schema == null) {
|
||||||
|
throw new ElasticsearchStatusException(
|
||||||
|
format(
|
||||||
|
"Streaming is not allowed for service [sagemaker], api [%s], and task [%s]. Supported streaming tasks: [%s]",
|
||||||
|
model.api(),
|
||||||
|
model.getTaskType().toString(),
|
||||||
|
streamingTasksByApi.getOrDefault(model.api(), Set.of())
|
||||||
|
),
|
||||||
|
RestStatus.METHOD_NOT_ALLOWED
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return schema;
|
||||||
|
}
|
||||||
|
|
||||||
|
public EnumSet<TaskType> supportedTaskTypes() {
|
||||||
|
return supportedTaskTypes;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Set<TaskType> supportedStreamingTasks() {
|
||||||
|
return supportedStreamingTasks;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,75 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||||
|
|
||||||
|
import org.elasticsearch.TransportVersion;
|
||||||
|
import org.elasticsearch.TransportVersions;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
|
import org.elasticsearch.inference.ServiceSettings;
|
||||||
|
import org.elasticsearch.xcontent.ToXContentObject;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Contains any model-specific settings that are stored in SageMakerServiceSettings.
|
||||||
|
*/
|
||||||
|
/**
 * Contains any model-specific settings that are stored in SageMakerServiceSettings.
 * Implementations serialize flatly into the surrounding service settings (see {@link #isFragment()}).
 */
public interface SageMakerStoredServiceSchema extends ServiceSettings {
    /**
     * Shared sentinel for payloads with no extra service settings: writes no bytes, renders no
     * XContent, and is registered as a singleton named writeable by SageMakerSchemas.
     */
    SageMakerStoredServiceSchema NO_OP = new SageMakerStoredServiceSchema() {

        private static final String NAME = "noop_sagemaker_service_schema";

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.ML_INFERENCE_SAGEMAKER;
        }

        @Override
        public void writeTo(StreamOutput out) {
            // intentionally empty: the no-op schema carries no state
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) {
            return builder;
        }
    };

    /**
     * The model is part of the SageMaker Endpoint definition and is not declared in the service settings.
     */
    @Override
    default String modelId() {
        return null;
    }

    // Nothing here is sensitive, so the filtered view is the full object.
    @Override
    default ToXContentObject getFilteredXContentObject() {
        return this;
    }

    /**
     * These extra service settings serialize flatly alongside the overall SageMaker ServiceSettings.
     */
    @Override
    default boolean isFragment() {
        return true;
    }

    /**
     * If this Schema supports Text Embeddings, then we need to implement this.
     * {@link org.elasticsearch.xpack.inference.services.validation.TextEmbeddingModelValidator} will set the dimensions if the user
     * does not do it, so we need to store the dimensions and flip the {@link #dimensionsSetByUser()} boolean.
     * Default implementation ignores the dimensions and returns {@code this} unchanged.
     */
    default SageMakerStoredServiceSchema updateModelWithEmbeddingDetails(Integer dimensions) {
        return this;
    }
}
|
|
@ -0,0 +1,64 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||||
|
|
||||||
|
import org.elasticsearch.TransportVersion;
|
||||||
|
import org.elasticsearch.TransportVersions;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
|
import org.elasticsearch.inference.TaskSettings;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Contains any model-specific settings that are stored in SageMakerTaskSettings.
|
||||||
|
*/
|
||||||
|
/**
 * Contains any model-specific settings that are stored in SageMakerTaskSettings.
 * Implementations serialize flatly into the surrounding task settings (see {@link #isFragment()}).
 */
public interface SageMakerStoredTaskSchema extends TaskSettings {
    /**
     * Shared sentinel for payloads with no extra task settings: always empty, writes no bytes,
     * renders no XContent, and ignores setting updates.
     */
    SageMakerStoredTaskSchema NO_OP = new SageMakerStoredTaskSchema() {
        @Override
        public boolean isEmpty() {
            return true;
        }

        @Override
        public SageMakerStoredTaskSchema updatedTaskSettings(Map<String, Object> newSettings) {
            // No settings to update; the singleton is immutable.
            return this;
        }

        private static final String NAME = "noop_sagemaker_task_schema";

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.ML_INFERENCE_SAGEMAKER;
        }

        @Override
        public void writeTo(StreamOutput out) {}

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) {
            return builder;
        }
    };

    /**
     * These extra service settings serialize flatly alongside the overall SageMaker ServiceSettings.
     */
    @Override
    default boolean isFragment() {
        return true;
    }

    // Covariant override: narrows TaskSettings#updatedTaskSettings to return this schema type,
    // so callers never need to cast.
    @Override
    SageMakerStoredTaskSchema updatedTaskSettings(Map<String, Object> newSettings);
}
|
|
@ -0,0 +1,152 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||||
|
|
||||||
|
import software.amazon.awssdk.core.SdkBytes;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
|
import org.elasticsearch.ExceptionsHelper;
|
||||||
|
import org.elasticsearch.common.CheckedBiFunction;
|
||||||
|
import org.elasticsearch.common.CheckedSupplier;
|
||||||
|
import org.elasticsearch.inference.InferenceServiceResults;
|
||||||
|
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||||
|
import org.elasticsearch.rest.RestStatus;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||||
|
|
||||||
|
import java.util.Locale;
|
||||||
|
import java.util.concurrent.Flow;
|
||||||
|
import java.util.function.BiFunction;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* All the logic that is required to call any SageMaker model is handled within this Schema class.
|
||||||
|
* Any model-specific logic is handled within the associated {@link SageMakerStreamSchemaPayload}.
|
||||||
|
* This schema is specific for SageMaker's streaming API. For non-streaming, see {@link SageMakerSchema}.
|
||||||
|
*/
|
||||||
|
/**
 * All the logic that is required to call any SageMaker model is handled within this Schema class.
 * Any model-specific logic is handled within the associated {@link SageMakerStreamSchemaPayload}.
 * This schema is specific for SageMaker's streaming API. For non-streaming, see {@link SageMakerSchema}.
 */
public class SageMakerStreamSchema extends SageMakerSchema {

    /** Stream-capable payload translator; also used by the superclass for non-streaming calls. */
    private final SageMakerStreamSchemaPayload payload;

    public SageMakerStreamSchema(SageMakerStreamSchemaPayload payload) {
        super(payload);
        this.payload = payload;
    }

    /** Builds the streaming request for a (non-chat) completion inference call. */
    public InvokeEndpointWithResponseStreamRequest streamRequest(SageMakerModel model, SageMakerInferenceRequest request) {
        return streamRequest(model, () -> payload.requestBytes(model, request));
    }

    /**
     * Shared request assembly for both completion and chat-completion streaming; the supplier
     * produces the serialized body and may throw, which is wrapped as a 500.
     */
    private InvokeEndpointWithResponseStreamRequest streamRequest(SageMakerModel model, CheckedSupplier<SdkBytes, Exception> body) {
        try {
            return createStreamRequest(model).accept(payload.accept(model))
                .contentType(payload.contentType(model))
                .body(body.get())
                .build();
        } catch (ElasticsearchStatusException e) {
            // Already carries a meaningful status and message; do not re-wrap.
            throw e;
        } catch (Exception e) {
            throw new ElasticsearchStatusException(
                "Failed to create SageMaker request for [%s]",
                RestStatus.INTERNAL_SERVER_ERROR,
                e,
                model.getInferenceEntityId()
            );
        }
    }

    /** Adapts the SageMaker response stream into streaming completion results. */
    public InferenceServiceResults streamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) {
        return streamResponse(model, response, payload::streamResponseBody, this::error);
    }

    /**
     * Bridges the SDK's {@link ResponseStream} publisher to the downstream inference subscriber.
     * Payload parts are parsed via {@code parseFunction}; parse and upstream failures are mapped
     * through {@code errorFunction} before reaching the downstream subscriber.
     */
    private InferenceServiceResults streamResponse(
        SageMakerModel model,
        SageMakerClient.SageMakerStream response,
        CheckedBiFunction<SageMakerModel, SdkBytes, InferenceServiceResults.Result, Exception> parseFunction,
        BiFunction<SageMakerModel, Exception, Exception> errorFunction
    ) {
        return new StreamingChatCompletionResults(downstream -> {
            response.responseStream().subscribe(new Flow.Subscriber<>() {
                // Written on onSubscribe, read on onNext (possibly another thread), hence volatile.
                private volatile Flow.Subscription upstream;

                @Override
                public void onSubscribe(Flow.Subscription subscription) {
                    this.upstream = subscription;
                    // Downstream drives demand directly through the upstream subscription.
                    downstream.onSubscribe(subscription);
                }

                @Override
                public void onNext(ResponseStream item) {
                    if (item.sdkEventType() == ResponseStream.EventType.PAYLOAD_PART) {
                        item.accept(InvokeEndpointWithResponseStreamResponseHandler.Visitor.builder().onPayloadPart(payloadPart -> {
                            try {
                                downstream.onNext(parseFunction.apply(model, payloadPart.bytes()));
                            } catch (Exception e) {
                                downstream.onError(errorFunction.apply(model, e));
                            }
                        }).build());
                    } else {
                        // Non-payload event: skip it and pull the next one ourselves, since the
                        // downstream never sees this item and so won't request a replacement.
                        assert upstream != null : "upstream is unset";
                        upstream.request(1);
                    }
                }

                @Override
                public void onError(Throwable throwable) {
                    if (throwable instanceof Exception e) {
                        downstream.onError(errorFunction.apply(model, e));
                    } else {
                        // Errors (e.g. OutOfMemoryError) may need to terminate the node; otherwise
                        // surface a wrapped RuntimeException to the downstream.
                        ExceptionsHelper.maybeError(throwable).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
                        var e = new RuntimeException("Fatal while streaming SageMaker response for [" + model.getInferenceEntityId() + "]");
                        downstream.onError(errorFunction.apply(model, e));
                    }

                }

                @Override
                public void onComplete() {
                    downstream.onComplete();
                }
            });
        });
    }

    /** Builds the streaming request for a unified chat completion call. */
    public InvokeEndpointWithResponseStreamRequest chatCompletionStreamRequest(SageMakerModel model, UnifiedCompletionRequest request) {
        return streamRequest(model, () -> payload.chatCompletionRequestBytes(model, request));
    }

    /** Adapts the SageMaker response stream into streaming chat completion results. */
    public InferenceServiceResults chatCompletionStreamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) {
        return streamResponse(model, response, payload::chatCompletionResponseBody, this::chatCompletionError);
    }

    /**
     * Maps failures to the unified chat completion error shape, reusing the superclass's
     * message/status mapping; UnifiedChatCompletionExceptions pass through unchanged.
     */
    public UnifiedChatCompletionException chatCompletionError(SageMakerModel model, Exception e) {
        if (e instanceof UnifiedChatCompletionException ucce) {
            return ucce;
        }

        var error = errorMessageAndStatus(model, e);
        return new UnifiedChatCompletionException(error.v2(), error.v1(), "error", error.v2().name().toLowerCase(Locale.ROOT));
    }

    /**
     * Copies optional routing/session attributes onto the streaming request builder.
     * NOTE(review): unlike SageMakerSchema#createRequest, this does not set enableExplanations or
     * targetModel — presumably the response-stream API does not support them; confirm against the SDK.
     */
    private InvokeEndpointWithResponseStreamRequest.Builder createStreamRequest(SageMakerModel model) {
        var request = InvokeEndpointWithResponseStreamRequest.builder();
        request.endpointName(model.endpointName());
        model.customAttributes().ifPresent(request::customAttributes);
        model.inferenceComponentName().ifPresent(request::inferenceComponentName);
        model.inferenceIdForDataCapture().ifPresent(request::inferenceId);
        model.sessionId().ifPresent(request::sessionId);
        model.targetContainerHostname().ifPresent(request::targetContainerHostname);
        model.targetVariant().ifPresent(request::targetVariant);
        return request;
    }
}
|
|
@ -0,0 +1,46 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||||
|
|
||||||
|
import software.amazon.awssdk.core.SdkBytes;
|
||||||
|
|
||||||
|
import org.elasticsearch.inference.InferenceServiceResults;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||||
|
|
||||||
|
import java.util.EnumSet;
|
||||||
|
|
||||||
|
/**
 * Implemented for models that support streaming.
 * This is an extension of {@link SageMakerSchemaPayload} because Elastic expects Completion tasks to handle both streaming and
 * non-streaming, and all models currently support toggling streaming on/off.
 */
public interface SageMakerStreamSchemaPayload extends SageMakerSchemaPayload {

    /**
     * We currently only support streaming for Completion and Chat Completion, and if we are going to implement one then we should implement
     * the other, so this interface requires both streaming input and streaming unified input.
     * If we ever allowed streaming for more than just Completion, then we'd probably break up this class so that Unified Chat Completion
     * was its own interface.
     */
    @Override
    default EnumSet<TaskType> supportedTasks() {
        return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
    }

    /**
     * This API would only be called for Completion task types. {@link #requestBytes(SageMakerModel, SageMakerInferenceRequest)} would
     * handle the request translation for both streaming and non-streaming.
     */
    InferenceServiceResults.Result streamResponseBody(SageMakerModel model, SdkBytes response) throws Exception;

    /**
     * Translates a unified (Chat Completion) request into the byte payload sent to the SageMaker endpoint.
     */
    SdkBytes chatCompletionRequestBytes(SageMakerModel model, UnifiedCompletionRequest request) throws Exception;

    /**
     * Parses a Chat Completion response body into inference results.
     * NOTE(review): presumably invoked once per streamed chunk rather than for the whole response — confirm against the caller.
     */
    InferenceServiceResults.Result chatCompletionResponseBody(SageMakerModel model, SdkBytes response) throws Exception;
}
|
|
@ -0,0 +1,12 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||||
|
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
|
||||||
|
/**
 * Pairs a {@link TaskType} with a payload api name (e.g. "openai").
 * NOTE(review): package-private with no visible usages in this view — presumably a lookup key for schema registration; confirm.
 */
record TaskAndApi(TaskType taskType, String api) {}
|
|
@ -0,0 +1,229 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;
|
||||||
|
|
||||||
|
import software.amazon.awssdk.core.SdkBytes;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
|
||||||
|
|
||||||
|
import org.elasticsearch.TransportVersion;
|
||||||
|
import org.elasticsearch.TransportVersions;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.SimilarityMeasure;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.xcontent.XContent;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||||
|
import org.elasticsearch.xcontent.json.JsonXContent;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
||||||
|
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiEmbeddingsResponseEntity;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayload;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.EnumSet;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
|
||||||
|
|
||||||
|
/**
 * Translates between Elastic inference requests/responses and the OpenAI-compatible JSON embeddings format for
 * SageMaker endpoints. Supports only the {@link TaskType#TEXT_EMBEDDING} task; both the request payload sent to
 * the endpoint and the response parsed from it are JSON.
 */
public class OpenAiTextEmbeddingPayload implements SageMakerSchemaPayload {

    private static final XContent jsonXContent = JsonXContent.jsonXContent;
    private static final String APPLICATION_JSON = jsonXContent.type().mediaTypeWithoutParameters();

    @Override
    public String api() {
        return "openai";
    }

    @Override
    public EnumSet<TaskType> supportedTasks() {
        return EnumSet.of(TaskType.TEXT_EMBEDDING);
    }

    /** Extra service settings specific to the OpenAI payload format (currently only {@code dimensions}). */
    @Override
    public SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
        return ApiServiceSettings.fromMap(serviceSettings, validationException);
    }

    /** Extra task settings specific to the OpenAI payload format (currently only {@code user}). */
    @Override
    public SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
        return ApiTaskSettings.fromMap(taskSettings, validationException);
    }

    @Override
    public Stream<NamedWriteableRegistry.Entry> namedWriteables() {
        return Stream.of(
            new NamedWriteableRegistry.Entry(SageMakerStoredServiceSchema.class, ApiServiceSettings.NAME, ApiServiceSettings::new),
            new NamedWriteableRegistry.Entry(SageMakerStoredTaskSchema.class, ApiTaskSettings.NAME, ApiTaskSettings::new)
        );
    }

    @Override
    public String accept(SageMakerModel model) {
        return APPLICATION_JSON;
    }

    @Override
    public String contentType(SageMakerModel model) {
        return APPLICATION_JSON;
    }

    /**
     * Serializes the inference request as an OpenAI-style JSON body. Writes {@code query} when present,
     * {@code input} as a single string or an array depending on input size, and the optional {@code user}
     * and {@code dimensions} fields. Fails if the model does not carry this payload's settings records.
     */
    @Override
    public SdkBytes requestBytes(SageMakerModel model, SageMakerInferenceRequest request) throws Exception {
        if (model.apiServiceSettings() instanceof ApiServiceSettings apiServiceSettings
            && model.apiTaskSettings() instanceof ApiTaskSettings apiTaskSettings) {
            try (var builder = JsonXContent.contentBuilder()) {
                builder.startObject();
                if (request.query() != null) {
                    builder.field("query", request.query());
                }
                // Single input is written as a bare string, multiple inputs as an array.
                if (request.input().size() == 1) {
                    builder.field("input", request.input().get(0));
                } else {
                    builder.field("input", request.input());
                }
                if (apiTaskSettings.user() != null) {
                    builder.field("user", apiTaskSettings.user());
                }
                // Only forward dimensions that the user explicitly requested; dimensions learned from the
                // endpoint (dimensionsSetByUser == false) are not echoed back in requests.
                if (apiServiceSettings.dimensionsSetByUser() && apiServiceSettings.dimensions() != null) {
                    builder.field("dimensions", apiServiceSettings.dimensions());
                }
                builder.endObject();
                return SdkBytes.fromUtf8String(Strings.toString(builder));
            }
        } else {
            throw createUnsupportedSchemaException(model);
        }
    }

    /** Parses the endpoint's JSON response using the existing OpenAI embeddings response parser. */
    @Override
    public TextEmbeddingFloatResults responseBody(SageMakerModel model, InvokeEndpointResponse response) throws Exception {
        try (var p = jsonXContent.createParser(XContentParserConfiguration.EMPTY, response.body().asInputStream())) {
            return OpenAiEmbeddingsResponseEntity.EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
        }
    }

    /**
     * OpenAI-specific service settings: the optional embedding {@code dimensions} and whether the user
     * supplied them (vs. them being discovered from the endpoint's actual output).
     */
    record ApiServiceSettings(@Nullable Integer dimensions, Boolean dimensionsSetByUser) implements SageMakerStoredServiceSchema {
        private static final String NAME = "sagemaker_openai_text_embeddings_service_settings";
        private static final String DIMENSIONS_FIELD = "dimensions";

        // Wire deserialization; field order must match writeTo.
        ApiServiceSettings(StreamInput in) throws IOException {
            this(in.readOptionalInt(), in.readBoolean());
        }

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.ML_INFERENCE_SAGEMAKER;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeOptionalInt(dimensions);
            out.writeBoolean(dimensionsSetByUser);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            // NOTE(review): dimensionsSetByUser is not persisted here, and fromMap re-derives it as
            // (dimensions != null); settings produced by updateModelWithEmbeddingDetails (false) would read
            // back as true after an index round trip — confirm this is intentional.
            if (dimensions != null) {
                builder.field(DIMENSIONS_FIELD, dimensions);
            }
            return builder;
        }

        static ApiServiceSettings fromMap(Map<String, Object> serviceSettings, ValidationException validationException) {
            var dimensions = extractOptionalPositiveInteger(
                serviceSettings,
                DIMENSIONS_FIELD,
                ModelConfigurations.SERVICE_SETTINGS,
                validationException
            );

            // Presence of the field in the map implies the user set it.
            return new ApiServiceSettings(dimensions, dimensions != null);
        }

        @Override
        public SimilarityMeasure similarity() {
            return SimilarityMeasure.DOT_PRODUCT;
        }

        @Override
        public DenseVectorFieldMapper.ElementType elementType() {
            return DenseVectorFieldMapper.ElementType.FLOAT;
        }

        // Dimensions discovered from the endpoint are recorded as not user-set.
        @Override
        public SageMakerStoredServiceSchema updateModelWithEmbeddingDetails(Integer dimensions) {
            return new ApiServiceSettings(dimensions, false);
        }
    }

    /** OpenAI-specific task settings: the optional {@code user} identifier forwarded in each request. */
    record ApiTaskSettings(@Nullable String user) implements SageMakerStoredTaskSchema {
        private static final String NAME = "sagemaker_openai_text_embeddings_task_settings";
        private static final String USER_FIELD = "user";

        // Wire deserialization; field order must match writeTo.
        ApiTaskSettings(StreamInput in) throws IOException {
            this(in.readOptionalString());
        }

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public TransportVersion getMinimalSupportedVersion() {
            return TransportVersions.ML_INFERENCE_SAGEMAKER;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeOptionalString(user);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            return user != null ? builder.field(USER_FIELD, user) : builder;
        }

        @Override
        public boolean isEmpty() {
            return user == null;
        }

        /** Merges new settings over the current ones; fields absent from the update keep their current value. */
        @Override
        public ApiTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
            var validationException = new ValidationException();
            var newTaskSettings = fromMap(newSettings, validationException);
            validationException.throwIfValidationErrorsExist();

            return new ApiTaskSettings(newTaskSettings.user() != null ? newTaskSettings.user() : user);
        }

        static ApiTaskSettings fromMap(Map<String, Object> map, ValidationException exception) {
            var user = extractOptionalString(map, USER_FIELD, ModelConfigurations.TASK_SETTINGS, exception);
            return new ApiTaskSettings(user);
        }
    }
}
|
|
@ -0,0 +1,101 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services;
|
||||||
|
|
||||||
|
import org.elasticsearch.TransportVersion;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
||||||
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
|
import org.elasticsearch.xcontent.ToXContent;
|
||||||
|
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||||
|
import org.elasticsearch.xcontent.json.JsonXContent;
|
||||||
|
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
||||||
|
import org.elasticsearch.xpack.inference.InferenceNamedWriteablesProvider;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* All Inference Setting classes derived from {@link org.elasticsearch.inference.ServiceSettings},
|
||||||
|
* {@link org.elasticsearch.inference.TaskSettings}, and {@link org.elasticsearch.inference.SecretSettings}
|
||||||
|
* must be able to read/write to an index via ToXContent as well as read/write between nodes via Writeable.
|
||||||
|
*/
|
||||||
|
public abstract class InferenceSettingsTestCase<T extends Writeable & ToXContent> extends AbstractBWCWireSerializationTestCase<T> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method is final because {@link org.elasticsearch.inference.ModelConfigurations} settings must be registered in
|
||||||
|
* {@link InferenceNamedWriteablesProvider}.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
protected final NamedWriteableRegistry getNamedWriteableRegistry() {
|
||||||
|
return new NamedWriteableRegistry(InferenceNamedWriteablesProvider.getNamedWriteables());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Helper implementation since most mutates can be handled by continuously randomizing until we have another test instance.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
protected T mutateInstance(T instance) throws IOException {
|
||||||
|
return randomValueOtherThan(instance, this::createTestInstance);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Change this for BWC once the implementation requires difference objects depending on the transport version.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
protected T mutateInstanceForVersion(T instance, TransportVersion version) {
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Verify that we can write to XContent and then read from XContent.
|
||||||
|
* This simulates saving the model to the index and then reading the model from the index.
|
||||||
|
*/
|
||||||
|
public final void testXContentRoundTrip() throws IOException {
|
||||||
|
var instance = createTestInstance();
|
||||||
|
var instanceAsMap = toMap(instance);
|
||||||
|
var roundTripInstance = fromMutableMap(new HashMap<>(instanceAsMap));
|
||||||
|
assertThat(roundTripInstance, equalTo(instance));
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract T fromMutableMap(Map<String, Object> mutableMap);
|
||||||
|
|
||||||
|
public static Map<String, Object> toMap(ToXContent instance) throws IOException {
|
||||||
|
try (var builder = JsonXContent.contentBuilder()) {
|
||||||
|
if (instance.isFragment()) {
|
||||||
|
builder.startObject();
|
||||||
|
}
|
||||||
|
instance.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||||
|
if (instance.isFragment()) {
|
||||||
|
builder.endObject();
|
||||||
|
}
|
||||||
|
var taskSettingsBytes = Strings.toString(builder).getBytes(StandardCharsets.UTF_8);
|
||||||
|
try (var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, taskSettingsBytes)) {
|
||||||
|
return parser.map();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Helper, since most settings contain optional strings.
|
||||||
|
*/
|
||||||
|
protected static String randomOptionalString() {
|
||||||
|
return randomBoolean() ? randomString() : null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Helper, since most settings contain strings.
|
||||||
|
*/
|
||||||
|
protected static String randomString() {
|
||||||
|
return randomAlphaOfLength(randomIntBetween(4, 8));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,180 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker;
|
||||||
|
|
||||||
|
import software.amazon.awssdk.core.async.SdkPublisher;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeAsyncClient;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;
|
||||||
|
|
||||||
|
import org.elasticsearch.action.ActionListener;
|
||||||
|
import org.elasticsearch.common.cache.CacheLoader;
|
||||||
|
import org.elasticsearch.common.settings.SecureString;
|
||||||
|
import org.elasticsearch.core.TimeValue;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
|
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
||||||
|
import org.junit.After;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.mockito.ArgumentMatchers;
|
||||||
|
import org.reactivestreams.Subscriber;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.assertArg;
|
||||||
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
|
import static org.mockito.Mockito.doAnswer;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.never;
|
||||||
|
import static org.mockito.Mockito.spy;
|
||||||
|
import static org.mockito.Mockito.times;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
/**
 * Unit tests for {@code SageMakerClient}: request invocation, streaming invocation, client caching per
 * region+credentials, timeout handling, and cache cleanup on close. The AWS async client is mocked, so no
 * network calls are made.
 */
public class SageMakerClientTests extends ESTestCase {
    private static final InvokeEndpointRequest REQUEST = InvokeEndpointRequest.builder().build();
    private static final InvokeEndpointWithResponseStreamRequest STREAM_REQUEST = InvokeEndpointWithResponseStreamRequest.builder().build();
    private SageMakerRuntimeAsyncClient awsClient;
    // Spied so tests can count how often the cache had to create a new AWS client.
    private CacheLoader<SageMakerClient.RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory;
    private SageMakerClient client;
    private ThreadPool threadPool;

    @Before
    public void setUp() throws Exception {
        super.setUp();
        threadPool = createThreadPool(inferenceUtilityPool());

        awsClient = mock();
        clientFactory = spy(new CacheLoader<>() {
            public SageMakerRuntimeAsyncClient load(SageMakerClient.RegionAndSecrets key) {
                return awsClient;
            }
        });
        client = new SageMakerClient(clientFactory, threadPool);
    }

    @After
    public void shutdown() throws IOException {
        terminate(threadPool);
    }

    /** A successful invoke creates one cached client and forwards the AWS response to the listener. */
    public void testInvoke() throws Exception {
        var expectedResponse = InvokeEndpointResponse.builder().build();
        when(awsClient.invokeEndpoint(any(InvokeEndpointRequest.class))).thenReturn(CompletableFuture.completedFuture(expectedResponse));

        var listener = invoke(TimeValue.THIRTY_SECONDS);
        verify(clientFactory, times(1)).load(any());
        verify(listener, times(1)).onResponse(eq(expectedResponse));
    }

    private static SageMakerClient.RegionAndSecrets regionAndSecrets() {
        return new SageMakerClient.RegionAndSecrets(
            "us-east-1",
            new AwsSecretSettings(new SecureString("access"), new SecureString("secrets"))
        );
    }

    // Invokes the client with a spied listener and blocks (up to 5s) until the listener completes.
    private ActionListener<InvokeEndpointResponse> invoke(TimeValue timeout) throws InterruptedException {
        var latch = new CountDownLatch(1);
        ActionListener<InvokeEndpointResponse> listener = spy(ActionListener.noop());
        client.invoke(regionAndSecrets(), REQUEST, timeout, ActionListener.runAfter(listener, latch::countDown));
        assertTrue("Timed out waiting for invoke call", latch.await(5, TimeUnit.SECONDS));
        return listener;
    }

    /** Two invocations with the same region+secrets reuse the cached client: the factory loads only once. */
    public void testInvokeCache() throws Exception {
        when(awsClient.invokeEndpoint(any(InvokeEndpointRequest.class))).thenReturn(
            CompletableFuture.completedFuture(InvokeEndpointResponse.builder().build())
        );

        invoke(TimeValue.THIRTY_SECONDS);
        invoke(TimeValue.THIRTY_SECONDS);
        verify(clientFactory, times(1)).load(any());
    }

    /** A never-completing AWS future must fail the listener with a timeout. */
    public void testInvokeTimeout() throws Exception {
        when(awsClient.invokeEndpoint(any(InvokeEndpointRequest.class))).thenReturn(new CompletableFuture<>());

        var listener = invoke(TimeValue.timeValueMillis(10));

        verify(clientFactory, times(1)).load(any());
        verifyTimeout(listener);
    }

    private static void verifyTimeout(ActionListener<?> listener) {
        verify(listener, times(1)).onFailure(assertArg(e -> assertThat(e.getMessage(), equalTo("Request timed out after [10ms]"))));
    }

    /** The stream publisher must not be subscribed until the caller subscribes to the returned stream. */
    public void testInvokeStream() throws Exception {
        SdkPublisher<ResponseStream> publisher = mockPublisher();

        var listener = invokeStream(TimeValue.THIRTY_SECONDS);

        verify(publisher, never()).subscribe(ArgumentMatchers.<Subscriber<ResponseStream>>any());
        verify(listener, times(1)).onResponse(assertArg(stream -> stream.responseStream().subscribe(mock())));
        verify(publisher, times(1)).subscribe(ArgumentMatchers.<Subscriber<ResponseStream>>any());
    }

    // Stubs the AWS streaming call to immediately deliver a response and the given event-stream publisher.
    private SdkPublisher<ResponseStream> mockPublisher() {
        SdkPublisher<ResponseStream> publisher = mock();
        doAnswer(ans -> {
            InvokeEndpointWithResponseStreamResponseHandler handler = ans.getArgument(1);
            handler.responseReceived(InvokeEndpointWithResponseStreamResponse.builder().build());
            handler.onEventStream(publisher);
            return CompletableFuture.completedFuture((Void) null);
        }).when(awsClient).invokeEndpointWithResponseStream(any(InvokeEndpointWithResponseStreamRequest.class), any());
        return publisher;
    }

    // Streaming variant of invoke(): spied listener plus latch so the test can block until completion.
    private ActionListener<SageMakerClient.SageMakerStream> invokeStream(TimeValue timeout) throws Exception {
        var latch = new CountDownLatch(1);
        ActionListener<SageMakerClient.SageMakerStream> listener = spy(ActionListener.noop());
        client.invokeStream(regionAndSecrets(), STREAM_REQUEST, timeout, ActionListener.runAfter(listener, latch::countDown));
        assertTrue("Timed out waiting for invoke call", latch.await(5, TimeUnit.SECONDS));
        return listener;
    }

    /** Streaming invocations also share the cached client. */
    public void testInvokeStreamCache() throws Exception {
        mockPublisher();

        invokeStream(TimeValue.THIRTY_SECONDS);
        invokeStream(TimeValue.THIRTY_SECONDS);

        verify(clientFactory, times(1)).load(any());
    }

    /** A never-completing streaming future must fail the listener with a timeout. */
    public void testInvokeStreamTimeout() throws Exception {
        when(awsClient.invokeEndpointWithResponseStream(any(InvokeEndpointWithResponseStreamRequest.class), any())).thenReturn(
            new CompletableFuture<>()
        );

        var listener = invokeStream(TimeValue.timeValueMillis(10));

        verify(clientFactory, times(1)).load(any());
        verifyTimeout(listener);
    }

    /** Closing the SageMaker client closes every cached AWS client. */
    public void testClose() throws Exception {
        // load cache
        mockPublisher();
        invokeStream(TimeValue.THIRTY_SECONDS);
        // clear cache
        client.close();
        verify(awsClient, times(1)).close();
    }
}
|
|
@ -0,0 +1,463 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker;
|
||||||
|
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
|
import org.elasticsearch.action.ActionListener;
|
||||||
|
import org.elasticsearch.common.settings.SecureString;
|
||||||
|
import org.elasticsearch.common.settings.Settings;
|
||||||
|
import org.elasticsearch.common.util.concurrent.EsExecutors;
|
||||||
|
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||||
|
import org.elasticsearch.inference.ChunkInferenceInput;
|
||||||
|
import org.elasticsearch.inference.InputType;
|
||||||
|
import org.elasticsearch.inference.Model;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.rest.RestStatus;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests;
|
||||||
|
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchema;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStreamSchema;
|
||||||
|
import org.junit.Before;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
import static org.elasticsearch.action.ActionListener.assertOnce;
|
||||||
|
import static org.elasticsearch.action.support.ActionTestUtils.assertNoFailureListener;
|
||||||
|
import static org.elasticsearch.action.support.ActionTestUtils.assertNoSuccessListener;
|
||||||
|
import static org.elasticsearch.core.TimeValue.THIRTY_SECONDS;
|
||||||
|
import static org.elasticsearch.xpack.core.inference.action.UnifiedCompletionRequestTests.randomUnifiedCompletionRequest;
|
||||||
|
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.hasSize;
|
||||||
|
import static org.hamcrest.Matchers.in;
|
||||||
|
import static org.hamcrest.Matchers.isA;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyString;
|
||||||
|
import static org.mockito.ArgumentMatchers.assertArg;
|
||||||
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
|
import static org.mockito.Mockito.doAnswer;
|
||||||
|
import static org.mockito.Mockito.doThrow;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.only;
|
||||||
|
import static org.mockito.Mockito.times;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
public class SageMakerServiceTests extends ESTestCase {
|
||||||
|
|
||||||
|
private static final String QUERY = "query";
|
||||||
|
private static final List<String> INPUT = List.of("input");
|
||||||
|
private static final InputType INPUT_TYPE = InputType.UNSPECIFIED;
|
||||||
|
|
||||||
|
private SageMakerModelBuilder modelBuilder;
|
||||||
|
private SageMakerClient client;
|
||||||
|
private SageMakerSchemas schemas;
|
||||||
|
private SageMakerService sageMakerService;
|
||||||
|
|
||||||
|
@Before
public void init() {
    modelBuilder = mock();
    client = mock();
    schemas = mock();
    ThreadPool threadPool = mock();
    // Run async work inline on the calling thread so the tests stay deterministic.
    when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
    when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
    sageMakerService = new SageMakerService(modelBuilder, client, schemas, threadPool, Map::of);
}
|
||||||
|
|
||||||
|
// Task-type support is delegated entirely to the schemas registry.
public void testSupportedTaskTypes() {
    sageMakerService.supportedTaskTypes();
    verify(schemas, only()).supportedTaskTypes();
}
|
||||||
|
|
||||||
|
// Streaming-task support is delegated entirely to the schemas registry.
public void testSupportedStreamingTasks() {
    sageMakerService.supportedStreamingTasks();
    verify(schemas, only()).supportedStreamingTasks();
}
|
||||||
|
|
||||||
|
// Parsing a request config delegates to the model builder's fromRequest with the service name filled in.
public void testParseRequestConfig() {
    sageMakerService.parseRequestConfig("modelId", TaskType.ANY, Map.of(), assertNoFailureListener(model -> {
        verify(modelBuilder, only()).fromRequest(eq("modelId"), eq(TaskType.ANY), eq(SageMakerService.NAME), eq(Map.of()));
    }));
}
|
||||||
|
|
||||||
|
public void testParsePersistedConfigWithSecrets() {
|
||||||
|
sageMakerService.parsePersistedConfigWithSecrets("modelId", TaskType.ANY, Map.of(), Map.of());
|
||||||
|
verify(modelBuilder, only()).fromStorage(eq("modelId"), eq(TaskType.ANY), eq(SageMakerService.NAME), eq(Map.of()), eq(Map.of()));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testParsePersistedConfig() {
|
||||||
|
sageMakerService.parsePersistedConfig("modelId", TaskType.ANY, Map.of());
|
||||||
|
verify(modelBuilder, only()).fromStorage(eq("modelId"), eq(TaskType.ANY), eq(SageMakerService.NAME), eq(Map.of()), eq(null));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testInferWithWrongModel() {
|
||||||
|
sageMakerService.infer(
|
||||||
|
mockUnsupportedModel(),
|
||||||
|
QUERY,
|
||||||
|
false,
|
||||||
|
null,
|
||||||
|
INPUT,
|
||||||
|
false,
|
||||||
|
null,
|
||||||
|
INPUT_TYPE,
|
||||||
|
THIRTY_SECONDS,
|
||||||
|
assertUnsupportedModel()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Model mockUnsupportedModel() {
|
||||||
|
Model model = mock();
|
||||||
|
ModelConfigurations modelConfigurations = mock();
|
||||||
|
when(modelConfigurations.getService()).thenReturn("mockService");
|
||||||
|
when(modelConfigurations.getInferenceEntityId()).thenReturn("mockInferenceId");
|
||||||
|
when(model.getConfigurations()).thenReturn(modelConfigurations);
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static <T> ActionListener<T> assertUnsupportedModel() {
|
||||||
|
return assertNoSuccessListener(e -> {
|
||||||
|
assertThat(e, isA(ElasticsearchStatusException.class));
|
||||||
|
assertThat(
|
||||||
|
e.getMessage(),
|
||||||
|
equalTo(
|
||||||
|
"The internal model was invalid, please delete the service [mockService] "
|
||||||
|
+ "with id [mockInferenceId] and add it again."
|
||||||
|
)
|
||||||
|
);
|
||||||
|
assertThat(((ElasticsearchStatusException) e).status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testInfer() {
|
||||||
|
var model = mockModel();
|
||||||
|
|
||||||
|
SageMakerSchema schema = mock();
|
||||||
|
when(schemas.schemaFor(model)).thenReturn(schema);
|
||||||
|
mockInvoke();
|
||||||
|
|
||||||
|
sageMakerService.infer(
|
||||||
|
model,
|
||||||
|
QUERY,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
INPUT,
|
||||||
|
false,
|
||||||
|
null,
|
||||||
|
INPUT_TYPE,
|
||||||
|
THIRTY_SECONDS,
|
||||||
|
assertNoFailureListener(ignored -> {
|
||||||
|
verify(schemas, only()).schemaFor(eq(model));
|
||||||
|
verify(schema, times(1)).request(eq(model), assertRequest());
|
||||||
|
verify(schema, times(1)).response(eq(model), any(), any());
|
||||||
|
})
|
||||||
|
);
|
||||||
|
verify(client, only()).invoke(any(), any(), any(), any());
|
||||||
|
verifyNoMoreInteractions(client, schemas, schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
private SageMakerModel mockModel() {
|
||||||
|
SageMakerModel model = mock();
|
||||||
|
when(model.override(null)).thenReturn(model);
|
||||||
|
when(model.awsSecretSettings()).thenReturn(
|
||||||
|
Optional.of(new AwsSecretSettings(new SecureString("test-accessKey"), new SecureString("test-secretKey")))
|
||||||
|
);
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void mockInvoke() {
|
||||||
|
doAnswer(ans -> {
|
||||||
|
ActionListener<InvokeEndpointResponse> responseListener = ans.getArgument(3);
|
||||||
|
responseListener.onResponse(InvokeEndpointResponse.builder().build());
|
||||||
|
return null; // Void
|
||||||
|
}).when(client).invoke(any(), any(), any(), any());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static SageMakerInferenceRequest assertRequest() {
|
||||||
|
return assertArg(request -> {
|
||||||
|
assertThat(request.query(), equalTo(QUERY));
|
||||||
|
assertThat(request.input(), containsInAnyOrder(INPUT.toArray()));
|
||||||
|
assertThat(request.inputType(), equalTo(InputType.UNSPECIFIED));
|
||||||
|
assertNull(request.returnDocuments());
|
||||||
|
assertNull(request.topN());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testInferStream() {
|
||||||
|
SageMakerModel model = mockModel();
|
||||||
|
|
||||||
|
SageMakerStreamSchema schema = mock();
|
||||||
|
when(schemas.streamSchemaFor(model)).thenReturn(schema);
|
||||||
|
mockInvokeStream();
|
||||||
|
|
||||||
|
sageMakerService.infer(model, QUERY, null, null, INPUT, true, null, INPUT_TYPE, THIRTY_SECONDS, assertNoFailureListener(ignored -> {
|
||||||
|
verify(schemas, only()).streamSchemaFor(eq(model));
|
||||||
|
verify(schema, times(1)).streamRequest(eq(model), assertRequest());
|
||||||
|
verify(schema, times(1)).streamResponse(eq(model), any());
|
||||||
|
}));
|
||||||
|
verify(client, only()).invokeStream(any(), any(), any(), any());
|
||||||
|
verifyNoMoreInteractions(client, schemas, schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void mockInvokeStream() {
|
||||||
|
doAnswer(ans -> {
|
||||||
|
ActionListener<SageMakerClient.SageMakerStream> responseListener = ans.getArgument(3);
|
||||||
|
responseListener.onResponse(
|
||||||
|
new SageMakerClient.SageMakerStream(InvokeEndpointWithResponseStreamResponse.builder().build(), mock())
|
||||||
|
);
|
||||||
|
return null; // Void
|
||||||
|
}).when(client).invokeStream(any(), any(), any(), any());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testInferError() {
|
||||||
|
SageMakerModel model = mockModel();
|
||||||
|
|
||||||
|
var expectedException = new IllegalArgumentException("hola");
|
||||||
|
SageMakerSchema schema = mock();
|
||||||
|
when(schemas.schemaFor(model)).thenReturn(schema);
|
||||||
|
mockInvokeFailure(expectedException);
|
||||||
|
|
||||||
|
sageMakerService.infer(
|
||||||
|
model,
|
||||||
|
QUERY,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
INPUT,
|
||||||
|
false,
|
||||||
|
null,
|
||||||
|
INPUT_TYPE,
|
||||||
|
THIRTY_SECONDS,
|
||||||
|
assertNoSuccessListener(ignored -> {
|
||||||
|
verify(schemas, only()).schemaFor(eq(model));
|
||||||
|
verify(schema, times(1)).request(eq(model), assertRequest());
|
||||||
|
verify(schema, times(1)).error(eq(model), assertArg(e -> assertThat(e, equalTo(expectedException))));
|
||||||
|
})
|
||||||
|
);
|
||||||
|
verify(client, only()).invoke(any(), any(), any(), any());
|
||||||
|
verifyNoMoreInteractions(client, schemas, schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void mockInvokeFailure(Exception e) {
|
||||||
|
doAnswer(ans -> {
|
||||||
|
ActionListener<?> responseListener = ans.getArgument(3);
|
||||||
|
responseListener.onFailure(e);
|
||||||
|
return null; // Void
|
||||||
|
}).when(client).invoke(any(), any(), any(), any());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testInferException() {
|
||||||
|
SageMakerModel model = mockModel();
|
||||||
|
when(model.getInferenceEntityId()).thenReturn("some id");
|
||||||
|
|
||||||
|
SageMakerStreamSchema schema = mock();
|
||||||
|
when(schemas.streamSchemaFor(model)).thenReturn(schema);
|
||||||
|
doThrow(new IllegalArgumentException("wow, really?")).when(client).invokeStream(any(), any(), any(), any());
|
||||||
|
|
||||||
|
sageMakerService.infer(model, QUERY, null, null, INPUT, true, null, INPUT_TYPE, THIRTY_SECONDS, assertNoSuccessListener(e -> {
|
||||||
|
verify(schemas, only()).streamSchemaFor(eq(model));
|
||||||
|
verify(schema, times(1)).streamRequest(eq(model), assertRequest());
|
||||||
|
assertThat(e, isA(ElasticsearchStatusException.class));
|
||||||
|
assertThat(((ElasticsearchStatusException) e).status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
|
||||||
|
assertThat(e.getMessage(), equalTo("Failed to call SageMaker for inference id [some id]."));
|
||||||
|
}));
|
||||||
|
verify(client, only()).invokeStream(any(), any(), any(), any());
|
||||||
|
verifyNoMoreInteractions(client, schemas, schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testUnifiedInferWithWrongModel() {
|
||||||
|
sageMakerService.unifiedCompletionInfer(
|
||||||
|
mockUnsupportedModel(),
|
||||||
|
randomUnifiedCompletionRequest(),
|
||||||
|
THIRTY_SECONDS,
|
||||||
|
assertUnsupportedModel()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testUnifiedInfer() {
|
||||||
|
var model = mockModel();
|
||||||
|
|
||||||
|
SageMakerStreamSchema schema = mock();
|
||||||
|
when(schemas.streamSchemaFor(model)).thenReturn(schema);
|
||||||
|
mockInvokeStream();
|
||||||
|
|
||||||
|
sageMakerService.unifiedCompletionInfer(
|
||||||
|
model,
|
||||||
|
randomUnifiedCompletionRequest(),
|
||||||
|
THIRTY_SECONDS,
|
||||||
|
assertNoFailureListener(ignored -> {
|
||||||
|
verify(schemas, only()).streamSchemaFor(eq(model));
|
||||||
|
verify(schema, times(1)).chatCompletionStreamRequest(eq(model), any());
|
||||||
|
verify(schema, times(1)).chatCompletionStreamResponse(eq(model), any());
|
||||||
|
})
|
||||||
|
);
|
||||||
|
verify(client, only()).invokeStream(any(), any(), any(), any());
|
||||||
|
verifyNoMoreInteractions(client, schemas, schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testUnifiedInferError() {
|
||||||
|
var model = mockModel();
|
||||||
|
|
||||||
|
var expectedException = new IllegalArgumentException("hola");
|
||||||
|
SageMakerStreamSchema schema = mock();
|
||||||
|
when(schemas.streamSchemaFor(model)).thenReturn(schema);
|
||||||
|
mockInvokeStreamFailure(expectedException);
|
||||||
|
|
||||||
|
sageMakerService.unifiedCompletionInfer(
|
||||||
|
model,
|
||||||
|
randomUnifiedCompletionRequest(),
|
||||||
|
THIRTY_SECONDS,
|
||||||
|
assertNoSuccessListener(ignored -> {
|
||||||
|
verify(schemas, only()).streamSchemaFor(eq(model));
|
||||||
|
verify(schema, times(1)).chatCompletionStreamRequest(eq(model), any());
|
||||||
|
verify(schema, times(1)).chatCompletionError(eq(model), assertArg(e -> assertThat(e, equalTo(expectedException))));
|
||||||
|
})
|
||||||
|
);
|
||||||
|
verify(client, only()).invokeStream(any(), any(), any(), any());
|
||||||
|
verifyNoMoreInteractions(client, schemas, schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void mockInvokeStreamFailure(Exception e) {
|
||||||
|
doAnswer(ans -> {
|
||||||
|
ActionListener<?> responseListener = ans.getArgument(3);
|
||||||
|
responseListener.onFailure(e);
|
||||||
|
return null; // Void
|
||||||
|
}).when(client).invokeStream(any(), any(), any(), any());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testUnifiedInferException() {
|
||||||
|
SageMakerModel model = mockModel();
|
||||||
|
when(model.getInferenceEntityId()).thenReturn("some id");
|
||||||
|
|
||||||
|
SageMakerStreamSchema schema = mock();
|
||||||
|
when(schemas.streamSchemaFor(model)).thenReturn(schema);
|
||||||
|
doThrow(new IllegalArgumentException("wow, rude")).when(client).invokeStream(any(), any(), any(), any());
|
||||||
|
|
||||||
|
sageMakerService.unifiedCompletionInfer(model, randomUnifiedCompletionRequest(), THIRTY_SECONDS, assertNoSuccessListener(e -> {
|
||||||
|
verify(schemas, only()).streamSchemaFor(eq(model));
|
||||||
|
verify(schema, times(1)).chatCompletionStreamRequest(eq(model), any());
|
||||||
|
assertThat(e, isA(ElasticsearchStatusException.class));
|
||||||
|
assertThat(((ElasticsearchStatusException) e).status(), equalTo(RestStatus.INTERNAL_SERVER_ERROR));
|
||||||
|
assertThat(e.getMessage(), equalTo("Failed to call SageMaker for inference id [some id]."));
|
||||||
|
}));
|
||||||
|
verify(client, only()).invokeStream(any(), any(), any(), any());
|
||||||
|
verifyNoMoreInteractions(client, schemas, schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testChunkedInferWithWrongModel() {
|
||||||
|
sageMakerService.chunkedInfer(
|
||||||
|
mockUnsupportedModel(),
|
||||||
|
QUERY,
|
||||||
|
INPUT.stream().map(ChunkInferenceInput::new).toList(),
|
||||||
|
null,
|
||||||
|
INPUT_TYPE,
|
||||||
|
THIRTY_SECONDS,
|
||||||
|
assertUnsupportedModel()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testChunkedInfer() throws Exception {
|
||||||
|
var model = mockModelForChunking();
|
||||||
|
|
||||||
|
SageMakerSchema schema = mock();
|
||||||
|
when(schema.response(any(), any(), any())).thenReturn(TextEmbeddingFloatResultsTests.createRandomResults());
|
||||||
|
when(schemas.schemaFor(model)).thenReturn(schema);
|
||||||
|
mockInvoke();
|
||||||
|
|
||||||
|
var expectedInput = Set.of("first", "second");
|
||||||
|
|
||||||
|
sageMakerService.chunkedInfer(
|
||||||
|
model,
|
||||||
|
QUERY,
|
||||||
|
expectedInput.stream().map(ChunkInferenceInput::new).toList(),
|
||||||
|
null,
|
||||||
|
INPUT_TYPE,
|
||||||
|
THIRTY_SECONDS,
|
||||||
|
assertOnce(assertNoFailureListener(chunkedInferences -> {
|
||||||
|
verify(schemas, times(2)).schemaFor(eq(model));
|
||||||
|
verify(schema, times(2)).request(eq(model), assertChunkRequest(expectedInput));
|
||||||
|
verify(schema, times(2)).response(eq(model), any(), any());
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
verify(client, times(2)).invoke(any(), any(), any(), any());
|
||||||
|
verifyNoMoreInteractions(client, schemas, schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
private SageMakerModel mockModelForChunking() {
|
||||||
|
var model = mockModel();
|
||||||
|
when(model.batchSize()).thenReturn(Optional.of(1));
|
||||||
|
ModelConfigurations modelConfigurations = mock();
|
||||||
|
when(modelConfigurations.getChunkingSettings()).thenReturn(new WordBoundaryChunkingSettings(1, 0));
|
||||||
|
when(model.getConfigurations()).thenReturn(modelConfigurations);
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static SageMakerInferenceRequest assertChunkRequest(Set<String> expectedInput) {
|
||||||
|
return assertArg(request -> {
|
||||||
|
assertThat(request.query(), equalTo(QUERY));
|
||||||
|
assertThat(request.input(), hasSize(1));
|
||||||
|
assertThat(request.input().get(0), in(expectedInput));
|
||||||
|
assertThat(request.inputType(), equalTo(InputType.UNSPECIFIED));
|
||||||
|
assertNull(request.returnDocuments());
|
||||||
|
assertNull(request.topN());
|
||||||
|
assertFalse(request.stream());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testChunkedInferError() {
|
||||||
|
var model = mockModelForChunking();
|
||||||
|
|
||||||
|
var expectedException = new IllegalArgumentException("hola");
|
||||||
|
SageMakerSchema schema = mock();
|
||||||
|
when(schema.error(any(), any())).thenReturn(expectedException);
|
||||||
|
when(schemas.schemaFor(model)).thenReturn(schema);
|
||||||
|
mockInvokeFailure(expectedException);
|
||||||
|
|
||||||
|
var expectedInput = Set.of("first", "second");
|
||||||
|
var expectedOutput = Stream.of(expectedException, expectedException).map(ChunkedInferenceError::new).toArray();
|
||||||
|
|
||||||
|
sageMakerService.chunkedInfer(
|
||||||
|
model,
|
||||||
|
QUERY,
|
||||||
|
expectedInput.stream().map(ChunkInferenceInput::new).toList(),
|
||||||
|
null,
|
||||||
|
INPUT_TYPE,
|
||||||
|
THIRTY_SECONDS,
|
||||||
|
assertOnce(assertNoFailureListener(chunkedInferences -> {
|
||||||
|
verify(schemas, times(2)).schemaFor(eq(model));
|
||||||
|
verify(schema, times(2)).request(eq(model), assertChunkRequest(expectedInput));
|
||||||
|
verify(schema, times(2)).error(eq(model), assertArg(e -> assertThat(e, equalTo(expectedException))));
|
||||||
|
assertThat(chunkedInferences, containsInAnyOrder(expectedOutput));
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
verify(client, times(2)).invoke(any(), any(), any(), any());
|
||||||
|
verifyNoMoreInteractions(client, schemas, schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testClose() throws IOException {
|
||||||
|
sageMakerService.close();
|
||||||
|
verify(client, only()).close();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,315 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.ModelSecrets;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.inference.UnparsedModel;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xcontent.ToXContent;
|
||||||
|
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||||
|
import org.elasticsearch.xcontent.json.JsonXContent;
|
||||||
|
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemasTests;
|
||||||
|
import org.junit.Before;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
import static org.elasticsearch.inference.ModelConfigurations.USE_ID_FOR_INDEX;
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
|
public class SageMakerModelBuilderTests extends ESTestCase {
|
||||||
|
private static final String inferenceId = "inferenceId";
|
||||||
|
private static final TaskType taskType = TaskType.ANY;
|
||||||
|
private static final String service = "service";
|
||||||
|
private SageMakerModelBuilder builder;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUp() throws Exception {
|
||||||
|
super.setUp();
|
||||||
|
builder = new SageMakerModelBuilder(SageMakerSchemasTests.mockSchemas());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromRequestWithRequiredFields() {
|
||||||
|
var model = fromRequest("""
|
||||||
|
{
|
||||||
|
"service_settings": {
|
||||||
|
"access_key": "test-access-key",
|
||||||
|
"secret_key": "test-secret-key",
|
||||||
|
"region": "us-east-1",
|
||||||
|
"api": "test-api",
|
||||||
|
"endpoint_name": "test-endpoint"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""");
|
||||||
|
|
||||||
|
assertNotNull(model);
|
||||||
|
assertTrue(model.awsSecretSettings().isPresent());
|
||||||
|
assertThat(model.awsSecretSettings().get().accessKey().toString(), equalTo("test-access-key"));
|
||||||
|
assertThat(model.awsSecretSettings().get().secretKey().toString(), equalTo("test-secret-key"));
|
||||||
|
assertThat(model.region(), equalTo("us-east-1"));
|
||||||
|
assertThat(model.api(), equalTo("test-api"));
|
||||||
|
assertThat(model.endpointName(), equalTo("test-endpoint"));
|
||||||
|
|
||||||
|
assertTrue(model.customAttributes().isEmpty());
|
||||||
|
assertTrue(model.enableExplanations().isEmpty());
|
||||||
|
assertTrue(model.inferenceComponentName().isEmpty());
|
||||||
|
assertTrue(model.inferenceIdForDataCapture().isEmpty());
|
||||||
|
assertTrue(model.sessionId().isEmpty());
|
||||||
|
assertTrue(model.targetContainerHostname().isEmpty());
|
||||||
|
assertTrue(model.targetModel().isEmpty());
|
||||||
|
assertTrue(model.targetVariant().isEmpty());
|
||||||
|
assertTrue(model.batchSize().isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromRequestWithOptionalFields() {
|
||||||
|
var model = fromRequest("""
|
||||||
|
{
|
||||||
|
"service_settings": {
|
||||||
|
"access_key": "test-access-key",
|
||||||
|
"secret_key": "test-secret-key",
|
||||||
|
"region": "us-east-1",
|
||||||
|
"api": "test-api",
|
||||||
|
"endpoint_name": "test-endpoint",
|
||||||
|
"target_model": "test-target",
|
||||||
|
"target_container_hostname": "test-target-container",
|
||||||
|
"inference_component_name": "test-inference-component",
|
||||||
|
"batch_size": 1234
|
||||||
|
},
|
||||||
|
"task_settings": {
|
||||||
|
"custom_attributes": "test-custom-attributes",
|
||||||
|
"enable_explanations": "test-enable-explanations",
|
||||||
|
"inference_id": "test-inference-id",
|
||||||
|
"session_id": "test-session-id",
|
||||||
|
"target_variant": "test-target-variant"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""");
|
||||||
|
|
||||||
|
assertNotNull(model);
|
||||||
|
assertTrue(model.awsSecretSettings().isPresent());
|
||||||
|
assertThat(model.awsSecretSettings().get().accessKey().toString(), equalTo("test-access-key"));
|
||||||
|
assertThat(model.awsSecretSettings().get().secretKey().toString(), equalTo("test-secret-key"));
|
||||||
|
assertThat(model.region(), equalTo("us-east-1"));
|
||||||
|
assertThat(model.api(), equalTo("test-api"));
|
||||||
|
assertThat(model.endpointName(), equalTo("test-endpoint"));
|
||||||
|
|
||||||
|
assertPresent(model.customAttributes(), "test-custom-attributes");
|
||||||
|
assertPresent(model.enableExplanations(), "test-enable-explanations");
|
||||||
|
assertPresent(model.inferenceComponentName(), "test-inference-component");
|
||||||
|
assertPresent(model.inferenceIdForDataCapture(), "test-inference-id");
|
||||||
|
assertPresent(model.sessionId(), "test-session-id");
|
||||||
|
assertPresent(model.targetContainerHostname(), "test-target-container");
|
||||||
|
assertPresent(model.targetModel(), "test-target");
|
||||||
|
assertPresent(model.targetVariant(), "test-target-variant");
|
||||||
|
assertPresent(model.batchSize(), 1234);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromRequestWithoutAccessKey() {
|
||||||
|
testExceptionFromRequest("""
|
||||||
|
{
|
||||||
|
"service_settings": {
|
||||||
|
"secret_key": "test-secret-key",
|
||||||
|
"region": "us-east-1",
|
||||||
|
"api": "test-api",
|
||||||
|
"endpoint_name": "test-endpoint"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""", ValidationException.class, "Validation Failed: 1: [secret_settings] does not contain the required setting [access_key];");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromRequestWithoutSecretKey() {
|
||||||
|
testExceptionFromRequest("""
|
||||||
|
{
|
||||||
|
"service_settings": {
|
||||||
|
"access_key": "test-access-key",
|
||||||
|
"region": "us-east-1",
|
||||||
|
"api": "test-api",
|
||||||
|
"endpoint_name": "test-endpoint"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""", ValidationException.class, "Validation Failed: 1: [secret_settings] does not contain the required setting [secret_key];");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromRequestWithoutRegion() {
|
||||||
|
testExceptionFromRequest("""
|
||||||
|
{
|
||||||
|
"service_settings": {
|
||||||
|
"access_key": "test-access-key",
|
||||||
|
"secret_key": "test-secret-key",
|
||||||
|
"api": "test-api",
|
||||||
|
"endpoint_name": "test-endpoint"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""", ValidationException.class, "Validation Failed: 1: [service_settings] does not contain the required setting [region];");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromRequestWithoutApi() {
|
||||||
|
testExceptionFromRequest("""
|
||||||
|
{
|
||||||
|
"service_settings": {
|
||||||
|
"access_key": "test-access-key",
|
||||||
|
"secret_key": "test-secret-key",
|
||||||
|
"region": "us-east-1",
|
||||||
|
"endpoint_name": "test-endpoint"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""", ValidationException.class, "Validation Failed: 1: [service_settings] does not contain the required setting [api];");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromRequestWithoutEndpointName() {
|
||||||
|
testExceptionFromRequest(
|
||||||
|
"""
|
||||||
|
{
|
||||||
|
"service_settings": {
|
||||||
|
"access_key": "test-access-key",
|
||||||
|
"secret_key": "test-secret-key",
|
||||||
|
"region": "us-east-1",
|
||||||
|
"api": "test-api"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""",
|
||||||
|
ValidationException.class,
|
||||||
|
"Validation Failed: 1: [service_settings] does not contain the required setting [endpoint_name];"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromRequestWithExtraServiceKeys() {
|
||||||
|
testExceptionFromRequest(
|
||||||
|
"""
|
||||||
|
{
|
||||||
|
"service_settings": {
|
||||||
|
"access_key": "test-access-key",
|
||||||
|
"secret_key": "test-secret-key",
|
||||||
|
"region": "us-east-1",
|
||||||
|
"api": "test-api",
|
||||||
|
"endpoint_name": "test-endpoint",
|
||||||
|
"hello": "there"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""",
|
||||||
|
ElasticsearchStatusException.class,
|
||||||
|
"Model configuration contains settings [{hello=there}] unknown to the [service] service"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromRequestWithExtraTaskKeys() {
|
||||||
|
testExceptionFromRequest(
|
||||||
|
"""
|
||||||
|
{
|
||||||
|
"service_settings": {
|
||||||
|
"access_key": "test-access-key",
|
||||||
|
"secret_key": "test-secret-key",
|
||||||
|
"region": "us-east-1",
|
||||||
|
"api": "test-api",
|
||||||
|
"endpoint_name": "test-endpoint"
|
||||||
|
},
|
||||||
|
"task_settings": {
|
||||||
|
"hello": "there"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""",
|
||||||
|
ElasticsearchStatusException.class,
|
||||||
|
"Model configuration contains settings [{hello=there}] unknown to the [service] service"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRoundTrip() throws IOException {
|
||||||
|
var expectedModel = fromRequest("""
|
||||||
|
{
|
||||||
|
"service_settings": {
|
||||||
|
"access_key": "test-access-key",
|
||||||
|
"secret_key": "test-secret-key",
|
||||||
|
"region": "us-east-1",
|
||||||
|
"api": "test-api",
|
||||||
|
"endpoint_name": "test-endpoint",
|
||||||
|
"target_model": "test-target",
|
||||||
|
"target_container_hostname": "test-target-container",
|
||||||
|
"inference_component_name": "test-inference-component",
|
||||||
|
"batch_size": 1234
|
||||||
|
},
|
||||||
|
"task_settings": {
|
||||||
|
"custom_attributes": "test-custom-attributes",
|
||||||
|
"enable_explanations": "test-enable-explanations",
|
||||||
|
"inference_id": "test-inference-id",
|
||||||
|
"session_id": "test-session-id",
|
||||||
|
"target_variant": "test-target-variant"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""");
|
||||||
|
|
||||||
|
var unparsedModelWithSecrets = unparsedModel(expectedModel.getConfigurations(), expectedModel.getSecrets());
|
||||||
|
var modelWithSecrets = builder.fromStorage(
|
||||||
|
unparsedModelWithSecrets.inferenceEntityId(),
|
||||||
|
unparsedModelWithSecrets.taskType(),
|
||||||
|
unparsedModelWithSecrets.service(),
|
||||||
|
unparsedModelWithSecrets.settings(),
|
||||||
|
unparsedModelWithSecrets.secrets()
|
||||||
|
);
|
||||||
|
assertThat(modelWithSecrets, equalTo(expectedModel));
|
||||||
|
assertNotNull(modelWithSecrets.getSecrets().getSecretSettings());
|
||||||
|
|
||||||
|
var unparsedModelWithoutSecrets = unparsedModel(expectedModel.getConfigurations(), null);
|
||||||
|
var modelWithoutSecrets = builder.fromStorage(
|
||||||
|
unparsedModelWithoutSecrets.inferenceEntityId(),
|
||||||
|
unparsedModelWithoutSecrets.taskType(),
|
||||||
|
unparsedModelWithoutSecrets.service(),
|
||||||
|
unparsedModelWithoutSecrets.settings(),
|
||||||
|
unparsedModelWithoutSecrets.secrets()
|
||||||
|
);
|
||||||
|
assertThat(modelWithoutSecrets.getConfigurations(), equalTo(expectedModel.getConfigurations()));
|
||||||
|
assertNull(modelWithoutSecrets.getSecrets().getSecretSettings());
|
||||||
|
}
|
||||||
|
|
||||||
|
private SageMakerModel fromRequest(String json) {
|
||||||
|
return builder.fromRequest(inferenceId, taskType, service, map(json));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void testExceptionFromRequest(String json, Class<? extends Exception> exceptionClass, String message) {
|
||||||
|
var exception = assertThrows(exceptionClass, () -> fromRequest(json));
|
||||||
|
assertThat(exception.getMessage(), equalTo(message));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static <T> void assertPresent(Optional<T> optional, T expectedValue) {
|
||||||
|
assertTrue(optional.isPresent());
|
||||||
|
assertThat(optional.get(), equalTo(expectedValue));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Map<String, Object> map(String json) {
|
||||||
|
try (
|
||||||
|
var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, json.getBytes(StandardCharsets.UTF_8))
|
||||||
|
) {
|
||||||
|
return parser.map();
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new AssertionError(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static UnparsedModel unparsedModel(ModelConfigurations modelConfigurations, ModelSecrets modelSecrets) throws IOException {
|
||||||
|
var modelConfigMap = new ModelRegistry.ModelConfigMap(
|
||||||
|
toJsonMap(modelConfigurations),
|
||||||
|
modelSecrets != null ? toJsonMap(modelSecrets) : null
|
||||||
|
);
|
||||||
|
|
||||||
|
return ModelRegistry.unparsedModelFromMap(modelConfigMap);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Map<String, Object> toJsonMap(ToXContent toXContent) throws IOException {
|
||||||
|
try (var builder = JsonXContent.contentBuilder()) {
|
||||||
|
toXContent.toXContent(builder, new ToXContent.MapParams(Map.of(USE_ID_FOR_INDEX, "true")));
|
||||||
|
return map(Strings.toString(builder));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemasTests;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class SageMakerServiceSettingsTests extends InferenceSettingsTestCase<SageMakerServiceSettings> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Writeable.Reader<SageMakerServiceSettings> instanceReader() {
|
||||||
|
return SageMakerServiceSettings::new;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected SageMakerServiceSettings createTestInstance() {
|
||||||
|
return new SageMakerServiceSettings(
|
||||||
|
randomString(),
|
||||||
|
randomString(),
|
||||||
|
randomString(),
|
||||||
|
randomOptionalString(),
|
||||||
|
randomOptionalString(),
|
||||||
|
randomOptionalString(),
|
||||||
|
randomBoolean() ? randomIntBetween(1, 1000) : null,
|
||||||
|
SageMakerStoredServiceSchema.NO_OP
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected SageMakerServiceSettings fromMutableMap(Map<String, Object> mutableMap) {
|
||||||
|
return SageMakerServiceSettings.fromMap(SageMakerSchemasTests.mockSchemas(), randomFrom(TaskType.values()), mutableMap);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,42 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.model;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
|
import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class SageMakerTaskSettingsTests extends InferenceSettingsTestCase<SageMakerTaskSettings> {
|
||||||
|
@Override
|
||||||
|
protected Writeable.Reader<SageMakerTaskSettings> instanceReader() {
|
||||||
|
return SageMakerTaskSettings::new;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected SageMakerTaskSettings createTestInstance() {
|
||||||
|
return new SageMakerTaskSettings(
|
||||||
|
randomOptionalString(),
|
||||||
|
randomOptionalString(),
|
||||||
|
randomOptionalString(),
|
||||||
|
randomOptionalString(),
|
||||||
|
randomOptionalString(),
|
||||||
|
SageMakerStoredTaskSchema.NO_OP
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected SageMakerTaskSettings fromMutableMap(Map<String, Object> mutableMap) {
|
||||||
|
var validationException = new ValidationException();
|
||||||
|
var taskSettings = SageMakerTaskSettings.fromMap(mutableMap, SageMakerStoredTaskSchema.NO_OP, validationException);
|
||||||
|
validationException.throwIfValidationErrorsExist();
|
||||||
|
return taskSettings;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,156 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||||
|
|
||||||
|
import software.amazon.awssdk.core.SdkBytes;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
|
||||||
|
import org.elasticsearch.inference.InputType;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||||
|
import org.junit.Before;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase.toMap;
|
||||||
|
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.startsWith;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
public abstract class SageMakerSchemaPayloadTestCase<T extends SageMakerSchemaPayload> extends ESTestCase {
|
||||||
|
protected T payload;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUp() throws Exception {
|
||||||
|
super.setUp();
|
||||||
|
payload = payload();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract T payload();
|
||||||
|
|
||||||
|
protected abstract String expectedApi();
|
||||||
|
|
||||||
|
protected abstract Set<TaskType> expectedSupportedTaskTypes();
|
||||||
|
|
||||||
|
protected abstract SageMakerStoredServiceSchema randomApiServiceSettings();
|
||||||
|
|
||||||
|
protected abstract SageMakerStoredTaskSchema randomApiTaskSettings();
|
||||||
|
|
||||||
|
public final void testApi() {
|
||||||
|
assertThat(payload.api(), equalTo(expectedApi()));
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void testSupportedTaskTypes() {
|
||||||
|
assertThat(payload.supportedTasks(), containsInAnyOrder(expectedSupportedTaskTypes().toArray()));
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void testApiServiceSettings() throws IOException {
|
||||||
|
var validationException = new ValidationException();
|
||||||
|
var expectedApiServiceSettings = randomApiServiceSettings();
|
||||||
|
var actualApiServiceSettings = payload.apiServiceSettings(toMap(expectedApiServiceSettings), validationException);
|
||||||
|
assertThat(actualApiServiceSettings, equalTo(expectedApiServiceSettings));
|
||||||
|
validationException.throwIfValidationErrorsExist();
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void testApiTaskSettings() throws IOException {
|
||||||
|
var validationException = new ValidationException();
|
||||||
|
var expectedApiTaskSettings = randomApiTaskSettings();
|
||||||
|
var actualApiTaskSettings = payload.apiTaskSettings(toMap(expectedApiTaskSettings), validationException);
|
||||||
|
assertThat(actualApiTaskSettings, equalTo(expectedApiTaskSettings));
|
||||||
|
validationException.throwIfValidationErrorsExist();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testNamedWriteables() {
|
||||||
|
var namedWriteables = payload.namedWriteables().map(entry -> entry.name).toList();
|
||||||
|
|
||||||
|
var filteredNames = Set.of(
|
||||||
|
SageMakerStoredServiceSchema.NO_OP.getWriteableName(),
|
||||||
|
SageMakerStoredTaskSchema.NO_OP.getWriteableName()
|
||||||
|
);
|
||||||
|
var expectedWriteables = Stream.of(randomApiServiceSettings(), randomApiTaskSettings())
|
||||||
|
.map(VersionedNamedWriteable::getWriteableName)
|
||||||
|
.filter(Predicate.not(filteredNames::contains))
|
||||||
|
.toArray();
|
||||||
|
assertThat(namedWriteables, containsInAnyOrder(expectedWriteables));
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void testWithUnknownApiServiceSettings() {
|
||||||
|
SageMakerModel model = mock();
|
||||||
|
when(model.apiServiceSettings()).thenReturn(mock());
|
||||||
|
when(model.apiTaskSettings()).thenReturn(randomApiTaskSettings());
|
||||||
|
when(model.api()).thenReturn("serviceApi");
|
||||||
|
when(model.getTaskType()).thenReturn(TaskType.ANY);
|
||||||
|
|
||||||
|
var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest()));
|
||||||
|
|
||||||
|
assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [serviceApi] and task type [any]:"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void testWithUnknownApiTaskSettings() {
|
||||||
|
SageMakerModel model = mock();
|
||||||
|
when(model.apiServiceSettings()).thenReturn(randomApiServiceSettings());
|
||||||
|
when(model.apiTaskSettings()).thenReturn(mock());
|
||||||
|
when(model.api()).thenReturn("taskApi");
|
||||||
|
when(model.getTaskType()).thenReturn(TaskType.ANY);
|
||||||
|
|
||||||
|
var e = assertThrows(IllegalArgumentException.class, () -> payload.requestBytes(model, randomRequest()));
|
||||||
|
|
||||||
|
assertThat(e.getMessage(), startsWith("Unsupported SageMaker settings for api [taskApi] and task type [any]:"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public final void testUpdate() throws IOException {
|
||||||
|
var taskSettings = randomApiTaskSettings();
|
||||||
|
if (taskSettings != SageMakerStoredTaskSchema.NO_OP) {
|
||||||
|
var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings);
|
||||||
|
|
||||||
|
var updatedSettings = toMap(taskSettings.updatedTaskSettings(toMap(otherTaskSettings)));
|
||||||
|
|
||||||
|
var initialSettings = toMap(taskSettings);
|
||||||
|
var newSettings = toMap(otherTaskSettings);
|
||||||
|
|
||||||
|
newSettings.forEach((key, value) -> {
|
||||||
|
assertThat("Value should have been updated for key " + key, value, equalTo(updatedSettings.remove(key)));
|
||||||
|
});
|
||||||
|
initialSettings.forEach((key, value) -> {
|
||||||
|
if (updatedSettings.containsKey(key)) {
|
||||||
|
assertThat("Value should not have been updated for key " + key, value, equalTo(updatedSettings.remove(key)));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
assertTrue("Map should be empty now that we verified all updated keys and all initial keys", updatedSettings.isEmpty());
|
||||||
|
}
|
||||||
|
if (payload instanceof SageMakerStoredTaskSchema taskSchema) {
|
||||||
|
var otherTaskSettings = randomValueOtherThan(randomApiTaskSettings(), this::randomApiTaskSettings);
|
||||||
|
var otherTaskSettingsAsMap = toMap(otherTaskSettings);
|
||||||
|
|
||||||
|
taskSchema.updatedTaskSettings(otherTaskSettingsAsMap);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected static SageMakerInferenceRequest randomRequest() {
|
||||||
|
return new SageMakerInferenceRequest(
|
||||||
|
randomBoolean() ? randomAlphaOfLengthBetween(4, 8) : null,
|
||||||
|
randomOptionalBoolean(),
|
||||||
|
randomBoolean() ? randomInt() : null,
|
||||||
|
randomList(randomIntBetween(2, 4), () -> randomAlphaOfLengthBetween(2, 4)),
|
||||||
|
randomBoolean(),
|
||||||
|
randomFrom(InputType.values())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected static void assertSdkBytes(SdkBytes sdkBytes, String expectedValue) {
|
||||||
|
assertThat(sdkBytes.asUtf8String(), equalTo(expectedValue));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,119 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.schema;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload;
|
||||||
|
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||||
|
import static org.hamcrest.Matchers.empty;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyMap;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyString;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
public class SageMakerSchemasTests extends ESTestCase {
|
||||||
|
public static SageMakerSchemas mockSchemas() {
|
||||||
|
SageMakerSchemas schemas = mock();
|
||||||
|
var schema = mockSchema();
|
||||||
|
when(schemas.schemaFor(any(TaskType.class), anyString())).thenReturn(schema);
|
||||||
|
return schemas;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static SageMakerSchema mockSchema() {
|
||||||
|
SageMakerSchema schema = mock();
|
||||||
|
when(schema.apiServiceSettings(anyMap(), any())).thenReturn(SageMakerStoredServiceSchema.NO_OP);
|
||||||
|
when(schema.apiTaskSettings(anyMap(), any())).thenReturn(SageMakerStoredTaskSchema.NO_OP);
|
||||||
|
return schema;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final SageMakerSchemas schemas = new SageMakerSchemas();
|
||||||
|
|
||||||
|
public void testSupportedTaskTypes() {
|
||||||
|
assertThat(schemas.supportedTaskTypes(), containsInAnyOrder(TaskType.TEXT_EMBEDDING));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testSupportedStreamingTasks() {
|
||||||
|
assertThat(schemas.supportedStreamingTasks(), empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testSchemaFor() {
|
||||||
|
var payloads = Stream.of(new OpenAiTextEmbeddingPayload());
|
||||||
|
payloads.forEach(payload -> {
|
||||||
|
payload.supportedTasks().forEach(taskType -> {
|
||||||
|
var model = mockModel(taskType, payload.api());
|
||||||
|
assertNotNull(schemas.schemaFor(model));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testStreamSchemaFor() {
|
||||||
|
var payloads = Stream.<SageMakerStreamSchemaPayload>of(/* For when we add support for streaming payloads */);
|
||||||
|
payloads.forEach(payload -> {
|
||||||
|
payload.supportedTasks().forEach(taskType -> {
|
||||||
|
var model = mockModel(taskType, payload.api());
|
||||||
|
assertNotNull(schemas.streamSchemaFor(model));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private SageMakerModel mockModel(TaskType taskType, String api) {
|
||||||
|
SageMakerModel model = mock();
|
||||||
|
when(model.getTaskType()).thenReturn(taskType);
|
||||||
|
when(model.api()).thenReturn(api);
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMissingTaskTypeThrowsException() {
|
||||||
|
var knownPayload = new OpenAiTextEmbeddingPayload();
|
||||||
|
var unknownTaskType = TaskType.COMPLETION;
|
||||||
|
var knownModel = mockModel(unknownTaskType, knownPayload.api());
|
||||||
|
assertThrows(
|
||||||
|
"Task [completion] is not compatible for service [sagemaker] and api [openai]. Supported tasks: [text_embedding]",
|
||||||
|
ElasticsearchStatusException.class,
|
||||||
|
() -> schemas.schemaFor(knownModel)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMissingSchemaThrowsException() {
|
||||||
|
var unknownModel = mockModel(TaskType.ANY, "blah");
|
||||||
|
assertThrows(
|
||||||
|
"Task [any] is not compatible for service [sagemaker] and api [blah]. Supported tasks: []",
|
||||||
|
ElasticsearchStatusException.class,
|
||||||
|
() -> schemas.schemaFor(unknownModel)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMissingStreamSchemaThrowsException() {
|
||||||
|
var unknownModel = mockModel(TaskType.ANY, "blah");
|
||||||
|
assertThrows(
|
||||||
|
"Streaming is not allowed for service [sagemaker], api [blah], and task [any]. Supported streaming tasks: []",
|
||||||
|
ElasticsearchStatusException.class,
|
||||||
|
() -> schemas.streamSchemaFor(unknownModel)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testNamedWriteables() {
|
||||||
|
var namedWriteables = Stream.of(new OpenAiTextEmbeddingPayload().namedWriteables());
|
||||||
|
|
||||||
|
var expectedNamedWriteables = Stream.concat(
|
||||||
|
namedWriteables.flatMap(names -> names.map(entry -> entry.name)),
|
||||||
|
Stream.of(SageMakerStoredServiceSchema.NO_OP.getWriteableName(), SageMakerStoredTaskSchema.NO_OP.getWriteableName())
|
||||||
|
).distinct().toArray();
|
||||||
|
|
||||||
|
var actualRegisteredNames = SageMakerSchemas.namedWriteables().stream().map(entry -> entry.name).toList();
|
||||||
|
|
||||||
|
assertThat(actualRegisteredNames, containsInAnyOrder(expectedNamedWriteables));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,155 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;
|
||||||
|
|
||||||
|
import software.amazon.awssdk.core.SdkBytes;
|
||||||
|
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
|
||||||
|
|
||||||
|
import org.elasticsearch.inference.InputType;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayloadTestCase;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
|
||||||
|
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
|
||||||
|
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
public class OpenAiTextEmbeddingPayloadTests extends SageMakerSchemaPayloadTestCase<OpenAiTextEmbeddingPayload> {
|
||||||
|
@Override
|
||||||
|
protected OpenAiTextEmbeddingPayload payload() {
|
||||||
|
return new OpenAiTextEmbeddingPayload();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected String expectedApi() {
|
||||||
|
return "openai";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Set<TaskType> expectedSupportedTaskTypes() {
|
||||||
|
return Set.of(TaskType.TEXT_EMBEDDING);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected SageMakerStoredServiceSchema randomApiServiceSettings() {
|
||||||
|
return SageMakerOpenAiServiceSettingsTests.randomApiServiceSettings();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected SageMakerStoredTaskSchema randomApiTaskSettings() {
|
||||||
|
return SageMakerOpenAiTaskSettingsTests.randomApiTaskSettings();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testAccept() {
|
||||||
|
assertThat(payload.accept(mock()), equalTo("application/json"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testContentType() {
|
||||||
|
assertThat(payload.contentType(mock()), equalTo("application/json"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRequestWithSingleInput() throws Exception {
|
||||||
|
SageMakerModel model = mock();
|
||||||
|
when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(null, false));
|
||||||
|
when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings((String) null));
|
||||||
|
var request = new SageMakerInferenceRequest(null, null, null, List.of("hello"), randomBoolean(), randomFrom(InputType.values()));
|
||||||
|
|
||||||
|
var sdkByes = payload.requestBytes(model, request);
|
||||||
|
assertSdkBytes(sdkByes, """
|
||||||
|
{"input":"hello"}""");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRequestWithArrayInput() throws Exception {
|
||||||
|
SageMakerModel model = mock();
|
||||||
|
when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(null, false));
|
||||||
|
when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings((String) null));
|
||||||
|
var request = new SageMakerInferenceRequest(
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
List.of("hello", "there"),
|
||||||
|
randomBoolean(),
|
||||||
|
randomFrom(InputType.values())
|
||||||
|
);
|
||||||
|
|
||||||
|
var sdkByes = payload.requestBytes(model, request);
|
||||||
|
assertSdkBytes(sdkByes, """
|
||||||
|
{"input":["hello","there"]}""");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRequestWithDimensionsNotSetByUserIgnoreDimensions() throws Exception {
|
||||||
|
SageMakerModel model = mock();
|
||||||
|
when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(123, false));
|
||||||
|
when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings((String) null));
|
||||||
|
var request = new SageMakerInferenceRequest(
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
List.of("hello", "there"),
|
||||||
|
randomBoolean(),
|
||||||
|
randomFrom(InputType.values())
|
||||||
|
);
|
||||||
|
|
||||||
|
var sdkByes = payload.requestBytes(model, request);
|
||||||
|
assertSdkBytes(sdkByes, """
|
||||||
|
{"input":["hello","there"]}""");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRequestWithOptionals() throws Exception {
|
||||||
|
SageMakerModel model = mock();
|
||||||
|
when(model.apiServiceSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiServiceSettings(1234, true));
|
||||||
|
when(model.apiTaskSettings()).thenReturn(new OpenAiTextEmbeddingPayload.ApiTaskSettings("user"));
|
||||||
|
var request = new SageMakerInferenceRequest("query", null, null, List.of("hello"), randomBoolean(), randomFrom(InputType.values()));
|
||||||
|
|
||||||
|
var sdkByes = payload.requestBytes(model, request);
|
||||||
|
assertSdkBytes(sdkByes, """
|
||||||
|
{"query":"query","input":"hello","user":"user","dimensions":1234}""");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testResponse() throws Exception {
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"object": "embedding",
|
||||||
|
"index": 0,
|
||||||
|
"embedding": [
|
||||||
|
0.014539449,
|
||||||
|
-0.015288644
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"model": "text-embedding-ada-002-v2",
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 8,
|
||||||
|
"total_tokens": 8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
var invokeEndpointResponse = InvokeEndpointResponse.builder()
|
||||||
|
.body(SdkBytes.fromString(responseJson, StandardCharsets.UTF_8))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
var textEmbeddingFloatResults = payload.responseBody(mock(), invokeEndpointResponse);
|
||||||
|
|
||||||
|
assertThat(
|
||||||
|
textEmbeddingFloatResults.embeddings(),
|
||||||
|
is(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.014539449F, -0.015288644F })))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
|
import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.not;
|
||||||
|
import static org.hamcrest.Matchers.sameInstance;
|
||||||
|
|
||||||
|
public class SageMakerOpenAiServiceSettingsTests extends InferenceSettingsTestCase<OpenAiTextEmbeddingPayload.ApiServiceSettings> {
|
||||||
|
@Override
|
||||||
|
protected OpenAiTextEmbeddingPayload.ApiServiceSettings fromMutableMap(Map<String, Object> mutableMap) {
|
||||||
|
var validationException = new ValidationException();
|
||||||
|
var settings = OpenAiTextEmbeddingPayload.ApiServiceSettings.fromMap(mutableMap, validationException);
|
||||||
|
validationException.throwIfValidationErrorsExist();
|
||||||
|
return settings;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Writeable.Reader<OpenAiTextEmbeddingPayload.ApiServiceSettings> instanceReader() {
|
||||||
|
return OpenAiTextEmbeddingPayload.ApiServiceSettings::new;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected OpenAiTextEmbeddingPayload.ApiServiceSettings createTestInstance() {
|
||||||
|
return randomApiServiceSettings();
|
||||||
|
}
|
||||||
|
|
||||||
|
static OpenAiTextEmbeddingPayload.ApiServiceSettings randomApiServiceSettings() {
|
||||||
|
var dimensions = randomBoolean() ? randomIntBetween(1, 100) : null;
|
||||||
|
return new OpenAiTextEmbeddingPayload.ApiServiceSettings(dimensions, dimensions != null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testDimensionsSetByUser() {
|
||||||
|
var expectedDimensions = randomIntBetween(1, 100);
|
||||||
|
var dimensionlessSettings = new OpenAiTextEmbeddingPayload.ApiServiceSettings(null, false);
|
||||||
|
var updatedSettings = dimensionlessSettings.updateModelWithEmbeddingDetails(expectedDimensions);
|
||||||
|
assertThat(updatedSettings, not(sameInstance(dimensionlessSettings)));
|
||||||
|
assertThat(updatedSettings.dimensions(), equalTo(expectedDimensions));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,38 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.sagemaker.schema.openai;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
|
import org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class SageMakerOpenAiTaskSettingsTests extends InferenceSettingsTestCase<OpenAiTextEmbeddingPayload.ApiTaskSettings> {
|
||||||
|
@Override
|
||||||
|
protected OpenAiTextEmbeddingPayload.ApiTaskSettings fromMutableMap(Map<String, Object> mutableMap) {
|
||||||
|
var validationException = new ValidationException();
|
||||||
|
var settings = OpenAiTextEmbeddingPayload.ApiTaskSettings.fromMap(mutableMap, validationException);
|
||||||
|
validationException.throwIfValidationErrorsExist();
|
||||||
|
return settings;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Writeable.Reader<OpenAiTextEmbeddingPayload.ApiTaskSettings> instanceReader() {
|
||||||
|
return OpenAiTextEmbeddingPayload.ApiTaskSettings::new;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected OpenAiTextEmbeddingPayload.ApiTaskSettings createTestInstance() {
|
||||||
|
return randomApiTaskSettings();
|
||||||
|
}
|
||||||
|
|
||||||
|
static OpenAiTextEmbeddingPayload.ApiTaskSettings randomApiTaskSettings() {
|
||||||
|
return new OpenAiTextEmbeddingPayload.ApiTaskSettings(randomOptionalString());
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue