mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 01:22:26 -04:00
[EIS] Dense Text Embedding task type integration (#129847)
This commit is contained in:
parent
0e2362432c
commit
3b51dd568c
26 changed files with 1860 additions and 266 deletions
|
@ -206,6 +206,7 @@ public class TransportVersions {
|
||||||
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_56);
|
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_56);
|
||||||
public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC_8_19 = def(8_841_0_57);
|
public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC_8_19 = def(8_841_0_57);
|
||||||
public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS_8_19 = def(8_841_0_58);
|
public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS_8_19 = def(8_841_0_58);
|
||||||
|
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED_8_19 = def(8_841_0_59);
|
||||||
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
|
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
|
||||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
|
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
|
||||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
|
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
|
||||||
|
@ -318,6 +319,7 @@ public class TransportVersions {
|
||||||
public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC = def(9_106_0_00);
|
public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC = def(9_106_0_00);
|
||||||
public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS = def(9_107_0_00);
|
public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS = def(9_107_0_00);
|
||||||
public static final TransportVersion CLUSTER_STATE_PROJECTS_SETTINGS = def(9_108_0_00);
|
public static final TransportVersion CLUSTER_STATE_PROJECTS_SETTINGS = def(9_108_0_00);
|
||||||
|
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED = def(9_109_00_0);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* STOP! READ THIS FIRST! No, really,
|
* STOP! READ THIS FIRST! No, really,
|
||||||
|
|
|
@ -33,7 +33,7 @@ public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEIS
|
||||||
var allModels = getAllModels();
|
var allModels = getAllModels();
|
||||||
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
|
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
|
||||||
|
|
||||||
assertThat(allModels, hasSize(6));
|
assertThat(allModels, hasSize(7));
|
||||||
assertThat(chatCompletionModels, hasSize(1));
|
assertThat(chatCompletionModels, hasSize(1));
|
||||||
|
|
||||||
for (var model : chatCompletionModels) {
|
for (var model : chatCompletionModels) {
|
||||||
|
@ -42,6 +42,7 @@ public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEIS
|
||||||
|
|
||||||
assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
|
assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
|
||||||
assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
|
assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
|
||||||
|
assertInferenceIdTaskType(allModels, ".multilingual-embed-v1-elastic", TaskType.TEXT_EMBEDDING);
|
||||||
assertInferenceIdTaskType(allModels, ".rerank-v1-elastic", TaskType.RERANK);
|
assertInferenceIdTaskType(allModels, ".rerank-v1-elastic", TaskType.RERANK);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ import java.util.Map;
|
||||||
|
|
||||||
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
|
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
|
||||||
import static org.hamcrest.Matchers.containsInAnyOrder;
|
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
|
||||||
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
||||||
|
|
||||||
|
@ -76,16 +77,21 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
|
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
|
||||||
|
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
|
||||||
|
assertThat(services.size(), equalTo(18));
|
||||||
|
|
||||||
assertThat(
|
assertThat(
|
||||||
providersFor(TaskType.TEXT_EMBEDDING),
|
providersFor(TaskType.TEXT_EMBEDDING),
|
||||||
containsInAnyOrder(
|
containsInAnyOrder(
|
||||||
List.of(
|
List.of(
|
||||||
"alibabacloud-ai-search",
|
"alibabacloud-ai-search",
|
||||||
"amazonbedrock",
|
"amazonbedrock",
|
||||||
|
"amazon_sagemaker",
|
||||||
"azureaistudio",
|
"azureaistudio",
|
||||||
"azureopenai",
|
"azureopenai",
|
||||||
"cohere",
|
"cohere",
|
||||||
"custom",
|
"custom",
|
||||||
|
"elastic",
|
||||||
"elasticsearch",
|
"elasticsearch",
|
||||||
"googleaistudio",
|
"googleaistudio",
|
||||||
"googlevertexai",
|
"googlevertexai",
|
||||||
|
@ -95,8 +101,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
||||||
"openai",
|
"openai",
|
||||||
"text_embedding_test_service",
|
"text_embedding_test_service",
|
||||||
"voyageai",
|
"voyageai",
|
||||||
"watsonxai",
|
"watsonxai"
|
||||||
"amazon_sagemaker"
|
|
||||||
).toArray()
|
).toArray()
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
|
@ -43,6 +43,10 @@ public class MockElasticInferenceServiceAuthorizationServer implements TestRule
|
||||||
"task_types": ["embed/text/sparse"]
|
"task_types": ["embed/text/sparse"]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"model_name": "multilingual-embed-v1",
|
||||||
|
"task_types": ["embed/text/dense"]
|
||||||
|
},
|
||||||
|
{
|
||||||
"model_name": "rerank-v1",
|
"model_name": "rerank-v1",
|
||||||
"task_types": ["rerank/text/text-similarity"]
|
"task_types": ["rerank/text/text-similarity"]
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,7 @@ import org.elasticsearch.ResourceNotFoundException;
|
||||||
import org.elasticsearch.action.support.PlainActionFuture;
|
import org.elasticsearch.action.support.PlainActionFuture;
|
||||||
import org.elasticsearch.common.settings.Settings;
|
import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.core.TimeValue;
|
import org.elasticsearch.core.TimeValue;
|
||||||
|
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||||
import org.elasticsearch.inference.InferenceService;
|
import org.elasticsearch.inference.InferenceService;
|
||||||
import org.elasticsearch.inference.MinimalServiceSettings;
|
import org.elasticsearch.inference.MinimalServiceSettings;
|
||||||
import org.elasticsearch.inference.Model;
|
import org.elasticsearch.inference.Model;
|
||||||
|
@ -43,6 +44,7 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
||||||
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
|
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
|
||||||
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
||||||
import static org.hamcrest.CoreMatchers.is;
|
import static org.hamcrest.CoreMatchers.is;
|
||||||
|
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
|
|
||||||
public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
|
public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
|
||||||
|
@ -190,13 +192,17 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
|
||||||
String responseJson = """
|
String responseJson = """
|
||||||
{
|
{
|
||||||
"models": [
|
"models": [
|
||||||
|
{
|
||||||
|
"model_name": "elser-v2",
|
||||||
|
"task_types": ["embed/text/sparse"]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"model_name": "rainbow-sprinkles",
|
"model_name": "rainbow-sprinkles",
|
||||||
"task_types": ["chat"]
|
"task_types": ["chat"]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model_name": "elser-v2",
|
"model_name": "multilingual-embed-v1",
|
||||||
"task_types": ["embed/text/sparse"]
|
"task_types": ["embed/text/dense"]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model_name": "rerank-v1",
|
"model_name": "rerank-v1",
|
||||||
|
@ -214,36 +220,48 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
|
||||||
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
|
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
|
||||||
assertThat(
|
assertThat(
|
||||||
service.defaultConfigIds(),
|
service.defaultConfigIds(),
|
||||||
is(
|
containsInAnyOrder(
|
||||||
List.of(
|
new InferenceService.DefaultConfigId(
|
||||||
new InferenceService.DefaultConfigId(
|
".elser-v2-elastic",
|
||||||
".elser-v2-elastic",
|
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
|
||||||
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
|
service
|
||||||
service
|
),
|
||||||
|
new InferenceService.DefaultConfigId(
|
||||||
|
".rainbow-sprinkles-elastic",
|
||||||
|
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
|
||||||
|
service
|
||||||
|
),
|
||||||
|
new InferenceService.DefaultConfigId(
|
||||||
|
".multilingual-embed-v1-elastic",
|
||||||
|
MinimalServiceSettings.textEmbedding(
|
||||||
|
ElasticInferenceService.NAME,
|
||||||
|
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
|
||||||
|
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
|
||||||
|
DenseVectorFieldMapper.ElementType.FLOAT
|
||||||
),
|
),
|
||||||
new InferenceService.DefaultConfigId(
|
service
|
||||||
".rainbow-sprinkles-elastic",
|
),
|
||||||
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
|
new InferenceService.DefaultConfigId(
|
||||||
service
|
".rerank-v1-elastic",
|
||||||
),
|
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
|
||||||
new InferenceService.DefaultConfigId(
|
service
|
||||||
".rerank-v1-elastic",
|
|
||||||
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
|
|
||||||
service
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
assertThat(
|
assertThat(
|
||||||
service.supportedTaskTypes(),
|
service.supportedTaskTypes(),
|
||||||
is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
|
is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING))
|
||||||
);
|
);
|
||||||
|
|
||||||
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
|
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
|
||||||
service.defaultConfigs(listener);
|
service.defaultConfigs(listener);
|
||||||
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
|
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
|
||||||
assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
|
assertThat(
|
||||||
assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
|
listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(),
|
||||||
|
is(".multilingual-embed-v1-elastic")
|
||||||
|
);
|
||||||
|
assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
|
||||||
|
assertThat(listener.actionGet(TIMEOUT).get(3).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
|
||||||
|
|
||||||
var getModelListener = new PlainActionFuture<UnparsedModel>();
|
var getModelListener = new PlainActionFuture<UnparsedModel>();
|
||||||
// persists the default endpoints
|
// persists the default endpoints
|
||||||
|
@ -265,6 +283,10 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
|
||||||
{
|
{
|
||||||
"model_name": "rerank-v1",
|
"model_name": "rerank-v1",
|
||||||
"task_types": ["rerank/text/text-similarity"]
|
"task_types": ["rerank/text/text-similarity"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "multilingual-embed-v1",
|
||||||
|
"task_types": ["embed/text/dense"]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
@ -278,22 +300,33 @@ public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
|
||||||
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
|
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
|
||||||
assertThat(
|
assertThat(
|
||||||
service.defaultConfigIds(),
|
service.defaultConfigIds(),
|
||||||
is(
|
containsInAnyOrder(
|
||||||
List.of(
|
new InferenceService.DefaultConfigId(
|
||||||
new InferenceService.DefaultConfigId(
|
".elser-v2-elastic",
|
||||||
".elser-v2-elastic",
|
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
|
||||||
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
|
service
|
||||||
service
|
),
|
||||||
|
new InferenceService.DefaultConfigId(
|
||||||
|
".multilingual-embed-v1-elastic",
|
||||||
|
MinimalServiceSettings.textEmbedding(
|
||||||
|
ElasticInferenceService.NAME,
|
||||||
|
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
|
||||||
|
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
|
||||||
|
DenseVectorFieldMapper.ElementType.FLOAT
|
||||||
),
|
),
|
||||||
new InferenceService.DefaultConfigId(
|
service
|
||||||
".rerank-v1-elastic",
|
),
|
||||||
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
|
new InferenceService.DefaultConfigId(
|
||||||
service
|
".rerank-v1-elastic",
|
||||||
)
|
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
|
||||||
|
service
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK)));
|
assertThat(
|
||||||
|
service.supportedTaskTypes(),
|
||||||
|
is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
|
||||||
|
);
|
||||||
|
|
||||||
var getModelListener = new PlainActionFuture<UnparsedModel>();
|
var getModelListener = new PlainActionFuture<UnparsedModel>();
|
||||||
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
|
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
|
||||||
|
|
|
@ -0,0 +1,103 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.external.response.elastic;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParserUtils;
|
||||||
|
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.xcontent.ParseField;
|
||||||
|
import org.elasticsearch.xcontent.XContentFactory;
|
||||||
|
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||||
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||||
|
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
|
||||||
|
|
||||||
|
public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntity {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parses the Elastic Inference Service Dense Text Embeddings response.
|
||||||
|
*
|
||||||
|
* For a request like:
|
||||||
|
*
|
||||||
|
* <pre>
|
||||||
|
* <code>
|
||||||
|
* {
|
||||||
|
* "inputs": ["Embed this text", "Embed this text, too"]
|
||||||
|
* }
|
||||||
|
* </code>
|
||||||
|
* </pre>
|
||||||
|
*
|
||||||
|
* The response would look like:
|
||||||
|
*
|
||||||
|
* <pre>
|
||||||
|
* <code>
|
||||||
|
* {
|
||||||
|
* "data": [
|
||||||
|
* [
|
||||||
|
* 2.1259406,
|
||||||
|
* 1.7073475,
|
||||||
|
* 0.9020516
|
||||||
|
* ],
|
||||||
|
* (...)
|
||||||
|
* ],
|
||||||
|
* "meta": {
|
||||||
|
* "usage": {...}
|
||||||
|
* }
|
||||||
|
* }
|
||||||
|
* </code>
|
||||||
|
* </pre>
|
||||||
|
*/
|
||||||
|
public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
|
||||||
|
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
|
||||||
|
return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public record EmbeddingFloatResult(List<EmbeddingFloatResultEntry> embeddingResults) {
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public static final ConstructingObjectParser<EmbeddingFloatResult, Void> PARSER = new ConstructingObjectParser<>(
|
||||||
|
EmbeddingFloatResult.class.getSimpleName(),
|
||||||
|
true,
|
||||||
|
args -> new EmbeddingFloatResult((List<EmbeddingFloatResultEntry>) args[0])
|
||||||
|
);
|
||||||
|
|
||||||
|
static {
|
||||||
|
// Custom field declaration to handle array of arrays format
|
||||||
|
PARSER.declareField(constructorArg(), (parser, context) -> {
|
||||||
|
return XContentParserUtils.parseList(parser, (p, index) -> {
|
||||||
|
List<Float> embedding = XContentParserUtils.parseList(p, (innerParser, innerIndex) -> innerParser.floatValue());
|
||||||
|
return EmbeddingFloatResultEntry.fromFloatArray(embedding);
|
||||||
|
});
|
||||||
|
}, new ParseField("data"), org.elasticsearch.xcontent.ObjectParser.ValueType.OBJECT_ARRAY);
|
||||||
|
}
|
||||||
|
|
||||||
|
public TextEmbeddingFloatResults toTextEmbeddingFloatResults() {
|
||||||
|
return new TextEmbeddingFloatResults(
|
||||||
|
embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a single embedding entry in the response.
|
||||||
|
* For the Elastic Inference Service, each entry is just an array of floats (no wrapper object).
|
||||||
|
* This is a simpler wrapper that just holds the float array.
|
||||||
|
*/
|
||||||
|
public record EmbeddingFloatResultEntry(List<Float> embedding) {
|
||||||
|
public static EmbeddingFloatResultEntry fromFloatArray(List<Float> floats) {
|
||||||
|
return new EmbeddingFloatResultEntry(floats);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private ElasticInferenceServiceDenseTextEmbeddingsResponseEntity() {}
|
||||||
|
}
|
|
@ -16,6 +16,7 @@ import org.elasticsearch.common.ValidationException;
|
||||||
import org.elasticsearch.common.util.LazyInitializable;
|
import org.elasticsearch.common.util.LazyInitializable;
|
||||||
import org.elasticsearch.core.Nullable;
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.core.TimeValue;
|
import org.elasticsearch.core.TimeValue;
|
||||||
|
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||||
import org.elasticsearch.inference.ChunkedInference;
|
import org.elasticsearch.inference.ChunkedInference;
|
||||||
import org.elasticsearch.inference.ChunkingSettings;
|
import org.elasticsearch.inference.ChunkingSettings;
|
||||||
import org.elasticsearch.inference.EmptySecretSettings;
|
import org.elasticsearch.inference.EmptySecretSettings;
|
||||||
|
@ -28,6 +29,7 @@ import org.elasticsearch.inference.Model;
|
||||||
import org.elasticsearch.inference.ModelConfigurations;
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
import org.elasticsearch.inference.ModelSecrets;
|
import org.elasticsearch.inference.ModelSecrets;
|
||||||
import org.elasticsearch.inference.SettingsConfiguration;
|
import org.elasticsearch.inference.SettingsConfiguration;
|
||||||
|
import org.elasticsearch.inference.SimilarityMeasure;
|
||||||
import org.elasticsearch.inference.TaskType;
|
import org.elasticsearch.inference.TaskType;
|
||||||
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
|
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
|
||||||
import org.elasticsearch.rest.RestStatus;
|
import org.elasticsearch.rest.RestStatus;
|
||||||
|
@ -54,6 +56,8 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
|
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
|
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
|
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
|
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
|
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
|
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
|
||||||
|
@ -84,15 +88,20 @@ public class ElasticInferenceService extends SenderService {
|
||||||
|
|
||||||
public static final String NAME = "elastic";
|
public static final String NAME = "elastic";
|
||||||
public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service";
|
public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service";
|
||||||
public static final int SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 512;
|
public static final Integer DENSE_TEXT_EMBEDDINGS_DIMENSIONS = 1024;
|
||||||
|
public static final Integer SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 512;
|
||||||
|
|
||||||
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
|
private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
|
||||||
TaskType.SPARSE_EMBEDDING,
|
TaskType.SPARSE_EMBEDDING,
|
||||||
TaskType.CHAT_COMPLETION,
|
TaskType.CHAT_COMPLETION,
|
||||||
TaskType.RERANK
|
TaskType.RERANK,
|
||||||
|
TaskType.TEXT_EMBEDDING
|
||||||
);
|
);
|
||||||
private static final String SERVICE_NAME = "Elastic";
|
private static final String SERVICE_NAME = "Elastic";
|
||||||
|
|
||||||
|
// TODO: confirm with the team that 32 is the right maximum batch size for dense text embedding requests
|
||||||
|
private static final Integer DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE = 32;
|
||||||
|
|
||||||
// rainbow-sprinkles
|
// rainbow-sprinkles
|
||||||
static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles";
|
static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles";
|
||||||
static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
|
static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
|
||||||
|
@ -101,6 +110,10 @@ public class ElasticInferenceService extends SenderService {
|
||||||
static final String DEFAULT_ELSER_MODEL_ID_V2 = "elser-v2";
|
static final String DEFAULT_ELSER_MODEL_ID_V2 = "elser-v2";
|
||||||
static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2);
|
static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2);
|
||||||
|
|
||||||
|
// multilingual-embed-v1
|
||||||
|
static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "multilingual-embed-v1";
|
||||||
|
static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = defaultEndpointId(DEFAULT_MULTILINGUAL_EMBED_MODEL_ID);
|
||||||
|
|
||||||
// rerank-v1
|
// rerank-v1
|
||||||
static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1";
|
static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1";
|
||||||
static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1);
|
static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1);
|
||||||
|
@ -108,7 +121,11 @@ public class ElasticInferenceService extends SenderService {
|
||||||
/**
|
/**
|
||||||
* The task types that the {@link InferenceAction.Request} can accept.
|
* The task types that the {@link InferenceAction.Request} can accept.
|
||||||
*/
|
*/
|
||||||
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK);
|
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(
|
||||||
|
TaskType.SPARSE_EMBEDDING,
|
||||||
|
TaskType.RERANK,
|
||||||
|
TaskType.TEXT_EMBEDDING
|
||||||
|
);
|
||||||
|
|
||||||
public static String defaultEndpointId(String modelId) {
|
public static String defaultEndpointId(String modelId) {
|
||||||
return Strings.format(".%s-elastic", modelId);
|
return Strings.format(".%s-elastic", modelId);
|
||||||
|
@ -171,6 +188,32 @@ public class ElasticInferenceService extends SenderService {
|
||||||
),
|
),
|
||||||
MinimalServiceSettings.sparseEmbedding(NAME)
|
MinimalServiceSettings.sparseEmbedding(NAME)
|
||||||
),
|
),
|
||||||
|
DEFAULT_MULTILINGUAL_EMBED_MODEL_ID,
|
||||||
|
new DefaultModelConfig(
|
||||||
|
new ElasticInferenceServiceDenseTextEmbeddingsModel(
|
||||||
|
DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID,
|
||||||
|
TaskType.TEXT_EMBEDDING,
|
||||||
|
NAME,
|
||||||
|
new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
|
||||||
|
DEFAULT_MULTILINGUAL_EMBED_MODEL_ID,
|
||||||
|
defaultDenseTextEmbeddingsSimilarity(),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS
|
||||||
|
),
|
||||||
|
EmptyTaskSettings.INSTANCE,
|
||||||
|
EmptySecretSettings.INSTANCE,
|
||||||
|
elasticInferenceServiceComponents,
|
||||||
|
ChunkingSettingsBuilder.DEFAULT_SETTINGS
|
||||||
|
),
|
||||||
|
MinimalServiceSettings.textEmbedding(
|
||||||
|
NAME,
|
||||||
|
DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
|
||||||
|
defaultDenseTextEmbeddingsSimilarity(),
|
||||||
|
DenseVectorFieldMapper.ElementType.FLOAT
|
||||||
|
)
|
||||||
|
),
|
||||||
|
|
||||||
DEFAULT_RERANK_MODEL_ID_V1,
|
DEFAULT_RERANK_MODEL_ID_V1,
|
||||||
new DefaultModelConfig(
|
new DefaultModelConfig(
|
||||||
new ElasticInferenceServiceRerankModel(
|
new ElasticInferenceServiceRerankModel(
|
||||||
|
@ -310,6 +353,23 @@ public class ElasticInferenceService extends SenderService {
|
||||||
TimeValue timeout,
|
TimeValue timeout,
|
||||||
ActionListener<List<ChunkedInference>> listener
|
ActionListener<List<ChunkedInference>> listener
|
||||||
) {
|
) {
|
||||||
|
if (model instanceof ElasticInferenceServiceDenseTextEmbeddingsModel denseTextEmbeddingsModel) {
|
||||||
|
var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo());
|
||||||
|
|
||||||
|
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
|
||||||
|
inputs.getInputs(),
|
||||||
|
DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE,
|
||||||
|
denseTextEmbeddingsModel.getConfigurations().getChunkingSettings()
|
||||||
|
).batchRequestsWithListeners(listener);
|
||||||
|
|
||||||
|
for (var request : batchedRequests) {
|
||||||
|
var action = denseTextEmbeddingsModel.accept(actionCreator, taskSettings);
|
||||||
|
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel sparseTextEmbeddingsModel) {
|
if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel sparseTextEmbeddingsModel) {
|
||||||
var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo());
|
var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo());
|
||||||
|
|
||||||
|
@ -348,7 +408,7 @@ public class ElasticInferenceService extends SenderService {
|
||||||
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
|
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
|
||||||
|
|
||||||
ChunkingSettings chunkingSettings = null;
|
ChunkingSettings chunkingSettings = null;
|
||||||
if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
|
if (TaskType.SPARSE_EMBEDDING.equals(taskType) || TaskType.TEXT_EMBEDDING.equals(taskType)) {
|
||||||
chunkingSettings = ChunkingSettingsBuilder.fromMap(
|
chunkingSettings = ChunkingSettingsBuilder.fromMap(
|
||||||
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
|
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
|
||||||
);
|
);
|
||||||
|
@ -359,11 +419,11 @@ public class ElasticInferenceService extends SenderService {
|
||||||
taskType,
|
taskType,
|
||||||
serviceSettingsMap,
|
serviceSettingsMap,
|
||||||
taskSettingsMap,
|
taskSettingsMap,
|
||||||
|
chunkingSettings,
|
||||||
serviceSettingsMap,
|
serviceSettingsMap,
|
||||||
elasticInferenceServiceComponents,
|
elasticInferenceServiceComponents,
|
||||||
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
|
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
|
||||||
ConfigurationParseContext.REQUEST,
|
ConfigurationParseContext.REQUEST
|
||||||
chunkingSettings
|
|
||||||
);
|
);
|
||||||
|
|
||||||
throwIfNotEmptyMap(config, NAME);
|
throwIfNotEmptyMap(config, NAME);
|
||||||
|
@ -396,11 +456,11 @@ public class ElasticInferenceService extends SenderService {
|
||||||
TaskType taskType,
|
TaskType taskType,
|
||||||
Map<String, Object> serviceSettings,
|
Map<String, Object> serviceSettings,
|
||||||
Map<String, Object> taskSettings,
|
Map<String, Object> taskSettings,
|
||||||
|
ChunkingSettings chunkingSettings,
|
||||||
@Nullable Map<String, Object> secretSettings,
|
@Nullable Map<String, Object> secretSettings,
|
||||||
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
|
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
|
||||||
String failureMessage,
|
String failureMessage,
|
||||||
ConfigurationParseContext context,
|
ConfigurationParseContext context
|
||||||
ChunkingSettings chunkingSettings
|
|
||||||
) {
|
) {
|
||||||
return switch (taskType) {
|
return switch (taskType) {
|
||||||
case SPARSE_EMBEDDING -> new ElasticInferenceServiceSparseEmbeddingsModel(
|
case SPARSE_EMBEDDING -> new ElasticInferenceServiceSparseEmbeddingsModel(
|
||||||
|
@ -434,6 +494,17 @@ public class ElasticInferenceService extends SenderService {
|
||||||
elasticInferenceServiceComponents,
|
elasticInferenceServiceComponents,
|
||||||
context
|
context
|
||||||
);
|
);
|
||||||
|
case TEXT_EMBEDDING -> new ElasticInferenceServiceDenseTextEmbeddingsModel(
|
||||||
|
inferenceEntityId,
|
||||||
|
taskType,
|
||||||
|
NAME,
|
||||||
|
serviceSettings,
|
||||||
|
taskSettings,
|
||||||
|
secretSettings,
|
||||||
|
elasticInferenceServiceComponents,
|
||||||
|
context,
|
||||||
|
chunkingSettings
|
||||||
|
);
|
||||||
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
|
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -450,7 +521,7 @@ public class ElasticInferenceService extends SenderService {
|
||||||
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
|
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
|
||||||
|
|
||||||
ChunkingSettings chunkingSettings = null;
|
ChunkingSettings chunkingSettings = null;
|
||||||
if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
|
if (TaskType.SPARSE_EMBEDDING.equals(taskType) || TaskType.TEXT_EMBEDDING.equals(taskType)) {
|
||||||
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
|
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -459,9 +530,9 @@ public class ElasticInferenceService extends SenderService {
|
||||||
taskType,
|
taskType,
|
||||||
serviceSettingsMap,
|
serviceSettingsMap,
|
||||||
taskSettingsMap,
|
taskSettingsMap,
|
||||||
|
chunkingSettings,
|
||||||
secretSettingsMap,
|
secretSettingsMap,
|
||||||
parsePersistedConfigErrorMsg(inferenceEntityId, NAME),
|
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
|
||||||
chunkingSettings
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -471,7 +542,7 @@ public class ElasticInferenceService extends SenderService {
|
||||||
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
|
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
|
||||||
|
|
||||||
ChunkingSettings chunkingSettings = null;
|
ChunkingSettings chunkingSettings = null;
|
||||||
if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
|
if (TaskType.SPARSE_EMBEDDING.equals(taskType) || TaskType.TEXT_EMBEDDING.equals(taskType)) {
|
||||||
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
|
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -480,9 +551,9 @@ public class ElasticInferenceService extends SenderService {
|
||||||
taskType,
|
taskType,
|
||||||
serviceSettingsMap,
|
serviceSettingsMap,
|
||||||
taskSettingsMap,
|
taskSettingsMap,
|
||||||
|
chunkingSettings,
|
||||||
null,
|
null,
|
||||||
parsePersistedConfigErrorMsg(inferenceEntityId, NAME),
|
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
|
||||||
chunkingSettings
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -496,23 +567,51 @@ public class ElasticInferenceService extends SenderService {
|
||||||
TaskType taskType,
|
TaskType taskType,
|
||||||
Map<String, Object> serviceSettings,
|
Map<String, Object> serviceSettings,
|
||||||
Map<String, Object> taskSettings,
|
Map<String, Object> taskSettings,
|
||||||
|
ChunkingSettings chunkingSettings,
|
||||||
@Nullable Map<String, Object> secretSettings,
|
@Nullable Map<String, Object> secretSettings,
|
||||||
String failureMessage,
|
String failureMessage
|
||||||
ChunkingSettings chunkingSettings
|
|
||||||
) {
|
) {
|
||||||
return createModel(
|
return createModel(
|
||||||
inferenceEntityId,
|
inferenceEntityId,
|
||||||
taskType,
|
taskType,
|
||||||
serviceSettings,
|
serviceSettings,
|
||||||
taskSettings,
|
taskSettings,
|
||||||
|
chunkingSettings,
|
||||||
secretSettings,
|
secretSettings,
|
||||||
elasticInferenceServiceComponents,
|
elasticInferenceServiceComponents,
|
||||||
failureMessage,
|
failureMessage,
|
||||||
ConfigurationParseContext.PERSISTENT,
|
ConfigurationParseContext.PERSISTENT
|
||||||
chunkingSettings
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
|
||||||
|
if (model instanceof ElasticInferenceServiceDenseTextEmbeddingsModel embeddingsModel) {
|
||||||
|
var serviceSettings = embeddingsModel.getServiceSettings();
|
||||||
|
var modelId = serviceSettings.modelId();
|
||||||
|
var similarityFromModel = serviceSettings.similarity();
|
||||||
|
var similarityToUse = similarityFromModel == null ? defaultDenseTextEmbeddingsSimilarity() : similarityFromModel;
|
||||||
|
var maxInputTokens = serviceSettings.maxInputTokens();
|
||||||
|
|
||||||
|
var updateServiceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
|
||||||
|
modelId,
|
||||||
|
similarityToUse,
|
||||||
|
embeddingSize,
|
||||||
|
maxInputTokens,
|
||||||
|
serviceSettings.rateLimitSettings()
|
||||||
|
);
|
||||||
|
|
||||||
|
return new ElasticInferenceServiceDenseTextEmbeddingsModel(embeddingsModel, updateServiceSettings);
|
||||||
|
} else {
|
||||||
|
throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() {
|
||||||
|
// TODO: double-check
|
||||||
|
return SimilarityMeasure.COSINE;
|
||||||
|
}
|
||||||
|
|
||||||
private static List<ChunkedInference> translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) {
|
private static List<ChunkedInference> translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) {
|
||||||
if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) {
|
if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) {
|
||||||
var inputsAsList = EmbeddingsInput.of(inputs).getStringInputs();
|
var inputsAsList = EmbeddingsInput.of(inputs).getStringInputs();
|
||||||
|
@ -550,8 +649,9 @@ public class ElasticInferenceService extends SenderService {
|
||||||
|
|
||||||
configurationMap.put(
|
configurationMap.put(
|
||||||
MODEL_ID,
|
MODEL_ID,
|
||||||
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK))
|
new SettingsConfiguration.Builder(
|
||||||
.setDescription("The name of the model to use for the inference task.")
|
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING)
|
||||||
|
).setDescription("The name of the model to use for the inference task.")
|
||||||
.setLabel("Model ID")
|
.setLabel("Model ID")
|
||||||
.setRequired(true)
|
.setRequired(true)
|
||||||
.setSensitive(false)
|
.setSensitive(false)
|
||||||
|
@ -562,7 +662,7 @@ public class ElasticInferenceService extends SenderService {
|
||||||
|
|
||||||
configurationMap.put(
|
configurationMap.put(
|
||||||
MAX_INPUT_TOKENS,
|
MAX_INPUT_TOKENS,
|
||||||
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING)).setDescription(
|
new SettingsConfiguration.Builder(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING)).setDescription(
|
||||||
"Allows you to specify the maximum number of tokens per input."
|
"Allows you to specify the maximum number of tokens per input."
|
||||||
)
|
)
|
||||||
.setLabel("Maximum Input Tokens")
|
.setLabel("Maximum Input Tokens")
|
||||||
|
@ -575,7 +675,7 @@ public class ElasticInferenceService extends SenderService {
|
||||||
|
|
||||||
configurationMap.putAll(
|
configurationMap.putAll(
|
||||||
RateLimitSettings.toSettingsConfiguration(
|
RateLimitSettings.toSettingsConfiguration(
|
||||||
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK)
|
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, TaskType.TEXT_EMBEDDING)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -11,14 +11,18 @@ import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||||
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
|
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
|
||||||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
|
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
|
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||||
import org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequest;
|
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity;
|
||||||
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceRerankResponseEntity;
|
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceRerankResponseEntity;
|
||||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler;
|
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager;
|
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsRequestManager;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceDenseTextEmbeddingsRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequest;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
|
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
|
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
|
||||||
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
|
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
|
||||||
|
@ -31,17 +35,22 @@ import static org.elasticsearch.xpack.inference.services.elastic.request.Elastic
|
||||||
|
|
||||||
public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor {
|
public class ElasticInferenceServiceActionCreator implements ElasticInferenceServiceActionVisitor {
|
||||||
|
|
||||||
private final Sender sender;
|
static final ResponseHandler DENSE_TEXT_EMBEDDINGS_HANDLER = new ElasticInferenceServiceResponseHandler(
|
||||||
|
"elastic dense text embedding",
|
||||||
private final ServiceComponents serviceComponents;
|
ElasticInferenceServiceDenseTextEmbeddingsResponseEntity::fromResponse
|
||||||
|
);
|
||||||
private final TraceContext traceContext;
|
|
||||||
|
|
||||||
static final ResponseHandler RERANK_HANDLER = new ElasticInferenceServiceResponseHandler(
|
static final ResponseHandler RERANK_HANDLER = new ElasticInferenceServiceResponseHandler(
|
||||||
"elastic rerank",
|
"elastic rerank",
|
||||||
(request, response) -> ElasticInferenceServiceRerankResponseEntity.fromResponse(response)
|
(request, response) -> ElasticInferenceServiceRerankResponseEntity.fromResponse(response)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
private final Sender sender;
|
||||||
|
|
||||||
|
private final ServiceComponents serviceComponents;
|
||||||
|
|
||||||
|
private final TraceContext traceContext;
|
||||||
|
|
||||||
public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents, TraceContext traceContext) {
|
public ElasticInferenceServiceActionCreator(Sender sender, ServiceComponents serviceComponents, TraceContext traceContext) {
|
||||||
this.sender = Objects.requireNonNull(sender);
|
this.sender = Objects.requireNonNull(sender);
|
||||||
this.serviceComponents = Objects.requireNonNull(serviceComponents);
|
this.serviceComponents = Objects.requireNonNull(serviceComponents);
|
||||||
|
@ -77,4 +86,26 @@ public class ElasticInferenceServiceActionCreator implements ElasticInferenceSer
|
||||||
var errorMessage = constructFailedToSendRequestMessage(Strings.format("%s rerank", ELASTIC_INFERENCE_SERVICE_IDENTIFIER));
|
var errorMessage = constructFailedToSendRequestMessage(Strings.format("%s rerank", ELASTIC_INFERENCE_SERVICE_IDENTIFIER));
|
||||||
return new SenderExecutableAction(sender, requestManager, errorMessage);
|
return new SenderExecutableAction(sender, requestManager, errorMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model) {
|
||||||
|
var threadPool = serviceComponents.threadPool();
|
||||||
|
|
||||||
|
var manager = new GenericRequestManager<>(
|
||||||
|
threadPool,
|
||||||
|
model,
|
||||||
|
DENSE_TEXT_EMBEDDINGS_HANDLER,
|
||||||
|
(embeddingsInput) -> new ElasticInferenceServiceDenseTextEmbeddingsRequest(
|
||||||
|
model,
|
||||||
|
embeddingsInput.getStringInputs(),
|
||||||
|
traceContext,
|
||||||
|
extractRequestMetadataFromThreadContext(threadPool.getThreadContext()),
|
||||||
|
embeddingsInput.getInputType()
|
||||||
|
),
|
||||||
|
EmbeddingsInput.class
|
||||||
|
);
|
||||||
|
|
||||||
|
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Elastic dense text embeddings");
|
||||||
|
return new SenderExecutableAction(sender, manager, failedToSendRequestErrorMessage);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
package org.elasticsearch.xpack.inference.services.elastic.action;
|
package org.elasticsearch.xpack.inference.services.elastic.action;
|
||||||
|
|
||||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
|
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
|
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
|
||||||
|
|
||||||
|
@ -17,4 +18,5 @@ public interface ElasticInferenceServiceActionVisitor {
|
||||||
|
|
||||||
ExecutableAction create(ElasticInferenceServiceRerankModel model);
|
ExecutableAction create(ElasticInferenceServiceRerankModel model);
|
||||||
|
|
||||||
|
ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel model);
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,116 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.inference.ChunkingSettings;
|
||||||
|
import org.elasticsearch.inference.EmptySecretSettings;
|
||||||
|
import org.elasticsearch.inference.EmptyTaskSettings;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.ModelSecrets;
|
||||||
|
import org.elasticsearch.inference.SecretSettings;
|
||||||
|
import org.elasticsearch.inference.TaskSettings;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.rest.RestStatus;
|
||||||
|
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||||
|
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceExecutableActionModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionVisitor;
|
||||||
|
|
||||||
|
import java.net.URI;
|
||||||
|
import java.net.URISyntaxException;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class ElasticInferenceServiceDenseTextEmbeddingsModel extends ElasticInferenceServiceExecutableActionModel {
|
||||||
|
|
||||||
|
private final URI uri;
|
||||||
|
|
||||||
|
public ElasticInferenceServiceDenseTextEmbeddingsModel(
|
||||||
|
String inferenceEntityId,
|
||||||
|
TaskType taskType,
|
||||||
|
String service,
|
||||||
|
Map<String, Object> serviceSettings,
|
||||||
|
Map<String, Object> taskSettings,
|
||||||
|
Map<String, Object> secrets,
|
||||||
|
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
|
||||||
|
ConfigurationParseContext context,
|
||||||
|
ChunkingSettings chunkingSettings
|
||||||
|
) {
|
||||||
|
this(
|
||||||
|
inferenceEntityId,
|
||||||
|
taskType,
|
||||||
|
service,
|
||||||
|
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap(serviceSettings, context),
|
||||||
|
EmptyTaskSettings.INSTANCE,
|
||||||
|
EmptySecretSettings.INSTANCE,
|
||||||
|
elasticInferenceServiceComponents,
|
||||||
|
chunkingSettings
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ElasticInferenceServiceDenseTextEmbeddingsModel(
|
||||||
|
String inferenceEntityId,
|
||||||
|
TaskType taskType,
|
||||||
|
String service,
|
||||||
|
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings serviceSettings,
|
||||||
|
@Nullable TaskSettings taskSettings,
|
||||||
|
@Nullable SecretSettings secretSettings,
|
||||||
|
ElasticInferenceServiceComponents elasticInferenceServiceComponents,
|
||||||
|
ChunkingSettings chunkingSettings
|
||||||
|
) {
|
||||||
|
super(
|
||||||
|
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
|
||||||
|
new ModelSecrets(secretSettings),
|
||||||
|
serviceSettings,
|
||||||
|
elasticInferenceServiceComponents
|
||||||
|
);
|
||||||
|
this.uri = createUri();
|
||||||
|
}
|
||||||
|
|
||||||
|
public ElasticInferenceServiceDenseTextEmbeddingsModel(
|
||||||
|
ElasticInferenceServiceDenseTextEmbeddingsModel model,
|
||||||
|
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings serviceSettings
|
||||||
|
) {
|
||||||
|
super(model, serviceSettings);
|
||||||
|
this.uri = createUri();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map<String, Object> taskSettings) {
|
||||||
|
return visitor.create(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings getServiceSettings() {
|
||||||
|
return (ElasticInferenceServiceDenseTextEmbeddingsServiceSettings) super.getServiceSettings();
|
||||||
|
}
|
||||||
|
|
||||||
|
public URI uri() {
|
||||||
|
return uri;
|
||||||
|
}
|
||||||
|
|
||||||
|
private URI createUri() throws ElasticsearchStatusException {
|
||||||
|
try {
|
||||||
|
// TODO, consider transforming the base URL into a URI for better error handling.
|
||||||
|
return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/embed/text/dense");
|
||||||
|
} catch (URISyntaxException e) {
|
||||||
|
throw new ElasticsearchStatusException(
|
||||||
|
"Failed to create URI for service ["
|
||||||
|
+ this.getConfigurations().getService()
|
||||||
|
+ "] with taskType ["
|
||||||
|
+ this.getTaskType()
|
||||||
|
+ "]: "
|
||||||
|
+ e.getMessage(),
|
||||||
|
RestStatus.BAD_REQUEST,
|
||||||
|
e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,236 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings;
|
||||||
|
|
||||||
|
import org.elasticsearch.TransportVersion;
|
||||||
|
import org.elasticsearch.TransportVersions;
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.ServiceSettings;
|
||||||
|
import org.elasticsearch.inference.SimilarityMeasure;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceRateLimitServiceSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
|
||||||
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
|
||||||
|
|
||||||
|
public class ElasticInferenceServiceDenseTextEmbeddingsServiceSettings extends FilteredXContentObject
|
||||||
|
implements
|
||||||
|
ServiceSettings,
|
||||||
|
ElasticInferenceServiceRateLimitServiceSettings {
|
||||||
|
|
||||||
|
public static final String NAME = "elastic_inference_service_dense_embeddings_service_settings";
|
||||||
|
|
||||||
|
public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
|
||||||
|
|
||||||
|
private final String modelId;
|
||||||
|
private final SimilarityMeasure similarity;
|
||||||
|
private final Integer dimensions;
|
||||||
|
private final Integer maxInputTokens;
|
||||||
|
private final RateLimitSettings rateLimitSettings;
|
||||||
|
|
||||||
|
public static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromMap(
|
||||||
|
Map<String, Object> map,
|
||||||
|
ConfigurationParseContext context
|
||||||
|
) {
|
||||||
|
return switch (context) {
|
||||||
|
case REQUEST -> fromRequestMap(map, context);
|
||||||
|
case PERSISTENT -> fromPersistentMap(map, context);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromRequestMap(
|
||||||
|
Map<String, Object> map,
|
||||||
|
ConfigurationParseContext context
|
||||||
|
) {
|
||||||
|
ValidationException validationException = new ValidationException();
|
||||||
|
|
||||||
|
String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||||
|
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
|
||||||
|
map,
|
||||||
|
DEFAULT_RATE_LIMIT_SETTINGS,
|
||||||
|
validationException,
|
||||||
|
ElasticInferenceService.NAME,
|
||||||
|
context
|
||||||
|
);
|
||||||
|
|
||||||
|
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||||
|
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
|
||||||
|
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
|
||||||
|
|
||||||
|
if (validationException.validationErrors().isEmpty() == false) {
|
||||||
|
throw validationException;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(modelId, similarity, dims, maxInputTokens, rateLimitSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromPersistentMap(
|
||||||
|
Map<String, Object> map,
|
||||||
|
ConfigurationParseContext context
|
||||||
|
) {
|
||||||
|
ValidationException validationException = new ValidationException();
|
||||||
|
|
||||||
|
String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||||
|
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
|
||||||
|
map,
|
||||||
|
DEFAULT_RATE_LIMIT_SETTINGS,
|
||||||
|
validationException,
|
||||||
|
ElasticInferenceService.NAME,
|
||||||
|
context
|
||||||
|
);
|
||||||
|
|
||||||
|
SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||||
|
Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
|
||||||
|
Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
|
||||||
|
|
||||||
|
if (validationException.validationErrors().isEmpty() == false) {
|
||||||
|
throw validationException;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(modelId, similarity, dims, maxInputTokens, rateLimitSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
|
||||||
|
String modelId,
|
||||||
|
@Nullable SimilarityMeasure similarity,
|
||||||
|
@Nullable Integer dimensions,
|
||||||
|
@Nullable Integer maxInputTokens,
|
||||||
|
RateLimitSettings rateLimitSettings
|
||||||
|
) {
|
||||||
|
this.modelId = modelId;
|
||||||
|
this.similarity = similarity;
|
||||||
|
this.dimensions = dimensions;
|
||||||
|
this.maxInputTokens = maxInputTokens;
|
||||||
|
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(StreamInput in) throws IOException {
|
||||||
|
this.modelId = in.readString();
|
||||||
|
this.similarity = in.readOptionalEnum(SimilarityMeasure.class);
|
||||||
|
this.dimensions = in.readOptionalVInt();
|
||||||
|
this.maxInputTokens = in.readOptionalVInt();
|
||||||
|
this.rateLimitSettings = new RateLimitSettings(in);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SimilarityMeasure similarity() {
|
||||||
|
return similarity;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Integer dimensions() {
|
||||||
|
return dimensions;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer maxInputTokens() {
|
||||||
|
return maxInputTokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String modelId() {
|
||||||
|
return modelId;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RateLimitSettings rateLimitSettings() {
|
||||||
|
return rateLimitSettings;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DenseVectorFieldMapper.ElementType elementType() {
|
||||||
|
return DenseVectorFieldMapper.ElementType.FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
|
public RateLimitSettings getRateLimitSettings() {
|
||||||
|
return rateLimitSettings;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getWriteableName() {
|
||||||
|
return NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.field(MODEL_ID, modelId);
|
||||||
|
rateLimitSettings.toXContent(builder, params);
|
||||||
|
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
|
||||||
|
if (similarity != null) {
|
||||||
|
builder.field(SIMILARITY, similarity);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dimensions != null) {
|
||||||
|
builder.field(DIMENSIONS, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (maxInputTokens != null) {
|
||||||
|
builder.field(MAX_INPUT_TOKENS, maxInputTokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
toXContentFragmentOfExposedFields(builder, params);
|
||||||
|
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TransportVersion getMinimalSupportedVersion() {
|
||||||
|
return TransportVersions.ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
|
out.writeString(modelId);
|
||||||
|
out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion()));
|
||||||
|
out.writeOptionalVInt(dimensions);
|
||||||
|
out.writeOptionalVInt(maxInputTokens);
|
||||||
|
rateLimitSettings.writeTo(out);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings that = (ElasticInferenceServiceDenseTextEmbeddingsServiceSettings) o;
|
||||||
|
return Objects.equals(modelId, that.modelId)
|
||||||
|
&& similarity == that.similarity
|
||||||
|
&& Objects.equals(dimensions, that.dimensions)
|
||||||
|
&& Objects.equals(maxInputTokens, that.maxInputTokens)
|
||||||
|
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(modelId, similarity, dimensions, maxInputTokens, rateLimitSettings);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,93 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.elastic.request;
|
||||||
|
|
||||||
|
import org.apache.http.HttpHeaders;
|
||||||
|
import org.apache.http.client.methods.HttpPost;
|
||||||
|
import org.apache.http.client.methods.HttpRequestBase;
|
||||||
|
import org.apache.http.entity.ByteArrayEntity;
|
||||||
|
import org.apache.http.message.BasicHeader;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.inference.InputType;
|
||||||
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel;
|
||||||
|
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
|
||||||
|
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;
|
||||||
|
|
||||||
|
import java.net.URI;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceSparseEmbeddingsRequest.inputTypeToUsageContext;
|
||||||
|
|
||||||
|
public class ElasticInferenceServiceDenseTextEmbeddingsRequest extends ElasticInferenceServiceRequest {
|
||||||
|
|
||||||
|
private final URI uri;
|
||||||
|
private final ElasticInferenceServiceDenseTextEmbeddingsModel model;
|
||||||
|
private final List<String> inputs;
|
||||||
|
private final TraceContextHandler traceContextHandler;
|
||||||
|
private final InputType inputType;
|
||||||
|
|
||||||
|
public ElasticInferenceServiceDenseTextEmbeddingsRequest(
|
||||||
|
ElasticInferenceServiceDenseTextEmbeddingsModel model,
|
||||||
|
List<String> inputs,
|
||||||
|
TraceContext traceContext,
|
||||||
|
ElasticInferenceServiceRequestMetadata metadata,
|
||||||
|
InputType inputType
|
||||||
|
) {
|
||||||
|
super(metadata);
|
||||||
|
this.inputs = inputs;
|
||||||
|
this.model = Objects.requireNonNull(model);
|
||||||
|
this.uri = model.uri();
|
||||||
|
this.traceContextHandler = new TraceContextHandler(traceContext);
|
||||||
|
this.inputType = inputType;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public HttpRequestBase createHttpRequestBase() {
|
||||||
|
var httpPost = new HttpPost(uri);
|
||||||
|
var usageContext = inputTypeToUsageContext(inputType);
|
||||||
|
var requestEntity = Strings.toString(
|
||||||
|
new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(inputs, model.getServiceSettings().modelId(), usageContext)
|
||||||
|
);
|
||||||
|
|
||||||
|
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
|
||||||
|
httpPost.setEntity(byteEntity);
|
||||||
|
|
||||||
|
traceContextHandler.propagateTraceContext(httpPost);
|
||||||
|
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
|
||||||
|
|
||||||
|
return httpPost;
|
||||||
|
}
|
||||||
|
|
||||||
|
public TraceContext getTraceContext() {
|
||||||
|
return traceContextHandler.traceContext();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getInferenceEntityId() {
|
||||||
|
return model.getInferenceEntityId();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public URI getURI() {
|
||||||
|
return this.uri;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Request truncate() {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean[] getTruncationInfo() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,57 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.elastic.request;
|
||||||
|
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.xcontent.ToXContentObject;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
public record ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
|
||||||
|
List<String> inputs,
|
||||||
|
String modelId,
|
||||||
|
@Nullable ElasticInferenceServiceUsageContext usageContext
|
||||||
|
) implements ToXContentObject {
|
||||||
|
|
||||||
|
private static final String INPUT_FIELD = "input";
|
||||||
|
private static final String MODEL_FIELD = "model";
|
||||||
|
private static final String USAGE_CONTEXT = "usage_context";
|
||||||
|
|
||||||
|
public ElasticInferenceServiceDenseTextEmbeddingsRequestEntity {
|
||||||
|
Objects.requireNonNull(inputs);
|
||||||
|
Objects.requireNonNull(modelId);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
builder.startArray(INPUT_FIELD);
|
||||||
|
|
||||||
|
for (String input : inputs) {
|
||||||
|
builder.value(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.endArray();
|
||||||
|
|
||||||
|
builder.field(MODEL_FIELD, modelId);
|
||||||
|
|
||||||
|
// optional field
|
||||||
|
if (Objects.nonNull(usageContext) && usageContext != ElasticInferenceServiceUsageContext.UNSPECIFIED) {
|
||||||
|
builder.field(USAGE_CONTEXT, usageContext);
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.endObject();
|
||||||
|
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -5,7 +5,7 @@
|
||||||
* 2.0.
|
* 2.0.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.external.request.elastic.rerank;
|
package org.elasticsearch.xpack.inference.services.elastic.request;
|
||||||
|
|
||||||
import org.apache.http.HttpHeaders;
|
import org.apache.http.HttpHeaders;
|
||||||
import org.apache.http.client.methods.HttpPost;
|
import org.apache.http.client.methods.HttpPost;
|
||||||
|
@ -15,8 +15,6 @@ import org.apache.http.message.BasicHeader;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.xcontent.XContentType;
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequest;
|
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestMetadata;
|
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
|
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
|
||||||
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
|
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
|
||||||
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;
|
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;
|
|
@ -5,7 +5,7 @@
|
||||||
* 2.0.
|
* 2.0.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.external.request.elastic.rerank;
|
package org.elasticsearch.xpack.inference.services.elastic.request;
|
||||||
|
|
||||||
import org.elasticsearch.core.Nullable;
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.xcontent.ToXContentObject;
|
import org.elasticsearch.xcontent.ToXContentObject;
|
|
@ -44,6 +44,8 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements Infer
|
||||||
TaskType.SPARSE_EMBEDDING,
|
TaskType.SPARSE_EMBEDDING,
|
||||||
"chat",
|
"chat",
|
||||||
TaskType.CHAT_COMPLETION,
|
TaskType.CHAT_COMPLETION,
|
||||||
|
"embed/text/dense",
|
||||||
|
TaskType.TEXT_EMBEDDING,
|
||||||
"rerank/text/text-similarity",
|
"rerank/text/text-similarity",
|
||||||
TaskType.RERANK
|
TaskType.RERANK
|
||||||
);
|
);
|
||||||
|
|
|
@ -51,11 +51,11 @@ public class ElasticInferenceServiceSparseEmbeddingsResponseEntity {
|
||||||
* <code>
|
* <code>
|
||||||
* {
|
* {
|
||||||
* "data": [
|
* "data": [
|
||||||
* {
|
* [
|
||||||
* "Embed": 2.1259406,
|
* 2.1259406,
|
||||||
* "this": 1.7073475,
|
* 1.7073475,
|
||||||
* "text": 0.9020516
|
* 0.9020516
|
||||||
* },
|
* ],
|
||||||
* (...)
|
* (...)
|
||||||
* ],
|
* ],
|
||||||
* "meta": {
|
* "meta": {
|
||||||
|
|
|
@ -12,7 +12,7 @@ import org.elasticsearch.test.ESTestCase;
|
||||||
import org.elasticsearch.xcontent.XContentBuilder;
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.xcontent.XContentFactory;
|
import org.elasticsearch.xcontent.XContentFactory;
|
||||||
import org.elasticsearch.xcontent.XContentType;
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
import org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequestEntity;
|
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequestEntity;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
|
@ -10,7 +10,7 @@ package org.elasticsearch.xpack.inference.external.request.elastic;
|
||||||
import org.apache.http.client.methods.HttpPost;
|
import org.apache.http.client.methods.HttpPost;
|
||||||
import org.elasticsearch.tasks.Task;
|
import org.elasticsearch.tasks.Task;
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
import org.elasticsearch.xpack.inference.external.request.elastic.rerank.ElasticInferenceServiceRerankRequest;
|
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRerankRequest;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
|
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
|
||||||
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
|
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ import org.elasticsearch.common.bytes.BytesReference;
|
||||||
import org.elasticsearch.common.settings.Settings;
|
import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||||
import org.elasticsearch.core.TimeValue;
|
import org.elasticsearch.core.TimeValue;
|
||||||
|
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
||||||
import org.elasticsearch.inference.ChunkInferenceInput;
|
import org.elasticsearch.inference.ChunkInferenceInput;
|
||||||
import org.elasticsearch.inference.ChunkedInference;
|
import org.elasticsearch.inference.ChunkedInference;
|
||||||
import org.elasticsearch.inference.EmptySecretSettings;
|
import org.elasticsearch.inference.EmptySecretSettings;
|
||||||
|
@ -29,7 +30,6 @@ import org.elasticsearch.inference.MinimalServiceSettings;
|
||||||
import org.elasticsearch.inference.Model;
|
import org.elasticsearch.inference.Model;
|
||||||
import org.elasticsearch.inference.TaskType;
|
import org.elasticsearch.inference.TaskType;
|
||||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||||
import org.elasticsearch.inference.WeightedToken;
|
|
||||||
import org.elasticsearch.plugins.Plugin;
|
import org.elasticsearch.plugins.Plugin;
|
||||||
import org.elasticsearch.test.ESSingleNodeTestCase;
|
import org.elasticsearch.test.ESSingleNodeTestCase;
|
||||||
import org.elasticsearch.test.http.MockResponse;
|
import org.elasticsearch.test.http.MockResponse;
|
||||||
|
@ -40,9 +40,8 @@ import org.elasticsearch.xcontent.XContentFactory;
|
||||||
import org.elasticsearch.xcontent.XContentType;
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||||
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
|
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
|
||||||
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
|
|
||||||
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
|
|
||||||
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
|
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
||||||
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
||||||
import org.elasticsearch.xpack.inference.InferencePlugin;
|
import org.elasticsearch.xpack.inference.InferencePlugin;
|
||||||
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
|
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
|
||||||
|
@ -59,6 +58,7 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
|
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
|
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
|
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
|
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
|
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity;
|
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity;
|
||||||
|
@ -421,47 +421,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException {
|
|
||||||
var sender = mock(Sender.class);
|
|
||||||
|
|
||||||
var factory = mock(HttpRequestSender.Factory.class);
|
|
||||||
when(factory.createSender()).thenReturn(sender);
|
|
||||||
|
|
||||||
var mockModel = getInvalidModel("model_id", "service_name", TaskType.TEXT_EMBEDDING);
|
|
||||||
|
|
||||||
try (var service = createService(factory)) {
|
|
||||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
|
||||||
service.infer(
|
|
||||||
mockModel,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
List.of(""),
|
|
||||||
false,
|
|
||||||
new HashMap<>(),
|
|
||||||
InputType.INGEST,
|
|
||||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
||||||
listener
|
|
||||||
);
|
|
||||||
|
|
||||||
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
|
|
||||||
MatcherAssert.assertThat(
|
|
||||||
thrownException.getMessage(),
|
|
||||||
is(
|
|
||||||
"Inference entity [model_id] does not support task type [text_embedding] "
|
|
||||||
+ "for inference, the task type must be one of [sparse_embedding, rerank]."
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
verify(factory, times(1)).createSender();
|
|
||||||
verify(sender, times(1)).start();
|
|
||||||
}
|
|
||||||
|
|
||||||
verify(sender, times(1)).close();
|
|
||||||
verifyNoMoreInteractions(factory);
|
|
||||||
verifyNoMoreInteractions(sender);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException {
|
public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException {
|
||||||
var sender = mock(Sender.class);
|
var sender = mock(Sender.class);
|
||||||
|
|
||||||
|
@ -490,7 +449,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
thrownException.getMessage(),
|
thrownException.getMessage(),
|
||||||
is(
|
is(
|
||||||
"Inference entity [model_id] does not support task type [chat_completion] "
|
"Inference entity [model_id] does not support task type [chat_completion] "
|
||||||
+ "for inference, the task type must be one of [sparse_embedding, rerank]. "
|
+ "for inference, the task type must be one of [text_embedding, sparse_embedding, rerank]. "
|
||||||
+ "The task type for the inference entity is chat_completion, "
|
+ "The task type for the inference entity is chat_completion, "
|
||||||
+ "please use the _inference/chat_completion/model_id/_stream URL."
|
+ "please use the _inference/chat_completion/model_id/_stream URL."
|
||||||
)
|
)
|
||||||
|
@ -701,82 +660,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException {
|
|
||||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
||||||
var elasticInferenceServiceURL = getUrl(webServer);
|
|
||||||
|
|
||||||
try (var service = createService(senderFactory, elasticInferenceServiceURL)) {
|
|
||||||
String responseJson = """
|
|
||||||
{
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"hello": 2.1259406,
|
|
||||||
"greet": 1.7073475
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
""";
|
|
||||||
|
|
||||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
|
||||||
|
|
||||||
// Set up the product use case in the thread context
|
|
||||||
String productUseCase = "test-product-use-case";
|
|
||||||
threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase);
|
|
||||||
|
|
||||||
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id");
|
|
||||||
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
|
|
||||||
|
|
||||||
try {
|
|
||||||
service.chunkedInfer(
|
|
||||||
model,
|
|
||||||
null,
|
|
||||||
List.of(new ChunkInferenceInput("input text")),
|
|
||||||
new HashMap<>(),
|
|
||||||
InputType.INGEST,
|
|
||||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
||||||
listener
|
|
||||||
);
|
|
||||||
|
|
||||||
var results = listener.actionGet(TIMEOUT);
|
|
||||||
|
|
||||||
// Verify the response was processed correctly
|
|
||||||
ChunkedInference inferenceResult = results.getFirst();
|
|
||||||
assertThat(inferenceResult, instanceOf(ChunkedInferenceEmbedding.class));
|
|
||||||
var sparseResult = (ChunkedInferenceEmbedding) inferenceResult;
|
|
||||||
assertThat(
|
|
||||||
sparseResult.chunks(),
|
|
||||||
is(
|
|
||||||
List.of(
|
|
||||||
new EmbeddingResults.Chunk(
|
|
||||||
new SparseEmbeddingResults.Embedding(
|
|
||||||
List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
|
|
||||||
false
|
|
||||||
),
|
|
||||||
new ChunkedInference.TextOffset(0, "input text".length())
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
// Verify the request was sent and contains expected headers
|
|
||||||
MatcherAssert.assertThat(webServer.requests(), hasSize(1));
|
|
||||||
var request = webServer.requests().getFirst();
|
|
||||||
assertNull(request.getUri().getQuery());
|
|
||||||
MatcherAssert.assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
|
|
||||||
|
|
||||||
// Check that the product use case header was set correctly
|
|
||||||
assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase));
|
|
||||||
|
|
||||||
// Verify request body
|
|
||||||
var requestMap = entityAsMap(request.getBody());
|
|
||||||
assertThat(requestMap, is(Map.of("input", List.of("input text"), "model", "my-model-id", "usage_context", "ingest")));
|
|
||||||
} finally {
|
|
||||||
// Clean up the thread context
|
|
||||||
threadPool.getThreadContext().stashContext();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testUnifiedCompletionInfer_PropagatesProductUseCaseHeader() throws IOException {
|
public void testUnifiedCompletionInfer_PropagatesProductUseCaseHeader() throws IOException {
|
||||||
var elasticInferenceServiceURL = getUrl(webServer);
|
var elasticInferenceServiceURL = getUrl(webServer);
|
||||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
@ -835,30 +718,45 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testChunkedInfer() throws IOException {
|
public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException {
|
||||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
var elasticInferenceServiceURL = getUrl(webServer);
|
|
||||||
|
|
||||||
try (var service = createService(senderFactory, elasticInferenceServiceURL)) {
|
try (var service = createService(senderFactory, getUrl(webServer))) {
|
||||||
|
|
||||||
|
// Batching will call the service with 2 inputs
|
||||||
String responseJson = """
|
String responseJson = """
|
||||||
{
|
{
|
||||||
"data": [
|
"data": [
|
||||||
{
|
[
|
||||||
"hello": 2.1259406,
|
0.123,
|
||||||
"greet": 1.7073475
|
-0.456,
|
||||||
|
0.789
|
||||||
|
],
|
||||||
|
[
|
||||||
|
0.987,
|
||||||
|
-0.654,
|
||||||
|
0.321
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"meta": {
|
||||||
|
"usage": {
|
||||||
|
"total_tokens": 10
|
||||||
}
|
}
|
||||||
]
|
}
|
||||||
}
|
}
|
||||||
""";
|
""";
|
||||||
|
|
||||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id");
|
||||||
|
|
||||||
|
String productUseCase = "test-product-use-case";
|
||||||
|
threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase);
|
||||||
|
|
||||||
var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id");
|
|
||||||
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
|
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
|
||||||
|
// 2 inputs
|
||||||
service.chunkedInfer(
|
service.chunkedInfer(
|
||||||
model,
|
model,
|
||||||
null,
|
null,
|
||||||
List.of(new ChunkInferenceInput("input text")),
|
List.of(new ChunkInferenceInput("hello world"), new ChunkInferenceInput("dense embedding")),
|
||||||
new HashMap<>(),
|
new HashMap<>(),
|
||||||
InputType.INGEST,
|
InputType.INGEST,
|
||||||
InferenceAction.Request.DEFAULT_TIMEOUT,
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||||
|
@ -866,32 +764,106 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
);
|
);
|
||||||
|
|
||||||
var results = listener.actionGet(TIMEOUT);
|
var results = listener.actionGet(TIMEOUT);
|
||||||
assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class));
|
assertThat(results, hasSize(2));
|
||||||
var sparseResult = (ChunkedInferenceEmbedding) results.get(0);
|
|
||||||
assertThat(
|
// Verify the response was processed correctly
|
||||||
sparseResult.chunks(),
|
ChunkedInference inferenceResult = results.getFirst();
|
||||||
is(
|
assertThat(inferenceResult, instanceOf(ChunkedInferenceEmbedding.class));
|
||||||
List.of(
|
|
||||||
new EmbeddingResults.Chunk(
|
// Verify the request was sent and contains expected headers
|
||||||
new SparseEmbeddingResults.Embedding(
|
assertThat(webServer.requests(), hasSize(1));
|
||||||
List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
|
var request = webServer.requests().getFirst();
|
||||||
false
|
assertNull(request.getUri().getQuery());
|
||||||
),
|
assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
|
||||||
new ChunkedInference.TextOffset(0, "input text".length())
|
|
||||||
)
|
// Check that the product use case header was set correctly
|
||||||
)
|
assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase));
|
||||||
)
|
|
||||||
|
} finally {
|
||||||
|
// Clean up the thread context
|
||||||
|
threadPool.getThreadContext().stashContext();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException {
|
||||||
|
var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id");
|
||||||
|
|
||||||
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
|
try (var service = createService(senderFactory, getUrl(webServer))) {
|
||||||
|
|
||||||
|
// Batching will call the service with 2 inputs
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"data": [
|
||||||
|
[
|
||||||
|
0.123,
|
||||||
|
-0.456,
|
||||||
|
0.789
|
||||||
|
],
|
||||||
|
[
|
||||||
|
0.987,
|
||||||
|
-0.654,
|
||||||
|
0.321
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"meta": {
|
||||||
|
"usage": {
|
||||||
|
"total_tokens": 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
|
||||||
|
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
|
||||||
|
// 2 inputs
|
||||||
|
service.chunkedInfer(
|
||||||
|
model,
|
||||||
|
null,
|
||||||
|
List.of(new ChunkInferenceInput("hello world"), new ChunkInferenceInput("dense embedding")),
|
||||||
|
new HashMap<>(),
|
||||||
|
InputType.INGEST,
|
||||||
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||||
|
listener
|
||||||
);
|
);
|
||||||
|
|
||||||
MatcherAssert.assertThat(webServer.requests(), hasSize(1));
|
var results = listener.actionGet(TIMEOUT);
|
||||||
assertNull(webServer.requests().get(0).getUri().getQuery());
|
assertThat(results, hasSize(2));
|
||||||
|
|
||||||
|
// First result
|
||||||
|
{
|
||||||
|
assertThat(results.getFirst(), instanceOf(ChunkedInferenceEmbedding.class));
|
||||||
|
var denseResult = (ChunkedInferenceEmbedding) results.getFirst();
|
||||||
|
assertThat(denseResult.chunks(), hasSize(1));
|
||||||
|
assertEquals(new ChunkedInference.TextOffset(0, "hello world".length()), denseResult.chunks().getFirst().offset());
|
||||||
|
assertThat(denseResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
|
||||||
|
|
||||||
|
var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
|
||||||
|
assertArrayEquals(new float[] { 0.123f, -0.456f, 0.789f }, embedding.values(), 0.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second result
|
||||||
|
{
|
||||||
|
assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class));
|
||||||
|
var denseResult = (ChunkedInferenceEmbedding) results.get(1);
|
||||||
|
assertThat(denseResult.chunks(), hasSize(1));
|
||||||
|
assertEquals(new ChunkedInference.TextOffset(0, "dense embedding".length()), denseResult.chunks().getFirst().offset());
|
||||||
|
assertThat(denseResult.chunks().getFirst().embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
|
||||||
|
|
||||||
|
var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
|
||||||
|
assertArrayEquals(new float[] { 0.987f, -0.654f, 0.321f }, embedding.values(), 0.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
assertThat(webServer.requests(), hasSize(1));
|
||||||
|
assertNull(webServer.requests().getFirst().getUri().getQuery());
|
||||||
|
assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
|
||||||
|
|
||||||
|
var requestMap = entityAsMap(webServer.requests().getFirst().getBody());
|
||||||
MatcherAssert.assertThat(
|
MatcherAssert.assertThat(
|
||||||
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
|
requestMap,
|
||||||
equalTo(XContentType.JSON.mediaType())
|
is(Map.of("input", List.of("hello world", "dense embedding"), "model", "my-dense-model-id", "usage_context", "ingest"))
|
||||||
);
|
);
|
||||||
|
|
||||||
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
|
||||||
assertThat(requestMap, is(Map.of("input", List.of("input text"), "model", "my-model-id", "usage_context", "ingest")));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -903,27 +875,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNotImplemented() throws Exception {
|
|
||||||
try (
|
|
||||||
var service = createServiceWithMockSender(
|
|
||||||
ElasticInferenceServiceAuthorizationModel.of(
|
|
||||||
new ElasticInferenceServiceAuthorizationResponseEntity(
|
|
||||||
List.of(
|
|
||||||
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
|
|
||||||
"model-1",
|
|
||||||
EnumSet.of(TaskType.TEXT_EMBEDDING)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
) {
|
|
||||||
ensureAuthorizationCallFinished(service);
|
|
||||||
|
|
||||||
assertTrue(service.hideFromConfigurationApi());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() throws Exception {
|
public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() throws Exception {
|
||||||
try (
|
try (
|
||||||
var service = createServiceWithMockSender(
|
var service = createServiceWithMockSender(
|
||||||
|
@ -953,7 +904,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
List.of(
|
List.of(
|
||||||
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
|
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
|
||||||
"model-1",
|
"model-1",
|
||||||
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)
|
EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -966,7 +917,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
{
|
{
|
||||||
"service": "elastic",
|
"service": "elastic",
|
||||||
"name": "Elastic",
|
"name": "Elastic",
|
||||||
"task_types": ["sparse_embedding", "chat_completion"],
|
"task_types": ["sparse_embedding", "chat_completion", "text_embedding"],
|
||||||
"configurations": {
|
"configurations": {
|
||||||
"rate_limit.requests_per_minute": {
|
"rate_limit.requests_per_minute": {
|
||||||
"description": "Minimize the number of rate limit errors.",
|
"description": "Minimize the number of rate limit errors.",
|
||||||
|
@ -975,7 +926,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
"sensitive": false,
|
"sensitive": false,
|
||||||
"updatable": false,
|
"updatable": false,
|
||||||
"type": "int",
|
"type": "int",
|
||||||
"supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
"supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"]
|
||||||
},
|
},
|
||||||
"model_id": {
|
"model_id": {
|
||||||
"description": "The name of the model to use for the inference task.",
|
"description": "The name of the model to use for the inference task.",
|
||||||
|
@ -984,7 +935,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
"sensitive": false,
|
"sensitive": false,
|
||||||
"updatable": false,
|
"updatable": false,
|
||||||
"type": "str",
|
"type": "str",
|
||||||
"supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
"supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"]
|
||||||
},
|
},
|
||||||
"max_input_tokens": {
|
"max_input_tokens": {
|
||||||
"description": "Allows you to specify the maximum number of tokens per input.",
|
"description": "Allows you to specify the maximum number of tokens per input.",
|
||||||
|
@ -993,7 +944,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
"sensitive": false,
|
"sensitive": false,
|
||||||
"updatable": false,
|
"updatable": false,
|
||||||
"type": "int",
|
"type": "int",
|
||||||
"supported_task_types": ["sparse_embedding"]
|
"supported_task_types": ["text_embedding", "sparse_embedding"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1030,7 +981,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
"sensitive": false,
|
"sensitive": false,
|
||||||
"updatable": false,
|
"updatable": false,
|
||||||
"type": "int",
|
"type": "int",
|
||||||
"supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
"supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"]
|
||||||
},
|
},
|
||||||
"model_id": {
|
"model_id": {
|
||||||
"description": "The name of the model to use for the inference task.",
|
"description": "The name of the model to use for the inference task.",
|
||||||
|
@ -1039,7 +990,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
"sensitive": false,
|
"sensitive": false,
|
||||||
"updatable": false,
|
"updatable": false,
|
||||||
"type": "str",
|
"type": "str",
|
||||||
"supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
"supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"]
|
||||||
},
|
},
|
||||||
"max_input_tokens": {
|
"max_input_tokens": {
|
||||||
"description": "Allows you to specify the maximum number of tokens per input.",
|
"description": "Allows you to specify the maximum number of tokens per input.",
|
||||||
|
@ -1048,7 +999,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
"sensitive": false,
|
"sensitive": false,
|
||||||
"updatable": false,
|
"updatable": false,
|
||||||
"type": "int",
|
"type": "int",
|
||||||
"supported_task_types": ["sparse_embedding"]
|
"supported_task_types": ["text_embedding", "sparse_embedding"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1090,7 +1041,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
{
|
{
|
||||||
"service": "elastic",
|
"service": "elastic",
|
||||||
"name": "Elastic",
|
"name": "Elastic",
|
||||||
"task_types": [],
|
"task_types": ["text_embedding"],
|
||||||
"configurations": {
|
"configurations": {
|
||||||
"rate_limit.requests_per_minute": {
|
"rate_limit.requests_per_minute": {
|
||||||
"description": "Minimize the number of rate limit errors.",
|
"description": "Minimize the number of rate limit errors.",
|
||||||
|
@ -1099,7 +1050,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
"sensitive": false,
|
"sensitive": false,
|
||||||
"updatable": false,
|
"updatable": false,
|
||||||
"type": "int",
|
"type": "int",
|
||||||
"supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
"supported_task_types": ["text_embedding" , "sparse_embedding", "rerank", "chat_completion"]
|
||||||
},
|
},
|
||||||
"model_id": {
|
"model_id": {
|
||||||
"description": "The name of the model to use for the inference task.",
|
"description": "The name of the model to use for the inference task.",
|
||||||
|
@ -1108,7 +1059,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
"sensitive": false,
|
"sensitive": false,
|
||||||
"updatable": false,
|
"updatable": false,
|
||||||
"type": "str",
|
"type": "str",
|
||||||
"supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
"supported_task_types": ["text_embedding" , "sparse_embedding", "rerank", "chat_completion"]
|
||||||
},
|
},
|
||||||
"max_input_tokens": {
|
"max_input_tokens": {
|
||||||
"description": "Allows you to specify the maximum number of tokens per input.",
|
"description": "Allows you to specify the maximum number of tokens per input.",
|
||||||
|
@ -1117,7 +1068,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
"sensitive": false,
|
"sensitive": false,
|
||||||
"updatable": false,
|
"updatable": false,
|
||||||
"type": "int",
|
"type": "int",
|
||||||
"supported_task_types": ["sparse_embedding"]
|
"supported_task_types": ["text_embedding", "sparse_embedding"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1296,6 +1247,10 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
"task_types": ["embed/text/sparse"]
|
"task_types": ["embed/text/sparse"]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
"model_name": "multilingual-embed-v1",
|
||||||
|
"task_types": ["embed/text/dense"]
|
||||||
|
},
|
||||||
|
{
|
||||||
"model_name": "rerank-v1",
|
"model_name": "rerank-v1",
|
||||||
"task_types": ["rerank/text/text-similarity"]
|
"task_types": ["rerank/text/text-similarity"]
|
||||||
}
|
}
|
||||||
|
@ -1319,6 +1274,16 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
|
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
|
||||||
service
|
service
|
||||||
),
|
),
|
||||||
|
new InferenceService.DefaultConfigId(
|
||||||
|
".multilingual-embed-v1-elastic",
|
||||||
|
MinimalServiceSettings.textEmbedding(
|
||||||
|
ElasticInferenceService.NAME,
|
||||||
|
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
|
||||||
|
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
|
||||||
|
DenseVectorFieldMapper.ElementType.FLOAT
|
||||||
|
),
|
||||||
|
service
|
||||||
|
),
|
||||||
new InferenceService.DefaultConfigId(
|
new InferenceService.DefaultConfigId(
|
||||||
".rainbow-sprinkles-elastic",
|
".rainbow-sprinkles-elastic",
|
||||||
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
|
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
|
||||||
|
@ -1332,16 +1297,19 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)));
|
assertThat(
|
||||||
|
service.supportedTaskTypes(),
|
||||||
|
is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING))
|
||||||
|
);
|
||||||
|
|
||||||
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
|
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
|
||||||
service.defaultConfigs(listener);
|
service.defaultConfigs(listener);
|
||||||
var models = listener.actionGet(TIMEOUT);
|
var models = listener.actionGet(TIMEOUT);
|
||||||
assertThat(models.size(), is(3));
|
assertThat(models.size(), is(4));
|
||||||
assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
|
assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
|
||||||
assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
|
assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".multilingual-embed-v1-elastic"));
|
||||||
assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
|
assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
|
||||||
|
assertThat(models.get(3).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,12 +22,14 @@ import org.elasticsearch.xcontent.XContentType;
|
||||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||||
import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
|
import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
|
||||||
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
|
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
||||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
|
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
|
||||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests;
|
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModelTests;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests;
|
||||||
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
|
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
|
||||||
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
|
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
|
@ -46,6 +48,7 @@ import static org.elasticsearch.xpack.inference.external.http.retry.RetrySetting
|
||||||
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
|
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
|
||||||
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
||||||
import static org.hamcrest.Matchers.contains;
|
import static org.hamcrest.Matchers.contains;
|
||||||
|
import static org.hamcrest.Matchers.containsString;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.hamcrest.Matchers.hasSize;
|
import static org.hamcrest.Matchers.hasSize;
|
||||||
import static org.hamcrest.Matchers.instanceOf;
|
import static org.hamcrest.Matchers.instanceOf;
|
||||||
|
@ -256,6 +259,206 @@ public class ElasticInferenceServiceActionCreatorTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() throws IOException {
|
||||||
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
|
try (var sender = createSender(senderFactory)) {
|
||||||
|
sender.start();
|
||||||
|
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"data": [
|
||||||
|
[
|
||||||
|
2.1259406,
|
||||||
|
1.7073475,
|
||||||
|
0.9020516
|
||||||
|
],
|
||||||
|
[
|
||||||
|
1.8342123,
|
||||||
|
2.3456789,
|
||||||
|
0.7654321
|
||||||
|
]
|
||||||
|
]
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
|
||||||
|
var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id");
|
||||||
|
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
|
||||||
|
var action = actionCreator.create(model);
|
||||||
|
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
action.execute(
|
||||||
|
new EmbeddingsInput(List.of("hello world", "second text"), null, InputType.UNSPECIFIED),
|
||||||
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||||
|
listener
|
||||||
|
);
|
||||||
|
|
||||||
|
var result = listener.actionGet(TIMEOUT);
|
||||||
|
|
||||||
|
assertThat(result, instanceOf(TextEmbeddingFloatResults.class));
|
||||||
|
var textEmbeddingResults = (TextEmbeddingFloatResults) result;
|
||||||
|
assertThat(textEmbeddingResults.embeddings(), hasSize(2));
|
||||||
|
|
||||||
|
var firstEmbedding = textEmbeddingResults.embeddings().get(0);
|
||||||
|
assertThat(firstEmbedding.values(), is(new float[] { 2.1259406f, 1.7073475f, 0.9020516f }));
|
||||||
|
|
||||||
|
var secondEmbedding = textEmbeddingResults.embeddings().get(1);
|
||||||
|
assertThat(secondEmbedding.values(), is(new float[] { 1.8342123f, 2.3456789f, 0.7654321f }));
|
||||||
|
|
||||||
|
assertThat(webServer.requests(), hasSize(1));
|
||||||
|
assertNull(webServer.requests().get(0).getUri().getQuery());
|
||||||
|
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
|
||||||
|
|
||||||
|
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
||||||
|
assertThat(requestMap.size(), is(2));
|
||||||
|
assertThat(requestMap.get("input"), instanceOf(List.class));
|
||||||
|
var inputList = (List<String>) requestMap.get("input");
|
||||||
|
assertThat(inputList, contains("hello world", "second text"));
|
||||||
|
assertThat(requestMap.get("model"), is("my-dense-model-id"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_WithUsageContext() throws IOException {
|
||||||
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
|
try (var sender = createSender(senderFactory)) {
|
||||||
|
sender.start();
|
||||||
|
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"data": [
|
||||||
|
[
|
||||||
|
0.1234567,
|
||||||
|
0.9876543
|
||||||
|
]
|
||||||
|
]
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
|
||||||
|
var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id");
|
||||||
|
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
|
||||||
|
var action = actionCreator.create(model);
|
||||||
|
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
action.execute(
|
||||||
|
new EmbeddingsInput(List.of("search query"), null, InputType.SEARCH),
|
||||||
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||||
|
listener
|
||||||
|
);
|
||||||
|
|
||||||
|
var result = listener.actionGet(TIMEOUT);
|
||||||
|
|
||||||
|
assertThat(result, instanceOf(TextEmbeddingFloatResults.class));
|
||||||
|
var textEmbeddingResults = (TextEmbeddingFloatResults) result;
|
||||||
|
assertThat(textEmbeddingResults.embeddings(), hasSize(1));
|
||||||
|
|
||||||
|
var embedding = textEmbeddingResults.embeddings().get(0);
|
||||||
|
assertThat(embedding.values(), is(new float[] { 0.1234567f, 0.9876543f }));
|
||||||
|
|
||||||
|
assertThat(webServer.requests(), hasSize(1));
|
||||||
|
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
||||||
|
assertThat(requestMap.size(), is(3));
|
||||||
|
assertThat(requestMap.get("input"), instanceOf(List.class));
|
||||||
|
var inputList = (List<String>) requestMap.get("input");
|
||||||
|
assertThat(inputList, contains("search query"));
|
||||||
|
assertThat(requestMap.get("model"), is("my-dense-model-id"));
|
||||||
|
assertThat(requestMap.get("usage_context"), is("search"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public void testSend_FailsFromInvalidResponseFormat_ForDenseTextEmbeddingsAction() throws IOException {
|
||||||
|
// timeout as zero for no retries
|
||||||
|
var settings = buildSettingsWithRetryFields(
|
||||||
|
TimeValue.timeValueMillis(1),
|
||||||
|
TimeValue.timeValueMinutes(1),
|
||||||
|
TimeValue.timeValueSeconds(0)
|
||||||
|
);
|
||||||
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
|
||||||
|
|
||||||
|
try (var sender = createSender(senderFactory)) {
|
||||||
|
sender.start();
|
||||||
|
|
||||||
|
// This will fail because the expected output is {"data": [[...]]}
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"embedding": [2.1259406, 1.7073475]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
|
||||||
|
var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id");
|
||||||
|
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
|
||||||
|
var action = actionCreator.create(model);
|
||||||
|
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
action.execute(
|
||||||
|
new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED),
|
||||||
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
||||||
|
listener
|
||||||
|
);
|
||||||
|
|
||||||
|
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||||
|
assertThat(thrownException.getMessage(), containsString("[EmbeddingFloatResult] failed to parse field [data]"));
|
||||||
|
|
||||||
|
assertThat(webServer.requests(), hasSize(1));
|
||||||
|
assertNull(webServer.requests().get(0).getUri().getQuery());
|
||||||
|
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
|
||||||
|
|
||||||
|
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
||||||
|
assertThat(requestMap.size(), is(2));
|
||||||
|
assertThat(requestMap.get("input"), instanceOf(List.class));
|
||||||
|
var inputList = (List<String>) requestMap.get("input");
|
||||||
|
assertThat(inputList, contains("hello world"));
|
||||||
|
assertThat(requestMap.get("model"), is("my-dense-model-id"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_EmptyEmbeddings() throws IOException {
|
||||||
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
|
try (var sender = createSender(senderFactory)) {
|
||||||
|
sender.start();
|
||||||
|
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"data": []
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
|
||||||
|
var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id");
|
||||||
|
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
|
||||||
|
var action = actionCreator.create(model);
|
||||||
|
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
action.execute(new EmbeddingsInput(List.of(), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||||
|
|
||||||
|
var result = listener.actionGet(TIMEOUT);
|
||||||
|
|
||||||
|
assertThat(result, instanceOf(TextEmbeddingFloatResults.class));
|
||||||
|
var textEmbeddingResults = (TextEmbeddingFloatResults) result;
|
||||||
|
assertThat(textEmbeddingResults.embeddings(), hasSize(0));
|
||||||
|
|
||||||
|
assertThat(webServer.requests(), hasSize(1));
|
||||||
|
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
||||||
|
assertThat(requestMap.get("input"), instanceOf(List.class));
|
||||||
|
var inputList = (List<String>) requestMap.get("input");
|
||||||
|
assertThat(inputList, hasSize(0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException {
|
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException {
|
||||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings;
|
||||||
|
|
||||||
|
import org.elasticsearch.inference.EmptySecretSettings;
|
||||||
|
import org.elasticsearch.inference.EmptyTaskSettings;
|
||||||
|
import org.elasticsearch.inference.SimilarityMeasure;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
|
||||||
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||||
|
|
||||||
|
public class ElasticInferenceServiceDenseTextEmbeddingsModelTests {
|
||||||
|
|
||||||
|
public static ElasticInferenceServiceDenseTextEmbeddingsModel createModel(String url, String modelId) {
|
||||||
|
return new ElasticInferenceServiceDenseTextEmbeddingsModel(
|
||||||
|
"id",
|
||||||
|
TaskType.TEXT_EMBEDDING,
|
||||||
|
"elastic",
|
||||||
|
new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
|
||||||
|
modelId,
|
||||||
|
SimilarityMeasure.COSINE,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
new RateLimitSettings(1000L)
|
||||||
|
),
|
||||||
|
EmptyTaskSettings.INSTANCE,
|
||||||
|
EmptySecretSettings.INSTANCE,
|
||||||
|
ElasticInferenceServiceComponents.of(url),
|
||||||
|
ChunkingSettingsBuilder.DEFAULT_SETTINGS
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,165 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
|
import org.elasticsearch.inference.SimilarityMeasure;
|
||||||
|
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xcontent.XContentFactory;
|
||||||
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||||
|
import org.elasticsearch.xpack.inference.services.ServiceFields;
|
||||||
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
|
||||||
|
public class ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase<
|
||||||
|
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Writeable.Reader<ElasticInferenceServiceDenseTextEmbeddingsServiceSettings> instanceReader() {
|
||||||
|
return ElasticInferenceServiceDenseTextEmbeddingsServiceSettings::new;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected ElasticInferenceServiceDenseTextEmbeddingsServiceSettings createTestInstance() {
|
||||||
|
return createRandom();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected ElasticInferenceServiceDenseTextEmbeddingsServiceSettings mutateInstance(
|
||||||
|
ElasticInferenceServiceDenseTextEmbeddingsServiceSettings instance
|
||||||
|
) throws IOException {
|
||||||
|
return randomValueOtherThan(instance, ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests::createRandom);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromMap_Request_WithAllSettings() {
|
||||||
|
var modelId = "my-dense-model-id";
|
||||||
|
var similarity = SimilarityMeasure.COSINE;
|
||||||
|
var dimensions = 384;
|
||||||
|
var maxInputTokens = 512;
|
||||||
|
|
||||||
|
var serviceSettings = ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap(
|
||||||
|
new HashMap<>(
|
||||||
|
Map.of(
|
||||||
|
ServiceFields.MODEL_ID,
|
||||||
|
modelId,
|
||||||
|
ServiceFields.SIMILARITY,
|
||||||
|
similarity.toString(),
|
||||||
|
ServiceFields.DIMENSIONS,
|
||||||
|
dimensions,
|
||||||
|
ServiceFields.MAX_INPUT_TOKENS,
|
||||||
|
maxInputTokens
|
||||||
|
)
|
||||||
|
),
|
||||||
|
ConfigurationParseContext.REQUEST
|
||||||
|
);
|
||||||
|
|
||||||
|
assertThat(serviceSettings.modelId(), is(modelId));
|
||||||
|
assertThat(serviceSettings.similarity(), is(similarity));
|
||||||
|
assertThat(serviceSettings.dimensions(), is(dimensions));
|
||||||
|
assertThat(serviceSettings.maxInputTokens(), is(maxInputTokens));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testToXContent_WritesAllFields() throws IOException {
|
||||||
|
var modelId = "my-dense-model";
|
||||||
|
var similarity = SimilarityMeasure.DOT_PRODUCT;
|
||||||
|
var dimensions = 1024;
|
||||||
|
var maxInputTokens = 256;
|
||||||
|
var rateLimitSettings = new RateLimitSettings(5000);
|
||||||
|
|
||||||
|
var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
|
||||||
|
modelId,
|
||||||
|
similarity,
|
||||||
|
dimensions,
|
||||||
|
maxInputTokens,
|
||||||
|
rateLimitSettings
|
||||||
|
);
|
||||||
|
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
serviceSettings.toXContent(builder, null);
|
||||||
|
String xContentResult = Strings.toString(builder);
|
||||||
|
|
||||||
|
String expectedResult = Strings.format(
|
||||||
|
"""
|
||||||
|
{"similarity":"%s","dimensions":%d,"max_input_tokens":%d,"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""",
|
||||||
|
similarity,
|
||||||
|
dimensions,
|
||||||
|
maxInputTokens,
|
||||||
|
modelId,
|
||||||
|
rateLimitSettings.requestsPerTimeUnit()
|
||||||
|
);
|
||||||
|
|
||||||
|
assertThat(xContentResult, is(expectedResult));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testToXContent_WritesOnlyNonNullFields() throws IOException {
|
||||||
|
var modelId = "my-dense-model";
|
||||||
|
var rateLimitSettings = new RateLimitSettings(2000);
|
||||||
|
|
||||||
|
var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
|
||||||
|
modelId,
|
||||||
|
null, // similarity
|
||||||
|
null, // dimensions
|
||||||
|
null, // maxInputTokens
|
||||||
|
rateLimitSettings
|
||||||
|
);
|
||||||
|
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
serviceSettings.toXContent(builder, null);
|
||||||
|
String xContentResult = Strings.toString(builder);
|
||||||
|
|
||||||
|
assertThat(xContentResult, is(Strings.format("""
|
||||||
|
{"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", modelId, rateLimitSettings.requestsPerTimeUnit())));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testToXContentFragmentOfExposedFields() throws IOException {
|
||||||
|
var modelId = "my-dense-model";
|
||||||
|
var rateLimitSettings = new RateLimitSettings(1500);
|
||||||
|
|
||||||
|
var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
|
||||||
|
modelId,
|
||||||
|
SimilarityMeasure.COSINE,
|
||||||
|
512,
|
||||||
|
128,
|
||||||
|
rateLimitSettings
|
||||||
|
);
|
||||||
|
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
builder.startObject();
|
||||||
|
serviceSettings.toXContentFragmentOfExposedFields(builder, null);
|
||||||
|
builder.endObject();
|
||||||
|
String xContentResult = Strings.toString(builder);
|
||||||
|
|
||||||
|
// Only model_id and rate_limit should be in exposed fields
|
||||||
|
assertThat(xContentResult, is(Strings.format("""
|
||||||
|
{"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", modelId, rateLimitSettings.requestsPerTimeUnit())));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings createRandom() {
|
||||||
|
var modelId = randomAlphaOfLength(10);
|
||||||
|
var similarity = SimilarityMeasure.COSINE;
|
||||||
|
var dimensions = randomBoolean() ? randomIntBetween(1, 1024) : null;
|
||||||
|
var maxInputTokens = randomBoolean() ? randomIntBetween(128, 256) : null;
|
||||||
|
var rateLimitSettings = randomBoolean() ? new RateLimitSettings(randomIntBetween(1, 10000)) : null;
|
||||||
|
|
||||||
|
return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(
|
||||||
|
modelId,
|
||||||
|
similarity,
|
||||||
|
dimensions,
|
||||||
|
maxInputTokens,
|
||||||
|
rateLimitSettings
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,147 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.elastic.request;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xcontent.XContentFactory;
|
||||||
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceUsageContext;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
|
||||||
|
|
||||||
|
public class ElasticInferenceServiceDenseTextEmbeddingsRequestEntityTests extends ESTestCase {
|
||||||
|
|
||||||
|
public void testToXContent_SingleInput_UnspecifiedUsageContext() throws IOException {
|
||||||
|
var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
|
||||||
|
List.of("abc"),
|
||||||
|
"my-model-id",
|
||||||
|
ElasticInferenceServiceUsageContext.UNSPECIFIED
|
||||||
|
);
|
||||||
|
String xContentString = xContentEntityToString(entity);
|
||||||
|
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
|
||||||
|
{
|
||||||
|
"input": ["abc"],
|
||||||
|
"model": "my-model-id"
|
||||||
|
}"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testToXContent_MultipleInputs_UnspecifiedUsageContext() throws IOException {
|
||||||
|
var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
|
||||||
|
List.of("abc", "def"),
|
||||||
|
"my-model-id",
|
||||||
|
ElasticInferenceServiceUsageContext.UNSPECIFIED
|
||||||
|
);
|
||||||
|
String xContentString = xContentEntityToString(entity);
|
||||||
|
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
|
||||||
|
{
|
||||||
|
"input": [
|
||||||
|
"abc",
|
||||||
|
"def"
|
||||||
|
],
|
||||||
|
"model": "my-model-id"
|
||||||
|
}
|
||||||
|
"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testToXContent_SingleInput_SearchUsageContext() throws IOException {
|
||||||
|
var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
|
||||||
|
List.of("abc"),
|
||||||
|
"my-model-id",
|
||||||
|
ElasticInferenceServiceUsageContext.SEARCH
|
||||||
|
);
|
||||||
|
String xContentString = xContentEntityToString(entity);
|
||||||
|
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
|
||||||
|
{
|
||||||
|
"input": ["abc"],
|
||||||
|
"model": "my-model-id",
|
||||||
|
"usage_context": "search"
|
||||||
|
}
|
||||||
|
"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testToXContent_SingleInput_IngestUsageContext() throws IOException {
|
||||||
|
var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
|
||||||
|
List.of("abc"),
|
||||||
|
"my-model-id",
|
||||||
|
ElasticInferenceServiceUsageContext.INGEST
|
||||||
|
);
|
||||||
|
String xContentString = xContentEntityToString(entity);
|
||||||
|
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
|
||||||
|
{
|
||||||
|
"input": ["abc"],
|
||||||
|
"model": "my-model-id",
|
||||||
|
"usage_context": "ingest"
|
||||||
|
}
|
||||||
|
"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testToXContent_MultipleInputs_SearchUsageContext() throws IOException {
|
||||||
|
var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
|
||||||
|
List.of("first input", "second input", "third input"),
|
||||||
|
"my-dense-model",
|
||||||
|
ElasticInferenceServiceUsageContext.SEARCH
|
||||||
|
);
|
||||||
|
String xContentString = xContentEntityToString(entity);
|
||||||
|
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
|
||||||
|
{
|
||||||
|
"input": [
|
||||||
|
"first input",
|
||||||
|
"second input",
|
||||||
|
"third input"
|
||||||
|
],
|
||||||
|
"model": "my-dense-model",
|
||||||
|
"usage_context": "search"
|
||||||
|
}
|
||||||
|
"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testToXContent_MultipleInputs_IngestUsageContext() throws IOException {
|
||||||
|
var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
|
||||||
|
List.of("document one", "document two"),
|
||||||
|
"embedding-model-v2",
|
||||||
|
ElasticInferenceServiceUsageContext.INGEST
|
||||||
|
);
|
||||||
|
String xContentString = xContentEntityToString(entity);
|
||||||
|
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
|
||||||
|
{
|
||||||
|
"input": [
|
||||||
|
"document one",
|
||||||
|
"document two"
|
||||||
|
],
|
||||||
|
"model": "embedding-model-v2",
|
||||||
|
"usage_context": "ingest"
|
||||||
|
}
|
||||||
|
"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testToXContent_EmptyInput_UnspecifiedUsageContext() throws IOException {
|
||||||
|
var entity = new ElasticInferenceServiceDenseTextEmbeddingsRequestEntity(
|
||||||
|
List.of(""),
|
||||||
|
"my-model-id",
|
||||||
|
ElasticInferenceServiceUsageContext.UNSPECIFIED
|
||||||
|
);
|
||||||
|
String xContentString = xContentEntityToString(entity);
|
||||||
|
assertThat(xContentString, equalToIgnoringWhitespaceInJsonString("""
|
||||||
|
{
|
||||||
|
"input": [""],
|
||||||
|
"model": "my-model-id"
|
||||||
|
}
|
||||||
|
"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
private String xContentEntityToString(ElasticInferenceServiceDenseTextEmbeddingsRequestEntity entity) throws IOException {
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
entity.toXContent(builder, null);
|
||||||
|
return Strings.toString(builder);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,165 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.elastic.request;
|
||||||
|
|
||||||
|
import org.apache.http.HttpHeaders;
|
||||||
|
import org.apache.http.client.methods.HttpPost;
|
||||||
|
import org.elasticsearch.inference.InputType;
|
||||||
|
import org.elasticsearch.tasks.Task;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests;
|
||||||
|
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata;
|
||||||
|
import static org.hamcrest.Matchers.aMapWithSize;
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.instanceOf;
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
import static org.hamcrest.Matchers.nullValue;
|
||||||
|
|
||||||
|
public class ElasticInferenceServiceDenseTextEmbeddingsRequestTests extends ESTestCase {
|
||||||
|
|
||||||
|
public void testCreateHttpRequest_UsageContextSearch() throws IOException {
|
||||||
|
var url = "http://eis-gateway.com";
|
||||||
|
var input = List.of("input text");
|
||||||
|
var modelId = "my-dense-model-id";
|
||||||
|
|
||||||
|
var request = createRequest(url, modelId, input, InputType.SEARCH);
|
||||||
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||||
|
|
||||||
|
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||||
|
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||||
|
assertThat(requestMap.size(), equalTo(3));
|
||||||
|
assertThat(requestMap.get("input"), is(input));
|
||||||
|
assertThat(requestMap.get("model"), is(modelId));
|
||||||
|
assertThat(requestMap.get("usage_context"), equalTo("search"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testCreateHttpRequest_UsageContextIngest() throws IOException {
|
||||||
|
var url = "http://eis-gateway.com";
|
||||||
|
var input = List.of("ingest text");
|
||||||
|
var modelId = "my-dense-model-id";
|
||||||
|
|
||||||
|
var request = createRequest(url, modelId, input, InputType.INGEST);
|
||||||
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||||
|
|
||||||
|
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||||
|
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||||
|
assertThat(requestMap.size(), equalTo(3));
|
||||||
|
assertThat(requestMap.get("input"), is(input));
|
||||||
|
assertThat(requestMap.get("model"), is(modelId));
|
||||||
|
assertThat(requestMap.get("usage_context"), equalTo("ingest"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testCreateHttpRequest_UsageContextUnspecified() throws IOException {
|
||||||
|
var url = "http://eis-gateway.com";
|
||||||
|
var input = List.of("unspecified text");
|
||||||
|
var modelId = "my-dense-model-id";
|
||||||
|
|
||||||
|
var request = createRequest(url, modelId, input, InputType.UNSPECIFIED);
|
||||||
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||||
|
|
||||||
|
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||||
|
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||||
|
assertThat(requestMap, aMapWithSize(2));
|
||||||
|
assertThat(requestMap.get("input"), is(input));
|
||||||
|
assertThat(requestMap.get("model"), is(modelId));
|
||||||
|
// usage_context should not be present for UNSPECIFIED
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testCreateHttpRequest_MultipleInputs() throws IOException {
|
||||||
|
var url = "http://eis-gateway.com";
|
||||||
|
var inputs = List.of("first input", "second input", "third input");
|
||||||
|
var modelId = "my-dense-model-id";
|
||||||
|
|
||||||
|
var request = createRequest(url, modelId, inputs, InputType.SEARCH);
|
||||||
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||||
|
|
||||||
|
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
||||||
|
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||||
|
assertThat(requestMap.size(), equalTo(3));
|
||||||
|
assertThat(requestMap.get("input"), is(inputs));
|
||||||
|
assertThat(requestMap.get("model"), is(modelId));
|
||||||
|
assertThat(requestMap.get("usage_context"), equalTo("search"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testTraceContextPropagatedThroughHTTPHeaders() {
|
||||||
|
var url = "http://eis-gateway.com";
|
||||||
|
var input = List.of("input text");
|
||||||
|
var modelId = "my-dense-model-id";
|
||||||
|
|
||||||
|
var request = createRequest(url, modelId, input, InputType.UNSPECIFIED);
|
||||||
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||||
|
|
||||||
|
var traceParent = request.getTraceContext().traceParent();
|
||||||
|
var traceState = request.getTraceContext().traceState();
|
||||||
|
|
||||||
|
assertThat(httpPost.getLastHeader(Task.TRACE_PARENT_HTTP_HEADER).getValue(), is(traceParent));
|
||||||
|
assertThat(httpPost.getLastHeader(Task.TRACE_STATE).getValue(), is(traceState));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testTruncate_ReturnsSameInstance() {
|
||||||
|
var url = "http://eis-gateway.com";
|
||||||
|
var input = List.of("input text");
|
||||||
|
var modelId = "my-dense-model-id";
|
||||||
|
|
||||||
|
var request = createRequest(url, modelId, input, InputType.UNSPECIFIED);
|
||||||
|
var truncatedRequest = request.truncate();
|
||||||
|
|
||||||
|
// Dense text embeddings request doesn't support truncation, should return same instance
|
||||||
|
assertThat(truncatedRequest, is(request));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testGetTruncationInfo_ReturnsNull() {
|
||||||
|
var url = "http://eis-gateway.com";
|
||||||
|
var input = List.of("input text");
|
||||||
|
var modelId = "my-dense-model-id";
|
||||||
|
|
||||||
|
var request = createRequest(url, modelId, input, InputType.UNSPECIFIED);
|
||||||
|
|
||||||
|
// Dense text embeddings request doesn't support truncation info
|
||||||
|
assertThat(request.getTruncationInfo(), is(nullValue()));
|
||||||
|
}
|
||||||
|
|
||||||
|
private ElasticInferenceServiceDenseTextEmbeddingsRequest createRequest(
|
||||||
|
String url,
|
||||||
|
String modelId,
|
||||||
|
List<String> inputs,
|
||||||
|
InputType inputType
|
||||||
|
) {
|
||||||
|
var embeddingsModel = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(url, modelId);
|
||||||
|
|
||||||
|
return new ElasticInferenceServiceDenseTextEmbeddingsRequest(
|
||||||
|
embeddingsModel,
|
||||||
|
inputs,
|
||||||
|
new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)),
|
||||||
|
randomElasticInferenceServiceRequestMetadata(),
|
||||||
|
inputType
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,124 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.elastic.response;
|
||||||
|
|
||||||
|
import org.apache.http.HttpResponse;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||||
|
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||||
|
import org.elasticsearch.xpack.inference.external.response.elastic.ElasticInferenceServiceDenseTextEmbeddingsResponseEntity;
|
||||||
|
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
|
||||||
|
import static org.hamcrest.CoreMatchers.is;
|
||||||
|
import static org.hamcrest.Matchers.hasSize;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
|
||||||
|
public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests extends ESTestCase {
|
||||||
|
|
||||||
|
public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_NoMeta() throws Exception {
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"data": [
|
||||||
|
[
|
||||||
|
1.23,
|
||||||
|
4.56,
|
||||||
|
7.89
|
||||||
|
]
|
||||||
|
]
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
|
||||||
|
mock(Request.class),
|
||||||
|
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||||
|
);
|
||||||
|
|
||||||
|
assertThat(parsedResults.embeddings(), hasSize(1));
|
||||||
|
|
||||||
|
var embedding = parsedResults.embeddings().get(0);
|
||||||
|
assertThat(embedding.values(), is(new float[] { 1.23f, 4.56f, 7.89f }));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testDenseTextEmbeddingsResponse_MultipleEmbeddingsInData_NoMeta() throws Exception {
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"data": [
|
||||||
|
[
|
||||||
|
1.23,
|
||||||
|
4.56,
|
||||||
|
7.89
|
||||||
|
],
|
||||||
|
[
|
||||||
|
0.12,
|
||||||
|
0.34,
|
||||||
|
0.56
|
||||||
|
]
|
||||||
|
]
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
|
||||||
|
mock(Request.class),
|
||||||
|
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||||
|
);
|
||||||
|
|
||||||
|
assertThat(parsedResults.embeddings(), hasSize(2));
|
||||||
|
|
||||||
|
var firstEmbedding = parsedResults.embeddings().get(0);
|
||||||
|
assertThat(firstEmbedding.values(), is(new float[] { 1.23f, 4.56f, 7.89f }));
|
||||||
|
|
||||||
|
var secondEmbedding = parsedResults.embeddings().get(1);
|
||||||
|
assertThat(secondEmbedding.values(), is(new float[] { 0.12f, 0.34f, 0.56f }));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testDenseTextEmbeddingsResponse_EmptyData() throws Exception {
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"data": []
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
|
||||||
|
mock(Request.class),
|
||||||
|
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||||
|
);
|
||||||
|
|
||||||
|
assertThat(parsedResults.embeddings(), hasSize(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testDenseTextEmbeddingsResponse_SingleEmbeddingInData_IgnoresMeta() throws Exception {
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"data": [
|
||||||
|
[
|
||||||
|
-1.0,
|
||||||
|
0.0,
|
||||||
|
1.0
|
||||||
|
]
|
||||||
|
],
|
||||||
|
"meta": {
|
||||||
|
"usage": {
|
||||||
|
"total_tokens": 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
|
||||||
|
TextEmbeddingFloatResults parsedResults = ElasticInferenceServiceDenseTextEmbeddingsResponseEntity.fromResponse(
|
||||||
|
mock(Request.class),
|
||||||
|
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||||
|
);
|
||||||
|
|
||||||
|
assertThat(parsedResults.embeddings(), hasSize(1));
|
||||||
|
|
||||||
|
var embedding = parsedResults.embeddings().get(0);
|
||||||
|
assertThat(embedding.values(), is(new float[] { -1.0f, 0.0f, 1.0f }));
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue