mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
[ML] Bedrock Cohere Task Settings Support (#126493)
Add support for Cohere Task Settings and Truncate, through the Amazon Bedrock provider integration. Task Settings can now be passed bother during Inference endpoint creation and Inference POST requests. Close #126156
This commit is contained in:
parent
6d86b202ea
commit
6c6500ec3b
16 changed files with 468 additions and 102 deletions
6
docs/changelog/126493.yaml
Normal file
6
docs/changelog/126493.yaml
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
pr: 126493
|
||||||
|
summary: Bedrock Cohere Task Settings Support
|
||||||
|
area: Machine Learning
|
||||||
|
type: enhancement
|
||||||
|
issues:
|
||||||
|
- 126156
|
|
@ -157,6 +157,7 @@ public class TransportVersions {
|
||||||
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
|
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
|
||||||
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
|
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
|
||||||
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
|
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
|
||||||
|
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
|
||||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
|
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
|
||||||
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
|
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
|
||||||
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
|
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
|
||||||
|
@ -215,6 +216,7 @@ public class TransportVersions {
|
||||||
public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00);
|
public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00);
|
||||||
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0);
|
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0);
|
||||||
public static final TransportVersion REPO_ANALYSIS_COPY_BLOB = def(9_048_00_0);
|
public static final TransportVersion REPO_ANALYSIS_COPY_BLOB = def(9_048_00_0);
|
||||||
|
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS = def(9_049_00_0);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* STOP! READ THIS FIRST! No, really,
|
* STOP! READ THIS FIRST! No, really,
|
||||||
|
|
|
@ -41,6 +41,7 @@ import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.Alib
|
||||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings;
|
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings;
|
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
|
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings;
|
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings;
|
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings;
|
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings;
|
||||||
|
@ -173,8 +174,13 @@ public class InferenceNamedWriteablesProvider {
|
||||||
AmazonBedrockEmbeddingsServiceSettings::new
|
AmazonBedrockEmbeddingsServiceSettings::new
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
namedWriteables.add(
|
||||||
// no task settings for Amazon Bedrock Embeddings
|
new NamedWriteableRegistry.Entry(
|
||||||
|
TaskSettings.class,
|
||||||
|
AmazonBedrockEmbeddingsTaskSettings.NAME,
|
||||||
|
AmazonBedrockEmbeddingsTaskSettings::new
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
namedWriteables.add(
|
namedWriteables.add(
|
||||||
new NamedWriteableRegistry.Entry(
|
new NamedWriteableRegistry.Entry(
|
||||||
|
|
|
@ -19,6 +19,8 @@ public class AmazonBedrockConstants {
|
||||||
public static final String TOP_K_FIELD = "top_k";
|
public static final String TOP_K_FIELD = "top_k";
|
||||||
public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens";
|
public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens";
|
||||||
|
|
||||||
|
public static final String TRUNCATE_FIELD = "truncate";
|
||||||
|
|
||||||
public static final Double MIN_TEMPERATURE_TOP_P_TOP_K_VALUE = 0.0;
|
public static final Double MIN_TEMPERATURE_TOP_P_TOP_K_VALUE = 0.0;
|
||||||
public static final Double MAX_TEMPERATURE_TOP_P_TOP_K_VALUE = 1.0;
|
public static final Double MAX_TEMPERATURE_TOP_P_TOP_K_VALUE = 1.0;
|
||||||
|
|
||||||
|
|
|
@ -303,6 +303,7 @@ public class AmazonBedrockService extends SenderService {
|
||||||
context
|
context
|
||||||
);
|
);
|
||||||
checkProviderForTask(TaskType.TEXT_EMBEDDING, model.provider());
|
checkProviderForTask(TaskType.TEXT_EMBEDDING, model.provider());
|
||||||
|
checkTaskSettingsForTextEmbeddingModel(model);
|
||||||
return model;
|
return model;
|
||||||
}
|
}
|
||||||
case COMPLETION -> {
|
case COMPLETION -> {
|
||||||
|
@ -368,6 +369,17 @@ public class AmazonBedrockService extends SenderService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static void checkTaskSettingsForTextEmbeddingModel(AmazonBedrockEmbeddingsModel model) {
|
||||||
|
if (model.provider() != AmazonBedrockProvider.COHERE && model.getTaskSettings().cohereTruncation() != null) {
|
||||||
|
throw new ElasticsearchStatusException(
|
||||||
|
"The [{}] task type for provider [{}] does not allow [truncate] field",
|
||||||
|
RestStatus.BAD_REQUEST,
|
||||||
|
TaskType.TEXT_EMBEDDING,
|
||||||
|
model.provider()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static void checkChatCompletionProviderForTopKParameter(AmazonBedrockChatCompletionModel model) {
|
private static void checkChatCompletionProviderForTopKParameter(AmazonBedrockChatCompletionModel model) {
|
||||||
var taskSettings = model.getTaskSettings();
|
var taskSettings = model.getTaskSettings();
|
||||||
if (taskSettings.topK() != null) {
|
if (taskSettings.topK() != null) {
|
||||||
|
|
|
@ -7,14 +7,11 @@
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
|
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
|
||||||
|
|
||||||
import org.elasticsearch.common.ValidationException;
|
|
||||||
import org.elasticsearch.inference.ChunkingSettings;
|
import org.elasticsearch.inference.ChunkingSettings;
|
||||||
import org.elasticsearch.inference.EmptyTaskSettings;
|
|
||||||
import org.elasticsearch.inference.Model;
|
import org.elasticsearch.inference.Model;
|
||||||
import org.elasticsearch.inference.ModelConfigurations;
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
import org.elasticsearch.inference.ModelSecrets;
|
import org.elasticsearch.inference.ModelSecrets;
|
||||||
import org.elasticsearch.inference.ServiceSettings;
|
import org.elasticsearch.inference.ServiceSettings;
|
||||||
import org.elasticsearch.inference.TaskSettings;
|
|
||||||
import org.elasticsearch.inference.TaskType;
|
import org.elasticsearch.inference.TaskType;
|
||||||
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
||||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||||
|
@ -28,10 +25,8 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
|
||||||
|
|
||||||
public static AmazonBedrockEmbeddingsModel of(AmazonBedrockEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings) {
|
public static AmazonBedrockEmbeddingsModel of(AmazonBedrockEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings) {
|
||||||
if (taskSettings != null && taskSettings.isEmpty() == false) {
|
if (taskSettings != null && taskSettings.isEmpty() == false) {
|
||||||
// no task settings allowed
|
var updatedTaskSettings = embeddingsModel.getTaskSettings().updatedTaskSettings(taskSettings);
|
||||||
var validationException = new ValidationException();
|
return new AmazonBedrockEmbeddingsModel(embeddingsModel, updatedTaskSettings);
|
||||||
validationException.addValidationError("Amazon Bedrock embeddings model cannot have task settings");
|
|
||||||
throw validationException;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return embeddingsModel;
|
return embeddingsModel;
|
||||||
|
@ -52,7 +47,7 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
|
||||||
taskType,
|
taskType,
|
||||||
service,
|
service,
|
||||||
AmazonBedrockEmbeddingsServiceSettings.fromMap(serviceSettings, context),
|
AmazonBedrockEmbeddingsServiceSettings.fromMap(serviceSettings, context),
|
||||||
new EmptyTaskSettings(),
|
AmazonBedrockEmbeddingsTaskSettings.fromMap(taskSettings),
|
||||||
chunkingSettings,
|
chunkingSettings,
|
||||||
AwsSecretSettings.fromMap(secretSettings)
|
AwsSecretSettings.fromMap(secretSettings)
|
||||||
);
|
);
|
||||||
|
@ -63,12 +58,12 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
|
||||||
TaskType taskType,
|
TaskType taskType,
|
||||||
String service,
|
String service,
|
||||||
AmazonBedrockEmbeddingsServiceSettings serviceSettings,
|
AmazonBedrockEmbeddingsServiceSettings serviceSettings,
|
||||||
TaskSettings taskSettings,
|
AmazonBedrockEmbeddingsTaskSettings taskSettings,
|
||||||
ChunkingSettings chunkingSettings,
|
ChunkingSettings chunkingSettings,
|
||||||
AwsSecretSettings secrets
|
AwsSecretSettings secrets
|
||||||
) {
|
) {
|
||||||
super(
|
super(
|
||||||
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings),
|
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
|
||||||
new ModelSecrets(secrets)
|
new ModelSecrets(secrets)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -77,6 +72,10 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
|
||||||
super(model, serviceSettings);
|
super(model, serviceSettings);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public AmazonBedrockEmbeddingsModel(Model model, AmazonBedrockEmbeddingsTaskSettings taskSettings) {
|
||||||
|
super(model, taskSettings);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map<String, Object> taskSettings) {
|
public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map<String, Object> taskSettings) {
|
||||||
return creator.create(this, taskSettings);
|
return creator.create(this, taskSettings);
|
||||||
|
@ -86,4 +85,9 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {
|
||||||
public AmazonBedrockEmbeddingsServiceSettings getServiceSettings() {
|
public AmazonBedrockEmbeddingsServiceSettings getServiceSettings() {
|
||||||
return (AmazonBedrockEmbeddingsServiceSettings) super.getServiceSettings();
|
return (AmazonBedrockEmbeddingsServiceSettings) super.getServiceSettings();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AmazonBedrockEmbeddingsTaskSettings getTaskSettings() {
|
||||||
|
return (AmazonBedrockEmbeddingsTaskSettings) super.getTaskSettings();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,98 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
|
||||||
|
|
||||||
|
import org.elasticsearch.TransportVersion;
|
||||||
|
import org.elasticsearch.TransportVersions;
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.TaskSettings;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD;
|
||||||
|
|
||||||
|
public record AmazonBedrockEmbeddingsTaskSettings(@Nullable CohereTruncation cohereTruncation) implements TaskSettings {
|
||||||
|
public static final AmazonBedrockEmbeddingsTaskSettings EMPTY = new AmazonBedrockEmbeddingsTaskSettings((CohereTruncation) null);
|
||||||
|
public static final String NAME = "amazon_bedrock_embeddings_task_settings";
|
||||||
|
|
||||||
|
public static AmazonBedrockEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
|
||||||
|
if (map == null || map.isEmpty()) {
|
||||||
|
return EMPTY;
|
||||||
|
}
|
||||||
|
|
||||||
|
ValidationException validationException = new ValidationException();
|
||||||
|
|
||||||
|
var cohereTruncation = extractOptionalEnum(
|
||||||
|
map,
|
||||||
|
TRUNCATE_FIELD,
|
||||||
|
ModelConfigurations.TASK_SETTINGS,
|
||||||
|
CohereTruncation::fromString,
|
||||||
|
CohereTruncation.ALL,
|
||||||
|
validationException
|
||||||
|
);
|
||||||
|
|
||||||
|
if (validationException.validationErrors().isEmpty() == false) {
|
||||||
|
throw validationException;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new AmazonBedrockEmbeddingsTaskSettings(cohereTruncation);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AmazonBedrockEmbeddingsTaskSettings(StreamInput in) throws IOException {
|
||||||
|
this(in.readOptionalEnum(CohereTruncation.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isEmpty() {
|
||||||
|
return cohereTruncation() == null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AmazonBedrockEmbeddingsTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
|
||||||
|
var newTaskSettings = fromMap(new HashMap<>(newSettings));
|
||||||
|
|
||||||
|
return new AmazonBedrockEmbeddingsTaskSettings(firstNonNullOrNull(newTaskSettings.cohereTruncation(), cohereTruncation()));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static <T> T firstNonNullOrNull(T first, T second) {
|
||||||
|
return first != null ? first : second;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getWriteableName() {
|
||||||
|
return NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TransportVersion getMinimalSupportedVersion() {
|
||||||
|
return TransportVersions.AMAZON_BEDROCK_TASK_SETTINGS;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
|
out.writeOptionalEnum(cohereTruncation());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
if (cohereTruncation != null) {
|
||||||
|
builder.field(TRUNCATE_FIELD, cohereTruncation);
|
||||||
|
}
|
||||||
|
return builder.endObject();
|
||||||
|
}
|
||||||
|
}
|
|
@ -11,6 +11,7 @@ import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.inference.InputType;
|
import org.elasticsearch.inference.InputType;
|
||||||
import org.elasticsearch.xcontent.ToXContentObject;
|
import org.elasticsearch.xcontent.ToXContentObject;
|
||||||
import org.elasticsearch.xcontent.XContentBuilder;
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -18,7 +19,11 @@ import java.util.Objects;
|
||||||
|
|
||||||
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
|
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
|
||||||
|
|
||||||
public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nullable InputType inputType) implements ToXContentObject {
|
public record AmazonBedrockCohereEmbeddingsRequestEntity(
|
||||||
|
List<String> input,
|
||||||
|
@Nullable InputType inputType,
|
||||||
|
AmazonBedrockEmbeddingsTaskSettings taskSettings
|
||||||
|
) implements ToXContentObject {
|
||||||
|
|
||||||
private static final String TEXTS_FIELD = "texts";
|
private static final String TEXTS_FIELD = "texts";
|
||||||
private static final String INPUT_TYPE_FIELD = "input_type";
|
private static final String INPUT_TYPE_FIELD = "input_type";
|
||||||
|
@ -26,9 +31,11 @@ public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nu
|
||||||
private static final String SEARCH_QUERY = "search_query";
|
private static final String SEARCH_QUERY = "search_query";
|
||||||
private static final String CLUSTERING = "clustering";
|
private static final String CLUSTERING = "clustering";
|
||||||
private static final String CLASSIFICATION = "classification";
|
private static final String CLASSIFICATION = "classification";
|
||||||
|
private static final String TRUNCATE = "truncate";
|
||||||
|
|
||||||
public AmazonBedrockCohereEmbeddingsRequestEntity {
|
public AmazonBedrockCohereEmbeddingsRequestEntity {
|
||||||
Objects.requireNonNull(input);
|
Objects.requireNonNull(input);
|
||||||
|
Objects.requireNonNull(taskSettings);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -43,6 +50,10 @@ public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nu
|
||||||
builder.field(INPUT_TYPE_FIELD, SEARCH_DOCUMENT);
|
builder.field(INPUT_TYPE_FIELD, SEARCH_DOCUMENT);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (taskSettings.cohereTruncation() != null) {
|
||||||
|
builder.field(TRUNCATE, taskSettings.cohereTruncation().name());
|
||||||
|
}
|
||||||
|
|
||||||
builder.endObject();
|
builder.endObject();
|
||||||
return builder;
|
return builder;
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,7 +39,7 @@ public final class AmazonBedrockEmbeddingsEntityFactory {
|
||||||
return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0));
|
return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0));
|
||||||
}
|
}
|
||||||
case COHERE -> {
|
case COHERE -> {
|
||||||
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType);
|
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType, model.getTaskSettings());
|
||||||
}
|
}
|
||||||
default -> {
|
default -> {
|
||||||
return null;
|
return null;
|
||||||
|
|
|
@ -76,6 +76,9 @@ public class AmazonBedrockEmbeddingsRequest extends AmazonBedrockRequest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Request truncate() {
|
public Request truncate() {
|
||||||
|
if (provider == AmazonBedrockProvider.COHERE) {
|
||||||
|
return this; // Cohere has its own truncation logic
|
||||||
|
}
|
||||||
var truncatedInput = truncator.truncate(truncationResult.input());
|
var truncatedInput = truncator.truncate(truncationResult.input());
|
||||||
return new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, embeddingsModel, requestEntity, timeout);
|
return new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, embeddingsModel, requestEntity, timeout);
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.services.cohere;
|
package org.elasticsearch.xpack.inference.services.cohere;
|
||||||
|
|
||||||
|
import java.util.EnumSet;
|
||||||
import java.util.Locale;
|
import java.util.Locale;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -31,6 +32,8 @@ public enum CohereTruncation {
|
||||||
*/
|
*/
|
||||||
END;
|
END;
|
||||||
|
|
||||||
|
public static final EnumSet<CohereTruncation> ALL = EnumSet.allOf(CohereTruncation.class);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return name().toLowerCase(Locale.ROOT);
|
return name().toLowerCase(Locale.ROOT);
|
||||||
|
|
|
@ -20,7 +20,6 @@ import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
|
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.EnumSet;
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
@ -63,7 +62,7 @@ public class CohereEmbeddingsTaskSettings implements TaskSettings {
|
||||||
TRUNCATE,
|
TRUNCATE,
|
||||||
ModelConfigurations.TASK_SETTINGS,
|
ModelConfigurations.TASK_SETTINGS,
|
||||||
CohereTruncation::fromString,
|
CohereTruncation::fromString,
|
||||||
EnumSet.allOf(CohereTruncation.class),
|
CohereTruncation.ALL,
|
||||||
validationException
|
validationException
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeExcept
|
||||||
import org.elasticsearch.ElasticsearchException;
|
import org.elasticsearch.ElasticsearchException;
|
||||||
import org.elasticsearch.ElasticsearchStatusException;
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
import org.elasticsearch.action.ActionListener;
|
import org.elasticsearch.action.ActionListener;
|
||||||
|
import org.elasticsearch.action.support.ActionTestUtils;
|
||||||
import org.elasticsearch.action.support.PlainActionFuture;
|
import org.elasticsearch.action.support.PlainActionFuture;
|
||||||
import org.elasticsearch.common.ValidationException;
|
import org.elasticsearch.common.ValidationException;
|
||||||
import org.elasticsearch.common.bytes.BytesArray;
|
import org.elasticsearch.common.bytes.BytesArray;
|
||||||
|
@ -51,6 +52,8 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.Amazo
|
||||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel;
|
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel;
|
||||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests;
|
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests;
|
||||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
|
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettingsTests;
|
||||||
|
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
|
||||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
|
||||||
import org.hamcrest.CoreMatchers;
|
import org.hamcrest.CoreMatchers;
|
||||||
import org.hamcrest.Matchers;
|
import org.hamcrest.Matchers;
|
||||||
|
@ -105,7 +108,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
|
|
||||||
public void testParseRequestConfig_CreatesAnAmazonBedrockModel() throws IOException {
|
public void testParseRequestConfig_CreatesAnAmazonBedrockModel() throws IOException {
|
||||||
try (var service = createAmazonBedrockService()) {
|
try (var service = createAmazonBedrockService()) {
|
||||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
|
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
|
||||||
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
|
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
|
||||||
|
|
||||||
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
|
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
|
||||||
|
@ -115,7 +118,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
var secretSettings = (AwsSecretSettings) model.getSecretSettings();
|
var secretSettings = (AwsSecretSettings) model.getSecretSettings();
|
||||||
assertThat(secretSettings.accessKey().toString(), is("access"));
|
assertThat(secretSettings.accessKey().toString(), is("access"));
|
||||||
assertThat(secretSettings.secretKey().toString(), is("secret"));
|
assertThat(secretSettings.secretKey().toString(), is("secret"));
|
||||||
}, exception -> fail("Unexpected exception: " + exception));
|
});
|
||||||
|
|
||||||
service.parseRequestConfig(
|
service.parseRequestConfig(
|
||||||
"id",
|
"id",
|
||||||
|
@ -130,15 +133,62 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testParseRequestConfig_CreatesACohereModel() throws IOException {
|
||||||
|
try (var service = createAmazonBedrockService()) {
|
||||||
|
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
|
||||||
|
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
|
||||||
|
|
||||||
|
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
|
||||||
|
assertThat(settings.region(), is("region"));
|
||||||
|
assertThat(settings.modelId(), is("model"));
|
||||||
|
assertThat(settings.provider(), is(AmazonBedrockProvider.COHERE));
|
||||||
|
var secretSettings = (AwsSecretSettings) model.getSecretSettings();
|
||||||
|
assertThat(secretSettings.accessKey().toString(), is("access"));
|
||||||
|
assertThat(secretSettings.secretKey().toString(), is("secret"));
|
||||||
|
});
|
||||||
|
|
||||||
|
service.parseRequestConfig(
|
||||||
|
"id",
|
||||||
|
TaskType.TEXT_EMBEDDING,
|
||||||
|
getRequestConfigMap(
|
||||||
|
createEmbeddingsRequestSettingsMap("region", "model", "cohere", null, null, null, null),
|
||||||
|
AmazonBedrockEmbeddingsTaskSettingsTests.mutableMap("truncate", CohereTruncation.START),
|
||||||
|
getAmazonBedrockSecretSettingsMap("access", "secret")
|
||||||
|
),
|
||||||
|
modelVerificationListener
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testParseRequestConfig_CohereSettingsWithNoCohereModel() throws IOException {
|
||||||
|
try (var service = createAmazonBedrockService()) {
|
||||||
|
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
|
||||||
|
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||||
|
assertThat(
|
||||||
|
exception.getMessage(),
|
||||||
|
is("The [text_embedding] task type for provider [amazontitan] does not allow [truncate] field")
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
service.parseRequestConfig(
|
||||||
|
"id",
|
||||||
|
TaskType.TEXT_EMBEDDING,
|
||||||
|
getRequestConfigMap(
|
||||||
|
createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, null, null, null),
|
||||||
|
AmazonBedrockEmbeddingsTaskSettingsTests.mutableMap("truncate", CohereTruncation.START),
|
||||||
|
getAmazonBedrockSecretSettingsMap("access", "secret")
|
||||||
|
),
|
||||||
|
modelVerificationListener
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
|
public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
|
||||||
try (var service = createAmazonBedrockService()) {
|
try (var service = createAmazonBedrockService()) {
|
||||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
|
||||||
model -> fail("Expected exception, but got model: " + model),
|
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||||
exception -> {
|
assertThat(exception.getMessage(), is("The [amazonbedrock] service does not support task type [sparse_embedding]"));
|
||||||
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
});
|
||||||
assertThat(exception.getMessage(), is("The [amazonbedrock] service does not support task type [sparse_embedding]"));
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
service.parseRequestConfig(
|
service.parseRequestConfig(
|
||||||
"id",
|
"id",
|
||||||
|
@ -247,13 +297,10 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
|
|
||||||
public void testCreateModel_ForEmbeddingsTask_InvalidProvider() throws IOException {
|
public void testCreateModel_ForEmbeddingsTask_InvalidProvider() throws IOException {
|
||||||
try (var service = createAmazonBedrockService()) {
|
try (var service = createAmazonBedrockService()) {
|
||||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
|
||||||
model -> fail("Expected exception, but got model: " + model),
|
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||||
exception -> {
|
assertThat(exception.getMessage(), is("The [text_embedding] task type for provider [anthropic] is not available"));
|
||||||
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
});
|
||||||
assertThat(exception.getMessage(), is("The [text_embedding] task type for provider [anthropic] is not available"));
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
service.parseRequestConfig(
|
service.parseRequestConfig(
|
||||||
"id",
|
"id",
|
||||||
|
@ -270,13 +317,10 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
|
|
||||||
public void testCreateModel_TopKParameter_NotAvailable() throws IOException {
|
public void testCreateModel_TopKParameter_NotAvailable() throws IOException {
|
||||||
try (var service = createAmazonBedrockService()) {
|
try (var service = createAmazonBedrockService()) {
|
||||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
|
||||||
model -> fail("Expected exception, but got model: " + model),
|
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||||
exception -> {
|
assertThat(exception.getMessage(), is("The [top_k] task parameter is not available for provider [amazontitan]"));
|
||||||
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
});
|
||||||
assertThat(exception.getMessage(), is("The [top_k] task parameter is not available for provider [amazontitan]"));
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
service.parseRequestConfig(
|
service.parseRequestConfig(
|
||||||
"id",
|
"id",
|
||||||
|
@ -301,16 +345,13 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
|
|
||||||
config.put("extra_key", "value");
|
config.put("extra_key", "value");
|
||||||
|
|
||||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
|
||||||
model -> fail("Expected exception, but got model: " + model),
|
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
||||||
exception -> {
|
assertThat(
|
||||||
assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
exception.getMessage(),
|
||||||
assertThat(
|
is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service")
|
||||||
exception.getMessage(),
|
);
|
||||||
is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service")
|
});
|
||||||
);
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
|
service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener);
|
||||||
}
|
}
|
||||||
|
@ -323,9 +364,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
|
|
||||||
var config = getRequestConfigMap(serviceSettings, Map.of(), getAmazonBedrockSecretSettingsMap("access", "secret"));
|
var config = getRequestConfigMap(serviceSettings, Map.of(), getAmazonBedrockSecretSettingsMap("access", "secret"));
|
||||||
|
|
||||||
ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> {
|
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> {
|
||||||
fail("Expected exception, but got model: " + model);
|
|
||||||
}, e -> {
|
|
||||||
assertThat(e, instanceOf(ElasticsearchStatusException.class));
|
assertThat(e, instanceOf(ElasticsearchStatusException.class));
|
||||||
assertThat(
|
assertThat(
|
||||||
e.getMessage(),
|
e.getMessage(),
|
||||||
|
@ -347,9 +386,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
|
|
||||||
var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap);
|
var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap);
|
||||||
|
|
||||||
ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> {
|
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> {
|
||||||
fail("Expected exception, but got model: " + model);
|
|
||||||
}, e -> {
|
|
||||||
assertThat(e, instanceOf(ElasticsearchStatusException.class));
|
assertThat(e, instanceOf(ElasticsearchStatusException.class));
|
||||||
assertThat(
|
assertThat(
|
||||||
e.getMessage(),
|
e.getMessage(),
|
||||||
|
@ -371,9 +408,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
|
|
||||||
var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap);
|
var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap);
|
||||||
|
|
||||||
ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> {
|
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> {
|
||||||
fail("Expected exception, but got model: " + model);
|
|
||||||
}, e -> {
|
|
||||||
assertThat(e, instanceOf(ElasticsearchStatusException.class));
|
assertThat(e, instanceOf(ElasticsearchStatusException.class));
|
||||||
assertThat(
|
assertThat(
|
||||||
e.getMessage(),
|
e.getMessage(),
|
||||||
|
@ -387,7 +422,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
|
|
||||||
public void testParseRequestConfig_MovesModel() throws IOException {
|
public void testParseRequestConfig_MovesModel() throws IOException {
|
||||||
try (var service = createAmazonBedrockService()) {
|
try (var service = createAmazonBedrockService()) {
|
||||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
|
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
|
||||||
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
|
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
|
||||||
|
|
||||||
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
|
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
|
||||||
|
@ -397,7 +432,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
var secretSettings = (AwsSecretSettings) model.getSecretSettings();
|
var secretSettings = (AwsSecretSettings) model.getSecretSettings();
|
||||||
assertThat(secretSettings.accessKey().toString(), is("access"));
|
assertThat(secretSettings.accessKey().toString(), is("access"));
|
||||||
assertThat(secretSettings.secretKey().toString(), is("secret"));
|
assertThat(secretSettings.secretKey().toString(), is("secret"));
|
||||||
}, exception -> fail("Unexpected exception: " + exception));
|
});
|
||||||
|
|
||||||
service.parseRequestConfig(
|
service.parseRequestConfig(
|
||||||
"id",
|
"id",
|
||||||
|
@ -414,7 +449,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
|
|
||||||
public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
|
public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
|
||||||
try (var service = createAmazonBedrockService()) {
|
try (var service = createAmazonBedrockService()) {
|
||||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
|
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
|
||||||
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
|
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
|
||||||
|
|
||||||
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
|
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
|
||||||
|
@ -425,7 +460,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
assertThat(secretSettings.accessKey().toString(), is("access"));
|
assertThat(secretSettings.accessKey().toString(), is("access"));
|
||||||
assertThat(secretSettings.secretKey().toString(), is("secret"));
|
assertThat(secretSettings.secretKey().toString(), is("secret"));
|
||||||
assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
||||||
}, exception -> fail("Unexpected exception: " + exception));
|
});
|
||||||
|
|
||||||
service.parseRequestConfig(
|
service.parseRequestConfig(
|
||||||
"id",
|
"id",
|
||||||
|
@ -443,7 +478,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
|
|
||||||
public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
|
public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
|
||||||
try (var service = createAmazonBedrockService()) {
|
try (var service = createAmazonBedrockService()) {
|
||||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
|
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> {
|
||||||
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
|
assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class));
|
||||||
|
|
||||||
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
|
var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings();
|
||||||
|
@ -454,7 +489,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
assertThat(secretSettings.accessKey().toString(), is("access"));
|
assertThat(secretSettings.accessKey().toString(), is("access"));
|
||||||
assertThat(secretSettings.secretKey().toString(), is("secret"));
|
assertThat(secretSettings.secretKey().toString(), is("secret"));
|
||||||
assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
||||||
}, exception -> fail("Unexpected exception: " + exception));
|
});
|
||||||
|
|
||||||
service.parseRequestConfig(
|
service.parseRequestConfig(
|
||||||
"id",
|
"id",
|
||||||
|
@ -471,13 +506,10 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
|
|
||||||
public void testCreateModel_ForEmbeddingsTask_DimensionsIsNotAllowed() throws IOException {
|
public void testCreateModel_ForEmbeddingsTask_DimensionsIsNotAllowed() throws IOException {
|
||||||
try (var service = createAmazonBedrockService()) {
|
try (var service = createAmazonBedrockService()) {
|
||||||
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
ActionListener<Model> modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> {
|
||||||
model -> fail("Expected exception, but got model: " + model),
|
assertThat(exception, instanceOf(ValidationException.class));
|
||||||
exception -> {
|
assertThat(exception.getMessage(), containsString("[service_settings] does not allow the setting [dimensions]"));
|
||||||
assertThat(exception, instanceOf(ValidationException.class));
|
});
|
||||||
assertThat(exception.getMessage(), containsString("[service_settings] does not allow the setting [dimensions]"));
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
service.parseRequestConfig(
|
service.parseRequestConfig(
|
||||||
"id",
|
"id",
|
||||||
|
@ -497,7 +529,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
||||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||||
|
|
||||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||||
|
|
||||||
var model = service.parsePersistedConfigWithSecrets(
|
var model = service.parsePersistedConfigWithSecrets(
|
||||||
"id",
|
"id",
|
||||||
|
@ -525,7 +557,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
|
|
||||||
var persistedConfig = getPersistedConfigMap(
|
var persistedConfig = getPersistedConfigMap(
|
||||||
settingsMap,
|
settingsMap,
|
||||||
new HashMap<String, Object>(Map.of()),
|
new HashMap<>(Map.of()),
|
||||||
createRandomChunkingSettingsMap(),
|
createRandomChunkingSettingsMap(),
|
||||||
secretSettingsMap
|
secretSettingsMap
|
||||||
);
|
);
|
||||||
|
@ -607,7 +639,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
||||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||||
|
|
||||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||||
persistedConfig.config().put("extra_key", "value");
|
persistedConfig.config().put("extra_key", "value");
|
||||||
|
|
||||||
var model = service.parsePersistedConfigWithSecrets(
|
var model = service.parsePersistedConfigWithSecrets(
|
||||||
|
@ -635,7 +667,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||||
secretSettingsMap.put("extra_key", "value");
|
secretSettingsMap.put("extra_key", "value");
|
||||||
|
|
||||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||||
|
|
||||||
var model = service.parsePersistedConfigWithSecrets(
|
var model = service.parsePersistedConfigWithSecrets(
|
||||||
"id",
|
"id",
|
||||||
|
@ -661,7 +693,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
||||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||||
|
|
||||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||||
persistedConfig.secrets().put("extra_key", "value");
|
persistedConfig.secrets().put("extra_key", "value");
|
||||||
|
|
||||||
var model = service.parsePersistedConfigWithSecrets(
|
var model = service.parsePersistedConfigWithSecrets(
|
||||||
|
@ -689,7 +721,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
settingsMap.put("extra_key", "value");
|
settingsMap.put("extra_key", "value");
|
||||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||||
|
|
||||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||||
|
|
||||||
var model = service.parsePersistedConfigWithSecrets(
|
var model = service.parsePersistedConfigWithSecrets(
|
||||||
"id",
|
"id",
|
||||||
|
@ -769,7 +801,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
|
|
||||||
var persistedConfig = getPersistedConfigMap(
|
var persistedConfig = getPersistedConfigMap(
|
||||||
settingsMap,
|
settingsMap,
|
||||||
new HashMap<String, Object>(Map.of()),
|
new HashMap<>(Map.of()),
|
||||||
createRandomChunkingSettingsMap(),
|
createRandomChunkingSettingsMap(),
|
||||||
secretSettingsMap
|
secretSettingsMap
|
||||||
);
|
);
|
||||||
|
@ -792,7 +824,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
||||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||||
|
|
||||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||||
|
|
||||||
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
|
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
|
||||||
|
|
||||||
|
@ -836,7 +868,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
||||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||||
|
|
||||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||||
|
|
||||||
var thrownException = expectThrows(
|
var thrownException = expectThrows(
|
||||||
ElasticsearchStatusException.class,
|
ElasticsearchStatusException.class,
|
||||||
|
@ -855,7 +887,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null);
|
||||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||||
|
|
||||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||||
persistedConfig.config().put("extra_key", "value");
|
persistedConfig.config().put("extra_key", "value");
|
||||||
|
|
||||||
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
|
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
|
||||||
|
@ -876,7 +908,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
settingsMap.put("extra_key", "value");
|
settingsMap.put("extra_key", "value");
|
||||||
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret");
|
||||||
|
|
||||||
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<String, Object>(Map.of()), secretSettingsMap);
|
var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap);
|
||||||
persistedConfig.config().put("extra_key", "value");
|
persistedConfig.config().put("extra_key", "value");
|
||||||
|
|
||||||
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
|
var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
|
||||||
|
|
|
@ -11,6 +11,7 @@ import org.elasticsearch.inference.InputType;
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockJsonBuilder;
|
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.AmazonBedrockJsonBuilder;
|
||||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockCohereEmbeddingsRequestEntity;
|
import org.elasticsearch.xpack.inference.services.amazonbedrock.request.embeddings.AmazonBedrockCohereEmbeddingsRequestEntity;
|
||||||
|
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -19,23 +20,46 @@ import static org.hamcrest.Matchers.is;
|
||||||
|
|
||||||
public class AmazonBedrockCohereEmbeddingsRequestEntityTests extends ESTestCase {
|
public class AmazonBedrockCohereEmbeddingsRequestEntityTests extends ESTestCase {
|
||||||
public void testRequestEntity_GeneratesExpectedJsonBody() throws IOException {
|
public void testRequestEntity_GeneratesExpectedJsonBody() throws IOException {
|
||||||
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), InputType.CLASSIFICATION);
|
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(
|
||||||
|
List.of("test input"),
|
||||||
|
InputType.CLASSIFICATION,
|
||||||
|
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
|
||||||
|
);
|
||||||
var builder = new AmazonBedrockJsonBuilder(entity);
|
var builder = new AmazonBedrockJsonBuilder(entity);
|
||||||
var result = builder.getStringContent();
|
var result = builder.getStringContent();
|
||||||
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"classification\"}"));
|
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"classification\"}"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testRequestEntity_GeneratesExpectedJsonBody_WithInternalInputType() throws IOException {
|
public void testRequestEntity_GeneratesExpectedJsonBody_WithInternalInputType() throws IOException {
|
||||||
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), InputType.INTERNAL_SEARCH);
|
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(
|
||||||
|
List.of("test input"),
|
||||||
|
InputType.INTERNAL_SEARCH,
|
||||||
|
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
|
||||||
|
);
|
||||||
var builder = new AmazonBedrockJsonBuilder(entity);
|
var builder = new AmazonBedrockJsonBuilder(entity);
|
||||||
var result = builder.getStringContent();
|
var result = builder.getStringContent();
|
||||||
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_query\"}"));
|
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_query\"}"));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testRequestEntity_GeneratesExpectedJsonBody_WithoutInputType() throws IOException {
|
public void testRequestEntity_GeneratesExpectedJsonBody_WithoutInputType() throws IOException {
|
||||||
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), null);
|
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(
|
||||||
|
List.of("test input"),
|
||||||
|
null,
|
||||||
|
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
|
||||||
|
);
|
||||||
var builder = new AmazonBedrockJsonBuilder(entity);
|
var builder = new AmazonBedrockJsonBuilder(entity);
|
||||||
var result = builder.getStringContent();
|
var result = builder.getStringContent();
|
||||||
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\"}"));
|
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\"}"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testRequestEntity_GeneratesExpectedJsonBody_WithCohereTruncation() throws IOException {
|
||||||
|
var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(
|
||||||
|
List.of("test input"),
|
||||||
|
null,
|
||||||
|
new AmazonBedrockEmbeddingsTaskSettings(CohereTruncation.START)
|
||||||
|
);
|
||||||
|
var builder = new AmazonBedrockJsonBuilder(entity);
|
||||||
|
var result = builder.getStringContent();
|
||||||
|
assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\",\"truncate\":\"START\"}"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,12 +7,9 @@
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
|
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
|
||||||
|
|
||||||
import org.elasticsearch.common.ValidationException;
|
|
||||||
import org.elasticsearch.common.settings.SecureString;
|
import org.elasticsearch.common.settings.SecureString;
|
||||||
import org.elasticsearch.core.Nullable;
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.inference.ChunkingSettings;
|
import org.elasticsearch.inference.ChunkingSettings;
|
||||||
import org.elasticsearch.inference.EmptyTaskSettings;
|
|
||||||
import org.elasticsearch.inference.InputType;
|
|
||||||
import org.elasticsearch.inference.SimilarityMeasure;
|
import org.elasticsearch.inference.SimilarityMeasure;
|
||||||
import org.elasticsearch.inference.TaskType;
|
import org.elasticsearch.inference.TaskType;
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
@ -20,19 +17,23 @@ import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider;
|
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider;
|
||||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||||
|
|
||||||
import java.util.Map;
|
import java.io.IOException;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.containsString;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.not;
|
||||||
|
|
||||||
public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
|
public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
|
||||||
|
|
||||||
public void testCreateModel_withTaskSettings_shouldFail() {
|
public void testCreateModel_withTaskSettingsOverride() throws IOException {
|
||||||
var baseModel = createModel("id", "region", "model", AmazonBedrockProvider.AMAZONTITAN, "accesskey", "secretkey");
|
var baseTaskSettings = AmazonBedrockEmbeddingsTaskSettingsTests.randomTaskSettings();
|
||||||
var thrownException = assertThrows(
|
var baseModel = createModel("id", "region", "model", AmazonBedrockProvider.AMAZONTITAN, "accesskey", "secretkey", baseTaskSettings);
|
||||||
ValidationException.class,
|
|
||||||
() -> AmazonBedrockEmbeddingsModel.of(baseModel, Map.of("testkey", "testvalue"))
|
var overrideTaskSettings = AmazonBedrockEmbeddingsTaskSettingsTests.mutateTaskSettings(baseTaskSettings);
|
||||||
);
|
var overrideTaskSettingsMap = AmazonBedrockEmbeddingsTaskSettingsTests.toMap(overrideTaskSettings);
|
||||||
assertThat(thrownException.getMessage(), containsString("Amazon Bedrock embeddings model cannot have task settings"));
|
|
||||||
|
var overriddenModel = AmazonBedrockEmbeddingsModel.of(baseModel, overrideTaskSettingsMap);
|
||||||
|
assertThat(overriddenModel.getTaskSettings(), equalTo(overrideTaskSettings));
|
||||||
|
assertThat(overriddenModel.getTaskSettings(), not(equalTo(baseTaskSettings)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// model creation only - no tests to define, but we want to have the public createModel
|
// model creation only - no tests to define, but we want to have the public createModel
|
||||||
|
@ -46,7 +47,15 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
|
||||||
String accessKey,
|
String accessKey,
|
||||||
String secretKey
|
String secretKey
|
||||||
) {
|
) {
|
||||||
return createModel(inferenceId, region, model, provider, null, false, null, null, new RateLimitSettings(240), accessKey, secretKey);
|
return createModel(
|
||||||
|
inferenceId,
|
||||||
|
region,
|
||||||
|
model,
|
||||||
|
provider,
|
||||||
|
accessKey,
|
||||||
|
secretKey,
|
||||||
|
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static AmazonBedrockEmbeddingsModel createModel(
|
public static AmazonBedrockEmbeddingsModel createModel(
|
||||||
|
@ -56,9 +65,22 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
|
||||||
AmazonBedrockProvider provider,
|
AmazonBedrockProvider provider,
|
||||||
String accessKey,
|
String accessKey,
|
||||||
String secretKey,
|
String secretKey,
|
||||||
InputType inputType
|
AmazonBedrockEmbeddingsTaskSettings taskSettings
|
||||||
) {
|
) {
|
||||||
return createModel(inferenceId, region, model, provider, null, false, null, null, new RateLimitSettings(240), accessKey, secretKey);
|
return createModel(
|
||||||
|
inferenceId,
|
||||||
|
region,
|
||||||
|
model,
|
||||||
|
provider,
|
||||||
|
null,
|
||||||
|
false,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
new RateLimitSettings(240),
|
||||||
|
accessKey,
|
||||||
|
secretKey,
|
||||||
|
taskSettings
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static AmazonBedrockEmbeddingsModel createModel(
|
public static AmazonBedrockEmbeddingsModel createModel(
|
||||||
|
@ -114,7 +136,7 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
|
||||||
similarity,
|
similarity,
|
||||||
rateLimitSettings
|
rateLimitSettings
|
||||||
),
|
),
|
||||||
new EmptyTaskSettings(),
|
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings(),
|
||||||
chunkingSettings,
|
chunkingSettings,
|
||||||
new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey))
|
new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey))
|
||||||
);
|
);
|
||||||
|
@ -132,6 +154,36 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
|
||||||
RateLimitSettings rateLimitSettings,
|
RateLimitSettings rateLimitSettings,
|
||||||
String accessKey,
|
String accessKey,
|
||||||
String secretKey
|
String secretKey
|
||||||
|
) {
|
||||||
|
return createModel(
|
||||||
|
inferenceId,
|
||||||
|
region,
|
||||||
|
model,
|
||||||
|
provider,
|
||||||
|
dimensions,
|
||||||
|
dimensionsSetByUser,
|
||||||
|
maxTokens,
|
||||||
|
similarity,
|
||||||
|
rateLimitSettings,
|
||||||
|
accessKey,
|
||||||
|
secretKey,
|
||||||
|
AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static AmazonBedrockEmbeddingsModel createModel(
|
||||||
|
String inferenceId,
|
||||||
|
String region,
|
||||||
|
String model,
|
||||||
|
AmazonBedrockProvider provider,
|
||||||
|
@Nullable Integer dimensions,
|
||||||
|
boolean dimensionsSetByUser,
|
||||||
|
@Nullable Integer maxTokens,
|
||||||
|
@Nullable SimilarityMeasure similarity,
|
||||||
|
RateLimitSettings rateLimitSettings,
|
||||||
|
String accessKey,
|
||||||
|
String secretKey,
|
||||||
|
AmazonBedrockEmbeddingsTaskSettings taskSettings
|
||||||
) {
|
) {
|
||||||
return new AmazonBedrockEmbeddingsModel(
|
return new AmazonBedrockEmbeddingsModel(
|
||||||
inferenceId,
|
inferenceId,
|
||||||
|
@ -147,7 +199,7 @@ public class AmazonBedrockEmbeddingsModelTests extends ESTestCase {
|
||||||
similarity,
|
similarity,
|
||||||
rateLimitSettings
|
rateLimitSettings
|
||||||
),
|
),
|
||||||
new EmptyTaskSettings(),
|
taskSettings,
|
||||||
null,
|
null,
|
||||||
new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey))
|
new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey))
|
||||||
);
|
);
|
||||||
|
|
|
@ -0,0 +1,112 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
|
||||||
|
|
||||||
|
import org.elasticsearch.TransportVersion;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
|
import org.elasticsearch.xcontent.ToXContent;
|
||||||
|
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||||
|
import org.elasticsearch.xcontent.json.JsonXContent;
|
||||||
|
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
||||||
|
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithoutUnspecified;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD;
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
|
public class AmazonBedrockEmbeddingsTaskSettingsTests extends AbstractBWCWireSerializationTestCase<AmazonBedrockEmbeddingsTaskSettings> {
|
||||||
|
|
||||||
|
public static AmazonBedrockEmbeddingsTaskSettings emptyTaskSettings() {
|
||||||
|
return AmazonBedrockEmbeddingsTaskSettings.EMPTY;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static AmazonBedrockEmbeddingsTaskSettings randomTaskSettings() {
|
||||||
|
var inputType = randomBoolean() ? randomWithoutUnspecified() : null;
|
||||||
|
var truncation = randomBoolean() ? randomFrom(CohereTruncation.values()) : null;
|
||||||
|
return new AmazonBedrockEmbeddingsTaskSettings(truncation);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static AmazonBedrockEmbeddingsTaskSettings mutateTaskSettings(AmazonBedrockEmbeddingsTaskSettings instance) {
|
||||||
|
return randomValueOtherThanMany(
|
||||||
|
v -> Objects.equals(instance, v) || (instance.cohereTruncation() != null && v.cohereTruncation() == null),
|
||||||
|
AmazonBedrockEmbeddingsTaskSettingsTests::randomTaskSettings
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected AmazonBedrockEmbeddingsTaskSettings mutateInstanceForVersion(
|
||||||
|
AmazonBedrockEmbeddingsTaskSettings instance,
|
||||||
|
TransportVersion version
|
||||||
|
) {
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Writeable.Reader<AmazonBedrockEmbeddingsTaskSettings> instanceReader() {
|
||||||
|
return AmazonBedrockEmbeddingsTaskSettings::new;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected AmazonBedrockEmbeddingsTaskSettings createTestInstance() {
|
||||||
|
return randomTaskSettings();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected AmazonBedrockEmbeddingsTaskSettings mutateInstance(AmazonBedrockEmbeddingsTaskSettings instance) throws IOException {
|
||||||
|
return mutateTaskSettings(instance);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testEmpty() {
|
||||||
|
assertTrue(emptyTaskSettings().isEmpty());
|
||||||
|
assertTrue(AmazonBedrockEmbeddingsTaskSettings.fromMap(null).isEmpty());
|
||||||
|
assertTrue(AmazonBedrockEmbeddingsTaskSettings.fromMap(Map.of()).isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Map<String, Object> mutableMap(String key, Enum<?> value) {
|
||||||
|
return new HashMap<>(Map.of(key, value.toString()));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testValidCohereTruncations() {
|
||||||
|
for (var expectedCohereTruncation : CohereTruncation.ALL) {
|
||||||
|
var map = mutableMap(TRUNCATE_FIELD, expectedCohereTruncation);
|
||||||
|
var taskSettings = AmazonBedrockEmbeddingsTaskSettings.fromMap(map);
|
||||||
|
assertFalse(taskSettings.isEmpty());
|
||||||
|
assertThat(taskSettings.cohereTruncation(), equalTo(expectedCohereTruncation));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testGarbageCohereTruncations() {
|
||||||
|
var map = new HashMap<String, Object>(Map.of(TRUNCATE_FIELD, "oiuesoirtuoawoeirha"));
|
||||||
|
assertThrows(ValidationException.class, () -> AmazonBedrockEmbeddingsTaskSettings.fromMap(map));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testXContent() throws IOException {
|
||||||
|
var taskSettings = randomTaskSettings();
|
||||||
|
var taskSettingsAsMap = toMap(taskSettings);
|
||||||
|
var roundTripTaskSettings = AmazonBedrockEmbeddingsTaskSettings.fromMap(new HashMap<>(taskSettingsAsMap));
|
||||||
|
assertThat(roundTripTaskSettings, equalTo(taskSettings));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Map<String, Object> toMap(AmazonBedrockEmbeddingsTaskSettings taskSettings) throws IOException {
|
||||||
|
try (var builder = JsonXContent.contentBuilder()) {
|
||||||
|
taskSettings.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||||
|
var taskSettingsBytes = Strings.toString(builder).getBytes(StandardCharsets.UTF_8);
|
||||||
|
try (var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, taskSettingsBytes)) {
|
||||||
|
return parser.map();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue