mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
[ML] Allow InputType for Bedrock Titan (#127021)
Semantic Search can now send InputType as part of the request to non-Cohere Bedrock models (for example, Amazon Titan). Fixes #126709.
This commit is contained in:
parent
f461f90d48
commit
d870f42c90
2 changed files with 26 additions and 45 deletions
|
@ -53,7 +53,6 @@ import java.util.EnumSet;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
|
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
|
||||||
|
@ -81,15 +80,14 @@ public class AmazonBedrockService extends SenderService {
|
||||||
|
|
||||||
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
|
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
|
||||||
|
|
||||||
private static final AmazonBedrockProvider PROVIDER_WITH_TASK_TYPE = AmazonBedrockProvider.COHERE;
|
|
||||||
|
|
||||||
private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
|
private static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
|
||||||
InputType.INGEST,
|
InputType.INGEST,
|
||||||
InputType.SEARCH,
|
InputType.SEARCH,
|
||||||
InputType.CLASSIFICATION,
|
InputType.CLASSIFICATION,
|
||||||
InputType.CLUSTERING,
|
InputType.CLUSTERING,
|
||||||
InputType.INTERNAL_INGEST,
|
InputType.INTERNAL_INGEST,
|
||||||
InputType.INTERNAL_SEARCH
|
InputType.INTERNAL_SEARCH,
|
||||||
|
InputType.UNSPECIFIED
|
||||||
);
|
);
|
||||||
|
|
||||||
public AmazonBedrockService(
|
public AmazonBedrockService(
|
||||||
|
@ -130,21 +128,8 @@ public class AmazonBedrockService extends SenderService {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
|
protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
|
||||||
if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
|
if (model instanceof AmazonBedrockModel) {
|
||||||
// inputType is only allowed when provider=cohere for text embeddings
|
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
|
||||||
var provider = baseAmazonBedrockModel.provider();
|
|
||||||
|
|
||||||
if (Objects.equals(provider, PROVIDER_WITH_TASK_TYPE)) {
|
|
||||||
// input type parameter allowed, so verify it is valid if specified
|
|
||||||
ServiceUtils.validateInputTypeAgainstAllowlist(inputType, VALID_INPUT_TYPE_VALUES, SERVICE_NAME, validationException);
|
|
||||||
} else {
|
|
||||||
// input type parameter not allowed so throw validation error if it is specified and not internal
|
|
||||||
ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(
|
|
||||||
inputType,
|
|
||||||
validationException,
|
|
||||||
Strings.format("Invalid value [%s] received. [%s] is not allowed for provider [%s]", inputType, "input_type", provider)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -988,7 +988,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
verifyNoMoreInteractions(sender);
|
verifyNoMoreInteractions(sender);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testInfer_ThrowsValidationErrorWhenInputTypeIsSpecifiedForProviderThatDoesNotAcceptTaskType() throws IOException {
|
public void testInfer_SendsRequest_ForTitanEmbeddingsModel() throws IOException {
|
||||||
var sender = mock(Sender.class);
|
var sender = mock(Sender.class);
|
||||||
var factory = mock(HttpRequestSender.Factory.class);
|
var factory = mock(HttpRequestSender.Factory.class);
|
||||||
when(factory.createSender()).thenReturn(sender);
|
when(factory.createSender()).thenReturn(sender);
|
||||||
|
@ -1006,37 +1006,33 @@ public class AmazonBedrockServiceTests extends ESTestCase {
|
||||||
"secret"
|
"secret"
|
||||||
);
|
);
|
||||||
|
|
||||||
try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) {
|
try (
|
||||||
|
var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool));
|
||||||
|
var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()
|
||||||
|
) {
|
||||||
|
var results = new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })));
|
||||||
|
requestSender.enqueue(results);
|
||||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
var thrownException = expectThrows(
|
service.infer(
|
||||||
ValidationException.class,
|
model,
|
||||||
() -> service.infer(
|
null,
|
||||||
model,
|
null,
|
||||||
null,
|
null,
|
||||||
null,
|
List.of("abc"),
|
||||||
null,
|
false,
|
||||||
List.of(""),
|
new HashMap<>(),
|
||||||
false,
|
InputType.INGEST,
|
||||||
new HashMap<>(),
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||||
InputType.INGEST,
|
listener
|
||||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
||||||
listener
|
|
||||||
)
|
|
||||||
);
|
|
||||||
assertThat(
|
|
||||||
thrownException.getMessage(),
|
|
||||||
is("Validation Failed: 1: Invalid value [ingest] received. [input_type] is not allowed for provider [amazontitan];")
|
|
||||||
);
|
);
|
||||||
|
|
||||||
verify(factory, times(1)).createSender();
|
var result = listener.actionGet(TIMEOUT);
|
||||||
verify(sender, times(1)).start();
|
|
||||||
|
assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.678F }))));
|
||||||
}
|
}
|
||||||
verify(sender, times(1)).close();
|
|
||||||
verifyNoMoreInteractions(factory);
|
|
||||||
verifyNoMoreInteractions(sender);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testInfer_SendsRequest_ForEmbeddingsModel() throws IOException {
|
public void testInfer_SendsRequest_ForCohereEmbeddingsModel() throws IOException {
|
||||||
var sender = mock(Sender.class);
|
var sender = mock(Sender.class);
|
||||||
var factory = mock(HttpRequestSender.Factory.class);
|
var factory = mock(HttpRequestSender.Factory.class);
|
||||||
when(factory.createSender()).thenReturn(sender);
|
when(factory.createSender()).thenReturn(sender);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue