mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
[ML] Bedrock Cohere Task Settings Support (#126493)
Add support for Cohere Task Settings and Truncate, through the Amazon Bedrock provider integration. Task Settings can now be passed both during Inference endpoint creation and Inference POST requests. Close #126156
This commit is contained in:
parent
6d86b202ea
commit
6c6500ec3b
16 changed files with 468 additions and 102 deletions
6
docs/changelog/126493.yaml
Normal file
6
docs/changelog/126493.yaml
Normal file
|
@ -0,0 +1,6 @@
|
|||
pr: 126493
|
||||
summary: Bedrock Cohere Task Settings Support
|
||||
area: Machine Learning
|
||||
type: enhancement
|
||||
issues:
|
||||
- 126156
|
|
@ -157,6 +157,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
|
||||
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
|
||||
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
|
||||
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
|
||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
|
||||
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
|
||||
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
|
||||
|
@ -215,6 +216,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00);
|
||||
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0);
|
||||
public static final TransportVersion REPO_ANALYSIS_COPY_BLOB = def(9_048_00_0);
|
||||
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS = def(9_049_00_0);
|
||||
|
||||
/*
|
||||
* STOP! READ THIS FIRST! No, really,
|
||||
|
|
|
@ -41,6 +41,7 @@ import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.Alib
|
|||
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings;
|
||||
|
@ -173,8 +174,13 @@ public class InferenceNamedWriteablesProvider {
|
|||
AmazonBedrockEmbeddingsServiceSettings::new
|
||||
)
|
||||
);
|
||||
|
||||
// no task settings for Amazon Bedrock Embeddings
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(
|
||||
TaskSettings.class,
|
||||
AmazonBedrockEmbeddingsTaskSettings.NAME,
|
||||
AmazonBedrockEmbeddingsTaskSettings::new
|
||||
)
|
||||
);
|
||||
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(
|
||||
|
|
|
@ -19,6 +19,8 @@ public class AmazonBedrockConstants {
|
|||
public static final String TOP_K_FIELD = "top_k";
|
||||
public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens";
|
||||
|
||||
public static final String TRUNCATE_FIELD = "truncate";
|
||||
|
||||
public static final Double MIN_TEMPERATURE_TOP_P_TOP_K_VALUE = 0.0;
|
||||
public static final Double MAX_TEMPERATURE_TOP_P_TOP_K_VALUE = 1.0;
|
||||
|
||||
|
|
|
@ -303,6 +303,7 @@ public class AmazonBedrockService extends SenderService {
|
|||
context
|
||||
);
|
||||
checkProviderForTask(TaskType.TEXT_EMBEDDING, model.provider());
|
||||
checkTaskSettingsForTextEmbeddingModel(model);
|
||||
return model;
|
||||
}
|
||||
case COMPLETION -> {
|
||||
|
@ -368,6 +369,17 @@ public class AmazonBedrockService extends SenderService {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Validates the task settings of a text-embedding model against its provider.
 * The {@code truncate} task setting is only accepted for the Cohere provider.
 *
 * @param model the embeddings model whose provider and task settings are checked
 * @throws ElasticsearchStatusException with {@code BAD_REQUEST} when a Cohere truncation
 *         value is set on a model whose provider is not {@code COHERE}
 */
