[ML] Bedrock Cohere Task Settings Support (#126493)

Add support for Cohere Task Settings and Truncate, through
the Amazon Bedrock provider integration.

Task Settings can now be passed both during Inference endpoint
creation and in Inference POST requests.

Close #126156
This commit is contained in:
Pat Whelan 2025-04-09 15:34:05 -04:00 committed by GitHub
parent 6d86b202ea
commit 6c6500ec3b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 468 additions and 102 deletions

View file

@ -0,0 +1,6 @@
pr: 126493
summary: Bedrock Cohere Task Settings Support
area: Machine Learning
type: enhancement
issues:
- 126156

View file

@ -157,6 +157,7 @@ public class TransportVersions {
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@ -215,6 +216,7 @@ public class TransportVersions {
public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00);
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0);
public static final TransportVersion REPO_ANALYSIS_COPY_BLOB = def(9_048_00_0);
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS = def(9_049_00_0);
/*
* STOP! READ THIS FIRST! No, really,

View file

@ -41,6 +41,7 @@ import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.Alib
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings;
@ -173,8 +174,13 @@ public class InferenceNamedWriteablesProvider {
AmazonBedrockEmbeddingsServiceSettings::new
)
);
// no task settings for Amazon Bedrock Embeddings
namedWriteables.add(
new NamedWriteableRegistry.Entry(
TaskSettings.class,
AmazonBedrockEmbeddingsTaskSettings.NAME,
AmazonBedrockEmbeddingsTaskSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(

View file

@ -19,6 +19,8 @@ public class AmazonBedrockConstants {
public static final String TOP_K_FIELD = "top_k";
public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens";
public static final String TRUNCATE_FIELD = "truncate";
public static final Double MIN_TEMPERATURE_TOP_P_TOP_K_VALUE = 0.0;
public static final Double MAX_TEMPERATURE_TOP_P_TOP_K_VALUE = 1.0;

View file

@ -303,6 +303,7 @@ public class AmazonBedrockService extends SenderService {
context
);
checkProviderForTask(TaskType.TEXT_EMBEDDING, model.provider());
checkTaskSettingsForTextEmbeddingModel(model);
return model;
}
case COMPLETION -> {
@ -368,6 +369,17 @@ public class AmazonBedrockService extends SenderService {
}
}
/**
 * Validates the task settings of a text-embedding model against its provider.
 * The {@code truncate} setting is only accepted when the provider is
 * {@link AmazonBedrockProvider#COHERE}.
 *
 * @param model the embeddings model whose provider and task settings are checked
 * @throws ElasticsearchStatusException with {@link RestStatus#BAD_REQUEST} when a
 *         truncation value is set for any other provider
 */
private static void checkTaskSettingsForTextEmbeddingModel(AmazonBedrockEmbeddingsModel model) {
    // Cohere models may carry a truncation setting; nothing to validate.
    if (model.provider() == AmazonBedrockProvider.COHERE) {
        return;
    }
    // Any other provider is fine as long as no truncation was requested.
    if (model.getTaskSettings().cohereTruncation() == null) {
        return;
    }
    throw new ElasticsearchStatusException(
        "The [{}] task type for provider [{}] does not allow [truncate] field",
        RestStatus.BAD_REQUEST,
        TaskType.TEXT_EMBEDDING,
        model.provider()
    );
}
private static void checkChatCompletionProviderForTopKParameter(AmazonBedrockChatCompletionModel model) {
var taskSettings = model.getTaskSettings();
if (taskSettings.topK() != null) {

View file

@ -7,14 +7,11 @@
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
@ -28,10 +25,8 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
public static AmazonBedrockEmbeddingsModel of(AmazonBedrockEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings) {
if (taskSettings != null && taskSettings.isEmpty() == false) {
// no task settings allowed
var validationException = new ValidationException();
validationException.addValidationError("Amazon Bedrock embeddings model cannot have task settings");
throw validationException;
var updatedTaskSettings = embeddingsModel.getTaskSettings().updatedTaskSettings(taskSettings);
return new AmazonBedrockEmbeddingsModel(embeddingsModel, updatedTaskSettings);
}
return embeddingsModel;
@ -52,7 +47,7 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
taskType,
service,
AmazonBedrockEmbeddingsServiceSettings.fromMap(serviceSettings, context),
new EmptyTaskSettings(),
AmazonBedrockEmbeddingsTaskSettings.fromMap(taskSettings),
chunkingSettings,
AwsSecretSettings.fromMap(secretSettings)
);
@ -63,12 +58,12 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
TaskType taskType,
String service,
AmazonBedrockEmbeddingsServiceSettings serviceSettings,
TaskSettings taskSettings,
AmazonBedrockEmbeddingsTaskSettings taskSettings,
ChunkingSettings chunkingSettings,
AwsSecretSettings secrets
) {
super(
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings),
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secrets)
);
}
@ -77,6 +72,10 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
super(model, serviceSettings);
}
/**
 * Creates an embeddings model from an existing model and the given task settings
 * by delegating to the superclass constructor.
 *
 * @param model        the model passed through to the superclass
 * @param taskSettings the embeddings task settings passed through to the superclass
 */
public AmazonBedrockEmbeddingsModel(Model model, AmazonBedrockEmbeddingsTaskSettings taskSettings) {
super(model, taskSettings);
}
@Override
public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map<String, Object> taskSettings) {
return creator.create(this, taskSettings);
@ -86,4 +85,9 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
public AmazonBedrockEmbeddingsServiceSettings getServiceSettings() {
return (AmazonBedrockEmbeddingsServiceSettings) super.getServiceSettings();
}
/**
 * Returns the task settings held by the superclass, cast to the
 * Bedrock-embeddings-specific type.
 */
@Override
public AmazonBedrockEmbeddingsTaskSettings getTaskSettings() {
return (AmazonBedrockEmbeddingsTaskSettings) super.getTaskSettings();
}
}

View file

@ -0,0 +1,98 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD;
/**
 * Task settings for Amazon Bedrock embeddings models.
 * <p>
 * The only setting is an optional {@link CohereTruncation} value, read from and
 * written to the {@code truncate} field. A {@code null} value means the setting
 * was not provided.
 *
 * @param cohereTruncation the requested truncation mode, or {@code null} when unset
 */
public record AmazonBedrockEmbeddingsTaskSettings(@Nullable CohereTruncation cohereTruncation) implements TaskSettings {

    public static final String NAME = "amazon_bedrock_embeddings_task_settings";

    /** Shared instance with no truncation configured. The cast selects the canonical constructor over the stream one. */
    public static final AmazonBedrockEmbeddingsTaskSettings EMPTY = new AmazonBedrockEmbeddingsTaskSettings((CohereTruncation) null);

    /**
     * Builds task settings from a settings map.
     *
     * @param map the task settings map; may be {@code null} or empty, in which case {@link #EMPTY} is returned
     * @return the parsed task settings
     * @throws ValidationException if extracting the {@code truncate} value recorded any validation error
     */
    public static AmazonBedrockEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
        if (map == null || map.isEmpty()) {
            return EMPTY;
        }

        var errors = new ValidationException();
        var truncation = extractOptionalEnum(
            map,
            TRUNCATE_FIELD,
            ModelConfigurations.TASK_SETTINGS,
            CohereTruncation::fromString,
            CohereTruncation.ALL,
            errors
        );

        if (errors.validationErrors().isEmpty()) {
            return new AmazonBedrockEmbeddingsTaskSettings(truncation);
        }
        throw errors;
    }

    /**
     * Reads the settings from a stream as a single optional enum, mirroring {@link #writeTo(StreamOutput)}.
     */
    public AmazonBedrockEmbeddingsTaskSettings(StreamInput in) throws IOException {
        this(in.readOptionalEnum(CohereTruncation.class));
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        // NOTE(review): an AMAZON_BEDROCK_TASK_SETTINGS_8_19 version is also defined in TransportVersions;
        // confirm whether the backport version needs to be accounted for here.
        return TransportVersions.AMAZON_BEDROCK_TASK_SETTINGS;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeOptionalEnum(cohereTruncation);
    }

    @Override
    public boolean isEmpty() {
        return cohereTruncation == null;
    }

    /**
     * Returns new settings in which a truncation value parsed from {@code newSettings}
     * takes precedence, falling back to this instance's value when none is provided.
     */
    @Override
    public AmazonBedrockEmbeddingsTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
        // Parse a copy of the incoming map rather than the caller's instance.
        var overrides = fromMap(new HashMap<>(newSettings));
        var requested = overrides.cohereTruncation();
        return new AmazonBedrockEmbeddingsTaskSettings(requested != null ? requested : cohereTruncation);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        // The field is omitted entirely when no truncation is configured.
        if (cohereTruncation != null) {
            builder.field(TRUNCATE_FIELD, cohereTruncation);
        }
        return builder.endObject();
    }
}

View file

@ -11,6 +11,7 @@ import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
import java.io.IOException;
import java.util.List;
@ -18,7 +19,11 @@ import java.util.Objects;
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nullable InputType inputType) implements ToXContentObject {
public record AmazonBedrockCohereEmbeddingsRequestEntity(
List<String> input,
@Nullable InputType inputType,
AmazonBedrockEmbeddingsTaskSettings taskSettings
) implements ToXContentObject {
private static final String TEXTS_FIELD = "texts";
private static final String INPUT_TYPE_FIELD = "input_type";
@ -26,9 +31,11 @@ public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nu
private static final String SEARCH_QUERY = "search_query";
private static final String CLUSTERING = "clustering";
private static final String CLASSIFICATION = "classification";
private static final String TRUNCATE = "truncate";
public AmazonBedrockCohereEmbeddingsRequestEntity {
Objects.requireNonNull(input);
Objects.requireNonNull(taskSettings);
}
@Override
@ -43,6 +50,10 @@ public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nu
builder.field(INPUT_TYPE_FIELD, SEARCH_DOCUMENT);
}
if (taskSettings.cohereTruncation() != null) {
builder.field(TRUNCATE, taskSettings.cohereTruncation().name());
}
builder.endObject();
return builder;
}

View file

@ -39,7 +39,7 @@ public final class AmazonBedrockEmbeddingsEntityFactory {
return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0));
}
case COHERE -> {
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType);
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType, model.getTaskSettings());
}
default -> {
return null;

View file

@ -76,6 +76,9 @@ public class AmazonBedrockEmbeddingsRequest extends AmazonBedrockRequest {
/**
 * Returns a request with truncated input, except for the Cohere provider,
 * for which this request is returned unchanged.
 */
@Override
public Request truncate() {
    // Cohere has its own truncation logic
    if (provider != AmazonBedrockProvider.COHERE) {
        var shortenedInput = truncator.truncate(truncationResult.input());
        return new AmazonBedrockEmbeddingsRequest(truncator, shortenedInput, embeddingsModel, requestEntity, timeout);
    }
    return this;
}

View file

@ -7,6 +7,7 @@
package org.elasticsearch.xpack.inference.services.cohere;
import java.util.EnumSet;
import java.util.Locale;
/**
@ -31,6 +32,8 @@ public enum CohereTruncation {
*/
END;
public static final EnumSet<CohereTruncation> ALL = EnumSet.allOf(CohereTruncation.class);
@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);

View file

@ -20,7 +20,6 @@ import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
import java.io.IOException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
@ -63,7 +62,7 @@ public class CohereEmbeddingsTaskSettings implements TaskSettings {
TRUNCATE,
ModelConfigurations.TASK_SETTINGS,
CohereTruncation::fromString,
EnumSet.allOf(CohereTruncation.class),
CohereTruncation.ALL,
validationException
);

View file

@ -12,6 +12,7 @@ import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeExcept
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.bytes.BytesArray;
@ -51,6 +52,8 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.Amazo
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettingsTests;
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
import org.hamcrest.CoreMatchers;
import org.hamcrest.Matchers;
@ -105,7 +108,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
public void testParseRequestConfig_CreatesAnAmazonBedrockModel() throws IOException {
try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
@ -115,7 +118,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var secretSettings = (AwsSecretSettings) model.getSecretSettings();
assertThat(secretSettings.accessKey().toString(), is("access"));
assertThat(secretSettings.secretKey().toString(), is("secret"));
}, exception -> fail("Unexpected exception: " + exception));
});
service.parseRequestConfig(
"id",
@ -130,15 +133,62 @@ public class AmazonBedrockServiceTests extends ESTestCase {
}
}
/**
 * Parses a text-embedding request config for the {@code cohere} provider that carries a
 * {@code truncate} task setting, and checks the service and secret settings of the parsed model.
 */
public void testParseRequestConfig_CreatesACohereModel() throws IOException {
    try (var service = createAmazonBedrockService()) {
        ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(parsedModel -> {
            assertThat(parsedModel, instanceOf(AmazonBedrockEmbeddingsModel.class));

            var serviceSettings = (AmazonBedrockEmbeddingsServiceSettings) parsedModel.getServiceSettings();
            assertThat(serviceSettings.region(), is("region"));
            assertThat(serviceSettings.modelId(), is("model"));
            assertThat(serviceSettings.provider(), is(AmazonBedrockProvider.COHERE));

            var secrets = (AwsSecretSettings) parsedModel.getSecretSettings();
            assertThat(secrets.accessKey().toString(), is("access"));
            assertThat(secrets.secretKey().toString(), is("secret"));
        });

        var requestConfig = getRequestConfigMap(
            createEmbeddingsRequestSettingsMap("region", "model", "cohere", null, null, null, null),
            AmazonBedrockEmbeddingsTaskSettingsTests.mutableMap("truncate", CohereTruncation.START),
            getAmazonBedrockSecretSettingsMap("access", "secret")
        );

        service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, requestConfig, modelVerificationListener);
    }
}
/**
 * Parses a text-embedding request config for the {@code amazontitan} provider that carries a
 * {@code truncate} task setting, and expects the listener to receive an
 * {@link ElasticsearchStatusException} with the "does not allow [truncate] field" message.
 */
