[ML] Allow InputType for Bedrock Titan (#127021)

Semantic Search can now send InputType as part of the request to
non-Cohere Bedrock models.

Fix #126709
This commit is contained in:
Pat Whelan 2025-04-18 12:38:59 -04:00 committed by GitHub
parent f461f90d48
commit d870f42c90
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 26 additions and 45 deletions

View file

@ -53,7 +53,6 @@ import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
@ -81,15 +80,14 @@ public class AmazonBedrockService extends SenderService {
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
private static final AmazonBedrockProvider PROVIDER_WITH_TASK_TYPE = AmazonBedrockProvider.COHERE;
private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
InputType.INGEST,
InputType.SEARCH,
InputType.CLASSIFICATION,
InputType.CLUSTERING,
InputType.INTERNAL_INGEST,
InputType.INTERNAL_SEARCH
InputType.INTERNAL_SEARCH,
InputType.UNSPECIFIED
);
public AmazonBedrockService(
@ -130,21 +128,8 @@ public class AmazonBedrockService extends SenderService {
@Override
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
// inputType is only allowed when provider=cohere for text embeddings
var provider = baseAmazonBedrockModel.provider();
if (Objects.equals(provider, PROVIDER_WITH_TASK_TYPE)) {
// input type parameter allowed, so verify it is valid if specified
if (model instanceof AmazonBedrockModel) {
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
} else {
// input type parameter not allowed so throw validation error if it is specified and not internal
ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(
inputType,
validationException,
Strings.format("Invalid value [%s] received. [%s] is not allowed for provider [%s]", inputType, "input_type", provider)
);
}
}
}

View file

@ -988,7 +988,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
verifyNoMoreInteractions(sender);
}
public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForProviderThatDoesNotAcceptTaskType() throws IOException {
public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException {
var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender()).thenReturn(sender);
@ -1006,37 +1006,33 @@ public class AmazonBedrockServiceTests extends ESTestCase {
"secret"
);
try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) {
try (
var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool));
var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()
) {
var results = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })));
requestSender.enqueue(results);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
var thrownException = expectThrows(
ValidationException.class,
() -> service.infer(
service.infer(
model,
null,
null,
null,
List.of(""),
List.of("abc"),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
listener
)
);
assertThat(
thrownException.getMessage(),
is("Validation Failed: 1: Invalid value [ingest] received. [input_type] is not allowed for provider [amazontitan];")
);
verify(factory, times(1)).createSender();
verify(sender, times(1)).start();
var result = listener.actionGet(TIMEOUT);
assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.678F }))));
}
verify(sender, times(1)).close();
verifyNoMoreInteractions(factory);
verifyNoMoreInteractions(sender);
}
public void testInfer_SendsRequest_ForEmbeddingsModel() throws IOException {
public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException {
var sender = mock(Sender.class);
var factory = mock(HttpRequestSender.Factory.class);
when(factory.createSender()).thenReturn(sender);