private static void checkTaskSettingsForTextEmbeddingModel(AmazonBedrockEmbeddingsModel model) {
|
||||
// Only reject when truncation was actually supplied; a null value means the setting is absent.
if (model.provider() != AmazonBedrockProvider.COHERE && model.getTaskSettings().cohereTruncation() != null) {
|
||||
throw new ElasticsearchStatusException(
|
||||
"The [{}] task type for provider [{}] does not allow [truncate] field",
|
||||
RestStatus.BAD_REQUEST,
|
||||
TaskType.TEXT_EMBEDDING,
|
||||
model.provider()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private static void checkChatCompletionProviderForTopKParameter(AmazonBedrockChatCompletionModel model) {
|
||||
var taskSettings = model.getTaskSettings();
|
||||
if (taskSettings.topK() != null) {
|
||||
|
|
|
@ -7,14 +7,11 @@
|
|||
|
||||
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
|
||||
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.inference.ChunkingSettings;
|
||||
import org.elasticsearch.inference.EmptyTaskSettings;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.ServiceSettings;
|
||||
import org.elasticsearch.inference.TaskSettings;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
|
@ -28,10 +25,8 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
|
|||
|
||||
public static AmazonBedrockEmbeddingsModel of(AmazonBedrockEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings) {
|
||||
if (taskSettings != null && taskSettings.isEmpty() == false) {
|
||||
// no task settings allowed
|
||||
var validationException = new ValidationException();
|
||||
validationException.addValidationError("Amazon Bedrock embeddings model cannot have task settings");
|
||||
throw validationException;
|
||||
var updatedTaskSettings = embeddingsModel.getTaskSettings().updatedTaskSettings(taskSettings);
|
||||
return new AmazonBedrockEmbeddingsModel(embeddingsModel, updatedTaskSettings);
|
||||
}
|
||||
|
||||
return embeddingsModel;
|
||||
|
@ -52,7 +47,7 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
|
|||
taskType,
|
||||
service,
|
||||
AmazonBedrockEmbeddingsServiceSettings.fromMap(serviceSettings, context),
|
||||
new EmptyTaskSettings(),
|
||||
AmazonBedrockEmbeddingsTaskSettings.fromMap(taskSettings),
|
||||
chunkingSettings,
|
||||
AwsSecretSettings.fromMap(secretSettings)
|
||||
);
|
||||
|
@ -63,12 +58,12 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
|
|||
TaskType taskType,
|
||||
String service,
|
||||
AmazonBedrockEmbeddingsServiceSettings serviceSettings,
|
||||
TaskSettings taskSettings,
|
||||
AmazonBedrockEmbeddingsTaskSettings taskSettings,
|
||||
ChunkingSettings chunkingSettings,
|
||||
AwsSecretSettings secrets
|
||||
) {
|
||||
super(
|
||||
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings),
|
||||
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
|
||||
new ModelSecrets(secrets)
|
||||
);
|
||||
}
|
||||
|
@ -77,6 +72,10 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
|
|||
super(model, serviceSettings);
|
||||
}
|
||||
|
||||
/**
 * Creates an embeddings model from an existing model and a set of task settings by
 * delegating to the superclass constructor. Used by {@code of(...)} to apply
 * request-time task settings.
 *
 * @param model        the model passed through to the superclass
 * @param taskSettings the task settings passed through to the superclass
 */
public AmazonBedrockEmbeddingsModel(Model model, AmazonBedrockEmbeddingsTaskSettings taskSettings) {
|
||||
super(model, taskSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map<String, Object> taskSettings) {
|
||||
return creator.create(this, taskSettings);
|
||||
|
@ -86,4 +85,9 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
|
|||
public AmazonBedrockEmbeddingsServiceSettings getServiceSettings() {
|
||||
return (AmazonBedrockEmbeddingsServiceSettings) super.getServiceSettings();
|
||||
}
|
||||
|
||||
/**
 * Returns the task settings narrowed to the Amazon Bedrock embeddings type.
 * The cast relies on the constructors of this class only ever storing
 * {@link AmazonBedrockEmbeddingsTaskSettings} — NOTE(review): confirm no other path
 * can install a different {@code TaskSettings} implementation.
 */
@Override
|
||||
public AmazonBedrockEmbeddingsTaskSettings getTaskSettings() {
|
||||
return (AmazonBedrockEmbeddingsTaskSettings) super.getTaskSettings();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.TaskSettings;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
|
||||
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD;
|
||||
|
||||
/**
 * Task settings for Amazon Bedrock text-embedding models.
 * The only setting carried is an optional Cohere truncation mode, read from and written to
 * the {@code truncate} field.
 *
 * @param cohereTruncation the truncation mode, or {@code null} when the setting is absent
 */
public record AmazonBedrockEmbeddingsTaskSettings(@Nullable CohereTruncation cohereTruncation) implements TaskSettings {
|
||||
// Shared instance representing "no task settings supplied".
public static final AmazonBedrockEmbeddingsTaskSettings EMPTY = new AmazonBedrockEmbeddingsTaskSettings((CohereTruncation) null);
|
||||
// Name returned by getWriteableName().
public static final String NAME = "amazon_bedrock_embeddings_task_settings";
|
||||
|
||||
/**
 * Parses task settings from a settings map.
 *
 * @param map the task settings map; may be {@code null} or empty
 * @return {@link #EMPTY} when the map is {@code null} or empty, otherwise a new instance
 *         holding the value extracted from the {@code truncate} field
 * @throws ValidationException when extraction recorded at least one validation error
 */
public static AmazonBedrockEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
|
||||
if (map == null || map.isEmpty()) {
|
||||
return EMPTY;
|
||||
}
|
||||
|
||||
// Collects errors reported by the extraction helper so they can be thrown together below.
ValidationException validationException = new ValidationException();
|
||||
|
||||
// CohereTruncation.ALL is passed alongside the parser — presumably as the set of accepted values; TODO confirm.
var cohereTruncation = extractOptionalEnum(
|
||||
map,
|
||||
TRUNCATE_FIELD,
|
||||
ModelConfigurations.TASK_SETTINGS,
|
||||
CohereTruncation::fromString,
|
||||
CohereTruncation.ALL,
|
||||
validationException
|
||||
);
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
|
||||
return new AmazonBedrockEmbeddingsTaskSettings(cohereTruncation);
|
||||
}
|
||||
|
||||
/**
 * Deserializes task settings from a stream; reads a single optional enum, matching
 * what {@link #writeTo(StreamOutput)} writes.
 */
public AmazonBedrockEmbeddingsTaskSettings(StreamInput in) throws IOException {
|
||||
this(in.readOptionalEnum(CohereTruncation.class));
|
||||
}
|
||||
|
||||
/** Returns {@code true} when no truncation mode is set. */
@Override
|
||||
public boolean isEmpty() {
|
||||
return cohereTruncation() == null;
|
||||
}
|
||||
|
||||
/**
 * Merges request-time settings over the current ones.
 * A truncation value present in {@code newSettings} wins; otherwise the current value is kept.
 *
 * @param newSettings the overriding settings map; must not be {@code null} because it is
 *                    copied with {@code new HashMap<>(newSettings)}
 * @return a new instance holding the merged value
 */
@Override
|
||||
public AmazonBedrockEmbeddingsTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
|
||||
// Parse from a copy — presumably so extraction cannot modify the caller's map; TODO confirm.
var newTaskSettings = fromMap(new HashMap<>(newSettings));
|
||||
|
||||
return new AmazonBedrockEmbeddingsTaskSettings(firstNonNullOrNull(newTaskSettings.cohereTruncation(), cohereTruncation()));
|
||||
}
|
||||
|
||||
/** Returns {@code first} when it is non-null, otherwise {@code second} (which may itself be null). */
private static <T> T firstNonNullOrNull(T first, T second) {
|
||||
return first != null ? first : second;
|
||||
}
|
||||
|
||||
/** Returns {@link #NAME}. */
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
/**
 * Returns {@code TransportVersions.AMAZON_BEDROCK_TASK_SETTINGS}.
 * NOTE(review): an {@code AMAZON_BEDROCK_TASK_SETTINGS_8_19} constant is also defined but is not
 * referenced here — confirm whether the 8.19 patch version needs to be accepted as well.
 */
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.AMAZON_BEDROCK_TASK_SETTINGS;
|
||||
}
|
||||
|
||||
/** Serializes the settings as a single optional enum, matching the {@code StreamInput} constructor. */
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeOptionalEnum(cohereTruncation());
|
||||
}
|
||||
|
||||
/**
 * Writes the settings as an object; the {@code truncate} field is emitted only when a
 * truncation mode is set, so empty settings produce an empty object.
 */
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (cohereTruncation != null) {
|
||||
builder.field(TRUNCATE_FIELD, cohereTruncation);
|
||||
}
|
||||
return builder.endObject();
|
||||
}
|
||||
}
|
|
@ -11,6 +11,7 @@ import org.elasticsearch.core.Nullable;
|
|||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
@ -18,7 +19,11 @@ import java.util.Objects;
|
|||
|
||||
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
|
||||
|
||||
public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nullable InputType inputType) implements ToXContentObject {
|
||||
public record AmazonBedrockCohereEmbeddingsRequestEntity(
|
||||
List<String> input,
|
||||
@Nullable InputType inputType,
|
||||
AmazonBedrockEmbeddingsTaskSettings taskSettings
|
||||
) implements ToXContentObject {
|
||||
|
||||
private static final String TEXTS_FIELD = "texts";
|
||||
private static final String INPUT_TYPE_FIELD = "input_type";
|
||||
|
@ -26,9 +31,11 @@ public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nu
|
|||
private static final String SEARCH_QUERY = "search_query";
|
||||
private static final String CLUSTERING = "clustering";
|
||||
private static final String CLASSIFICATION = "classification";
|
||||
private static final String TRUNCATE = "truncate";
|
||||
|
||||
public AmazonBedrockCohereEmbeddingsRequestEntity {
|
||||
Objects.requireNonNull(input);
|
||||
Objects.requireNonNull(taskSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -43,6 +50,10 @@ public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nu
|
|||
builder.field(INPUT_TYPE_FIELD, SEARCH_DOCUMENT);
|
||||
}
|
||||
|
||||
if (taskSettings.cohereTruncation() != null) {
|
||||
builder.field(TRUNCATE, taskSettings.cohereTruncation().name());
|
||||
}
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
|
|
@ -39,7 +39,7 @@ public final class AmazonBedrockEmbeddingsEntityFactory {
|
|||
return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0));
|
||||
}
|
||||
case COHERE -> {
|
||||
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType);
|
||||
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType, model.getTaskSettings());
|
||||
}
|
||||
default -> {
|
||||
return null;
|
||||
|
|
|
@ -76,6 +76,9 @@ public class AmazonBedrockEmbeddingsRequest extends AmazonBedrockRequest {
|
|||
|
||||
/**
 * Produces the request to send after input truncation.
 * For the Cohere provider the request is returned unchanged; for every other provider the
 * current input is passed through {@code truncator} and a new request is built from the result.
 *
 * @return this request for Cohere, otherwise a new request carrying the truncated input
 */
@Override
|
||||
public Request truncate() {
|
||||
if (provider == AmazonBedrockProvider.COHERE) {
|
||||
return this; // Cohere has its own truncation logic
|
||||
}
|
||||
var truncatedInput = truncator.truncate(truncationResult.input());
|
||||
return new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, embeddingsModel, requestEntity, timeout);
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
package org.elasticsearch.xpack.inference.services.cohere;
|
||||
|
||||
import java.util.EnumSet;
|
||||
import java.util.Locale;
|
||||
|
||||
/**
|
||||
|
@ -31,6 +32,8 @@ public enum CohereTruncation {
|
|||
*/
|
||||
END;
|
||||
|
||||
public static final EnumSet<CohereTruncation> ALL = EnumSet.allOf(CohereTruncation.class);
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return name().toLowerCase(Locale.ROOT);
|
||||
|
|
|
@ -20,7 +20,6 @@ import org.elasticsearch.xcontent.XContentBuilder;
|
|||
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.EnumSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
@ -63,7 +62,7 @@ public class CohereEmbeddingsTaskSettings implements TaskSettings {
|
|||
TRUNCATE,
|
||||
ModelConfigurations.TASK_SETTINGS,
|
||||
CohereTruncation::fromString,
|
||||
EnumSet.allOf(CohereTruncation.class),
|
||||
CohereTruncation.ALL,
|
||||
validationException
|
||||
);
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeExcept
|
|||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.ActionTestUtils;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
|
@ -51,6 +52,8 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.Amazo
|
|||
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettingsTests;
|
||||
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
|
||||
import org.hamcrest.CoreMatchers;
|
||||
import org.hamcrest.Matchers;
|
||||
|
@ -105,7 +108,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
|
||||
public void testParseRequestConfig_CreatesAnAmazonBedrockModel() throws IOException {
|
||||
try (var service = createAmazonBedrockService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
|
||||
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
|
||||
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
|
||||
|
||||
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
|
||||
|
@ -115,7 +118,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
var secretSettings = (AwsSecretSettings) model.getSecretSettings();
|
||||
assertThat(secretSettings.accessKey().toString(), is("access"));
|
||||
assertThat(secretSettings.secretKey().toString(), is("secret"));
|
||||
}, exception -> fail("Unexpected exception: " + exception));
|
||||
});
|
||||
|
||||
service.parseRequestConfig(
|
||||
"id",
|
||||
|
@ -130,15 +133,62 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Parses a text-embedding request config for the {@code cohere} provider that carries a
 * {@code truncate} task setting and verifies that an embeddings model is produced with the
 * expected service and secret settings.
 */
public void testParseRequestConfig_CreatesACohereModel() throws IOException {
|
||||
try (var service = createAmazonBedrockService()) {
|
||||
// NOTE(review): the listener checks service and secret settings only; the parsed
// truncate task setting (CohereTruncation.START) is never asserted.
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
|
||||
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
|
||||
|
||||
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
|
||||
assertThat(settings.region(), is("region"));
|
||||
assertThat(settings.modelId(), is("model"));
|
||||
assertThat(settings.provider(), is(AmazonBedrockProvider.COHERE));
|
||||
var secretSettings = (AwsSecretSettings) model.getSecretSettings();
|
||||
assertThat(secretSettings.accessKey().toString(), is("access"));
|
||||
assertThat(secretSettings.secretKey().toString(), is("secret"));
|
||||
});
|
||||
|
||||
service.parseRequestConfig(
|
||||
"id",
|
||||
TaskType.TEXT_EMBEDDING,
|
||||
getRequestConfigMap(
|
||||
createEmbeddingsRequestSettingsMap("region", "model", "cohere", null, null, null, null),
|
||||
AmazonBedrockEmbeddingsTaskSettingsTests.mutableMap("truncate", CohereTruncation.START),
|
||||
getAmazonBedrockSecretSettingsMap("access", "secret")
|
||||
),
|
||||
modelVerificationListener
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Supplies a {@code truncate} task setting together with the {@code amazontitan} provider and
 * verifies that parsing fails with an {@code ElasticsearchStatusException} whose message names
 * the task type, the provider and the rejected field.
 */
public void testParseRequestConfig_CohereSettingsWithNoCohereModel() throws IOException {
|
||||
try (var service = createAmazonBedrockService()) {
|
||||
// The listener fails the test if a model is returned; only the failure path is expected.
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
|
||||
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||
assertThat(
|
||||
exception.getMessage(),
|
||||
is("The [text_embedding] task type for provider [amazontitan] does not allow [truncate] field")
|
||||
);
|
||||
});
|
||||
|
||||
service.parseRequestConfig(
|
||||
"id",
|
||||
TaskType.TEXT_EMBEDDING,
|
||||
getRequestConfigMap(
|
||||
createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, null, null, null),
|
||||
AmazonBedrockEmbeddingsTaskSettingsTests.mutableMap("truncate", CohereTruncation.START),
|
||||
getAmazonBedrockSecretSettingsMap("access", "secret")
|
||||
),
|
||||
modelVerificationListener
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
|
||||
try (var service = createAmazonBedrockService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
model -> fail("Expected exception, but got model: " + model),
|
||||
exception -> {
|
||||
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
|
||||
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||
assertThat(exception.getMessage(), is("The [amazonbedrock] service does not support task type [sparse_embedding]"));
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
service.parseRequestConfig(
|
||||
"id",
|
||||
|
@ -247,13 +297,10 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
|
||||
public void testCreateModel_ForEmbeddingsTask_InvalidProvider() throws IOException {
|
||||
try (var service = createAmazonBedrockService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
model -> fail("Expected exception, but got model: " + model),
|
||||
exception -> {
|
||||
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
|
||||
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||
assertThat(exception.getMessage(), is("The [text_embedding] task type for provider [anthropic] is not available"));
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
service.parseRequestConfig(
|
||||
"id",
|
||||
|
@ -270,13 +317,10 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
|
||||
public void testCreateModel_TopKParameter_NotAvailable() throws IOException {
|
||||
try (var service = createAmazonBedrockService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
model -> fail("Expected exception, but got model: " + model),
|
||||
exception -> {
|
||||
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
|
||||
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||
assertThat(exception.getMessage(), is("The [top_k] task parameter is not available for provider [amazontitan]"));
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
service.parseRequestConfig(
|
||||
"id",
|
||||
|
@ -301,16 +345,13 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
|
||||
config.put("extra_key", "value");
|
||||
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
model -> fail("Expected exception, but got model: " + model),
|
||||
exception -> {
|
||||
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
|
||||
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||
assertThat(
|
||||
exception.getMessage(),
|
||||
is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service")
|
||||
);
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
|
||||
}
|
||||
|
@ -323,9 +364,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
|
||||
var config = getRequestConfigMap(serviceSettings, Map.of(), getAmazonBedrockSecretSettingsMap("access", "secret"));
|
||||
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> {
|
||||
fail("Expected exception, but got model: " + model);
|
||||
}, e -> {
|
||||
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> {
|
||||
assertThat(e, instanceOf(ElasticsearchStatusException.class));
|
||||
assertThat(
|
||||
e.getMessage(),
|
||||
|
@ -347,9 +386,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
|
||||
var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap);
|
||||
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> {
|
||||
fail("Expected exception, but got model: " + model);
|
||||
}, e -> {
|
||||
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> {
|
||||
assertThat(e, instanceOf(ElasticsearchStatusException.class));
|
||||
assertThat(
|
||||
e.getMessage(),
|
||||
|
@ -371,9 +408,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
|
||||
var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap);
|
||||
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> {
|
||||
fail("Expected exception, but got model: " + model);
|
||||
}, e -> {
|
||||
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> {
|
||||
assertThat(e, instanceOf(ElasticsearchStatusException.class));
|
||||
assertThat(
|
||||
e.getMessage(),
|
||||
|
@ -387,7 +422,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
|
||||
public void testParseRequestConfig_MovesModel() throws IOException {
|
||||
try (var service = createAmazonBedrockService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
|
||||
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
|
||||
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
|
||||
|
||||
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
|
||||
|
@ -397,7 +432,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
var secretSettings = (AwsSecretSettings) model.getSecretSettings();
|
||||
assertThat(secretSettings.accessKey().toString(), is("access"));
|
||||
assertThat(secretSettings.secretKey().toString(), is("secret"));
|
||||
}, exception -> fail("Unexpected exception: " + exception));
|
||||
});
|
||||
|
||||
service.parseRequestConfig(
|
||||
"id",
|
||||
|
@ -414,7 +449,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
|
||||
public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
|
||||
try (var service = createAmazonBedrockService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
|
||||
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
|
||||
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
|
||||
|
||||
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
|
||||
|
@ -425,7 +460,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
assertThat(secretSettings.accessKey().toString(), is("access"));
|
||||
assertThat(secretSettings.secretKey().toString(), is("secret"));
|
||||
assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
||||
}, exception -> fail("Unexpected exception: " + exception));
|
||||
});
|
||||
|
||||
service.parseRequestConfig(
|
||||
"id",
|
||||
|
@ -443,7 +478,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
|
||||
public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
|
||||
try (var service = createAmazonBedrockService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
|
||||
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
|
||||
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
|
||||
|
||||
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
|
||||
|
@ -454,7 +489,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
assertThat(secretSettings.accessKey().toString(), is("access"));
|
||||
assertThat(secretSettings.secretKey().toString(), is("secret"));
|
||||
assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
||||
}, exception -> fail("Unexpected exception: " + exception));
|
||||
});
|
||||
|
||||
service.parseRequestConfig(
|
||||
"id",
|
||||
|
@ -471,13 +506,10 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
|
||||
public void testCreateModel_ForEmbeddingsTask_DimensionsIsNotAllowed() throws IOException {
|
||||
try (var service = createAmazonBedrockService()) {
|
||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
||||
model -> fail("Expected exception, but got model: " + model),
|
||||
exception -> {
|
||||
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
|
||||
assertThat(exception, instanceOf(ValidationException.class));
|
||||
assertThat(exception.getMessage(), containsString("[service_settings] does not allow the setting [dimensions]"));
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
service.parseRequestConfig(
|
||||
"id",
|
||||
|
@ -497,7 +529,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||
|
||||
var model = service.parsePersistedConfigWithSecrets(
|
||||
"id",
|
||||
|
@ -525,7 +557,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
|
||||
var persistedConfig = getPersistedConfigMap(
|
||||
settingsMap,
|
||||
new HashMap<String, Object>(Map.of()),
|
||||
new HashMap<>(Map.of()),
|
||||
createRandomChunkingSettingsMap(),
|
||||
secretSettingsMap
|
||||
);
|
||||
|
@ -607,7 +639,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||
persistedConfig.config().put("extra_key", "value");
|
||||
|
||||
var model = service.parsePersistedConfigWithSecrets(
|
||||
|
@ -635,7 +667,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||
secretSettingsMap.put("extra_key", "value");
|
||||
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||
|
||||
var model = service.parsePersistedConfigWithSecrets(
|
||||
"id",
|
||||
|
@ -661,7 +693,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||
persistedConfig.secrets().put("extra_key", "value");
|
||||
|
||||
var model = service.parsePersistedConfigWithSecrets(
|
||||
|
@ -689,7 +721,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
settingsMap.put("extra_key", "value");
|
||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||
|
||||
var model = service.parsePersistedConfigWithSecrets(
|
||||
"id",
|
||||
|
@ -769,7 +801,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
|
||||
var persistedConfig = getPersistedConfigMap(
|
||||
settingsMap,
|
||||
new HashMap<String, Object>(Map.of()),
|
||||
new HashMap<>(Map.of()),
|
||||
createRandomChunkingSettingsMap(),
|
||||
secretSettingsMap
|
||||
);
|
||||
|
@ -792,7 +824,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||
|
||||
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
|
||||
|
||||
|
@ -836,7 +868,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||
|
||||
var thrownException = expectThrows(
|
||||
ElasticsearchStatusException.class,
|
||||
|
@ -855,7 +887,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||
persistedConfig.config().put("extra_key", "value");
|
||||
|
||||
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
|
||||
|
@ -876,7 +908,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
settingsMap.put("extra_key", "value");
|
||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||
persistedConfig.config().put("extra_key", "value");
|
||||
|
||||
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
|
||||
|
|
|
@ -11,6 +11,7 @@ import org.elasticsearch.inference.InputType;
|
|||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockJsonBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockCohereEmbeddingsRequestEntity;
|
||||
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
@ -19,23 +20,46 @@ import static org.hamcrest.Matchers.is;
|
|||
|
||||
public class AmazonBedrockCohereEmbeddingsRequestEntityTests extends ESTestCase {
|
||||
public void testRequestEntity_GeneratesExpectedJsonBody() throws IOException {
|
||||
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), InputType.CLASSIFICATION);
|
||||
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(
|
||||
List.of("test input"),
|
||||
InputType.CLASSIFICATION,
|
||||
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
|
||||
);
|
||||
var builder = new AmazonBedrockJsonBuilder(entity);
|
||||
var result = builder.getStringContent();
|
||||
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"classification\"}"));
|
||||
}
|
||||
|
||||
public void testRequestEntity_GeneratesExpectedJsonBody_WithInternalInputType() throws IOException {
|
||||
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), InputType.INTERNAL_SEARCH);
|
||||
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(
|
||||
List.of("test input"),
|
||||
InputType.INTERNAL_SEARCH,
|
||||
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
|
||||
);
|
||||
var builder = new AmazonBedrockJsonBuilder(entity);
|
||||
var result = builder.getStringContent();
|
||||
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_query\"}"));
|
||||
}
|
||||
|
||||
public void testRequestEntity_GeneratesExpectedJsonBody_WithoutInputType() throws IOException {
|
||||
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), null);
|
||||
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(
|
||||
List.of("test input"),
|
||||
null,
|
||||
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
|
||||
);
|
||||
var builder = new AmazonBedrockJsonBuilder(entity);
|
||||
var result = builder.getStringContent();
|
||||
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\"}"));
|
||||
}
|
||||
|
||||
public void testRequestEntity_GeneratesExpectedJsonBody_WithCohereTruncation() throws IOException {
|
||||
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(
|
||||
List.of("test input"),
|
||||
null,
|
||||
new AmazonBedrockEmbeddingsTaskSettings(CohereTruncation.START)
|
||||
);
|
||||
var builder = new AmazonBedrockJsonBuilder(entity);
|
||||
var result = builder.getStringContent();
|
||||
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\",\"truncate\":\"START\"}"));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,12 +7,9 @@
|
|||
|
||||
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
|
||||
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.settings.SecureString;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ChunkingSettings;
|
||||
import org.elasticsearch.inference.EmptyTaskSettings;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.inference.SimilarityMeasure;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
@ -20,19 +17,23 @@ import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
|||
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
|
||||
import java.util.Map;
|
||||
import java.io.IOException;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
|
||||
public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
|
||||
|
||||
public void testCreateModel_withTaskSettings_shouldFail() {
|
||||
var baseModel = createModel("id", "region", "model", AmazonBedrockProvider.AMAZONTITAN, "accesskey", "secretkey");
|
||||
var thrownException = assertThrows(
|
||||
ValidationException.class,
|
||||
() -> AmazonBedrockEmbeddingsModel.of(baseModel, Map.of("testkey", "testvalue"))
|
||||
);
|
||||
assertThat(thrownException.getMessage(), containsString("Amazon Bedrock embeddings model cannot have task settings"));
|
||||
public void testCreateModel_withTaskSettingsOverride() throws IOException {
|
||||
var baseTaskSettings = AmazonBedrockEmbeddingsTaskSettingsTests.randomTaskSettings();
|
||||
var baseModel = createModel("id", "region", "model", AmazonBedrockProvider.AMAZONTITAN, "accesskey", "secretkey", baseTaskSettings);
|
||||
|
||||
var overrideTaskSettings = AmazonBedrockEmbeddingsTaskSettingsTests.mutateTaskSettings(baseTaskSettings);
|
||||
var overrideTaskSettingsMap = AmazonBedrockEmbeddingsTaskSettingsTests.toMap(overrideTaskSettings);
|
||||
|
||||
var overriddenModel = AmazonBedrockEmbeddingsModel.of(baseModel, overrideTaskSettingsMap);
|
||||
assertThat(overriddenModel.getTaskSettings(), equalTo(overrideTaskSettings));
|
||||
assertThat(overriddenModel.getTaskSettings(), not(equalTo(baseTaskSettings)));
|
||||
}
|
||||
|
||||
// model creation only - no tests to define, but we want to have the public createModel
|
||||
|
@ -46,7 +47,15 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
|
|||
String accessKey,
|
||||
String secretKey
|
||||
) {
|
||||
return createModel(inferenceId, region, model, provider, null, false, null, null, new RateLimitSettings(240), accessKey, secretKey);
|
||||
return createModel(
|
||||
inferenceId,
|
||||
region,
|
||||
model,
|
||||
provider,
|
||||
accessKey,
|
||||
secretKey,
|
||||
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
|
||||
);
|
||||
}
|
||||
|
||||
public static AmazonBedrockEmbeddingsModel createModel(
|
||||
|
@ -56,9 +65,22 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
|
|||
AmazonBedrockProvider provider,
|
||||
String accessKey,
|
||||
String secretKey,
|
||||
InputType inputType
|
||||
AmazonBedrockEmbeddingsTaskSettings taskSettings
|
||||
) {
|
||||
return createModel(inferenceId, region, model, provider, null, false, null, null, new RateLimitSettings(240), accessKey, secretKey);
|
||||
return createModel(
|
||||
inferenceId,
|
||||
region,
|
||||
model,
|
||||
provider,
|
||||
null,
|
||||
false,
|
||||
null,
|
||||
null,
|
||||
new RateLimitSettings(240),
|
||||
accessKey,
|
||||
secretKey,
|
||||
taskSettings
|
||||
);
|
||||
}
|
||||
|
||||
public static AmazonBedrockEmbeddingsModel createModel(
|
||||
|
@ -114,7 +136,7 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
|
|||
similarity,
|
||||
rateLimitSettings
|
||||
),
|
||||
new EmptyTaskSettings(),
|
||||
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings(),
|
||||
chunkingSettings,
|
||||
new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey))
|
||||
);
|
||||
|
@ -132,6 +154,36 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
|
|||
RateLimitSettings rateLimitSettings,
|
||||
String accessKey,
|
||||
String secretKey
|
||||
) {
|
||||
return createModel(
|
||||
inferenceId,
|
||||
region,
|
||||
model,
|
||||
provider,
|
||||
dimensions,
|
||||
dimensionsSetByUser,
|
||||
maxTokens,
|
||||
similarity,
|
||||
rateLimitSettings,
|
||||
accessKey,
|
||||
secretKey,
|
||||
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
|
||||
);
|
||||
}
|
||||
|
||||
public static AmazonBedrockEmbeddingsModel createModel(
|
||||
String inferenceId,
|
||||
String region,
|
||||
String model,
|
||||
AmazonBedrockProvider provider,
|
||||
@Nullable Integer dimensions,
|
||||
boolean dimensionsSetByUser,
|
||||
@Nullable Integer maxTokens,
|
||||
@Nullable SimilarityMeasure similarity,
|
||||
RateLimitSettings rateLimitSettings,
|
||||
String accessKey,
|
||||
String secretKey,
|
||||
AmazonBedrockEmbeddingsTaskSettings taskSettings
|
||||
) {
|
||||
return new AmazonBedrockEmbeddingsModel(
|
||||
inferenceId,
|
||||
|
@ -147,7 +199,7 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
|
|||
similarity,
|
||||
rateLimitSettings
|
||||
),
|
||||
new EmptyTaskSettings(),
|
||||
taskSettings,
|
||||
null,
|
||||
new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey))
|
||||
);
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.xcontent.ToXContent;
|
||||
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||
import org.elasticsearch.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithoutUnspecified;
|
||||
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
|
||||
public class AmazonBedrockEmbeddingsTaskSettingsTests extends AbstractBWCWireSerializationTestCase<AmazonBedrockEmbeddingsTaskSettings> {
|
||||
|
||||
public static AmazonBedrockEmbeddingsTaskSettings emptyTaskSettings() {
|
||||
return AmazonBedrockEmbeddingsTaskSettings.EMPTY;
|
||||
}
|
||||
|
||||
public static AmazonBedrockEmbeddingsTaskSettings randomTaskSettings() {
|
||||
var inputType = randomBoolean() ? randomWithoutUnspecified() : null;
|
||||
var truncation = randomBoolean() ? randomFrom(CohereTruncation.values()) : null;
|
||||
return new AmazonBedrockEmbeddingsTaskSettings(truncation);
|
||||
}
|
||||
|
||||
public static AmazonBedrockEmbeddingsTaskSettings mutateTaskSettings(AmazonBedrockEmbeddingsTaskSettings instance) {
|
||||
return randomValueOtherThanMany(
|
||||
v -> Objects.equals(instance, v) || (instance.cohereTruncation() != null && v.cohereTruncation() == null),
|
||||
AmazonBedrockEmbeddingsTaskSettingsTests::randomTaskSettings
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AmazonBedrockEmbeddingsTaskSettings mutateInstanceForVersion(
|
||||
AmazonBedrockEmbeddingsTaskSettings instance,
|
||||
TransportVersion version
|
||||
) {
|
||||
return instance;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<AmazonBedrockEmbeddingsTaskSettings> instanceReader() {
|
||||
return AmazonBedrockEmbeddingsTaskSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AmazonBedrockEmbeddingsTaskSettings createTestInstance() {
|
||||
return randomTaskSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected AmazonBedrockEmbeddingsTaskSettings mutateInstance(AmazonBedrockEmbeddingsTaskSettings instance) throws IOException {
|
||||
return mutateTaskSettings(instance);
|
||||
}
|
||||
|
||||
public void testEmpty() {
|
||||
assertTrue(emptyTaskSettings().isEmpty());
|
||||
assertTrue(AmazonBedrockEmbeddingsTaskSettings.fromMap(null).isEmpty());
|
||||
assertTrue(AmazonBedrockEmbeddingsTaskSettings.fromMap(Map.of()).isEmpty());
|
||||
}
|
||||
|
||||
public static Map<String, Object> mutableMap(String key, Enum<?> value) {
|
||||
return new HashMap<>(Map.of(key, value.toString()));
|
||||
}
|
||||
|
||||
public void testValidCohereTruncations() {
|
||||
for (var expectedCohereTruncation : CohereTruncation.ALL) {
|
||||
var map = mutableMap(TRUNCATE_FIELD, expectedCohereTruncation);
|
||||
var taskSettings = AmazonBedrockEmbeddingsTaskSettings.fromMap(map);
|
||||
assertFalse(taskSettings.isEmpty());
|
||||
assertThat(taskSettings.cohereTruncation(), equalTo(expectedCohereTruncation));
|
||||
}
|
||||
}
|
||||
|
||||
public void testGarbageCohereTruncations() {
|
||||
var map = new HashMap<String, Object>(Map.of(TRUNCATE_FIELD, "oiuesoirtuoawoeirha"));
|
||||
assertThrows(ValidationException.class, () -> AmazonBedrockEmbeddingsTaskSettings.fromMap(map));
|
||||
}
|
||||
|
||||
public void testXContent() throws IOException {
|
||||
var taskSettings = randomTaskSettings();
|
||||
var taskSettingsAsMap = toMap(taskSettings);
|
||||
var roundTripTaskSettings = AmazonBedrockEmbeddingsTaskSettings.fromMap(new HashMap<>(taskSettingsAsMap));
|
||||
assertThat(roundTripTaskSettings, equalTo(taskSettings));
|
||||
}
|
||||
|
||||
public static Map<String, Object> toMap(AmazonBedrockEmbeddingsTaskSettings taskSettings) throws IOException {
|
||||
try (var builder = JsonXContent.contentBuilder()) {
|
||||
taskSettings.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
var taskSettingsBytes = Strings.toString(builder).getBytes(StandardCharsets.UTF_8);
|
||||
try (var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, taskSettingsBytes)) {
|
||||
return parser.map();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue