Add working dense text embeddings integration with default endpoint. Some tests WIP

Tim Grein 2025-06-23 13:19:47 +02:00
parent 6e4cb8142b
commit f054dca0b3
19 changed files with 1092 additions and 214 deletions

View file

@@ -157,6 +157,7 @@ public class TransportVersions {
     public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
     public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
     public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
+    public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED_8_19 = def(8_841_0_17);
     public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
     public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
     public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@@ -214,6 +215,7 @@ public class TransportVersions {
     public static final TransportVersion ESQL_REMOVE_AGGREGATE_TYPE = def(9_045_0_00);
     public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00);
     public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_0_00);
+    public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED = def(9_048_0_00);
     /*
      * STOP! READ THIS FIRST! No, really,
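Both constants follow the existing pattern: a 9.x id plus an 8.19 backport id. For context, the typical use of such a constant is to gate wire reads and writes on the peer's version; a minimal sketch (the surrounding field and method are illustrative, not part of this diff):

    // Sketch: only serialize the new field when the receiving node knows about it.
    if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED)) {
        out.writeOptionalVInt(dimensions);
    }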

View file

@@ -26,7 +26,7 @@ public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEIS
     var allModels = getAllModels();
     var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
-    assertThat(allModels, hasSize(5));
+    assertThat(allModels, hasSize(6));
     assertThat(chatCompletionModels, hasSize(1));
     for (var model : chatCompletionModels) {
@@ -35,6 +35,7 @@ public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEIS
     assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
     assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
+    assertInferenceIdTaskType(allModels, ".multilingual-embed-elastic", TaskType.TEXT_EMBEDDING);
 }

 private static void assertInferenceIdTaskType(List<Map<String, Object>> models, String inferenceId, TaskType taskType) {

View file

@@ -64,7 +64,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
     @SuppressWarnings("unchecked")
     public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
         List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
-        assertThat(services.size(), equalTo(15));
+        assertThat(services.size(), equalTo(16));
         String[] providers = new String[services.size()];
         for (int i = 0; i < services.size(); i++) {
@@ -79,6 +79,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
             "azureaistudio",
             "azureopenai",
             "cohere",
+            "elastic",
             "elasticsearch",
             "googleaistudio",
             "googlevertexai",

View file

@@ -36,6 +36,10 @@ public class MockElasticInferenceServiceAuthorizationServer implements TestRule
     {
       "model_name": "elser-v2",
       "task_types": ["embed/text/sparse"]
+    },
+    {
+      "model_name": "multilingual-embed",
+      "task_types": ["embed/text/dense"]
     }
   ]
 }

View file

@@ -11,6 +11,7 @@ import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.Model;
@@ -197,6 +198,10 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
     {
       "model_name": "elser-v2",
       "task_types": ["embed/text/sparse"]
+    },
+    {
+      "model_name": "multilingual-embed",
+      "task_types": ["embed/text/dense"]
     }
   ]
 }
@@ -221,16 +226,33 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
                 ".rainbow-sprinkles-elastic",
                 MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
                 service
+            ),
+            new InferenceService.DefaultConfigId(
+                ".multilingual-embed-elastic",
+                MinimalServiceSettings.textEmbedding(
+                    ElasticInferenceService.NAME,
+                    ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
+                    ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
+                    DenseVectorFieldMapper.ElementType.FLOAT
+                ),
+                service
             )
         )
     )
 );
-assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
+assertThat(
+    service.supportedTaskTypes(),
+    is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING))
+);

 PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
 service.defaultConfigs(listener);
 assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
 assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
+assertThat(
+    listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(),
+    is(".multilingual-embed-elastic")
+);

 var getModelListener = new PlainActionFuture<UnparsedModel>();
 // persists the default endpoints
@@ -267,6 +289,16 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
                 ".elser-v2-elastic",
                 MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
                 service
+            ),
+            new InferenceService.DefaultConfigId(
+                ".multilingual-embed-elastic",
+                MinimalServiceSettings.textEmbedding(
+                    ElasticInferenceService.NAME,
+                    ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
+                    ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
+                    DenseVectorFieldMapper.ElementType.FLOAT
+                ),
+                service
             )
         )
     )

View file

@@ -0,0 +1,93 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceSparseEmbeddingsRequest.inputTypeToUsageContext;

public class ElasticInferenceServiceDenseTextEmbeddingsRequest extends ElasticInferenceServiceRequest {

    private final URI uri;
    private final ElasticInferenceServiceDenseTextEmbeddingsModel model;
    private final List<String> inputs;
    private final TraceContextHandler traceContextHandler;
    private final InputType inputType;

    public ElasticInferenceServiceDenseTextEmbeddingsRequest(
        ElasticInferenceServiceDenseTextEmbeddingsModel model,
        List<String> inputs,
        TraceContext traceContext,
        ElasticInferenceServiceRequestMetadata metadata,
        InputType inputType
    ) {
        super(metadata);
        this.inputs = inputs;
        this.model = Objects.requireNonNull(model);
        this.uri = model.uri();
        this.traceContextHandler = new TraceContextHandler(traceContext);
        this.inputType = inputType;
    }

    @Override
    public HttpRequestBase createHttpRequestBase() {
        var httpPost = new HttpPost(uri);
        var usageContext = inputTypeToUsageContext(inputType);
        var requestEntity = Strings.toString(
            new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(inputs, model.getServiceSettings().modelId(), usageContext)
        );

        ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
        httpPost.setEntity(byteEntity);

        traceContextHandler.propagateTraceContext(httpPost);
        httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

        return httpPost;
    }

    public TraceContext getTraceContext() {
        return traceContextHandler.traceContext();
    }

    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }

    @Override
    public URI getURI() {
        return this.uri;
    }

    @Override
    public Request truncate() {
        return this;
    }

    @Override
    public boolean[] getTruncationInfo() {
        return null;
    }
}

View file

@@ -0,0 +1,57 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public record ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
    List<String> inputs,
    String modelId,
    @Nullable ElasticInferenceServiceUsageContext usageContext
) implements ToXContentObject {

    private static final String INPUT_FIELD = "input";
    private static final String MODEL_FIELD = "model";
    private static final String USAGE_CONTEXT = "usage_context";

    public ElasticInferenceServiceDenseTextEmbeddingsRequestEntity {
        Objects.requireNonNull(inputs);
        Objects.requireNonNull(modelId);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();

        builder.startArray(INPUT_FIELD);
        for (String input : inputs) {
            builder.value(input);
        }
        builder.endArray();

        builder.field(MODEL_FIELD, modelId);

        // optional field
        if ((usageContext == ElasticInferenceServiceUsageContext.UNSPECIFIED) == false) {
            builder.field(USAGE_CONTEXT, usageContext);
        }

        builder.endObject();

        return builder;
    }
}
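For reference, the entity above serializes to a body like the following (a sketch; the exact usage_context string depends on how ElasticInferenceServiceUsageContext renders itself, and the field is omitted entirely when it is UNSPECIFIED):

    {
        "input": ["Embed this text", "Embed this text, too"],
        "model": "multilingual-embed",
        "usage_context": "search"
    }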

View file

@@ -43,7 +43,9 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements Infer
     "embed/text/sparse",
     TaskType.SPARSE_EMBEDDING,
     "chat",
-    TaskType.CHAT_COMPLETION
+    TaskType.CHAT_COMPLETION,
+    "embed/text/dense",
+    TaskType.TEXT_EMBEDDING
 );

 @SuppressWarnings("unchecked")
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")

View file

@@ -0,0 +1,107 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.external.response.elastic;

import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.io.IOException;
import java.util.Collections;
import java.util.List;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;

public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntity {

    private static final String FAILED_TO_FIND_FIELD_TEMPLATE =
        "Failed to find required field [%s] in Elastic Inference Service dense text embeddings response";

    /**
     * Parses the Elastic Inference Service Dense Text Embeddings response.
     *
     * For a request like:
     *
     * <pre>
     * <code>
     * {
     *     "inputs": ["Embed this text", "Embed this text, too"]
     * }
     * </code>
     * </pre>
     *
     * The response would look like:
     *
     * <pre>
     * <code>
     * {
     *     "data": [
     *         [
     *             2.1259406,
     *             1.7073475,
     *             0.9020516
     *         ],
     *         (...)
     *     ],
     *     "meta": {
     *         "usage": {...}
     *     }
     * }
     * </code>
     * </pre>
     */
    public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
        var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);

        try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
            moveToFirstToken(jsonParser);

            XContentParser.Token token = jsonParser.currentToken();
            ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);

            positionParserAtTokenAfterField(jsonParser, "data", FAILED_TO_FIND_FIELD_TEMPLATE);

            List<TextEmbeddingFloatResults.Embedding> parsedEmbeddings = parseList(
                jsonParser,
                (parser, index) -> ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.parseTextEmbeddingObject(parser)
            );

            if (parsedEmbeddings.isEmpty()) {
                return new TextEmbeddingFloatResults(Collections.emptyList());
            }

            return new TextEmbeddingFloatResults(parsedEmbeddings);
        }
    }

    private static TextEmbeddingFloatResults.Embedding parseTextEmbeddingObject(XContentParser parser) throws IOException {
        List<Float> embeddingValueList = parseList(
            parser,
            ElasticInferenceServiceDenseTextEmbeddingsResponseEntity::parseEmbeddingFloatValueList
        );

        return TextEmbeddingFloatResults.Embedding.of(embeddingValueList);
    }

    private static float parseEmbeddingFloatValueList(XContentParser parser) throws IOException {
        XContentParser.Token token = parser.currentToken();
        XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser);
        return parser.floatValue();
    }

    private ElasticInferenceServiceDenseTextEmbeddingsResponseEntity() {}
}
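A minimal sketch of exercising this parser the way the surrounding response-entity unit tests usually do (Mockito's mock, Apache's HttpResponse, and Hamcrest's hasSize are assumptions based on this module's test conventions, not part of this diff):

    // Sketch: feed the documented response body through fromResponse.
    String responseJson = """
        {
            "data": [
                [2.1259406, 1.7073475, 0.9020516]
            ],
            "meta": { "usage": {} }
        }
        """;

    TextEmbeddingFloatResults results = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
        mock(Request.class),
        new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
    );
    // one embedding per element of "data", each holding the float values in order
    assertThat(results.embeddings(), hasSize(1));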

View file

@@ -51,11 +51,11 @@ public class ElasticInferenceServiceSparseEmbeddingsResponseEntity {
  * <code>
  * {
  *     "data": [
- *         {
- *             "Embed": 2.1259406,
- *             "this": 1.7073475,
- *             "text": 0.9020516
- *         },
+ *         [
+ *             2.1259406,
+ *             1.7073475,
+ *             0.9020516
+ *         ],
  *         (...)
  *     ],
  *     "meta": {

View file

@@ -16,7 +16,9 @@ import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.inference.ChunkedInference;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptySecretSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -27,6 +29,7 @@ import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.SettingsConfiguration;
+import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
 import org.elasticsearch.rest.RestStatus;
@@ -36,6 +39,8 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
 import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -51,6 +56,8 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI
 import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
+import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -70,6 +77,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT
 import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
@@ -79,10 +87,18 @@ public class ElasticInferenceService extends SenderService {

     public static final String NAME = "elastic";
     public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service";
+    public static final Integer DENSE_TEXT_EMBEDDINGS_DIMENSIONS = 1024;

-    private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
+    private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
+        TaskType.SPARSE_EMBEDDING,
+        TaskType.CHAT_COMPLETION,
+        TaskType.TEXT_EMBEDDING
+    );
     private static final String SERVICE_NAME = "Elastic";
+    // TODO: check with team, what makes the most sense
+    private static final Integer DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE = 32;

     // rainbow-sprinkles
     static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles";
     static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
@@ -91,10 +107,17 @@ public class ElasticInferenceService extends SenderService {
     static final String DEFAULT_ELSER_MODEL_ID_V2 = "elser-v2";
     static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2);

+    // multilingual-text-embed
+    static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "multilingual-embed";
+    static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = defaultEndpointId(DEFAULT_MULTILINGUAL_EMBED_MODEL_ID);
+
     /**
      * The task types that the {@link InferenceAction.Request} can accept.
      */
-    private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING);
+    private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(
+        TaskType.SPARSE_EMBEDDING,
+        TaskType.TEXT_EMBEDDING
+    );

     public static String defaultEndpointId(String modelId) {
         return Strings.format(".%s-elastic", modelId);
@@ -155,6 +178,31 @@ public class ElasticInferenceService extends SenderService {
                 elasticInferenceServiceComponents
             ),
             MinimalServiceSettings.sparseEmbedding(NAME)
+        ),
+        DEFAULT_MULTILINGUAL_EMBED_MODEL_ID,
+        new DefaultModelConfig(
+            new ElasticInferenceServiceDenseTextEmbeddingsModel(
+                DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID,
+                TaskType.TEXT_EMBEDDING,
+                NAME,
+                new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
+                    DEFAULT_MULTILINGUAL_EMBED_MODEL_ID,
+                    defaultDenseTextEmbeddingsSimilarity(),
+                    null,
+                    null,
+                    false,
+                    ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS
+                ),
+                EmptyTaskSettings.INSTANCE,
+                EmptySecretSettings.INSTANCE,
+                elasticInferenceServiceComponents
+            ),
+            MinimalServiceSettings.textEmbedding(
+                NAME,
+                DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
+                defaultDenseTextEmbeddingsSimilarity(),
+                DenseVectorFieldMapper.ElementType.FLOAT
+            )
         )
     );
 }
@@ -270,12 +318,26 @@ public class ElasticInferenceService extends SenderService {
     TimeValue timeout,
     ActionListener<List<ChunkedInference>> listener
 ) {
-    // Pass-through without actually performing chunking (result will have a single chunk per input)
-    ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
-        (delegate, response) -> delegate.onResponse(translateToChunkedResults(inputs, response))
-    );
+    // TODO: we probably want to allow chunked inference for both sparse and dense?
+    if (model instanceof ElasticInferenceServiceDenseTextEmbeddingsModel == false) {
+        listener.onFailure(createInvalidModelException(model));
+        return;
+    }

-    doInfer(model, inputs, taskSettings, timeout, inferListener);
+    ElasticInferenceServiceDenseTextEmbeddingsModel elasticInferenceServiceModel =
+        (ElasticInferenceServiceDenseTextEmbeddingsModel) model;
+    var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo());
+
+    List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
+        inputs.getInputs(),
+        DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE,
+        elasticInferenceServiceModel.getConfigurations().getChunkingSettings()
+    ).batchRequestsWithListeners(listener);
+
+    for (var request : batchedRequests) {
+        var action = elasticInferenceServiceModel.accept(actionCreator, taskSettings);
+        action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+    }
 }

 @Override
@@ -294,11 +356,19 @@ public class ElasticInferenceService extends SenderService {
     Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
     Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

+    ChunkingSettings chunkingSettings = null;
+    if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+        chunkingSettings = ChunkingSettingsBuilder.fromMap(
+            removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
+        );
+    }
+
     ElasticInferenceServiceModel model = createModel(
         inferenceEntityId,
         taskType,
         serviceSettingsMap,
         taskSettingsMap,
+        chunkingSettings,
         serviceSettingsMap,
         elasticInferenceServiceComponents,
         TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
@@ -335,6 +405,7 @@ public class ElasticInferenceService extends SenderService {
     TaskType taskType,
     Map<String, Object> serviceSettings,
     Map<String, Object> taskSettings,
+    ChunkingSettings chunkingSettings,
     @Nullable Map<String, Object> secretSettings,
     ElasticInferenceServiceComponents eisServiceComponents,
     String failureMessage,
@@ -361,6 +432,16 @@ public class ElasticInferenceService extends SenderService {
             eisServiceComponents,
             context
         );
+        case TEXT_EMBEDDING -> new ElasticInferenceServiceDenseTextEmbeddingsModel(
+            inferenceEntityId,
+            taskType,
+            NAME,
+            serviceSettings,
+            taskSettings,
+            secretSettings,
+            eisServiceComponents,
+            context
+        );
         default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
     };
 }
@@ -376,11 +457,17 @@ public class ElasticInferenceService extends SenderService {
     Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
     Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);

+    ChunkingSettings chunkingSettings = null;
+    if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+        chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+    }
+
     return createModelFromPersistent(
         inferenceEntityId,
         taskType,
         serviceSettingsMap,
         taskSettingsMap,
+        chunkingSettings,
         secretSettingsMap,
         parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
     );
@@ -391,11 +478,17 @@ public class ElasticInferenceService extends SenderService {
     Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
     Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

+    ChunkingSettings chunkingSettings = null;
+    if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+        chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+    }
+
     return createModelFromPersistent(
         inferenceEntityId,
         taskType,
         serviceSettingsMap,
         taskSettingsMap,
+        chunkingSettings,
         null,
         parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
     );
@@ -411,6 +504,7 @@ public class ElasticInferenceService extends SenderService {
     TaskType taskType,
     Map<String, Object> serviceSettings,
     Map<String, Object> taskSettings,
+    ChunkingSettings chunkingSettings,
     @Nullable Map<String, Object> secretSettings,
     String failureMessage
 ) {
@@ -419,6 +513,7 @@ public class ElasticInferenceService extends SenderService {
         taskType,
         serviceSettings,
         taskSettings,
+        chunkingSettings,
         secretSettings,
         elasticInferenceServiceComponents,
         failureMessage,
@@ -432,6 +527,36 @@ public class ElasticInferenceService extends SenderService {
     ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
 }

+@Override
+public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
+    if (model instanceof ElasticInferenceServiceDenseTextEmbeddingsModel embeddingsModel) {
+        var serviceSettings = embeddingsModel.getServiceSettings();
+        var modelId = serviceSettings.modelId();
+        var similarityFromModel = serviceSettings.similarity();
+        var similarityToUse = similarityFromModel == null ? defaultDenseTextEmbeddingsSimilarity() : similarityFromModel;
+        var maxInputTokens = serviceSettings.maxInputTokens();
+        var dimensionsSetByUser = serviceSettings.dimensionsSetByUser();
+
+        var updateServiceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
+            modelId,
+            similarityToUse,
+            embeddingSize,
+            maxInputTokens,
+            dimensionsSetByUser,
+            serviceSettings.rateLimitSettings()
+        );
+
+        return new ElasticInferenceServiceDenseTextEmbeddingsModel(embeddingsModel, updateServiceSettings);
+    } else {
+        throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
+    }
+}
+
+public static SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() {
+    // TODO: double-check
+    return SimilarityMeasure.COSINE;
+}
+
 private static List<ChunkedInference> translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) {
     if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) {
         var inputsAsList = EmbeddingsInput.of(inputs).getStringInputs();
@@ -469,9 +594,9 @@ public class ElasticInferenceService extends SenderService {
     configurationMap.put(
         MODEL_ID,
-        new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
-            "The name of the model to use for the inference task."
-        )
+        new SettingsConfiguration.Builder(
+            EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING)
+        ).setDescription("The name of the model to use for the inference task.")
             .setLabel("Model ID")
             .setRequired(true)
             .setSensitive(false)
@@ -482,7 +607,7 @@ public class ElasticInferenceService extends SenderService {
     configurationMap.put(
         MAX_INPUT_TOKENS,
-        new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription(
+        new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING)).setDescription(
             "Allows you to specify the maximum number of tokens per input."
         )
             .setLabel("Maximum Input Tokens")
@@ -494,7 +619,9 @@ public class ElasticInferenceService extends SenderService {
     );

     configurationMap.putAll(
-        RateLimitSettings.toSettingsConfiguration(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))
+        RateLimitSettings.toSettingsConfiguration(
+            EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING)
+        )
     );

     return new InferenceServiceConfiguration.Builder().setService(NAME)
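With the default endpoint registered above, the new model is addressable through the standard inference API. An illustrative call (the endpoint id comes from defaultEndpointId(DEFAULT_MULTILINGUAL_EMBED_MODEL_ID); the request shape is the stock _inference API, not something introduced by this diff):

    POST _inference/text_embedding/.multilingual-embed-elastic
    {
        "input": ["Embed this text", "Embed this text, too"]
    }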

View file

@@ -7,14 +7,15 @@

 package org.elasticsearch.xpack.inference.services.elastic;

-import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

 import java.util.Objects;

-public abstract class ElasticInferenceServiceModel extends Model {
+public abstract class ElasticInferenceServiceModel extends RateLimitGroupingModel {

     private final ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings;
@@ -35,12 +36,18 @@ public abstract class ElasticInferenceServiceModel {
     public ElasticInferenceServiceModel(ElasticInferenceServiceModel model, ServiceSettings serviceSettings) {
         super(model, serviceSettings);

-        this.rateLimitServiceSettings = model.rateLimitServiceSettings();
+        this.rateLimitServiceSettings = model.rateLimitServiceSettings;
         this.elasticInferenceServiceComponents = model.elasticInferenceServiceComponents();
     }

-    public ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings() {
-        return rateLimitServiceSettings;
+    @Override
+    public int rateLimitGroupingHash() {
+        // We only have one model for rerank
+        return Objects.hash(this.getServiceSettings().modelId());
+    }
+
+    public RateLimitSettings rateLimitSettings() {
+        return rateLimitServiceSettings.rateLimitSettings();
     }

     public ElasticInferenceServiceComponents elasticInferenceServiceComponents() {

View file

@@ -20,7 +20,7 @@ public abstract class ElasticInferenceServiceRequestManager extends BaseRequestM
     private final ElasticInferenceServiceRequestMetadata requestMetadata;

     protected ElasticInferenceServiceRequestManager(ThreadPool threadPool, ElasticInferenceServiceModel model) {
-        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
+        super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitSettings());
         this.requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext());
     }
@@ -32,7 +32,7 @@ public abstract class ElasticInferenceServiceRequestManager extends BaseRequestM
         public static RateLimitGrouping of(ElasticInferenceServiceModel model) {
             Objects.requireNonNull(model);

-            return new RateLimitGrouping(model.rateLimitServiceSettings().modelId().hashCode());
+            return new RateLimitGrouping(model.rateLimitGroupingHash());
         }
     }
 }

View file

@@ -9,9 +9,16 @@ package org.elasticsearch.xpack.inference.services.elastic.action;

 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
+import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceDenseTextEmbeddingsRequest;
+import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler;
 import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager;
+import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
 import org.elasticsearch.xpack.inference.telemetry.TraceContext;
@@ -19,10 +26,16 @@ import java.util.Locale;
 import java.util.Objects;

 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
+import static org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceRequest.extractRequestMetadataFromThreadContext;
 import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;

 public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor {

+    public static final ResponseHandler DENSE_TEXT_EMBEDDINGS_HANDLER = new ElasticInferenceServiceResponseHandler(
+        "elastic dense text embedding",
+        ElasticInferenceServiceDenseTextEmbeddingsResponseEntity::fromResponse
+    );
+
     private final Sender sender;

     private final ServiceComponents serviceComponents;
@@ -43,4 +56,26 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer
         );
         return new SenderExecutableAction(sender, requestManager, errorMessage);
     }
+
+    @Override
+    public ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model) {
+        var threadPool = serviceComponents.threadPool();
+
+        var manager = new GenericRequestManager<>(
+            threadPool,
+            model,
+            DENSE_TEXT_EMBEDDINGS_HANDLER,
+            (embeddingsInput) -> new ElasticInferenceServiceDenseTextEmbeddingsRequest(
+                model,
+                embeddingsInput.getStringInputs(),
+                traceContext,
+                extractRequestMetadataFromThreadContext(threadPool.getThreadContext()),
+                embeddingsInput.getInputType()
+            ),
+            EmbeddingsInput.class
+        );
+
+        var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Elastic dense text embeddings");
+        return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage);
+    }
 }

View file

@@ -8,10 +8,12 @@
 package org.elasticsearch.xpack.inference.services.elastic.action;

 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;

 public interface ElasticInferenceServiceActionVisitor {

     ExecutableAction create(ElasticInferenceServiceSparseEmbeddingsModel model);

+    ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model);
 }

View file

@@ -0,0 +1,114 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.EmptySecretSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.SecretSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceExecutableActionModel;
import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionVisitor;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;

public class ElasticInferenceServiceDenseTextEmbeddingsModel extends ElasticInferenceServiceExecutableActionModel {

    private final URI uri;

    public ElasticInferenceServiceDenseTextEmbeddingsModel(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        Map<String, Object> serviceSettings,
        Map<String, Object> taskSettings,
        Map<String, Object> secrets,
        ElasticInferenceServiceComponents elasticInferenceServiceComponents,
        ConfigurationParseContext context
    ) {
        this(
            inferenceEntityId,
            taskType,
            service,
            ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap(serviceSettings, context),
            // TODO: we probably want dense embeddings task settings
            EmptyTaskSettings.INSTANCE,
            EmptySecretSettings.INSTANCE,
            elasticInferenceServiceComponents
        );
    }

    public ElasticInferenceServiceDenseTextEmbeddingsModel(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        ElasticInferenceServiceDenseTextEmbeddingsServiceSettings serviceSettings,
        // TODO: we probably want dense embeddings task settings
        @Nullable TaskSettings taskSettings,
        @Nullable SecretSettings secretSettings,
        ElasticInferenceServiceComponents elasticInferenceServiceComponents
    ) {
        super(
            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
            new ModelSecrets(secretSettings),
            serviceSettings,
            elasticInferenceServiceComponents
        );
        this.uri = createUri();
    }

    public ElasticInferenceServiceDenseTextEmbeddingsModel(
        ElasticInferenceServiceDenseTextEmbeddingsModel model,
        ElasticInferenceServiceDenseTextEmbeddingsServiceSettings serviceSettings
    ) {
        super(model, serviceSettings);
        this.uri = createUri();
    }

    @Override
    public ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map<String, Object> taskSettings) {
        return visitor.create(this);
    }

    @Override
    public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings getServiceSettings() {
        return (ElasticInferenceServiceDenseTextEmbeddingsServiceSettings) super.getServiceSettings();
    }

    public URI uri() {
        return uri;
    }

    private URI createUri() throws ElasticsearchStatusException {
        try {
            // TODO, consider transforming the base URL into a URI for better error handling.
            return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/embed/text/dense");
        } catch (URISyntaxException e) {
            throw new ElasticsearchStatusException(
                "Failed to create URI for service ["
                    + this.getConfigurations().getService()
                    + "] with taskType ["
                    + this.getTaskType()
                    + "]: "
                    + e.getMessage(),
                RestStatus.BAD_REQUEST,
                e
            );
        }
    }
}
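createUri() simply appends the fixed dense-embeddings route to the configured EIS base URL. For example (the base URL below is a placeholder, not a value from this diff):

    // assuming elasticInferenceServiceUrl() returns "http://localhost:8080"
    new URI("http://localhost:8080" + "/api/v1/embed/text/dense")
    // -> http://localhost:8080/api/v1/embed/text/dense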

View file

@@ -0,0 +1,263 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
public class ElasticInferenceServiceDenseTextEmbeddingsServiceSettings extends FilteredXContentObject
    implements
        ServiceSettings,
        ElasticInferenceServiceRateLimitServiceSettings {

    public static final String NAME = "elastic_inference_service_dense_embeddings_service_settings";
    static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";

    public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);

    private final String modelId;
    private final SimilarityMeasure similarity;
    private final Integer dimensions;
    private final Integer maxInputTokens;
    private final boolean dimensionsSetByUser;
    private final RateLimitSettings rateLimitSettings;

    public static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromMap(
        Map<String, Object> map,
        ConfigurationParseContext context
    ) {
        return switch (context) {
            case REQUEST -> fromRequestMap(map, context);
            case PERSISTENT -> fromPersistentMap(map, context);
        };
    }

    private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromRequestMap(
        Map<String, Object> map,
        ConfigurationParseContext context
    ) {
        ValidationException validationException = new ValidationException();

        String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
            map,
            DEFAULT_RATE_LIMIT_SETTINGS,
            validationException,
            ElasticInferenceService.NAME,
            context
        );
        SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
        Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
        Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);

        if (validationException.validationErrors().isEmpty() == false) {
            throw validationException;
        }

        var dimensionsSetByUser = dims != null;

        return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
            modelId,
            similarity,
            dims,
            maxInputTokens,
            dimensionsSetByUser,
            rateLimitSettings
        );
    }

    private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromPersistentMap(
        Map<String, Object> map,
        ConfigurationParseContext context
    ) {
        ValidationException validationException = new ValidationException();

        String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
            map,
            DEFAULT_RATE_LIMIT_SETTINGS,
            validationException,
            ElasticInferenceService.NAME,
            context
        );
        SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
        Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
        Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);

        Boolean dimensionsSetByUser = removeAsType(map, DIMENSIONS_SET_BY_USER, Boolean.class);
        if (dimensionsSetByUser == null) {
            dimensionsSetByUser = Boolean.FALSE;
        }

        if (validationException.validationErrors().isEmpty() == false) {
            throw validationException;
        }

        return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
            modelId,
            similarity,
            dims,
            maxInputTokens,
            dimensionsSetByUser,
            rateLimitSettings
        );
    }

    public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
        String modelId,
        @Nullable SimilarityMeasure similarity,
        @Nullable Integer dimensions,
        @Nullable Integer maxInputTokens,
        boolean dimensionsSetByUser,
        RateLimitSettings rateLimitSettings
    ) {
        this.modelId = modelId;
        this.similarity = similarity;
        this.dimensions = dimensions;
        this.maxInputTokens = maxInputTokens;
        this.dimensionsSetByUser = dimensionsSetByUser;
        this.rateLimitSettings = rateLimitSettings;
    }

    public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(StreamInput in) throws IOException {
        this.modelId = in.readString();
        this.similarity = in.readOptionalEnum(SimilarityMeasure.class);
        this.dimensions = in.readOptionalVInt();
        this.maxInputTokens = in.readOptionalVInt();
        this.dimensionsSetByUser = in.readBoolean();
        this.rateLimitSettings = new RateLimitSettings(in);
    }

    @Override
    public SimilarityMeasure similarity() {
        return similarity;
    }

    @Override
    public Integer dimensions() {
        return dimensions;
    }

    public Integer maxInputTokens() {
        return maxInputTokens;
    }

    @Override
    public String modelId() {
        return modelId;
    }

    @Override
    public RateLimitSettings rateLimitSettings() {
        return rateLimitSettings;
    }

    @Override
    public Boolean dimensionsSetByUser() {
        return dimensionsSetByUser;
    }

    @Override
    public DenseVectorFieldMapper.ElementType elementType() {
        return DenseVectorFieldMapper.ElementType.FLOAT;
    }

    public RateLimitSettings getRateLimitSettings() {
        return rateLimitSettings;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
        builder.field(MODEL_ID, modelId);
        rateLimitSettings.toXContent(builder, params);

        return builder;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();

        if (similarity != null) {
            builder.field(SIMILARITY, similarity);
        }
        if (dimensions != null) {
            builder.field(DIMENSIONS, dimensions);
        }
        if (maxInputTokens != null) {
            builder.field(MAX_INPUT_TOKENS, maxInputTokens);
        }

        toXContentFragmentOfExposedFields(builder, params);
        builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser);

        builder.endObject();

        return builder;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED;
    }
@Override
public void writeTo(StreamOutput out) throws IOException {
// Write order must mirror the StreamInput constructor: model id first, rate limit settings last.
out.writeString(modelId);
out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion()));
out.writeOptionalVInt(dimensions);
out.writeOptionalVInt(maxInputTokens);
out.writeBoolean(dimensionsSetByUser);
rateLimitSettings.writeTo(out);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings that = (ElasticInferenceServiceDenseTextEmbeddingsServiceSettings) o;
// Compare every serialized field, so equality stays consistent with writeTo and the stream constructor.
return Objects.equals(modelId, that.modelId)
&& Objects.equals(similarity, that.similarity)
&& Objects.equals(dimensions, that.dimensions)
&& Objects.equals(maxInputTokens, that.maxInputTokens)
&& Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser)
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
}
@Override
public int hashCode() {
return Objects.hash(modelId, similarity, dimensions, maxInputTokens, dimensionsSetByUser, rateLimitSettings);
}
}
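For reference, toXContent above produces (and fromPersistentMap re-reads) a service_settings object roughly along these lines; the field names follow the constants used in the code, while the concrete values are made up for illustration:

{
    "similarity": "cosine",
    "dimensions": 1024,
    "max_input_tokens": 512,
    "model_id": "multilingual-embed",
    "rate_limit": {
        "requests_per_minute": 1000
    },
    "dimensions_set_by_user": false
}

Note that dimensions_set_by_user is only read back on the persisted path; on the request path it is derived from whether the user supplied dimensions at all.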

View file

@@ -16,6 +16,7 @@ import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.EmptySecretSettings;
@@ -38,11 +39,9 @@ import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
-import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
-import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -59,6 +58,8 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI
 import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
+import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests;
 import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -85,6 +86,7 @@ import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
@@ -394,47 +396,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
         verifyNoMoreInteractions(sender);
     }

-    public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException {
-        var sender = mock(Sender.class);
-
-        var factory = mock(HttpRequestSender.Factory.class);
-        when(factory.createSender()).thenReturn(sender);
-
-        var mockModel = getInvalidModel("model_id", "service_name", TaskType.TEXT_EMBEDDING);
-
-        try (var service = createService(factory)) {
-            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            service.infer(
-                mockModel,
-                null,
-                null,
-                null,
-                List.of(""),
-                false,
-                new HashMap<>(),
-                InputType.INGEST,
-                InferenceAction.Request.DEFAULT_TIMEOUT,
-                listener
-            );
-
-            var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
-            MatcherAssert.assertThat(
-                thrownException.getMessage(),
-                is(
-                    "Inference entity [model_id] does not support task type [text_embedding] "
-                        + "for inference, the task type must be one of [sparse_embedding]."
-                )
-            );
-
-            verify(factory, times(1)).createSender();
-            verify(sender, times(1)).start();
-        }
-
-        verify(sender, times(1)).close();
-        verifyNoMoreInteractions(factory);
-        verifyNoMoreInteractions(sender);
-    }
-
     public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException {
         var sender = mock(Sender.class);

@@ -463,7 +424,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                 thrownException.getMessage(),
                 is(
                     "Inference entity [model_id] does not support task type [chat_completion] "
-                        + "for inference, the task type must be one of [sparse_embedding]. "
+                        + "for inference, the task type must be one of [text_embedding, sparse_embedding]. "
                         + "The task type for the inference entity is chat_completion, "
                         + "please use the _inference/chat_completion/model_id/_stream URL."
                 )
@@ -604,82 +565,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
         }
     }

-    public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException {
-        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-        var elasticInferenceServiceURL = getUrl(webServer);
-
-        try (var service = createService(senderFactory, elasticInferenceServiceURL)) {
-            String responseJson = """
-                {
-                    "data": [
-                        {
-                            "hello": 2.1259406,
-                            "greet": 1.7073475
-                        }
-                    ]
-                }
-                """;
-            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
-
-            // Set up the product use case in the thread context
-            String productUseCase = "test-product-use-case";
-            threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase);
-
-            var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id");
-            PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
-
-            try {
-                service.chunkedInfer(
-                    model,
-                    null,
-                    List.of(new ChunkInferenceInput("input text")),
-                    new HashMap<>(),
-                    InputType.INGEST,
-                    InferenceAction.Request.DEFAULT_TIMEOUT,
-                    listener
-                );
-
-                var results = listener.actionGet(TIMEOUT);
-
-                // Verify the response was processed correctly
-                ChunkedInference inferenceResult = results.getFirst();
-                assertThat(inferenceResult, instanceOf(ChunkedInferenceEmbedding.class));
-
-                var sparseResult = (ChunkedInferenceEmbedding) inferenceResult;
-                assertThat(
-                    sparseResult.chunks(),
-                    is(
-                        List.of(
-                            new EmbeddingResults.Chunk(
-                                new SparseEmbeddingResults.Embedding(
-                                    List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
-                                    false
-                                ),
-                                new ChunkedInference.TextOffset(0, "input text".length())
-                            )
-                        )
-                    )
-                );
-
-                // Verify the request was sent and contains expected headers
-                MatcherAssert.assertThat(webServer.requests(), hasSize(1));
-                var request = webServer.requests().getFirst();
-                assertNull(request.getUri().getQuery());
-                MatcherAssert.assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
-
-                // Check that the product use case header was set correctly
-                assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase));
-
-                // Verify request body
-                var requestMap = entityAsMap(request.getBody());
-                assertThat(requestMap, is(Map.of("input", List.of("input text"), "model", "my-model-id", "usage_context", "ingest")));
-            } finally {
-                // Clean up the thread context
-                threadPool.getThreadContext().stashContext();
-            }
-        }
-    }
-
     public void testUnifiedCompletionInfer_PropagatesProductUseCaseHeader() throws IOException {
         var elasticInferenceServiceURL = getUrl(webServer);
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
@@ -738,30 +623,45 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
         }
     }

-    public void testChunkedInfer_PassesThrough() throws IOException {
+    public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-        var elasticInferenceServiceURL = getUrl(webServer);

-        try (var service = createService(senderFactory, elasticInferenceServiceURL)) {
+        try (var service = createService(senderFactory, getUrl(webServer))) {
+            // Batching will call the service with 2 inputs
             String responseJson = """
                 {
                     "data": [
-                        {
-                            "hello": 2.1259406,
-                            "greet": 1.7073475
-                        }
-                    ]
+                        [
+                            0.123,
+                            -0.456,
+                            0.789
+                        ],
+                        [
+                            0.987,
+                            -0.654,
+                            0.321
+                        ]
+                    ],
+                    "meta": {
+                        "usage": {
+                            "total_tokens": 10
+                        }
+                    }
                 }
                 """;
             webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

-            var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id");
+            var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null);
+
+            String productUseCase = "test-product-use-case";
+            threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase);
+
             PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
+            // 2 inputs
             service.chunkedInfer(
                 model,
                 null,
-                List.of(new ChunkInferenceInput("input text")),
+                List.of(new ChunkInferenceInput("hello world"), new ChunkInferenceInput("dense embedding")),
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
                 listener
@@ -769,32 +669,123 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
             );

             var results = listener.actionGet(TIMEOUT);
-            assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class));
-            var sparseResult = (ChunkedInferenceEmbedding) results.get(0);
-            assertThat(
-                sparseResult.chunks(),
-                is(
-                    List.of(
-                        new EmbeddingResults.Chunk(
-                            new SparseEmbeddingResults.Embedding(
-                                List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
-                                false
-                            ),
-                            new ChunkedInference.TextOffset(0, "input text".length())
-                        )
-                    )
-                )
-            );
+            assertThat(results, hasSize(2));
+
+            // Verify the response was processed correctly
+            ChunkedInference inferenceResult = results.getFirst();
+            assertThat(inferenceResult, instanceOf(ChunkedInferenceEmbedding.class));
+
+            // Verify the request was sent and contains expected headers
+            assertThat(webServer.requests(), hasSize(1));
+            var request = webServer.requests().getFirst();
+            assertNull(request.getUri().getQuery());
+            assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+
+            // Check that the product use case header was set correctly
+            assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase));
+        } finally {
+            // Clean up the thread context
+            threadPool.getThreadContext().stashContext();
+        }
+    }
+
+    public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException {
+        var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(
+            getUrl(webServer),
+            "my-dense-model-id",
+            createRandomChunkingSettings()
+        );
+
+        testChunkedInfer_BatchesCalls(model);
+    }
+
+    public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
+        var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id", null);
+
+        testChunkedInfer_BatchesCalls(model);
+    }
+
+    private void testChunkedInfer_BatchesCalls(ElasticInferenceServiceDenseTextEmbeddingsModel model) throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = createService(senderFactory, getUrl(webServer))) {
+            // Batching will call the service with 2 inputs
+            String responseJson = """
+                {
+                    "data": [
+                        [
+                            0.123,
+                            -0.456,
+                            0.789
+                        ],
+                        [
+                            0.987,
+                            -0.654,
+                            0.321
+                        ]
+                    ],
+                    "meta": {
+                        "usage": {
+                            "total_tokens": 10
+                        }
+                    }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
+            // 2 inputs
+            service.chunkedInfer(
+                model,
+                null,
+                List.of(new ChunkInferenceInput("hello world"), new ChunkInferenceInput("dense embedding")),
+                new HashMap<>(),
+                InputType.INGEST,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var results = listener.actionGet(TIMEOUT);
+            assertThat(results, hasSize(2));
+
+            // First result
+            {
+                assertThat(results.getFirst(), instanceOf(ChunkedInferenceEmbedding.class));
+                var denseResult = (ChunkedInferenceEmbedding) results.getFirst();
+                assertThat(denseResult.chunks(), hasSize(1));
+                assertEquals(new ChunkedInference.TextOffset(0, "hello world".length()), denseResult.chunks().getFirst().offset());
+                assertThat(denseResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+
+                var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
+                assertArrayEquals(new float[] { 0.123f, -0.456f, 0.789f }, embedding.values(), 0.0f);
+            }
+
+            // Second result
+            {
+                assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class));
+                var denseResult = (ChunkedInferenceEmbedding) results.get(1);
+                assertThat(denseResult.chunks(), hasSize(1));
+                assertEquals(new ChunkedInference.TextOffset(0, "dense embedding".length()), denseResult.chunks().getFirst().offset());
+                assertThat(denseResult.chunks().getFirst().embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
+
+                var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
+                assertArrayEquals(new float[] { 0.987f, -0.654f, 0.321f }, embedding.values(), 0.0f);
+            }

             MatcherAssert.assertThat(webServer.requests(), hasSize(1));
-            assertNull(webServer.requests().get(0).getUri().getQuery());
+            assertNull(webServer.requests().getFirst().getUri().getQuery());
             MatcherAssert.assertThat(
-                webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
+                webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE),
                 equalTo(XContentType.JSON.mediaType())
             );
-            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
-            assertThat(requestMap, is(Map.of("input", List.of("input text"), "model", "my-model-id", "usage_context", "ingest")));
+
+            var requestMap = entityAsMap(webServer.requests().getFirst().getBody());
+            MatcherAssert.assertThat(
+                requestMap,
+                is(Map.of("input", List.of("hello world", "dense embedding"), "model", "my-dense-model-id", "usage_context", "ingest"))
+            );
         }
     }
@@ -806,27 +797,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
         }
     }

-    public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNotImplemented() throws Exception {
-        try (
-            var service = createServiceWithMockSender(
-                ElasticInferenceServiceAuthorizationModel.of(
-                    new ElasticInferenceServiceAuthorizationResponseEntity(
-                        List.of(
-                            new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
-                                "model-1",
-                                EnumSet.of(TaskType.TEXT_EMBEDDING)
-                            )
-                        )
-                    )
-                )
-            )
-        ) {
-            ensureAuthorizationCallFinished(service);
-
-            assertTrue(service.hideFromConfigurationApi());
-        }
-    }
-
     public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() throws Exception {
         try (
             var service = createServiceWithMockSender(
@@ -856,7 +826,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                     List.of(
                         new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
                             "model-1",
-                            EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)
+                            EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING)
                         )
                     )
                 )
@@ -869,7 +839,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
             {
                 "service": "elastic",
                 "name": "Elastic",
-                "task_types": ["sparse_embedding", "chat_completion"],
+                "task_types": ["sparse_embedding", "chat_completion", "text_embedding"],
                 "configurations": {
                     "rate_limit.requests_per_minute": {
                         "description": "Minimize the number of rate limit errors.",
@@ -878,7 +848,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                         "sensitive": false,
                         "updatable": false,
                         "type": "int",
-                        "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                        "supported_task_types": ["text_embedding", "sparse_embedding" , "chat_completion"]
                     },
                     "model_id": {
                         "description": "The name of the model to use for the inference task.",
@@ -887,7 +857,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                         "sensitive": false,
                         "updatable": false,
                         "type": "str",
-                        "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                        "supported_task_types": ["text_embedding", "sparse_embedding" , "chat_completion"]
                     },
                     "max_input_tokens": {
                         "description": "Allows you to specify the maximum number of tokens per input.",
@@ -896,7 +866,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                         "sensitive": false,
                         "updatable": false,
                         "type": "int",
-                        "supported_task_types": ["sparse_embedding"]
+                        "supported_task_types": ["text_embedding", "sparse_embedding"]
                     }
                 }
             }
@@ -933,7 +903,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                         "sensitive": false,
                         "updatable": false,
                         "type": "int",
-                        "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                        "supported_task_types": ["text_embedding", "sparse_embedding" , "chat_completion"]
                     },
                     "model_id": {
                         "description": "The name of the model to use for the inference task.",
@@ -942,7 +912,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                         "sensitive": false,
                         "updatable": false,
                         "type": "str",
-                        "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                        "supported_task_types": ["text_embedding", "sparse_embedding" , "chat_completion"]
                     },
                     "max_input_tokens": {
                         "description": "Allows you to specify the maximum number of tokens per input.",
@@ -951,7 +921,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                         "sensitive": false,
                         "updatable": false,
                         "type": "int",
-                        "supported_task_types": ["sparse_embedding"]
+                        "supported_task_types": ["text_embedding", "sparse_embedding"]
                     }
                 }
             }
@@ -993,7 +963,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
             {
                 "service": "elastic",
                 "name": "Elastic",
-                "task_types": [],
+                "task_types": ["text_embedding"],
                 "configurations": {
                     "rate_limit.requests_per_minute": {
                         "description": "Minimize the number of rate limit errors.",
@@ -1002,7 +972,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                         "sensitive": false,
                         "updatable": false,
                         "type": "int",
-                        "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                        "supported_task_types": ["text_embedding" , "sparse_embedding", "chat_completion"]
                     },
                     "model_id": {
                         "description": "The name of the model to use for the inference task.",
@@ -1011,7 +981,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                         "sensitive": false,
                         "updatable": false,
                         "type": "str",
-                        "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                        "supported_task_types": ["text_embedding" , "sparse_embedding", "chat_completion"]
                     },
                     "max_input_tokens": {
                         "description": "Allows you to specify the maximum number of tokens per input.",
@@ -1020,7 +990,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                         "sensitive": false,
                         "updatable": false,
                         "type": "int",
-                        "supported_task_types": ["sparse_embedding"]
+                        "supported_task_types": ["text_embedding", "sparse_embedding"]
                     }
                 }
             }
@@ -1197,6 +1167,10 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                     {
                         "model_name": "elser-v2",
                         "task_types": ["embed/text/sparse"]
+                    },
+                    {
+                        "model_name": "multilingual-embed",
+                        "task_types": ["embed/text/dense"]
                     }
                 ]
             }
@@ -1218,6 +1192,16 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                     MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
                     service
                 ),
+                new InferenceService.DefaultConfigId(
+                    ".multilingual-embed-elastic",
+                    MinimalServiceSettings.textEmbedding(
+                        ElasticInferenceService.NAME,
+                        ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
+                        ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
+                        DenseVectorFieldMapper.ElementType.FLOAT
+                    ),
+                    service
+                ),
                 new InferenceService.DefaultConfigId(
                     ".rainbow-sprinkles-elastic",
                     MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
@@ -1226,14 +1210,18 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                 )
             )
         );
-        assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
+        assertThat(
+            service.supportedTaskTypes(),
+            is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING))
+        );

         PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
         service.defaultConfigs(listener);
         var models = listener.actionGet(TIMEOUT);
-        assertThat(models.size(), is(2));
+        assertThat(models.size(), is(3));
         assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
-        assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
+        assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".multilingual-embed-elastic"));
+        assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
     }
 }
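Once the authorization response advertises the dense model, the default endpoint asserted above should be callable through the standard inference API; a sketch (the id is the default .multilingual-embed-elastic from the test, the body follows the usual inference request shape):

POST _inference/text_embedding/.multilingual-embed-elastic
{
    "input": ["some text to embed"]
}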

View file

@@ -0,0 +1,43 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings;

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptySecretSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

public class ElasticInferenceServiceDenseTextEmbeddingsModelTests {

    public static ElasticInferenceServiceDenseTextEmbeddingsModel createModel(
        String url,
        String modelId,
        ChunkingSettings chunkingSettings
    ) {
        return new ElasticInferenceServiceDenseTextEmbeddingsModel(
            "id",
            TaskType.TEXT_EMBEDDING,
            "elastic",
            new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
                modelId,
                SimilarityMeasure.COSINE,
                null,
                null,
                false,
                new RateLimitSettings(1000L)
            ),
            EmptyTaskSettings.INSTANCE,
            EmptySecretSettings.INSTANCE,
            ElasticInferenceServiceComponents.of(url)
        );
    }
}
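A wire round-trip check pairs naturally with the settings class above; the following is a hypothetical sketch, not part of this commit. It assumes org.elasticsearch.common.io.stream.BytesStreamOutput and the JUnit assertions already used in this suite; the dimension and token values are made up:

public void testServiceSettingsRoundTrip() throws IOException {
    var settings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
        "multilingual-embed",
        SimilarityMeasure.COSINE,
        1024, // hypothetical dimension count
        512,  // hypothetical max input tokens
        true,
        new RateLimitSettings(1000L)
    );

    try (var out = new BytesStreamOutput()) {
        settings.writeTo(out);
        var roundTripped = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(out.bytes().streamInput());
        // Holds only because writeTo mirrors the StreamInput constructor field-for-field.
        assertEquals(settings, roundTripped);
    }
}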