mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
[ML] Allow InputType for Bedrock Titan (#127021)
Semantic Search can now send InputType as part of the request to non-Cohere Bedrock models. Fixes #126709.
This commit is contained in:
parent
f461f90d48
commit
d870f42c90
2 changed files with 26 additions and 45 deletions
|
@ -53,7 +53,6 @@ import java.util.EnumSet;
|
|||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
|
||||
|
@ -81,15 +80,14 @@ public class AmazonBedrockService extends SenderService {
|
|||
|
||||
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
|
||||
|
||||
private static final AmazonBedrockProvider PROVIDER_WITH_TASK_TYPE = AmazonBedrockProvider.COHERE;
|
||||
|
||||
private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
|
||||
InputType.INGEST,
|
||||
InputType.SEARCH,
|
||||
InputType.CLASSIFICATION,
|
||||
InputType.CLUSTERING,
|
||||
InputType.INTERNAL_INGEST,
|
||||
InputType.INTERNAL_SEARCH
|
||||
InputType.INTERNAL_SEARCH,
|
||||
InputType.UNSPECIFIED
|
||||
);
|
||||
|
||||
public AmazonBedrockService(
|
||||
|
@ -130,21 +128,8 @@ public class AmazonBedrockService extends SenderService {
|
|||
|
||||
@Override
|
||||
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
|
||||
if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
|
||||
// inputType is only allowed when provider=cohere for text embeddings
|
||||
var provider = baseAmazonBedrockModel.provider();
|
||||
|
||||
if (Objects.equals(provider, PROVIDER_WITH_TASK_TYPE)) {
|
||||
// input type parameter allowed, so verify it is valid if specified
|
||||
if (model instanceof AmazonBedrockModel) {
|
||||
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
|
||||
} else {
|
||||
// input type parameter not allowed so throw validation error if it is specified and not internal
|
||||
ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(
|
||||
inputType,
|
||||
validationException,
|
||||
Strings.format("Invalid value [%s] received. [%s] is not allowed for provider [%s]", inputType, "input_type", provider)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -988,7 +988,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
verifyNoMoreInteractions(sender);
|
||||
}
|
||||
|
||||
public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForProviderThatDoesNotAcceptTaskType() throws IOException {
|
||||
public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException {
|
||||
var sender = mock(Sender.class);
|
||||
var factory = mock(HttpRequestSender.Factory.class);
|
||||
when(factory.createSender()).thenReturn(sender);
|
||||
|
@ -1006,37 +1006,33 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
|||
"secret"
|
||||
);
|
||||
|
||||
try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) {
|
||||
try (
|
||||
var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool));
|
||||
var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()
|
||||
) {
|
||||
var results = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })));
|
||||
requestSender.enqueue(results);
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
var thrownException = expectThrows(
|
||||
ValidationException.class,
|
||||
() -> service.infer(
|
||||
service.infer(
|
||||
model,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
List.of(""),
|
||||
List.of("abc"),
|
||||
false,
|
||||
new HashMap<>(),
|
||||
InputType.INGEST,
|
||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||
listener
|
||||
)
|
||||
);
|
||||
assertThat(
|
||||
thrownException.getMessage(),
|
||||
is("Validation Failed: 1: Invalid value [ingest] received. [input_type] is not allowed for provider [amazontitan];")
|
||||
);
|
||||
|
||||
verify(factory, times(1)).createSender();
|
||||
verify(sender, times(1)).start();
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
|
||||
assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.678F }))));
|
||||
}
|
||||
verify(sender, times(1)).close();
|
||||
verifyNoMoreInteractions(factory);
|
||||
verifyNoMoreInteractions(sender);
|
||||
}
|
||||
|
||||
public void testInfer_SendsRequest_ForEmbeddingsModel() throws IOException {
|
||||
public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException {
|
||||
var sender = mock(Sender.class);
|
||||
var factory = mock(HttpRequestSender.Factory.class);
|
||||
when(factory.createSender()).thenReturn(sender);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue