From d870f42c907582e7666779dca3dcaa8c64b54f41 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Fri, 18 Apr 2025 12:38:59 -0400 Subject: [PATCH] [ML] Allow InputType for Bedrock Titan (#127021) Semantic Search can now send InputType as part of the request to non-Cohere Bedrock models. Fix #126709 --- .../amazonbedrock/AmazonBedrockService.java | 23 ++------- .../AmazonBedrockServiceTests.java | 48 +++++++++---------- 2 files changed, 26 insertions(+), 45 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index b0b4b7eed1a7..591607953ea1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -53,7 +53,6 @@ import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; @@ -81,15 +80,14 @@ public class AmazonBedrockService extends SenderService { private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION); - private static final AmazonBedrockProvider PROVIDER_WITH_TASK_TYPE = AmazonBedrockProvider.COHERE; - private static final EnumSet VALID_INPUT_TYPE_VALUES = EnumSet.of( InputType.INGEST, InputType.SEARCH, InputType.CLASSIFICATION, InputType.CLUSTERING, InputType.INTERNAL_INGEST, - InputType.INTERNAL_SEARCH + InputType.INTERNAL_SEARCH, + InputType.UNSPECIFIED ); public AmazonBedrockService( @@ -130,21 +128,8 @@ public class AmazonBedrockService extends SenderService { @Override protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { - if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) { - // inputType is only allowed when provider=cohere for text embeddings - var provider = baseAmazonBedrockModel.provider(); - - if (Objects.equals(provider, PROVIDER_WITH_TASK_TYPE)) { - // input type parameter allowed, so verify it is valid if specified - ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException); - } else { - // input type parameter not allowed so throw validation error if it is specified and not internal - ServiceUtils.validateInputTypeIsUnspecifiedOrInternal( - inputType, - validationException, - Strings.format("Invalid value [%s] received. [%s] is not allowed for provider [%s]", inputType, "input_type", provider) - ); - } + if (model instanceof AmazonBedrockModel) { + ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index d34e8b3408fe..07ee7397504a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -988,7 +988,7 @@ public class AmazonBedrockServiceTests extends ESTestCase { verifyNoMoreInteractions(sender); } - public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForProviderThatDoesNotAcceptTaskType() throws IOException { + public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException { var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); @@ -1006,37 +1006,33 @@ public class AmazonBedrockServiceTests extends ESTestCase { "secret" ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool)); + var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender() + ) { + var results = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F }))); + requestSender.enqueue(results); PlainActionFuture listener = new PlainActionFuture<>(); - var thrownException = expectThrows( - ValidationException.class, - () -> service.infer( - model, - null, - null, - null, - List.of(""), - false, - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ) - ); - assertThat( - thrownException.getMessage(), - is("Validation Failed: 1: Invalid value [ingest] received. [input_type] is not allowed for provider [amazontitan];") + service.infer( + model, + null, + null, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener ); - verify(factory, times(1)).createSender(); - verify(sender, times(1)).start(); + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.678F })))); } - verify(sender, times(1)).close(); - verifyNoMoreInteractions(factory); - verifyNoMoreInteractions(sender); } - public void testInfer_SendsRequest_ForEmbeddingsModel() throws IOException { + public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException { var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender);