diff --git a/docs/changelog/126493.yaml b/docs/changelog/126493.yaml new file mode 100644 index 000000000000..84a54b105882 --- /dev/null +++ b/docs/changelog/126493.yaml @@ -0,0 +1,6 @@ +pr: 126493 +summary: Bedrock Cohere Task Settings Support +area: Machine Learning +type: enhancement +issues: + - 126156 diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 2ad225417ce8..c730a0fe4cf0 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -157,6 +157,7 @@ public class TransportVersions { public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14); public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15); public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16); + public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01); public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02); @@ -215,6 +216,7 @@ public class TransportVersions { public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00); public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0); public static final TransportVersion REPO_ANALYSIS_COPY_BLOB = def(9_048_00_0); + public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS = def(9_049_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index a0691fcfe517..63d9d8a3bd9d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -41,6 +41,7 @@ import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.Alib import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings; @@ -173,8 +174,13 @@ public class InferenceNamedWriteablesProvider { AmazonBedrockEmbeddingsServiceSettings::new ) ); - - // no task settings for Amazon Bedrock Embeddings + namedWriteables.add( + new NamedWriteableRegistry.Entry( + TaskSettings.class, + AmazonBedrockEmbeddingsTaskSettings.NAME, + AmazonBedrockEmbeddingsTaskSettings::new + ) + ); namedWriteables.add( new NamedWriteableRegistry.Entry( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java index 1755dac2ac13..b9e3a237a3cc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java @@ -19,6 +19,8 @@ public class AmazonBedrockConstants { public static final String TOP_K_FIELD = "top_k"; public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens"; + public static final String TRUNCATE_FIELD = "truncate"; + public static final Double MIN_TEMPERATURE_TOP_P_TOP_K_VALUE = 0.0; public static final Double MAX_TEMPERATURE_TOP_P_TOP_K_VALUE = 1.0; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 93e0033d88c6..b0b4b7eed1a7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -303,6 +303,7 @@ public class AmazonBedrockService extends SenderService { context ); checkProviderForTask(TaskType.TEXT_EMBEDDING, model.provider()); + checkTaskSettingsForTextEmbeddingModel(model); return model; } case COMPLETION -> { @@ -368,6 +369,17 @@ public class AmazonBedrockService extends SenderService { } } + private static void checkTaskSettingsForTextEmbeddingModel(AmazonBedrockEmbeddingsModel model) { + if (model.provider() != AmazonBedrockProvider.COHERE && model.getTaskSettings().cohereTruncation() != null) { + throw new ElasticsearchStatusException( + "The [{}] task type for provider [{}] does not allow [truncate] field", + RestStatus.BAD_REQUEST, + TaskType.TEXT_EMBEDDING, + model.provider() + ); + } + } + private static void checkChatCompletionProviderForTopKParameter(AmazonBedrockChatCompletionModel model) { var taskSettings = model.getTaskSettings(); if (taskSettings.topK() != null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java index 30703f584739..f7874304b457 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java @@ -7,14 +7,11 @@ package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; -import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; @@ -28,10 +25,8 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel { public static AmazonBedrockEmbeddingsModel of(AmazonBedrockEmbeddingsModel embeddingsModel, Map taskSettings) { if (taskSettings != null && taskSettings.isEmpty() == false) { - // no task settings allowed - var validationException = new ValidationException(); - validationException.addValidationError("Amazon Bedrock embeddings model cannot have task settings"); - throw validationException; + var updatedTaskSettings = embeddingsModel.getTaskSettings().updatedTaskSettings(taskSettings); + return new AmazonBedrockEmbeddingsModel(embeddingsModel, updatedTaskSettings); } return embeddingsModel; @@ -52,7 +47,7 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel { taskType, service, AmazonBedrockEmbeddingsServiceSettings.fromMap(serviceSettings, context), - new EmptyTaskSettings(), + AmazonBedrockEmbeddingsTaskSettings.fromMap(taskSettings), chunkingSettings, AwsSecretSettings.fromMap(secretSettings) ); @@ -63,12 +58,12 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel { TaskType taskType, String service, AmazonBedrockEmbeddingsServiceSettings serviceSettings, - TaskSettings taskSettings, + AmazonBedrockEmbeddingsTaskSettings taskSettings, ChunkingSettings chunkingSettings, AwsSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secrets) ); } @@ -77,6 +72,10 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel { super(model, serviceSettings); } + public AmazonBedrockEmbeddingsModel(Model model, AmazonBedrockEmbeddingsTaskSettings taskSettings) { + super(model, taskSettings); + } + @Override public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map taskSettings) { return creator.create(this, taskSettings); @@ -86,4 +85,9 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel { public AmazonBedrockEmbeddingsServiceSettings getServiceSettings() { return (AmazonBedrockEmbeddingsServiceSettings) super.getServiceSettings(); } + + @Override + public AmazonBedrockEmbeddingsTaskSettings getTaskSettings() { + return (AmazonBedrockEmbeddingsTaskSettings) super.getTaskSettings(); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsTaskSettings.java new file mode 100644 index 000000000000..bb0a8a3348ad --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsTaskSettings.java @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD; + +public record AmazonBedrockEmbeddingsTaskSettings(@Nullable CohereTruncation cohereTruncation) implements TaskSettings { + public static final AmazonBedrockEmbeddingsTaskSettings EMPTY = new AmazonBedrockEmbeddingsTaskSettings((CohereTruncation) null); + public static final String NAME = "amazon_bedrock_embeddings_task_settings"; + + public static AmazonBedrockEmbeddingsTaskSettings fromMap(Map map) { + if (map == null || map.isEmpty()) { + return EMPTY; + } + + ValidationException validationException = new ValidationException(); + + var cohereTruncation = extractOptionalEnum( + map, + TRUNCATE_FIELD, + ModelConfigurations.TASK_SETTINGS, + CohereTruncation::fromString, + CohereTruncation.ALL, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AmazonBedrockEmbeddingsTaskSettings(cohereTruncation); + } + + public AmazonBedrockEmbeddingsTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalEnum(CohereTruncation.class)); + } + + @Override + public boolean isEmpty() { + return cohereTruncation() == null; + } + + @Override + public AmazonBedrockEmbeddingsTaskSettings updatedTaskSettings(Map newSettings) { + var newTaskSettings = fromMap(new HashMap<>(newSettings)); + + return new AmazonBedrockEmbeddingsTaskSettings(firstNonNullOrNull(newTaskSettings.cohereTruncation(), cohereTruncation())); + } + + private static T firstNonNullOrNull(T first, T second) { + return first != null ? first : second; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.AMAZON_BEDROCK_TASK_SETTINGS; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalEnum(cohereTruncation()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (cohereTruncation != null) { + builder.field(TRUNCATE_FIELD, cohereTruncation); + } + return builder.endObject(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockCohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockCohereEmbeddingsRequestEntity.java index d0a11f5ebc50..78ede2315464 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockCohereEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockCohereEmbeddingsRequestEntity.java @@ -11,6 +11,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InputType; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings; import java.io.IOException; import java.util.List; @@ -18,7 +19,11 @@ import java.util.Objects; import static org.elasticsearch.inference.InputType.invalidInputTypeMessage; -public record AmazonBedrockCohereEmbeddingsRequestEntity(List input, @Nullable InputType inputType) implements ToXContentObject { +public record AmazonBedrockCohereEmbeddingsRequestEntity( + List input, + @Nullable InputType inputType, + AmazonBedrockEmbeddingsTaskSettings taskSettings +) implements ToXContentObject { private static final String TEXTS_FIELD = "texts"; private static final String INPUT_TYPE_FIELD = "input_type"; @@ -26,9 +31,11 @@ public record AmazonBedrockCohereEmbeddingsRequestEntity(List input, @Nu private static final String SEARCH_QUERY = "search_query"; private static final String CLUSTERING = "clustering"; private static final String CLASSIFICATION = "classification"; + private static final String TRUNCATE = "truncate"; public AmazonBedrockCohereEmbeddingsRequestEntity { Objects.requireNonNull(input); + Objects.requireNonNull(taskSettings); } @Override @@ -43,6 +50,10 @@ public record AmazonBedrockCohereEmbeddingsRequestEntity(List input, @Nu builder.field(INPUT_TYPE_FIELD, SEARCH_DOCUMENT); } + if (taskSettings.cohereTruncation() != null) { + builder.field(TRUNCATE, taskSettings.cohereTruncation().name()); + } + builder.endObject(); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockEmbeddingsEntityFactory.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockEmbeddingsEntityFactory.java index 0bd1b191f050..860934961ddc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockEmbeddingsEntityFactory.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockEmbeddingsEntityFactory.java @@ -39,7 +39,7 @@ public final class AmazonBedrockEmbeddingsEntityFactory { return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0)); } case COHERE -> { - return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType); + return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType, model.getTaskSettings()); } default -> { return null; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockEmbeddingsRequest.java index b8a676001995..006c2b681cc3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/embeddings/AmazonBedrockEmbeddingsRequest.java @@ -76,6 +76,9 @@ public class AmazonBedrockEmbeddingsRequest extends AmazonBedrockRequest { @Override public Request truncate() { + if (provider == AmazonBedrockProvider.COHERE) { + return this; // Cohere has its own truncation logic + } var truncatedInput = truncator.truncate(truncationResult.input()); return new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, embeddingsModel, requestEntity, timeout); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java index e7c9a0247bb1..16164b2d887f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.cohere; +import java.util.EnumSet; import java.util.Locale; /** @@ -31,6 +32,8 @@ public enum CohereTruncation { */ END; + public static final EnumSet ALL = EnumSet.allOf(CohereTruncation.class); + @Override public String toString() { return name().toLowerCase(Locale.ROOT); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java index 88bb50def78f..09d78708b688 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java @@ -20,7 +20,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; import java.io.IOException; -import java.util.EnumSet; import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -63,7 +62,7 @@ public class CohereEmbeddingsTaskSettings implements TaskSettings { TRUNCATE, ModelConfigurations.TASK_SETTINGS, CohereTruncation::fromString, - EnumSet.allOf(CohereTruncation.class), + CohereTruncation.ALL, validationException ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 688bd3d4afc5..d34e8b3408fe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -12,6 +12,7 @@ import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeExcept import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; @@ -51,6 +52,8 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.Amazo import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.CoreMatchers; import org.hamcrest.Matchers; @@ -105,7 +108,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { public void testParseRequestConfig_CreatesAnAmazonBedrockModel() throws IOException { try (var service = createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { + ActionListener modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> { assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); @@ -115,7 +118,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { var secretSettings = (AwsSecretSettings) model.getSecretSettings(); assertThat(secretSettings.accessKey().toString(), is("access")); assertThat(secretSettings.secretKey().toString(), is("secret")); - }, exception -> fail("Unexpected exception: " + exception)); + }); service.parseRequestConfig( "id", @@ -130,15 +133,62 @@ public class AmazonBedrockServiceTests extends ESTestCase { } } + public void testParseRequestConfig_CreatesACohereModel() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> { + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.modelId(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.COHERE)); + var secretSettings = (AwsSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey().toString(), is("access")); + assertThat(secretSettings.secretKey().toString(), is("secret")); + }); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "cohere", null, null, null, null), + AmazonBedrockEmbeddingsTaskSettingsTests.mutableMap("truncate", CohereTruncation.START), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CohereSettingsWithNoCohereModel() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("The [text_embedding] task type for provider [amazontitan] does not allow [truncate] field") + ); + }); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, null, null, null), + AmazonBedrockEmbeddingsTaskSettingsTests.mutableMap("truncate", CohereTruncation.START), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + modelVerificationListener + ); + } + } + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { try (var service = createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap( - model -> fail("Expected exception, but got model: " + model), - exception -> { - assertThat(exception, instanceOf(ElasticsearchStatusException.class)); - assertThat(exception.getMessage(), is("The [amazonbedrock] service does not support task type [sparse_embedding]")); - } - ); + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [amazonbedrock] service does not support task type [sparse_embedding]")); + }); service.parseRequestConfig( "id", @@ -247,13 +297,10 @@ public class AmazonBedrockServiceTests extends ESTestCase { public void testCreateModel_ForEmbeddingsTask_InvalidProvider() throws IOException { try (var service = createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap( - model -> fail("Expected exception, but got model: " + model), - exception -> { - assertThat(exception, instanceOf(ElasticsearchStatusException.class)); - assertThat(exception.getMessage(), is("The [text_embedding] task type for provider [anthropic] is not available")); - } - ); + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [text_embedding] task type for provider [anthropic] is not available")); + }); service.parseRequestConfig( "id", @@ -270,13 +317,10 @@ public class AmazonBedrockServiceTests extends ESTestCase { public void testCreateModel_TopKParameter_NotAvailable() throws IOException { try (var service = createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap( - model -> fail("Expected exception, but got model: " + model), - exception -> { - assertThat(exception, instanceOf(ElasticsearchStatusException.class)); - assertThat(exception.getMessage(), is("The [top_k] task parameter is not available for provider [amazontitan]")); - } - ); + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [top_k] task parameter is not available for provider [amazontitan]")); + }); service.parseRequestConfig( "id", @@ -301,16 +345,13 @@ public class AmazonBedrockServiceTests extends ESTestCase { config.put("extra_key", "value"); - ActionListener modelVerificationListener = ActionListener.wrap( - model -> fail("Expected exception, but got model: " + model), - exception -> { - assertThat(exception, instanceOf(ElasticsearchStatusException.class)); - assertThat( - exception.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") - ); - } - ); + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") + ); + }); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } @@ -323,9 +364,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { var config = getRequestConfigMap(serviceSettings, Map.of(), getAmazonBedrockSecretSettingsMap("access", "secret")); - ActionListener modelVerificationListener = ActionListener.wrap((model) -> { - fail("Expected exception, but got model: " + model); - }, e -> { + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> { assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat( e.getMessage(), @@ -347,9 +386,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); - ActionListener modelVerificationListener = ActionListener.wrap((model) -> { - fail("Expected exception, but got model: " + model); - }, e -> { + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> { assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat( e.getMessage(), @@ -371,9 +408,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); - ActionListener modelVerificationListener = ActionListener.wrap((model) -> { - fail("Expected exception, but got model: " + model); - }, e -> { + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> { assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat( e.getMessage(), @@ -387,7 +422,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { public void testParseRequestConfig_MovesModel() throws IOException { try (var service = createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { + ActionListener modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> { assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); @@ -397,7 +432,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { var secretSettings = (AwsSecretSettings) model.getSecretSettings(); assertThat(secretSettings.accessKey().toString(), is("access")); assertThat(secretSettings.secretKey().toString(), is("secret")); - }, exception -> fail("Unexpected exception: " + exception)); + }); service.parseRequestConfig( "id", @@ -414,7 +449,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { + ActionListener modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> { assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); @@ -425,7 +460,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { assertThat(secretSettings.accessKey().toString(), is("access")); assertThat(secretSettings.secretKey().toString(), is("secret")); assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - }, exception -> fail("Unexpected exception: " + exception)); + }); service.parseRequestConfig( "id", @@ -443,7 +478,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { + ActionListener modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> { assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); @@ -454,7 +489,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { assertThat(secretSettings.accessKey().toString(), is("access")); assertThat(secretSettings.secretKey().toString(), is("secret")); assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - }, exception -> fail("Unexpected exception: " + exception)); + }); service.parseRequestConfig( "id", @@ -471,13 +506,10 @@ public class AmazonBedrockServiceTests extends ESTestCase { public void testCreateModel_ForEmbeddingsTask_DimensionsIsNotAllowed() throws IOException { try (var service = createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap( - model -> fail("Expected exception, but got model: " + model), - exception -> { - assertThat(exception, instanceOf(ValidationException.class)); - assertThat(exception.getMessage(), containsString("[service_settings] does not allow the setting [dimensions]")); - } - ); + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> { + assertThat(exception, instanceOf(ValidationException.class)); + assertThat(exception.getMessage(), containsString("[service_settings] does not allow the setting [dimensions]")); + }); service.parseRequestConfig( "id", @@ -497,7 +529,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); var model = service.parsePersistedConfigWithSecrets( "id", @@ -525,7 +557,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { var persistedConfig = getPersistedConfigMap( settingsMap, - new HashMap(Map.of()), + new HashMap<>(Map.of()), createRandomChunkingSettingsMap(), secretSettingsMap ); @@ -607,7 +639,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); persistedConfig.config().put("extra_key", "value"); var model = service.parsePersistedConfigWithSecrets( @@ -635,7 +667,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); secretSettingsMap.put("extra_key", "value"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); var model = service.parsePersistedConfigWithSecrets( "id", @@ -661,7 +693,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); persistedConfig.secrets().put("extra_key", "value"); var model = service.parsePersistedConfigWithSecrets( @@ -689,7 +721,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { settingsMap.put("extra_key", "value"); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); var model = service.parsePersistedConfigWithSecrets( "id", @@ -769,7 +801,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { var persistedConfig = getPersistedConfigMap( settingsMap, - new HashMap(Map.of()), + new HashMap<>(Map.of()), createRandomChunkingSettingsMap(), secretSettingsMap ); @@ -792,7 +824,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); @@ -836,7 +868,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); var thrownException = expectThrows( ElasticsearchStatusException.class, @@ -855,7 +887,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); persistedConfig.config().put("extra_key", "value"); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); @@ -876,7 +908,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { settingsMap.put("extra_key", "value"); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); persistedConfig.config().put("extra_key", "value"); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntityTests.java index 508105824684..a0db4f97c31d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntityTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockJsonBuilder; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockCohereEmbeddingsRequestEntity; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; import java.io.IOException; import java.util.List; @@ -19,23 +20,46 @@ import static org.hamcrest.Matchers.is; public class AmazonBedrockCohereEmbeddingsRequestEntityTests extends ESTestCase { public void testRequestEntity_GeneratesExpectedJsonBody() throws IOException { - var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), InputType.CLASSIFICATION); + var entity = new AmazonBedrockCohereEmbeddingsRequestEntity( + List.of("test input"), + InputType.CLASSIFICATION, + AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings() + ); var builder = new AmazonBedrockJsonBuilder(entity); var result = builder.getStringContent(); assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"classification\"}")); } public void testRequestEntity_GeneratesExpectedJsonBody_WithInternalInputType() throws IOException { - var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), InputType.INTERNAL_SEARCH); + var entity = new AmazonBedrockCohereEmbeddingsRequestEntity( + List.of("test input"), + InputType.INTERNAL_SEARCH, + AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings() + ); var builder = new AmazonBedrockJsonBuilder(entity); var result = builder.getStringContent(); assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_query\"}")); } public void testRequestEntity_GeneratesExpectedJsonBody_WithoutInputType() throws IOException { - var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), null); + var entity = new AmazonBedrockCohereEmbeddingsRequestEntity( + List.of("test input"), + null, + AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings() + ); var builder = new AmazonBedrockJsonBuilder(entity); var result = builder.getStringContent(); assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\"}")); } + + public void testRequestEntity_GeneratesExpectedJsonBody_WithCohereTruncation() throws IOException { + var entity = new AmazonBedrockCohereEmbeddingsRequestEntity( + List.of("test input"), + null, + new AmazonBedrockEmbeddingsTaskSettings(CohereTruncation.START) + ); + var builder = new AmazonBedrockJsonBuilder(entity); + var result = builder.getStringContent(); + assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\",\"truncate\":\"START\"}")); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModelTests.java index e9e31cf0ccca..be7796a4b7d8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModelTests.java @@ -7,12 +7,9 @@ package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -20,19 +17,23 @@ import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import java.util.Map; +import java.io.IOException; -import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; public class AmazonBedrockEmbeddingsModelTests extends ESTestCase { - public void testCreateModel_withTaskSettings_shouldFail() { - var baseModel = createModel("id", "region", "model", AmazonBedrockProvider.AMAZONTITAN, "accesskey", "secretkey"); - var thrownException = assertThrows( - ValidationException.class, - () -> AmazonBedrockEmbeddingsModel.of(baseModel, Map.of("testkey", "testvalue")) - ); - assertThat(thrownException.getMessage(), containsString("Amazon Bedrock embeddings model cannot have task settings")); + public void testCreateModel_withTaskSettingsOverride() throws IOException { + var baseTaskSettings = AmazonBedrockEmbeddingsTaskSettingsTests.randomTaskSettings(); + var baseModel = createModel("id", "region", "model", AmazonBedrockProvider.AMAZONTITAN, "accesskey", "secretkey", baseTaskSettings); + + var overrideTaskSettings = AmazonBedrockEmbeddingsTaskSettingsTests.mutateTaskSettings(baseTaskSettings); + var overrideTaskSettingsMap = AmazonBedrockEmbeddingsTaskSettingsTests.toMap(overrideTaskSettings); + + var overriddenModel = AmazonBedrockEmbeddingsModel.of(baseModel, overrideTaskSettingsMap); + assertThat(overriddenModel.getTaskSettings(), equalTo(overrideTaskSettings)); + assertThat(overriddenModel.getTaskSettings(), not(equalTo(baseTaskSettings))); } // model creation only - no tests to define, but we want to have the public createModel @@ -46,7 +47,15 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase { String accessKey, String secretKey ) { - return createModel(inferenceId, region, model, provider, null, false, null, null, new RateLimitSettings(240), accessKey, secretKey); + return createModel( + inferenceId, + region, + model, + provider, + accessKey, + secretKey, + AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings() + ); } public static AmazonBedrockEmbeddingsModel createModel( @@ -56,9 +65,22 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase { AmazonBedrockProvider provider, String accessKey, String secretKey, - InputType inputType + AmazonBedrockEmbeddingsTaskSettings taskSettings ) { - return createModel(inferenceId, region, model, provider, null, false, null, null, new RateLimitSettings(240), accessKey, secretKey); + return createModel( + inferenceId, + region, + model, + provider, + null, + false, + null, + null, + new RateLimitSettings(240), + accessKey, + secretKey, + taskSettings + ); } public static AmazonBedrockEmbeddingsModel createModel( @@ -114,7 +136,7 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase { similarity, rateLimitSettings ), - new EmptyTaskSettings(), + AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings(), chunkingSettings, new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey)) ); @@ -132,6 +154,36 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase { RateLimitSettings rateLimitSettings, String accessKey, String secretKey + ) { + return createModel( + inferenceId, + region, + model, + provider, + dimensions, + dimensionsSetByUser, + maxTokens, + similarity, + rateLimitSettings, + accessKey, + secretKey, + AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings() + ); + } + + public static AmazonBedrockEmbeddingsModel createModel( + String inferenceId, + String region, + String model, + AmazonBedrockProvider provider, + @Nullable Integer dimensions, + boolean dimensionsSetByUser, + @Nullable Integer maxTokens, + @Nullable SimilarityMeasure similarity, + RateLimitSettings rateLimitSettings, + String accessKey, + String secretKey, + AmazonBedrockEmbeddingsTaskSettings taskSettings ) { return new AmazonBedrockEmbeddingsModel( inferenceId, @@ -147,7 +199,7 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase { similarity, rateLimitSettings ), - new EmptyTaskSettings(), + taskSettings, null, new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey)) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsTaskSettingsTests.java new file mode 100644 index 000000000000..3fc76743cc87 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsTaskSettingsTests.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithoutUnspecified; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD; +import static org.hamcrest.Matchers.equalTo; + +public class AmazonBedrockEmbeddingsTaskSettingsTests extends AbstractBWCWireSerializationTestCase { + + public static AmazonBedrockEmbeddingsTaskSettings emptyTaskSettings() { + return AmazonBedrockEmbeddingsTaskSettings.EMPTY; + } + + public static AmazonBedrockEmbeddingsTaskSettings randomTaskSettings() { + var inputType = randomBoolean() ? randomWithoutUnspecified() : null; + var truncation = randomBoolean() ? randomFrom(CohereTruncation.values()) : null; + return new AmazonBedrockEmbeddingsTaskSettings(truncation); + } + + public static AmazonBedrockEmbeddingsTaskSettings mutateTaskSettings(AmazonBedrockEmbeddingsTaskSettings instance) { + return randomValueOtherThanMany( + v -> Objects.equals(instance, v) || (instance.cohereTruncation() != null && v.cohereTruncation() == null), + AmazonBedrockEmbeddingsTaskSettingsTests::randomTaskSettings + ); + } + + @Override + protected AmazonBedrockEmbeddingsTaskSettings mutateInstanceForVersion( + AmazonBedrockEmbeddingsTaskSettings instance, + TransportVersion version + ) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return AmazonBedrockEmbeddingsTaskSettings::new; + } + + @Override + protected AmazonBedrockEmbeddingsTaskSettings createTestInstance() { + return randomTaskSettings(); + } + + @Override + protected AmazonBedrockEmbeddingsTaskSettings mutateInstance(AmazonBedrockEmbeddingsTaskSettings instance) throws IOException { + return mutateTaskSettings(instance); + } + + public void testEmpty() { + assertTrue(emptyTaskSettings().isEmpty()); + assertTrue(AmazonBedrockEmbeddingsTaskSettings.fromMap(null).isEmpty()); + assertTrue(AmazonBedrockEmbeddingsTaskSettings.fromMap(Map.of()).isEmpty()); + } + + public static Map mutableMap(String key, Enum value) { + return new HashMap<>(Map.of(key, value.toString())); + } + + public void testValidCohereTruncations() { + for (var expectedCohereTruncation : CohereTruncation.ALL) { + var map = mutableMap(TRUNCATE_FIELD, expectedCohereTruncation); + var taskSettings = AmazonBedrockEmbeddingsTaskSettings.fromMap(map); + assertFalse(taskSettings.isEmpty()); + assertThat(taskSettings.cohereTruncation(), equalTo(expectedCohereTruncation)); + } + } + + public void testGarbageCohereTruncations() { + var map = new HashMap(Map.of(TRUNCATE_FIELD, "oiuesoirtuoawoeirha")); + assertThrows(ValidationException.class, () -> AmazonBedrockEmbeddingsTaskSettings.fromMap(map)); + } + + public void testXContent() throws IOException { + var taskSettings = randomTaskSettings(); + var taskSettingsAsMap = toMap(taskSettings); + var roundTripTaskSettings = AmazonBedrockEmbeddingsTaskSettings.fromMap(new HashMap<>(taskSettingsAsMap)); + assertThat(roundTripTaskSettings, equalTo(taskSettings)); + } + + public static Map toMap(AmazonBedrockEmbeddingsTaskSettings taskSettings) throws IOException { + try (var builder = JsonXContent.contentBuilder()) { + taskSettings.toXContent(builder, ToXContent.EMPTY_PARAMS); + var taskSettingsBytes = Strings.toString(builder).getBytes(StandardCharsets.UTF_8); + try (var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, taskSettingsBytes)) { + return parser.map(); + } + } + } +}