[ML] Bedrock Cohere Task Settings Support (#126493)

Add support for Cohere Task Settings and Truncate, through
the Amazon Bedrock provider integration.

Task Settings can now be passed both during Inference endpoint
creation and in Inference POST requests.

Close #126156
This commit is contained in:
Pat Whelan 2025-04-09 15:34:05 -04:00 committed by GitHub
parent 6d86b202ea
commit 6c6500ec3b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 468 additions and 102 deletions

View file

@ -0,0 +1,6 @@
pr: 126493
summary: Bedrock Cohere Task Settings Support
area: Machine Learning
type: enhancement
issues:
- 126156

View file

@ -157,6 +157,7 @@ public class TransportVersions {
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14); public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15); public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16); public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02); public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@ -215,6 +216,7 @@ public class TransportVersions {
public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00); public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00);
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0); public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0);
public static final TransportVersion REPO_ANALYSIS_COPY_BLOB = def(9_048_00_0); public static final TransportVersion REPO_ANALYSIS_COPY_BLOB = def(9_048_00_0);
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS = def(9_049_00_0);
/* /*
* STOP! READ THIS FIRST! No, really, * STOP! READ THIS FIRST! No, really,

View file

@ -41,6 +41,7 @@ import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.Alib
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings;
@ -173,8 +174,13 @@ public class InferenceNamedWriteablesProvider {
AmazonBedrockEmbeddingsServiceSettings::new AmazonBedrockEmbeddingsServiceSettings::new
) )
); );
namedWriteables.add(
// no task settings for Amazon Bedrock Embeddings new NamedWriteableRegistry.Entry(
TaskSettings.class,
AmazonBedrockEmbeddingsTaskSettings.NAME,
AmazonBedrockEmbeddingsTaskSettings::new
)
);
namedWriteables.add( namedWriteables.add(
new NamedWriteableRegistry.Entry( new NamedWriteableRegistry.Entry(

View file

@ -19,6 +19,8 @@ public class AmazonBedrockConstants {
public static final String TOP_K_FIELD = "top_k"; public static final String TOP_K_FIELD = "top_k";
public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens"; public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens";
public static final String TRUNCATE_FIELD = "truncate";
public static final Double MIN_TEMPERATURE_TOP_P_TOP_K_VALUE = 0.0; public static final Double MIN_TEMPERATURE_TOP_P_TOP_K_VALUE = 0.0;
public static final Double MAX_TEMPERATURE_TOP_P_TOP_K_VALUE = 1.0; public static final Double MAX_TEMPERATURE_TOP_P_TOP_K_VALUE = 1.0;

View file

@ -303,6 +303,7 @@ public class AmazonBedrockService extends SenderService {
context context
); );
checkProviderForTask(TaskType.TEXT_EMBEDDING, model.provider()); checkProviderForTask(TaskType.TEXT_EMBEDDING, model.provider());
checkTaskSettingsForTextEmbeddingModel(model);
return model; return model;
} }
case COMPLETION -> { case COMPLETION -> {
@ -368,6 +369,17 @@ public class AmazonBedrockService extends SenderService {
} }
} }
/**
 * Rejects text embedding models that carry a {@code truncate} task setting while using a
 * provider other than {@link AmazonBedrockProvider#COHERE}.
 *
 * @param model the parsed embeddings model to validate
 * @throws ElasticsearchStatusException with {@link RestStatus#BAD_REQUEST} when a non-Cohere
 *                                      provider has a non-null truncation setting
 */
private static void checkTaskSettingsForTextEmbeddingModel(AmazonBedrockEmbeddingsModel model) {
    if (model.provider() == AmazonBedrockProvider.COHERE) {
        // Cohere is the one provider for which the truncate setting is accepted
        return;
    }
    if (model.getTaskSettings().cohereTruncation() == null) {
        // nothing Cohere-specific was configured, so there is nothing to reject
        return;
    }
    throw new ElasticsearchStatusException(
        "The [{}] task type for provider [{}] does not allow [truncate] field",
        RestStatus.BAD_REQUEST,
        TaskType.TEXT_EMBEDDING,
        model.provider()
    );
}
private static void checkChatCompletionProviderForTopKParameter(AmazonBedrockChatCompletionModel model) { private static void checkChatCompletionProviderForTopKParameter(AmazonBedrockChatCompletionModel model) {
var taskSettings = model.getTaskSettings(); var taskSettings = model.getTaskSettings();
if (taskSettings.topK() != null) { if (taskSettings.topK() != null) {

View file

@ -7,14 +7,11 @@
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.Model; import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
@ -28,10 +25,8 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
public static AmazonBedrockEmbeddingsModel of(AmazonBedrockEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings) { public static AmazonBedrockEmbeddingsModel of(AmazonBedrockEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings) {
if (taskSettings != null && taskSettings.isEmpty() == false) { if (taskSettings != null && taskSettings.isEmpty() == false) {
// no task settings allowed var updatedTaskSettings = embeddingsModel.getTaskSettings().updatedTaskSettings(taskSettings);
var validationException = new ValidationException(); return new AmazonBedrockEmbeddingsModel(embeddingsModel, updatedTaskSettings);
validationException.addValidationError("Amazon Bedrock embeddings model cannot have task settings");
throw validationException;
} }
return embeddingsModel; return embeddingsModel;
@ -52,7 +47,7 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
taskType, taskType,
service, service,
AmazonBedrockEmbeddingsServiceSettings.fromMap(serviceSettings, context), AmazonBedrockEmbeddingsServiceSettings.fromMap(serviceSettings, context),
new EmptyTaskSettings(), AmazonBedrockEmbeddingsTaskSettings.fromMap(taskSettings),
chunkingSettings, chunkingSettings,
AwsSecretSettings.fromMap(secretSettings) AwsSecretSettings.fromMap(secretSettings)
); );
@ -63,12 +58,12 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
TaskType taskType, TaskType taskType,
String service, String service,
AmazonBedrockEmbeddingsServiceSettings serviceSettings, AmazonBedrockEmbeddingsServiceSettings serviceSettings,
TaskSettings taskSettings, AmazonBedrockEmbeddingsTaskSettings taskSettings,
ChunkingSettings chunkingSettings, ChunkingSettings chunkingSettings,
AwsSecretSettings secrets AwsSecretSettings secrets
) { ) {
super( super(
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings), new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secrets) new ModelSecrets(secrets)
); );
} }
@ -77,6 +72,10 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
super(model, serviceSettings); super(model, serviceSettings);
} }
/**
 * Builds an embeddings model from an existing model and a replacement set of task settings by
 * delegating to the superclass constructor. The static {@code of} factory uses this to apply
 * task settings supplied with an individual request.
 *
 * @param model        the model to base the new instance on
 * @param taskSettings the task settings the new instance should carry
 */
public AmazonBedrockEmbeddingsModel(Model model, AmazonBedrockEmbeddingsTaskSettings taskSettings) {
super(model, taskSettings);
}
@Override @Override
public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map<String, Object> taskSettings) { public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map<String, Object> taskSettings) {
return creator.create(this, taskSettings); return creator.create(this, taskSettings);
@ -86,4 +85,9 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
public AmazonBedrockEmbeddingsServiceSettings getServiceSettings() { public AmazonBedrockEmbeddingsServiceSettings getServiceSettings() {
return (AmazonBedrockEmbeddingsServiceSettings) super.getServiceSettings(); return (AmazonBedrockEmbeddingsServiceSettings) super.getServiceSettings();
} }
/**
 * Returns the superclass task settings cast to the embeddings-specific
 * {@link AmazonBedrockEmbeddingsTaskSettings} type, so callers can read
 * {@code cohereTruncation()} without casting themselves.
 */
@Override
public AmazonBedrockEmbeddingsTaskSettings getTaskSettings() {
return (AmazonBedrockEmbeddingsTaskSettings) super.getTaskSettings();
}
} }

View file

@ -0,0 +1,98 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD;
/**
 * Task settings for Amazon Bedrock text embedding models.
 * <p>
 * The single supported setting is the optional {@code truncate} value, held as a
 * {@link CohereTruncation}. A {@code null} value means no truncation mode was configured.
 *
 * @param cohereTruncation the configured truncation mode, or {@code null} when absent
 */
public record AmazonBedrockEmbeddingsTaskSettings(@Nullable CohereTruncation cohereTruncation) implements TaskSettings {

    public static final String NAME = "amazon_bedrock_embeddings_task_settings";

    // The cast selects the canonical constructor; a bare null would be ambiguous with the StreamInput constructor.
    public static final AmazonBedrockEmbeddingsTaskSettings EMPTY = new AmazonBedrockEmbeddingsTaskSettings((CohereTruncation) null);

    /**
     * Parses task settings from a settings map.
     *
     * @param map the task settings map; may be {@code null} or empty
     * @return {@link #EMPTY} for a {@code null} or empty map, otherwise settings holding the parsed truncation
     * @throws ValidationException if parsing the {@code truncate} value recorded any validation error
     */
    public static AmazonBedrockEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
        if (map == null || map.isEmpty()) {
            return EMPTY;
        }

        var errors = new ValidationException();
        CohereTruncation truncation = extractOptionalEnum(
            map,
            TRUNCATE_FIELD,
            ModelConfigurations.TASK_SETTINGS,
            CohereTruncation::fromString,
            CohereTruncation.ALL,
            errors
        );

        if (errors.validationErrors().isEmpty()) {
            return new AmazonBedrockEmbeddingsTaskSettings(truncation);
        }
        throw errors;
    }

    /**
     * Reads the settings from a stream as a single optional {@link CohereTruncation} enum.
     */
    public AmazonBedrockEmbeddingsTaskSettings(StreamInput in) throws IOException {
        this(in.readOptionalEnum(CohereTruncation.class));
    }

    /**
     * Writes the settings to a stream as a single optional enum, mirroring the stream constructor.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeOptionalEnum(cohereTruncation);
    }

    /**
     * @return {@code true} when no truncation mode is set
     */
    @Override
    public boolean isEmpty() {
        return cohereTruncation == null;
    }

    /**
     * Produces new settings in which a truncation mode present in {@code newSettings} replaces the
     * current one; when {@code newSettings} carries none, the current value is kept.
     *
     * @param newSettings the settings map to overlay on this instance
     * @return a new settings instance with the merged truncation mode
     */
    @Override
    public AmazonBedrockEmbeddingsTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
        // Parse a copy of the map (presumably because parsing consumes entries - verify against extractOptionalEnum).
        var requested = fromMap(new HashMap<>(newSettings)).cohereTruncation();
        var merged = requested != null ? requested : cohereTruncation;
        return new AmazonBedrockEmbeddingsTaskSettings(merged);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.AMAZON_BEDROCK_TASK_SETTINGS;
    }

    /**
     * Renders the settings as an object that contains the {@code truncate} field only when a
     * truncation mode is set.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (isEmpty() == false) {
            builder.field(TRUNCATE_FIELD, cohereTruncation);
        }
        return builder.endObject();
    }
}

View file

@ -11,6 +11,7 @@ import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.InputType;
import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
@ -18,7 +19,11 @@ import java.util.Objects;
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage; import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nullable InputType inputType) implements ToXContentObject { public record AmazonBedrockCohereEmbeddingsRequestEntity(
List<String> input,
@Nullable InputType inputType,
AmazonBedrockEmbeddingsTaskSettings taskSettings
) implements ToXContentObject {
private static final String TEXTS_FIELD = "texts"; private static final String TEXTS_FIELD = "texts";
private static final String INPUT_TYPE_FIELD = "input_type"; private static final String INPUT_TYPE_FIELD = "input_type";
@ -26,9 +31,11 @@ public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nu
private static final String SEARCH_QUERY = "search_query"; private static final String SEARCH_QUERY = "search_query";
private static final String CLUSTERING = "clustering"; private static final String CLUSTERING = "clustering";
private static final String CLASSIFICATION = "classification"; private static final String CLASSIFICATION = "classification";
private static final String TRUNCATE = "truncate";
public AmazonBedrockCohereEmbeddingsRequestEntity { public AmazonBedrockCohereEmbeddingsRequestEntity {
Objects.requireNonNull(input); Objects.requireNonNull(input);
Objects.requireNonNull(taskSettings);
} }
@Override @Override
@ -43,6 +50,10 @@ public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nu
builder.field(INPUT_TYPE_FIELD, SEARCH_DOCUMENT); builder.field(INPUT_TYPE_FIELD, SEARCH_DOCUMENT);
} }
if (taskSettings.cohereTruncation() != null) {
builder.field(TRUNCATE, taskSettings.cohereTruncation().name());
}
builder.endObject(); builder.endObject();
return builder; return builder;
} }

View file

@ -39,7 +39,7 @@ public final class AmazonBedrockEmbeddingsEntityFactory {
return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0)); return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0));
} }
case COHERE -> { case COHERE -> {
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType); return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType, model.getTaskSettings());
} }
default -> { default -> {
return null; return null;

View file

@ -76,6 +76,9 @@ public class AmazonBedrockEmbeddingsRequest extends AmazonBedrockRequest {
@Override @Override
public Request truncate() { public Request truncate() {
if (provider == AmazonBedrockProvider.COHERE) {
return this; // Cohere has its own truncation logic
}
var truncatedInput = truncator.truncate(truncationResult.input()); var truncatedInput = truncator.truncate(truncationResult.input());
return new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, embeddingsModel, requestEntity, timeout); return new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, embeddingsModel, requestEntity, timeout);
} }

View file

@ -7,6 +7,7 @@
package org.elasticsearch.xpack.inference.services.cohere; package org.elasticsearch.xpack.inference.services.cohere;
import java.util.EnumSet;
import java.util.Locale; import java.util.Locale;
/** /**
@ -31,6 +32,8 @@ public enum CohereTruncation {
*/ */
END; END;
public static final EnumSet<CohereTruncation> ALL = EnumSet.allOf(CohereTruncation.class);
@Override @Override
public String toString() { public String toString() {
return name().toLowerCase(Locale.ROOT); return name().toLowerCase(Locale.ROOT);

View file

@ -20,7 +20,6 @@ import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
import java.io.IOException; import java.io.IOException;
import java.util.EnumSet;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -63,7 +62,7 @@ public class CohereEmbeddingsTaskSettings implements TaskSettings {
TRUNCATE, TRUNCATE,
ModelConfigurations.TASK_SETTINGS, ModelConfigurations.TASK_SETTINGS,
CohereTruncation::fromString, CohereTruncation::fromString,
EnumSet.allOf(CohereTruncation.class), CohereTruncation.ALL,
validationException validationException
); );

View file

@ -12,6 +12,7 @@ import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeExcept
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesArray;
@ -51,6 +52,8 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.Amazo
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettingsTests;
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
import org.hamcrest.CoreMatchers; import org.hamcrest.CoreMatchers;
import org.hamcrest.Matchers; import org.hamcrest.Matchers;
@ -105,7 +108,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
public void testParseRequestConfig_CreatesAnAmazonBedrockModel() throws IOException { public void testParseRequestConfig_CreatesAnAmazonBedrockModel() throws IOException {
try (var service = createAmazonBedrockService()) { try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> { ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
@ -115,7 +118,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var secretSettings = (AwsSecretSettings) model.getSecretSettings(); var secretSettings = (AwsSecretSettings) model.getSecretSettings();
assertThat(secretSettings.accessKey().toString(), is("access")); assertThat(secretSettings.accessKey().toString(), is("access"));
assertThat(secretSettings.secretKey().toString(), is("secret")); assertThat(secretSettings.secretKey().toString(), is("secret"));
}, exception -> fail("Unexpected exception: " + exception)); });
service.parseRequestConfig( service.parseRequestConfig(
"id", "id",
@ -130,15 +133,62 @@ public class AmazonBedrockServiceTests extends ESTestCase {
} }
} }
/**
 * Verifies that a text embedding endpoint for the Cohere provider is created from the request
 * config, and that the {@code truncate} task setting supplied in the request ends up on the
 * resulting model.
 */
public void testParseRequestConfig_CreatesACohereModel() throws IOException {
    try (var service = createAmazonBedrockService()) {
        ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
            assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));

            var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
            assertThat(settings.region(), is("region"));
            assertThat(settings.modelId(), is("model"));
            assertThat(settings.provider(), is(AmazonBedrockProvider.COHERE));

            // The request sets truncate=START; assert it was parsed onto the model rather than silently dropped.
            var taskSettings = ((AmazonBedrockEmbeddingsModel) model).getTaskSettings();
            assertThat(taskSettings.cohereTruncation(), is(CohereTruncation.START));

            var secretSettings = (AwsSecretSettings) model.getSecretSettings();
            assertThat(secretSettings.accessKey().toString(), is("access"));
            assertThat(secretSettings.secretKey().toString(), is("secret"));
        });

        service.parseRequestConfig(
            "id",
            TaskType.TEXT_EMBEDDING,
            getRequestConfigMap(
                createEmbeddingsRequestSettingsMap("region", "model", "cohere", null, null, null, null),
                AmazonBedrockEmbeddingsTaskSettingsTests.mutableMap("truncate", CohereTruncation.START),
                getAmazonBedrockSecretSettingsMap("access", "secret")
            ),
            modelVerificationListener
        );
    }
}
/**
 * Verifies that supplying the {@code truncate} task setting for a non-Cohere provider
 * (here {@code amazontitan}) fails with an {@link ElasticsearchStatusException}.
 */
public void testParseRequestConfig_CohereSettingsWithNoCohereModel() throws IOException {
    try (var service = createAmazonBedrockService()) {
        ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(failure -> {
            assertThat(failure, instanceOf(ElasticsearchStatusException.class));
            assertThat(
                failure.getMessage(),
                is("The [text_embedding] task type for provider [amazontitan] does not allow [truncate] field")
            );
        });

        // Build the request config piece by piece: Titan provider combined with a truncate task setting.
        var serviceSettingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, null, null, null);
        var taskSettingsMap = AmazonBedrockEmbeddingsTaskSettingsTests.mutableMap("truncate", CohereTruncation.START);
        var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
        var config = getRequestConfigMap(serviceSettingsMap, taskSettingsMap, secretSettingsMap);

        service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
    }
}
public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
try (var service = createAmazonBedrockService()) { try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap( ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
model -> fail("Expected exception, but got model: " + model), assertThat(exception, instanceOf(ElasticsearchStatusException.class));
exception -> { assertThat(exception.getMessage(), is("The [amazonbedrock] service does not support task type [sparse_embedding]"));
assertThat(exception, instanceOf(ElasticsearchStatusException.class)); });
assertThat(exception.getMessage(), is("The [amazonbedrock] service does not support task type [sparse_embedding]"));
}
);
service.parseRequestConfig( service.parseRequestConfig(
"id", "id",
@ -247,13 +297,10 @@ public class AmazonBedrockServiceTests extends ESTestCase {
public void testCreateModel_ForEmbeddingsTask_InvalidProvider() throws IOException { public void testCreateModel_ForEmbeddingsTask_InvalidProvider() throws IOException {
try (var service = createAmazonBedrockService()) { try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap( ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
model -> fail("Expected exception, but got model: " + model), assertThat(exception, instanceOf(ElasticsearchStatusException.class));
exception -> { assertThat(exception.getMessage(), is("The [text_embedding] task type for provider [anthropic] is not available"));
assertThat(exception, instanceOf(ElasticsearchStatusException.class)); });
assertThat(exception.getMessage(), is("The [text_embedding] task type for provider [anthropic] is not available"));
}
);
service.parseRequestConfig( service.parseRequestConfig(
"id", "id",
@ -270,13 +317,10 @@ public class AmazonBedrockServiceTests extends ESTestCase {
public void testCreateModel_TopKParameter_NotAvailable() throws IOException { public void testCreateModel_TopKParameter_NotAvailable() throws IOException {
try (var service = createAmazonBedrockService()) { try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap( ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
model -> fail("Expected exception, but got model: " + model), assertThat(exception, instanceOf(ElasticsearchStatusException.class));
exception -> { assertThat(exception.getMessage(), is("The [top_k] task parameter is not available for provider [amazontitan]"));
assertThat(exception, instanceOf(ElasticsearchStatusException.class)); });
assertThat(exception.getMessage(), is("The [top_k] task parameter is not available for provider [amazontitan]"));
}
);
service.parseRequestConfig( service.parseRequestConfig(
"id", "id",
@ -301,16 +345,13 @@ public class AmazonBedrockServiceTests extends ESTestCase {
config.put("extra_key", "value"); config.put("extra_key", "value");
ActionListener<Model> modelVerificationListener = ActionListener.wrap( ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
model -> fail("Expected exception, but got model: " + model), assertThat(exception, instanceOf(ElasticsearchStatusException.class));
exception -> { assertThat(
assertThat(exception, instanceOf(ElasticsearchStatusException.class)); exception.getMessage(),
assertThat( is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service")
exception.getMessage(), );
is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") });
);
}
);
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
} }
@ -323,9 +364,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var config = getRequestConfigMap(serviceSettings, Map.of(), getAmazonBedrockSecretSettingsMap("access", "secret")); var config = getRequestConfigMap(serviceSettings, Map.of(), getAmazonBedrockSecretSettingsMap("access", "secret"));
ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> { ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> {
fail("Expected exception, but got model: " + model);
}, e -> {
assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat(e, instanceOf(ElasticsearchStatusException.class));
assertThat( assertThat(
e.getMessage(), e.getMessage(),
@ -347,9 +386,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap);
ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> { ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> {
fail("Expected exception, but got model: " + model);
}, e -> {
assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat(e, instanceOf(ElasticsearchStatusException.class));
assertThat( assertThat(
e.getMessage(), e.getMessage(),
@ -371,9 +408,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap);
ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> { ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> {
fail("Expected exception, but got model: " + model);
}, e -> {
assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat(e, instanceOf(ElasticsearchStatusException.class));
assertThat( assertThat(
e.getMessage(), e.getMessage(),
@ -387,7 +422,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
public void testParseRequestConfig_MovesModel() throws IOException { public void testParseRequestConfig_MovesModel() throws IOException {
try (var service = createAmazonBedrockService()) { try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> { ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
@ -397,7 +432,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var secretSettings = (AwsSecretSettings) model.getSecretSettings(); var secretSettings = (AwsSecretSettings) model.getSecretSettings();
assertThat(secretSettings.accessKey().toString(), is("access")); assertThat(secretSettings.accessKey().toString(), is("access"));
assertThat(secretSettings.secretKey().toString(), is("secret")); assertThat(secretSettings.secretKey().toString(), is("secret"));
}, exception -> fail("Unexpected exception: " + exception)); });
service.parseRequestConfig( service.parseRequestConfig(
"id", "id",
@ -414,7 +449,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
try (var service = createAmazonBedrockService()) { try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> { ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
@ -425,7 +460,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
assertThat(secretSettings.accessKey().toString(), is("access")); assertThat(secretSettings.accessKey().toString(), is("access"));
assertThat(secretSettings.secretKey().toString(), is("secret")); assertThat(secretSettings.secretKey().toString(), is("secret"));
assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
}, exception -> fail("Unexpected exception: " + exception)); });
service.parseRequestConfig( service.parseRequestConfig(
"id", "id",
@ -443,7 +478,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
try (var service = createAmazonBedrockService()) { try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> { ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
@ -454,7 +489,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
assertThat(secretSettings.accessKey().toString(), is("access")); assertThat(secretSettings.accessKey().toString(), is("access"));
assertThat(secretSettings.secretKey().toString(), is("secret")); assertThat(secretSettings.secretKey().toString(), is("secret"));
assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
}, exception -> fail("Unexpected exception: " + exception)); });
service.parseRequestConfig( service.parseRequestConfig(
"id", "id",
@ -471,13 +506,10 @@ public class AmazonBedrockServiceTests extends ESTestCase {
public void testCreateModel_ForEmbeddingsTask_DimensionsIsNotAllowed() throws IOException { public void testCreateModel_ForEmbeddingsTask_DimensionsIsNotAllowed() throws IOException {
try (var service = createAmazonBedrockService()) { try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap( ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
model -> fail("Expected exception, but got model: " + model), assertThat(exception, instanceOf(ValidationException.class));
exception -> { assertThat(exception.getMessage(), containsString("[service_settings] does not allow the setting [dimensions]"));
assertThat(exception, instanceOf(ValidationException.class)); });
assertThat(exception.getMessage(), containsString("[service_settings] does not allow the setting [dimensions]"));
}
);
service.parseRequestConfig( service.parseRequestConfig(
"id", "id",
@ -497,7 +529,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap); var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
var model = service.parsePersistedConfigWithSecrets( var model = service.parsePersistedConfigWithSecrets(
"id", "id",
@ -525,7 +557,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var persistedConfig = getPersistedConfigMap( var persistedConfig = getPersistedConfigMap(
settingsMap, settingsMap,
new HashMap<String, Object>(Map.of()), new HashMap<>(Map.of()),
createRandomChunkingSettingsMap(), createRandomChunkingSettingsMap(),
secretSettingsMap secretSettingsMap
); );
@ -607,7 +639,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap); var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
persistedConfig.config().put("extra_key", "value"); persistedConfig.config().put("extra_key", "value");
var model = service.parsePersistedConfigWithSecrets( var model = service.parsePersistedConfigWithSecrets(
@ -635,7 +667,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
secretSettingsMap.put("extra_key", "value"); secretSettingsMap.put("extra_key", "value");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap); var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
var model = service.parsePersistedConfigWithSecrets( var model = service.parsePersistedConfigWithSecrets(
"id", "id",
@ -661,7 +693,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap); var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
persistedConfig.secrets().put("extra_key", "value"); persistedConfig.secrets().put("extra_key", "value");
var model = service.parsePersistedConfigWithSecrets( var model = service.parsePersistedConfigWithSecrets(
@ -689,7 +721,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
settingsMap.put("extra_key", "value"); settingsMap.put("extra_key", "value");
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap); var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
var model = service.parsePersistedConfigWithSecrets( var model = service.parsePersistedConfigWithSecrets(
"id", "id",
@ -769,7 +801,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var persistedConfig = getPersistedConfigMap( var persistedConfig = getPersistedConfigMap(
settingsMap, settingsMap,
new HashMap<String, Object>(Map.of()), new HashMap<>(Map.of()),
createRandomChunkingSettingsMap(), createRandomChunkingSettingsMap(),
secretSettingsMap secretSettingsMap
); );
@ -792,7 +824,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap); var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
@ -836,7 +868,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap); var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
var thrownException = expectThrows( var thrownException = expectThrows(
ElasticsearchStatusException.class, ElasticsearchStatusException.class,
@ -855,7 +887,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap); var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
persistedConfig.config().put("extra_key", "value"); persistedConfig.config().put("extra_key", "value");
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
@ -876,7 +908,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
settingsMap.put("extra_key", "value"); settingsMap.put("extra_key", "value");
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap); var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
persistedConfig.config().put("extra_key", "value"); persistedConfig.config().put("extra_key", "value");
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());

View file

@ -11,6 +11,7 @@ import org.elasticsearch.inference.InputType;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockJsonBuilder; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockJsonBuilder;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockCohereEmbeddingsRequestEntity; import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockCohereEmbeddingsRequestEntity;
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
@ -19,23 +20,46 @@ import static org.hamcrest.Matchers.is;
public class AmazonBedrockCohereEmbeddingsRequestEntityTests extends ESTestCase { public class AmazonBedrockCohereEmbeddingsRequestEntityTests extends ESTestCase {
public void testRequestEntity_GeneratesExpectedJsonBody() throws IOException { public void testRequestEntity_GeneratesExpectedJsonBody() throws IOException {
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), InputType.CLASSIFICATION); var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(
List.of("test input"),
InputType.CLASSIFICATION,
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
);
var builder = new AmazonBedrockJsonBuilder(entity); var builder = new AmazonBedrockJsonBuilder(entity);
var result = builder.getStringContent(); var result = builder.getStringContent();
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"classification\"}")); assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"classification\"}"));
} }
public void testRequestEntity_GeneratesExpectedJsonBody_WithInternalInputType() throws IOException { public void testRequestEntity_GeneratesExpectedJsonBody_WithInternalInputType() throws IOException {
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), InputType.INTERNAL_SEARCH); var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(
List.of("test input"),
InputType.INTERNAL_SEARCH,
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
);
var builder = new AmazonBedrockJsonBuilder(entity); var builder = new AmazonBedrockJsonBuilder(entity);
var result = builder.getStringContent(); var result = builder.getStringContent();
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_query\"}")); assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_query\"}"));
} }
public void testRequestEntity_GeneratesExpectedJsonBody_WithoutInputType() throws IOException { public void testRequestEntity_GeneratesExpectedJsonBody_WithoutInputType() throws IOException {
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), null); var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(
List.of("test input"),
null,
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
);
var builder = new AmazonBedrockJsonBuilder(entity); var builder = new AmazonBedrockJsonBuilder(entity);
var result = builder.getStringContent(); var result = builder.getStringContent();
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\"}")); assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\"}"));
} }
public void testRequestEntity_GeneratesExpectedJsonBody_WithCohereTruncation() throws IOException {
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(
List.of("test input"),
null,
new AmazonBedrockEmbeddingsTaskSettings(CohereTruncation.START)
);
var builder = new AmazonBedrockJsonBuilder(entity);
var result = builder.getStringContent();
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\",\"truncate\":\"START\"}"));
}
} }

View file

@ -7,12 +7,9 @@
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
@ -20,19 +17,23 @@ import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.util.Map; import java.io.IOException;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
public class AmazonBedrockEmbeddingsModelTests extends ESTestCase { public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
public void testCreateModel_withTaskSettings_shouldFail() { public void testCreateModel_withTaskSettingsOverride() throws IOException {
var baseModel = createModel("id", "region", "model", AmazonBedrockProvider.AMAZONTITAN, "accesskey", "secretkey"); var baseTaskSettings = AmazonBedrockEmbeddingsTaskSettingsTests.randomTaskSettings();
var thrownException = assertThrows( var baseModel = createModel("id", "region", "model", AmazonBedrockProvider.AMAZONTITAN, "accesskey", "secretkey", baseTaskSettings);
ValidationException.class,
() -> AmazonBedrockEmbeddingsModel.of(baseModel, Map.of("testkey", "testvalue")) var overrideTaskSettings = AmazonBedrockEmbeddingsTaskSettingsTests.mutateTaskSettings(baseTaskSettings);
); var overrideTaskSettingsMap = AmazonBedrockEmbeddingsTaskSettingsTests.toMap(overrideTaskSettings);
assertThat(thrownException.getMessage(), containsString("Amazon Bedrock embeddings model cannot have task settings"));
var overriddenModel = AmazonBedrockEmbeddingsModel.of(baseModel, overrideTaskSettingsMap);
assertThat(overriddenModel.getTaskSettings(), equalTo(overrideTaskSettings));
assertThat(overriddenModel.getTaskSettings(), not(equalTo(baseTaskSettings)));
} }
// model creation only - no tests to define, but we want to have the public createModel // model creation only - no tests to define, but we want to have the public createModel
@ -46,7 +47,15 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
String accessKey, String accessKey,
String secretKey String secretKey
) { ) {
return createModel(inferenceId, region, model, provider, null, false, null, null, new RateLimitSettings(240), accessKey, secretKey); return createModel(
inferenceId,
region,
model,
provider,
accessKey,
secretKey,
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
);
} }
public static AmazonBedrockEmbeddingsModel createModel( public static AmazonBedrockEmbeddingsModel createModel(
@ -56,9 +65,22 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
AmazonBedrockProvider provider, AmazonBedrockProvider provider,
String accessKey, String accessKey,
String secretKey, String secretKey,
InputType inputType AmazonBedrockEmbeddingsTaskSettings taskSettings
) { ) {
return createModel(inferenceId, region, model, provider, null, false, null, null, new RateLimitSettings(240), accessKey, secretKey); return createModel(
inferenceId,
region,
model,
provider,
null,
false,
null,
null,
new RateLimitSettings(240),
accessKey,
secretKey,
taskSettings
);
} }
public static AmazonBedrockEmbeddingsModel createModel( public static AmazonBedrockEmbeddingsModel createModel(
@ -114,7 +136,7 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
similarity, similarity,
rateLimitSettings rateLimitSettings
), ),
new EmptyTaskSettings(), AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings(),
chunkingSettings, chunkingSettings,
new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey)) new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey))
); );
@ -132,6 +154,36 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
RateLimitSettings rateLimitSettings, RateLimitSettings rateLimitSettings,
String accessKey, String accessKey,
String secretKey String secretKey
) {
return createModel(
inferenceId,
region,
model,
provider,
dimensions,
dimensionsSetByUser,
maxTokens,
similarity,
rateLimitSettings,
accessKey,
secretKey,
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
);
}
public static AmazonBedrockEmbeddingsModel createModel(
String inferenceId,
String region,
String model,
AmazonBedrockProvider provider,
@Nullable Integer dimensions,
boolean dimensionsSetByUser,
@Nullable Integer maxTokens,
@Nullable SimilarityMeasure similarity,
RateLimitSettings rateLimitSettings,
String accessKey,
String secretKey,
AmazonBedrockEmbeddingsTaskSettings taskSettings
) { ) {
return new AmazonBedrockEmbeddingsModel( return new AmazonBedrockEmbeddingsModel(
inferenceId, inferenceId,
@ -147,7 +199,7 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
similarity, similarity,
rateLimitSettings rateLimitSettings
), ),
new EmptyTaskSettings(), taskSettings,
null, null,
new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey)) new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey))
); );

View file

@ -0,0 +1,112 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithoutUnspecified;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD;
import static org.hamcrest.Matchers.equalTo;
/**
 * Wire-serialization and parsing tests for {@link AmazonBedrockEmbeddingsTaskSettings}.
 * <p>
 * The static helpers ({@link #emptyTaskSettings()}, {@link #randomTaskSettings()},
 * {@link #mutateTaskSettings}, {@link #mutableMap} and {@link #toMap}) are public so that other
 * test classes can build task settings fixtures.
 */
public class AmazonBedrockEmbeddingsTaskSettingsTests extends AbstractBWCWireSerializationTestCase<AmazonBedrockEmbeddingsTaskSettings> {

    /**
     * @return the shared {@link AmazonBedrockEmbeddingsTaskSettings#EMPTY} instance
     */
    public static AmazonBedrockEmbeddingsTaskSettings emptyTaskSettings() {
        return AmazonBedrockEmbeddingsTaskSettings.EMPTY;
    }

    /**
     * Creates task settings whose truncation is, with equal probability, either {@code null} or a
     * random {@link CohereTruncation} value. Truncation is the only value randomized because it is
     * the only argument the constructor used here accepts.
     *
     * @return randomly generated task settings
     */
    public static AmazonBedrockEmbeddingsTaskSettings randomTaskSettings() {
        var truncation = randomBoolean() ? randomFrom(CohereTruncation.values()) : null;
        return new AmazonBedrockEmbeddingsTaskSettings(truncation);
    }