public void testParseRequestConfig_CohereSettingsWithNoCohereModel() throws IOException {
    try (var service = createAmazonBedrockService()) {
        ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(failure -> {
            assertThat(failure, instanceOf(ElasticsearchStatusException.class));
            assertThat(
                failure.getMessage(),
                is("The [text_embedding] task type for provider [amazontitan] does not allow [truncate] field")
            );
        });

        var requestConfig = getRequestConfigMap(
            createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, null, null, null),
            AmazonBedrockEmbeddingsTaskSettingsTests.mutableMap("truncate", CohereTruncation.START),
            getAmazonBedrockSecretSettingsMap("access", "secret")
        );

        service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, requestConfig, modelVerificationListener);
    }
}
public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
model -> fail("Expected exception, but got model: " + model),
exception -> {
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
assertThat(exception.getMessage(), is("The [amazonbedrock] service does not support task type [sparse_embedding]"));
}
);
});
service.parseRequestConfig(
"id",
@ -247,13 +297,10 @@ public class AmazonBedrockServiceTests extends ESTestCase {
public void testCreateModel_ForEmbeddingsTask_InvalidProvider() throws IOException {
try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
model -> fail("Expected exception, but got model: " + model),
exception -> {
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
assertThat(exception.getMessage(), is("The [text_embedding] task type for provider [anthropic] is not available"));
}
);
});
service.parseRequestConfig(
"id",
@ -270,13 +317,10 @@ public class AmazonBedrockServiceTests extends ESTestCase {
public void testCreateModel_TopKParameter_NotAvailable() throws IOException {
try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
model -> fail("Expected exception, but got model: " + model),
exception -> {
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
assertThat(exception.getMessage(), is("The [top_k] task parameter is not available for provider [amazontitan]"));
}
);
});
service.parseRequestConfig(
"id",
@ -301,16 +345,13 @@ public class AmazonBedrockServiceTests extends ESTestCase {
config.put("extra_key", "value");
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
model -> fail("Expected exception, but got model: " + model),
exception -> {
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
assertThat(
exception.getMessage(),
is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service")
);
}
);
});
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
}
@ -323,9 +364,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var config = getRequestConfigMap(serviceSettings, Map.of(), getAmazonBedrockSecretSettingsMap("access", "secret"));
ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> {
fail("Expected exception, but got model: " + model);
}, e -> {
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> {
assertThat(e, instanceOf(ElasticsearchStatusException.class));
assertThat(
e.getMessage(),
@ -347,9 +386,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap);
ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> {
fail("Expected exception, but got model: " + model);
}, e -> {
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> {
assertThat(e, instanceOf(ElasticsearchStatusException.class));
assertThat(
e.getMessage(),
@ -371,9 +408,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap);
ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> {
fail("Expected exception, but got model: " + model);
}, e -> {
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> {
assertThat(e, instanceOf(ElasticsearchStatusException.class));
assertThat(
e.getMessage(),
@ -387,7 +422,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
public void testParseRequestConfig_MovesModel() throws IOException {
try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
@ -397,7 +432,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var secretSettings = (AwsSecretSettings) model.getSecretSettings();
assertThat(secretSettings.accessKey().toString(), is("access"));
assertThat(secretSettings.secretKey().toString(), is("secret"));
}, exception -> fail("Unexpected exception: " + exception));
});
service.parseRequestConfig(
"id",
@ -414,7 +449,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
@ -425,7 +460,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
assertThat(secretSettings.accessKey().toString(), is("access"));
assertThat(secretSettings.secretKey().toString(), is("secret"));
assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
}, exception -> fail("Unexpected exception: " + exception));
});
service.parseRequestConfig(
"id",
@ -443,7 +478,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
@ -454,7 +489,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
assertThat(secretSettings.accessKey().toString(), is("access"));
assertThat(secretSettings.secretKey().toString(), is("secret"));
assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
}, exception -> fail("Unexpected exception: " + exception));
});
service.parseRequestConfig(
"id",
@ -471,13 +506,10 @@ public class AmazonBedrockServiceTests extends ESTestCase {
public void testCreateModel_ForEmbeddingsTask_DimensionsIsNotAllowed() throws IOException {
try (var service = createAmazonBedrockService()) {
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
model -> fail("Expected exception, but got model: " + model),
exception -> {
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
assertThat(exception, instanceOf(ValidationException.class));
assertThat(exception.getMessage(), containsString("[service_settings] does not allow the setting [dimensions]"));
}
);
});
service.parseRequestConfig(
"id",
@ -497,7 +529,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
var model = service.parsePersistedConfigWithSecrets(
"id",
@ -525,7 +557,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var persistedConfig = getPersistedConfigMap(
settingsMap,
new HashMap<String, Object>(Map.of()),
new HashMap<>(Map.of()),
createRandomChunkingSettingsMap(),
secretSettingsMap
);
@ -607,7 +639,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
persistedConfig.config().put("extra_key", "value");
var model = service.parsePersistedConfigWithSecrets(
@ -635,7 +667,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
secretSettingsMap.put("extra_key", "value");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
var model = service.parsePersistedConfigWithSecrets(
"id",
@ -661,7 +693,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
persistedConfig.secrets().put("extra_key", "value");
var model = service.parsePersistedConfigWithSecrets(
@ -689,7 +721,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
settingsMap.put("extra_key", "value");
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
var model = service.parsePersistedConfigWithSecrets(
"id",
@ -769,7 +801,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var persistedConfig = getPersistedConfigMap(
settingsMap,
new HashMap<String, Object>(Map.of()),
new HashMap<>(Map.of()),
createRandomChunkingSettingsMap(),
secretSettingsMap
);
@ -792,7 +824,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
@ -836,7 +868,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
var thrownException = expectThrows(
ElasticsearchStatusException.class,
@ -855,7 +887,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
persistedConfig.config().put("extra_key", "value");
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
@ -876,7 +908,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
settingsMap.put("extra_key", "value");
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
persistedConfig.config().put("extra_key", "value");
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());

View file

@ -11,6 +11,7 @@ import org.elasticsearch.inference.InputType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockJsonBuilder;
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockCohereEmbeddingsRequestEntity;
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
import java.io.IOException;
import java.util.List;
@ -19,23 +20,46 @@ import static org.hamcrest.Matchers.is;
public class AmazonBedrockCohereEmbeddingsRequestEntityTests extends ESTestCase {
public void testRequestEntity_GeneratesExpectedJsonBody() throws IOException {
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), InputType.CLASSIFICATION);
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(
List.of("test input"),
InputType.CLASSIFICATION,
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
);
var builder = new AmazonBedrockJsonBuilder(entity);
var result = builder.getStringContent();
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"classification\"}"));
}
public void testRequestEntity_GeneratesExpectedJsonBody_WithInternalInputType() throws IOException {
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), InputType.INTERNAL_SEARCH);
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(
List.of("test input"),
InputType.INTERNAL_SEARCH,
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
);
var builder = new AmazonBedrockJsonBuilder(entity);
var result = builder.getStringContent();
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_query\"}"));
}
public void testRequestEntity_GeneratesExpectedJsonBody_WithoutInputType() throws IOException {
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), null);
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(
List.of("test input"),
null,
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
);
var builder = new AmazonBedrockJsonBuilder(entity);
var result = builder.getStringContent();
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\"}"));
}
/**
 * Builds a Cohere embeddings request entity with a {@code null} input type and
 * {@link CohereTruncation#START} task settings, then checks the exact JSON body it produces,
 * including the {@code truncate} field.
 */
public void testRequestEntity_GeneratesExpectedJsonBody_WithCohereTruncation() throws IOException {
    var taskSettings = new AmazonBedrockEmbeddingsTaskSettings(CohereTruncation.START);
    var requestEntity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), null, taskSettings);
    var json = new AmazonBedrockJsonBuilder(requestEntity).getStringContent();
    assertThat(json, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\",\"truncate\":\"START\"}"));
}
}

View file

@ -7,12 +7,9 @@
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
@ -20,19 +17,23 @@ import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.util.Map;
import java.io.IOException;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
public void testCreateModel_withTaskSettings_shouldFail() {
var baseModel = createModel("id", "region", "model", AmazonBedrockProvider.AMAZONTITAN, "accesskey", "secretkey");
var thrownException = assertThrows(
ValidationException.class,
() -> AmazonBedrockEmbeddingsModel.of(baseModel, Map.of("testkey", "testvalue"))
);
assertThat(thrownException.getMessage(), containsString("Amazon Bedrock embeddings model cannot have task settings"));
/**
 * Creates a base model with random task settings, derives a different set of task settings,
 * passes them as a map to {@code AmazonBedrockEmbeddingsModel.of}, and asserts that the returned
 * model reports the override settings rather than the base ones.
 */
public void testCreateModel_withTaskSettingsOverride() throws IOException {
    var originalSettings = AmazonBedrockEmbeddingsTaskSettingsTests.randomTaskSettings();
    var original = createModel(
        "id",
        "region",
        "model",
        AmazonBedrockProvider.AMAZONTITAN,
        "accesskey",
        "secretkey",
        originalSettings
    );
    var expectedSettings = AmazonBedrockEmbeddingsTaskSettingsTests.mutateTaskSettings(originalSettings);
    var overridden = AmazonBedrockEmbeddingsModel.of(original, AmazonBedrockEmbeddingsTaskSettingsTests.toMap(expectedSettings));

    assertThat(overridden.getTaskSettings(), equalTo(expectedSettings));
    assertThat(overridden.getTaskSettings(), not(equalTo(originalSettings)));
}
// model creation only - no tests to define, but we want to have the public createModel
@ -46,7 +47,15 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
String accessKey,
String secretKey
) {
return createModel(inferenceId, region, model, provider, null, false, null, null, new RateLimitSettings(240), accessKey, secretKey);
return createModel(
inferenceId,
region,
model,
provider,
accessKey,
secretKey,
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
);
}
public static AmazonBedrockEmbeddingsModel createModel(
@ -56,9 +65,22 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
AmazonBedrockProvider provider,
String accessKey,
String secretKey,
InputType inputType
AmazonBedrockEmbeddingsTaskSettings taskSettings
) {
return createModel(inferenceId, region, model, provider, null, false, null, null, new RateLimitSettings(240), accessKey, secretKey);
return createModel(
inferenceId,
region,
model,
provider,
null,
false,
null,
null,
new RateLimitSettings(240),
accessKey,
secretKey,
taskSettings
);
}
public static AmazonBedrockEmbeddingsModel createModel(
@ -114,7 +136,7 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
similarity,
rateLimitSettings
),
new EmptyTaskSettings(),
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings(),
chunkingSettings,
new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey))
);
@ -132,6 +154,36 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
RateLimitSettings rateLimitSettings,
String accessKey,
String secretKey
) {
return createModel(
inferenceId,
region,
model,
provider,
dimensions,
dimensionsSetByUser,
maxTokens,
similarity,
rateLimitSettings,
accessKey,
secretKey,
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
);
}
public static AmazonBedrockEmbeddingsModel createModel(
String inferenceId,
String region,
String model,
AmazonBedrockProvider provider,
@Nullable Integer dimensions,
boolean dimensionsSetByUser,
@Nullable Integer maxTokens,
@Nullable SimilarityMeasure similarity,
RateLimitSettings rateLimitSettings,
String accessKey,
String secretKey,
AmazonBedrockEmbeddingsTaskSettings taskSettings
) {
return new AmazonBedrockEmbeddingsModel(
inferenceId,
@ -147,7 +199,7 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
similarity,
rateLimitSettings
),
new EmptyTaskSettings(),
taskSettings,
null,
new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey))
);