    /**
     * Generates random task settings that differ from {@code instance}.
     * <p>
     * Candidates are also rejected when {@code instance} has a truncation set and the candidate has
     * none, so a non-null truncation is always mutated to a different non-null truncation.
     * NOTE(review): presumably this is because settings without a truncation serialize to a map
     * that cannot override an existing value - confirm against the model's override logic.
     *
     * @param instance the settings to mutate
     * @return settings that are not equal to {@code instance}
     */
    public static AmazonBedrockEmbeddingsTaskSettings mutateTaskSettings(AmazonBedrockEmbeddingsTaskSettings instance) {
        return randomValueOtherThanMany(
            v -> Objects.equals(instance, v) || (instance.cohereTruncation() != null && v.cohereTruncation() == null),
            AmazonBedrockEmbeddingsTaskSettingsTests::randomTaskSettings
        );
    }

    /**
     * Returns {@code instance} unchanged for every transport version.
     * NOTE(review): this assumes the wire format round-trips identically on all tested versions -
     * confirm against the version handling in the settings' stream read/write code.
     */
    @Override
    protected AmazonBedrockEmbeddingsTaskSettings mutateInstanceForVersion(
        AmazonBedrockEmbeddingsTaskSettings instance,
        TransportVersion version
    ) {
        return instance;
    }

    @Override
    protected Writeable.Reader<AmazonBedrockEmbeddingsTaskSettings> instanceReader() {
        return AmazonBedrockEmbeddingsTaskSettings::new;
    }

    @Override
    protected AmazonBedrockEmbeddingsTaskSettings createTestInstance() {
        return randomTaskSettings();
    }

    @Override
    protected AmazonBedrockEmbeddingsTaskSettings mutateInstance(AmazonBedrockEmbeddingsTaskSettings instance) throws IOException {
        return mutateTaskSettings(instance);
    }

    /**
     * The EMPTY constant, a {@code null} map and an empty map must all yield empty task settings.
     */
    public void testEmpty() {
        assertTrue(emptyTaskSettings().isEmpty());
        assertTrue(AmazonBedrockEmbeddingsTaskSettings.fromMap(null).isEmpty());
        assertTrue(AmazonBedrockEmbeddingsTaskSettings.fromMap(Map.of()).isEmpty());
    }

    /**
     * Builds a mutable single-entry map from {@code key} to the string form of {@code value}.
     *
     * @param key   the map key
     * @param value the enum constant whose {@code toString()} becomes the map value
     * @return a new {@link HashMap} containing the single entry
     */
    public static Map<String, Object> mutableMap(String key, Enum<?> value) {
        return new HashMap<>(Map.of(key, value.toString()));
    }

    /**
     * Every {@link CohereTruncation} value must parse from its string form into non-empty settings
     * carrying that same truncation.
     */
    public void testValidCohereTruncations() {
        for (var expectedCohereTruncation : CohereTruncation.ALL) {
            var map = mutableMap(TRUNCATE_FIELD, expectedCohereTruncation);
            var taskSettings = AmazonBedrockEmbeddingsTaskSettings.fromMap(map);
            assertFalse(taskSettings.isEmpty());
            assertThat(taskSettings.cohereTruncation(), equalTo(expectedCohereTruncation));
        }
    }

    /**
     * An unrecognized truncation string must be rejected with a {@link ValidationException}.
     */
    public void testGarbageCohereTruncations() {
        var map = new HashMap<String, Object>(Map.of(TRUNCATE_FIELD, "oiuesoirtuoawoeirha"));
        assertThrows(ValidationException.class, () -> AmazonBedrockEmbeddingsTaskSettings.fromMap(map));
    }

    /**
     * Settings rendered with {@link #toMap} and parsed back with {@code fromMap} must equal the original.
     */
    public void testXContent() throws IOException {
        var taskSettings = randomTaskSettings();
        var taskSettingsAsMap = toMap(taskSettings);
        // fromMap is handed a fresh HashMap copy rather than the parser's own map.
        var roundTripTaskSettings = AmazonBedrockEmbeddingsTaskSettings.fromMap(new HashMap<>(taskSettingsAsMap));
        assertThat(roundTripTaskSettings, equalTo(taskSettings));
    }

    /**
     * Renders {@code taskSettings} to JSON through {@code toXContent} and parses the result back
     * into a map.
     *
     * @param taskSettings the settings to render
     * @return the parsed JSON content as a map
     * @throws IOException if rendering or parsing fails
     */
    public static Map<String, Object> toMap(AmazonBedrockEmbeddingsTaskSettings taskSettings) throws IOException {
        try (var builder = JsonXContent.contentBuilder()) {
            taskSettings.toXContent(builder, ToXContent.EMPTY_PARAMS);
            var taskSettingsBytes = Strings.toString(builder).getBytes(StandardCharsets.UTF_8);
            try (var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, taskSettingsBytes)) {
                return parser.map();
            }
        }
    }
}