View file

@ -0,0 +1,112 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithoutUnspecified;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD;
import static org.hamcrest.Matchers.equalTo;
/**
 * Wire-serialization and map-parsing tests for {@link AmazonBedrockEmbeddingsTaskSettings}.
 * The public static helpers are also used by other test classes to build task settings.
 */
public class AmazonBedrockEmbeddingsTaskSettingsTests extends AbstractBWCWireSerializationTestCase<AmazonBedrockEmbeddingsTaskSettings> {

    /**
     * @return {@link AmazonBedrockEmbeddingsTaskSettings#EMPTY}
     */
    public static AmazonBedrockEmbeddingsTaskSettings emptyTaskSettings() {
        return AmazonBedrockEmbeddingsTaskSettings.EMPTY;
    }

    /**
     * @return task settings whose truncation is either a random {@link CohereTruncation} value or {@code null}
     */
    public static AmazonBedrockEmbeddingsTaskSettings randomTaskSettings() {
        // Truncation is the only value the task settings constructor accepts, so nothing else is randomized here.
        var truncation = randomBoolean() ? randomFrom(CohereTruncation.values()) : null;
        return new AmazonBedrockEmbeddingsTaskSettings(truncation);
    }

    /**
     * Produces random task settings that are not equal to {@code instance}.
     * <p>
     * When {@code instance} has a truncation set, candidates with a {@code null} truncation are rejected as well.
     * NOTE(review): presumably because a {@code null} truncation in an override would not replace the base value;
     * confirm against how {@link AmazonBedrockEmbeddingsTaskSettings} merges overrides.
     *
     * @param instance the settings the result must differ from
     * @return a different set of task settings
     */
    public static AmazonBedrockEmbeddingsTaskSettings mutateTaskSettings(AmazonBedrockEmbeddingsTaskSettings instance) {
        return randomValueOtherThanMany(
            v -> Objects.equals(instance, v) || (instance.cohereTruncation() != null && v.cohereTruncation() == null),
            AmazonBedrockEmbeddingsTaskSettingsTests::randomTaskSettings
        );
    }

    /**
     * Returns {@code instance} unchanged for every transport version.
     */
    @Override
    protected AmazonBedrockEmbeddingsTaskSettings mutateInstanceForVersion(
        AmazonBedrockEmbeddingsTaskSettings instance,
        TransportVersion version
    ) {
        return instance;
    }

    @Override
    protected Writeable.Reader<AmazonBedrockEmbeddingsTaskSettings> instanceReader() {
        return AmazonBedrockEmbeddingsTaskSettings::new;
    }

    @Override
    protected AmazonBedrockEmbeddingsTaskSettings createTestInstance() {
        return randomTaskSettings();
    }

    @Override
    protected AmazonBedrockEmbeddingsTaskSettings mutateInstance(AmazonBedrockEmbeddingsTaskSettings instance) throws IOException {
        return mutateTaskSettings(instance);
    }

    /**
     * The {@code EMPTY} constant, a {@code null} map and an empty map must all yield empty task settings.
     */
    public void testEmpty() {
        assertTrue(emptyTaskSettings().isEmpty());
        assertTrue(AmazonBedrockEmbeddingsTaskSettings.fromMap(null).isEmpty());
        assertTrue(AmazonBedrockEmbeddingsTaskSettings.fromMap(Map.of()).isEmpty());
    }

    /**
     * Builds a mutable single-entry map from {@code key} to {@code value.toString()}.
     * NOTE(review): a mutable copy is presumably needed because {@code fromMap} modifies its argument; confirm.
     *
     * @param key   the map key
     * @param value the enum whose {@code toString()} becomes the map value
     * @return a new {@link HashMap} holding the single entry
     */
    public static Map<String, Object> mutableMap(String key, Enum<?> value) {
        return new HashMap<>(Map.of(key, value.toString()));
    }

    /**
     * Every value in {@code CohereTruncation.ALL} must parse from its string form into non-empty task settings.
     */
    public void testValidCohereTruncations() {
        for (var expectedCohereTruncation : CohereTruncation.ALL) {
            var map = mutableMap(TRUNCATE_FIELD, expectedCohereTruncation);
            var taskSettings = AmazonBedrockEmbeddingsTaskSettings.fromMap(map);
            assertFalse(taskSettings.isEmpty());
            assertThat(taskSettings.cohereTruncation(), equalTo(expectedCohereTruncation));
        }
    }

    /**
     * An unrecognized truncation string must be rejected with a {@link ValidationException}.
     */
    public void testGarbageCohereTruncations() {
        var map = new HashMap<String, Object>(Map.of(TRUNCATE_FIELD, "oiuesoirtuoawoeirha"));
        assertThrows(ValidationException.class, () -> AmazonBedrockEmbeddingsTaskSettings.fromMap(map));
    }

    /**
     * Task settings rendered through {@code toXContent} and parsed back with {@code fromMap} must equal the original.
     */
    public void testXContent() throws IOException {
        var taskSettings = randomTaskSettings();
        var taskSettingsAsMap = toMap(taskSettings);
        var roundTripTaskSettings = AmazonBedrockEmbeddingsTaskSettings.fromMap(new HashMap<>(taskSettingsAsMap));
        assertThat(roundTripTaskSettings, equalTo(taskSettings));
    }

    /**
     * Renders {@code taskSettings} to JSON via {@code toXContent} and parses that JSON back into a map.
     *
     * @param taskSettings the settings to convert
     * @return the parsed map representation of the settings' JSON
     * @throws IOException if building or parsing the JSON fails
     */
    public static Map<String, Object> toMap(AmazonBedrockEmbeddingsTaskSettings taskSettings) throws IOException {
        try (var builder = JsonXContent.contentBuilder()) {
            taskSettings.toXContent(builder, ToXContent.EMPTY_PARAMS);
            var taskSettingsBytes = Strings.toString(builder).getBytes(StandardCharsets.UTF_8);
            try (var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, taskSettingsBytes)) {
                return parser.map();
            }
        }
    }
}