mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-29 01:44:36 -04:00
Add Hugging Face Rerank support (#127966)
* Add Hugging Face Rerank support
* Address comments
* Add transport version
* Add transport version
* Add to inference service and crud IT rerank tests
* Refactor slightly / error message
* correct 'testGetConfiguration' test case
* apply suggestions
* fix tests
* apply suggestions
* [CI] Auto commit changes from spotless
* add changelog information

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
This commit is contained in:
parent
f6e4a26480
commit
c7cf8507a2
31 changed files with 1485 additions and 85 deletions
5
docs/changelog/127966.yaml
Normal file
5
docs/changelog/127966.yaml
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
pr: 127966
|
||||||
|
summary: "[ML] Add Rerank support to the Inference Plugin"
|
||||||
|
area: Machine Learning
|
||||||
|
type: enhancement
|
||||||
|
issues: []
|
|
@ -179,6 +179,7 @@ public class TransportVersions {
|
||||||
public static final TransportVersion V_8_19_FIELD_CAPS_ADD_CLUSTER_ALIAS = def(8_841_0_32);
|
public static final TransportVersion V_8_19_FIELD_CAPS_ADD_CLUSTER_ALIAS = def(8_841_0_32);
|
||||||
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34);
|
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34);
|
||||||
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
|
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
|
||||||
|
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
|
||||||
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
|
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
|
||||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
|
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
|
||||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
|
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
|
||||||
|
@ -261,7 +262,7 @@ public class TransportVersions {
|
||||||
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
|
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
|
||||||
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_078_0_00);
|
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_078_0_00);
|
||||||
public static final TransportVersion NODES_STATS_SUPPORTS_MULTI_PROJECT = def(9_079_0_00);
|
public static final TransportVersion NODES_STATS_SUPPORTS_MULTI_PROJECT = def(9_079_0_00);
|
||||||
|
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00);
|
||||||
/*
|
/*
|
||||||
* STOP! READ THIS FIRST! No, really,
|
* STOP! READ THIS FIRST! No, really,
|
||||||
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
|
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
|
||||||
|
|
|
@ -20,9 +20,8 @@ import static org.elasticsearch.test.ESTestCase.randomInt;
|
||||||
public class SettingsConfigurationTestUtils {
|
public class SettingsConfigurationTestUtils {
|
||||||
|
|
||||||
public static SettingsConfiguration getRandomSettingsConfigurationField() {
|
public static SettingsConfiguration getRandomSettingsConfigurationField() {
|
||||||
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue(
|
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
|
||||||
randomAlphaOfLength(10)
|
.setDefaultValue(randomAlphaOfLength(10))
|
||||||
)
|
|
||||||
.setDescription(randomAlphaOfLength(10))
|
.setDescription(randomAlphaOfLength(10))
|
||||||
.setLabel(randomAlphaOfLength(10))
|
.setLabel(randomAlphaOfLength(10))
|
||||||
.setRequired(randomBoolean())
|
.setRequired(randomBoolean())
|
||||||
|
|
|
@ -82,4 +82,13 @@ public class RankedDocsResultsTests extends AbstractChunkedBWCSerializationTestC
|
||||||
protected RankedDocsResults doParseInstance(XContentParser parser) throws IOException {
|
protected RankedDocsResults doParseInstance(XContentParser parser) throws IOException {
|
||||||
return RankedDocsResults.createParser(true).apply(parser, null);
|
return RankedDocsResults.createParser(true).apply(parser, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public record RerankExpectation(Map<String, Object> rankedDocFields) {}
|
||||||
|
|
||||||
|
public static Map<String, Object> buildExpectationRerank(List<RerankExpectation> rerank) {
|
||||||
|
return Map.of(
|
||||||
|
RankedDocsResults.RERANK,
|
||||||
|
rerank.stream().map(rerankExpectation -> Map.of(RankedDocsResults.RankedDoc.NAME, rerankExpectation.rankedDocFields)).toList()
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -171,6 +171,20 @@ public class InferenceBaseRestTest extends ESRestTestCase {
|
||||||
""";
|
""";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static String mockRerankServiceModelConfig() {
|
||||||
|
return """
|
||||||
|
{
|
||||||
|
"service": "test_reranking_service",
|
||||||
|
"service_settings": {
|
||||||
|
"model_id": "my_model",
|
||||||
|
"api_key": "abc64"
|
||||||
|
},
|
||||||
|
"task_settings": {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
}
|
||||||
|
|
||||||
static void deleteModel(String modelId) throws IOException {
|
static void deleteModel(String modelId) throws IOException {
|
||||||
var request = new Request("DELETE", "_inference/" + modelId);
|
var request = new Request("DELETE", "_inference/" + modelId);
|
||||||
var response = client().performRequest(request);
|
var response = client().performRequest(request);
|
||||||
|
@ -484,6 +498,10 @@ public class InferenceBaseRestTest extends ESRestTestCase {
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
|
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
|
||||||
switch (taskType) {
|
switch (taskType) {
|
||||||
|
case RERANK -> {
|
||||||
|
var results = (List<Map<String, Object>>) resultMap.get(TaskType.RERANK.toString());
|
||||||
|
assertThat(results, hasSize(expectedNumberOfResults));
|
||||||
|
}
|
||||||
case SPARSE_EMBEDDING -> {
|
case SPARSE_EMBEDDING -> {
|
||||||
var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
|
var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
|
||||||
assertThat(results, hasSize(expectedNumberOfResults));
|
assertThat(results, hasSize(expectedNumberOfResults));
|
||||||
|
|
|
@ -53,9 +53,12 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
putModel("te_model_" + i, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
|
putModel("te_model_" + i, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
|
||||||
}
|
}
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
putModel("re-model-" + i, mockRerankServiceModelConfig(), TaskType.RERANK);
|
||||||
|
}
|
||||||
|
|
||||||
var getAllModels = getAllModels();
|
var getAllModels = getAllModels();
|
||||||
int numModels = 12;
|
int numModels = 15;
|
||||||
assertThat(getAllModels, hasSize(numModels));
|
assertThat(getAllModels, hasSize(numModels));
|
||||||
|
|
||||||
var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
|
var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
|
||||||
|
@ -71,6 +74,13 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
|
||||||
for (var denseModel : getDenseModels) {
|
for (var denseModel : getDenseModels) {
|
||||||
assertEquals("text_embedding", denseModel.get("task_type"));
|
assertEquals("text_embedding", denseModel.get("task_type"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var getRerankModels = getModels("_all", TaskType.RERANK);
|
||||||
|
int numRerankModels = 4;
|
||||||
|
assertThat(getRerankModels, hasSize(numRerankModels));
|
||||||
|
for (var denseModel : getRerankModels) {
|
||||||
|
assertEquals("rerank", denseModel.get("task_type"));
|
||||||
|
}
|
||||||
String oldApiKey;
|
String oldApiKey;
|
||||||
{
|
{
|
||||||
var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING);
|
var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING);
|
||||||
|
@ -100,6 +110,9 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
deleteModel("te_model_" + i, TaskType.TEXT_EMBEDDING);
|
deleteModel("te_model_" + i, TaskType.TEXT_EMBEDDING);
|
||||||
}
|
}
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
deleteModel("re-model-" + i, TaskType.RERANK);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testGetModelWithWrongTaskType() throws IOException {
|
public void testGetModelWithWrongTaskType() throws IOException {
|
||||||
|
|
|
@ -101,7 +101,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
||||||
|
|
||||||
public void testGetServicesWithRerankTaskType() throws IOException {
|
public void testGetServicesWithRerankTaskType() throws IOException {
|
||||||
List<Object> services = getServices(TaskType.RERANK);
|
List<Object> services = getServices(TaskType.RERANK);
|
||||||
assertThat(services.size(), equalTo(7));
|
assertThat(services.size(), equalTo(8));
|
||||||
|
|
||||||
var providers = providers(services);
|
var providers = providers(services);
|
||||||
|
|
||||||
|
@ -115,7 +115,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
||||||
"googlevertexai",
|
"googlevertexai",
|
||||||
"jinaai",
|
"jinaai",
|
||||||
"test_reranking_service",
|
"test_reranking_service",
|
||||||
"voyageai"
|
"voyageai",
|
||||||
|
"hugging_face"
|
||||||
).toArray()
|
).toArray()
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference;
|
||||||
|
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class MockRerankInferenceServiceIT extends InferenceBaseRestTest {
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public void testMockService() throws IOException {
|
||||||
|
String inferenceEntityId = "test-mock";
|
||||||
|
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
|
||||||
|
var model = getModels(inferenceEntityId, TaskType.RERANK).get(0);
|
||||||
|
|
||||||
|
for (var modelMap : List.of(putModel, model)) {
|
||||||
|
assertEquals(inferenceEntityId, modelMap.get("inference_id"));
|
||||||
|
assertEquals(TaskType.RERANK, TaskType.fromString((String) modelMap.get("task_type")));
|
||||||
|
assertEquals("test_reranking_service", modelMap.get("service"));
|
||||||
|
}
|
||||||
|
|
||||||
|
List<String> input = List.of(randomAlphaOfLength(10));
|
||||||
|
var inference = infer(inferenceEntityId, input);
|
||||||
|
assertNonEmptyInferenceResults(inference, 1, TaskType.RERANK);
|
||||||
|
assertEquals(inference, infer(inferenceEntityId, input));
|
||||||
|
assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMockServiceWithMultipleInputs() throws IOException {
|
||||||
|
String inferenceEntityId = "test-mock-with-multi-inputs";
|
||||||
|
putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
|
||||||
|
var queryParams = Map.of("timeout", "120s");
|
||||||
|
|
||||||
|
var inference = infer(
|
||||||
|
inferenceEntityId,
|
||||||
|
TaskType.RERANK,
|
||||||
|
List.of(randomAlphaOfLength(5), randomAlphaOfLength(10)),
|
||||||
|
"What if?",
|
||||||
|
queryParams
|
||||||
|
);
|
||||||
|
|
||||||
|
assertNonEmptyInferenceResults(inference, 2, TaskType.RERANK);
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
|
||||||
|
String inferenceEntityId = "test-mock";
|
||||||
|
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
|
||||||
|
var model = getModels(inferenceEntityId, TaskType.RERANK).get(0);
|
||||||
|
|
||||||
|
var serviceSettings = (Map<String, Object>) model.get("service_settings");
|
||||||
|
assertNull(serviceSettings.get("api_key"));
|
||||||
|
assertNotNull(serviceSettings.get("model_id"));
|
||||||
|
|
||||||
|
var putServiceSettings = (Map<String, Object>) putModel.get("service_settings");
|
||||||
|
assertNull(putServiceSettings.get("api_key"));
|
||||||
|
assertNotNull(putServiceSettings.get("model_id"));
|
||||||
|
}
|
||||||
|
}
|
|
@ -80,6 +80,8 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVe
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
|
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings;
|
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
|
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
|
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
|
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
|
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
|
||||||
|
@ -365,6 +367,16 @@ public class InferenceNamedWriteablesProvider {
|
||||||
HuggingFaceChatCompletionServiceSettings::new
|
HuggingFaceChatCompletionServiceSettings::new
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
namedWriteables.add(
|
||||||
|
new NamedWriteableRegistry.Entry(TaskSettings.class, HuggingFaceRerankTaskSettings.NAME, HuggingFaceRerankTaskSettings::new)
|
||||||
|
);
|
||||||
|
namedWriteables.add(
|
||||||
|
new NamedWriteableRegistry.Entry(
|
||||||
|
ServiceSettings.class,
|
||||||
|
HuggingFaceRerankServiceSettings.NAME,
|
||||||
|
HuggingFaceRerankServiceSettings::new
|
||||||
|
)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
|
private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
|
||||||
|
|
|
@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||||
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
|
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
|
||||||
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
|
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
|
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -94,6 +95,9 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
|
||||||
} else if (r.getEndpoints().isEmpty() == false
|
} else if (r.getEndpoints().isEmpty() == false
|
||||||
&& r.getEndpoints().get(0).getTaskSettings() instanceof GoogleVertexAiRerankTaskSettings googleVertexAiTaskSettings) {
|
&& r.getEndpoints().get(0).getTaskSettings() instanceof GoogleVertexAiRerankTaskSettings googleVertexAiTaskSettings) {
|
||||||
configuredTopN = googleVertexAiTaskSettings.topN();
|
configuredTopN = googleVertexAiTaskSettings.topN();
|
||||||
|
} else if (r.getEndpoints().isEmpty() == false
|
||||||
|
&& r.getEndpoints().get(0).getTaskSettings() instanceof HuggingFaceRerankTaskSettings huggingFaceRerankTaskSettings) {
|
||||||
|
configuredTopN = huggingFaceRerankTaskSettings.getTopNDocumentsOnly();
|
||||||
}
|
}
|
||||||
if (configuredTopN != null && configuredTopN < rankWindowSize) {
|
if (configuredTopN != null && configuredTopN < rankWindowSize) {
|
||||||
l.onFailure(
|
l.onFailure(
|
||||||
|
|
|
@ -57,6 +57,7 @@ public abstract class HuggingFaceBaseService extends SenderService {
|
||||||
) {
|
) {
|
||||||
try {
|
try {
|
||||||
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
|
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
|
||||||
|
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
|
||||||
|
|
||||||
ChunkingSettings chunkingSettings = null;
|
ChunkingSettings chunkingSettings = null;
|
||||||
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
|
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
|
||||||
|
@ -66,17 +67,21 @@ public abstract class HuggingFaceBaseService extends SenderService {
|
||||||
}
|
}
|
||||||
|
|
||||||
var model = createModel(
|
var model = createModel(
|
||||||
|
new HuggingFaceModelParameters(
|
||||||
inferenceEntityId,
|
inferenceEntityId,
|
||||||
taskType,
|
taskType,
|
||||||
serviceSettingsMap,
|
serviceSettingsMap,
|
||||||
|
taskSettingsMap,
|
||||||
chunkingSettings,
|
chunkingSettings,
|
||||||
serviceSettingsMap,
|
serviceSettingsMap,
|
||||||
TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
|
TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
|
||||||
ConfigurationParseContext.REQUEST
|
ConfigurationParseContext.REQUEST
|
||||||
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
throwIfNotEmptyMap(config, name());
|
throwIfNotEmptyMap(config, name());
|
||||||
throwIfNotEmptyMap(serviceSettingsMap, name());
|
throwIfNotEmptyMap(serviceSettingsMap, name());
|
||||||
|
throwIfNotEmptyMap(taskSettingsMap, name());
|
||||||
|
|
||||||
parsedModelListener.onResponse(model);
|
parsedModelListener.onResponse(model);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
|
@ -92,6 +97,7 @@ public abstract class HuggingFaceBaseService extends SenderService {
|
||||||
Map<String, Object> secrets
|
Map<String, Object> secrets
|
||||||
) {
|
) {
|
||||||
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
|
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
|
||||||
|
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
|
||||||
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
|
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
|
||||||
|
|
||||||
ChunkingSettings chunkingSettings = null;
|
ChunkingSettings chunkingSettings = null;
|
||||||
|
@ -100,19 +106,23 @@ public abstract class HuggingFaceBaseService extends SenderService {
|
||||||
}
|
}
|
||||||
|
|
||||||
return createModel(
|
return createModel(
|
||||||
|
new HuggingFaceModelParameters(
|
||||||
inferenceEntityId,
|
inferenceEntityId,
|
||||||
taskType,
|
taskType,
|
||||||
serviceSettingsMap,
|
serviceSettingsMap,
|
||||||
|
taskSettingsMap,
|
||||||
chunkingSettings,
|
chunkingSettings,
|
||||||
secretSettingsMap,
|
secretSettingsMap,
|
||||||
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
|
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
|
||||||
ConfigurationParseContext.PERSISTENT
|
ConfigurationParseContext.PERSISTENT
|
||||||
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
|
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
|
||||||
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
|
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
|
||||||
|
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
|
||||||
|
|
||||||
ChunkingSettings chunkingSettings = null;
|
ChunkingSettings chunkingSettings = null;
|
||||||
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
|
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
|
||||||
|
@ -120,25 +130,20 @@ public abstract class HuggingFaceBaseService extends SenderService {
|
||||||
}
|
}
|
||||||
|
|
||||||
return createModel(
|
return createModel(
|
||||||
|
new HuggingFaceModelParameters(
|
||||||
inferenceEntityId,
|
inferenceEntityId,
|
||||||
taskType,
|
taskType,
|
||||||
serviceSettingsMap,
|
serviceSettingsMap,
|
||||||
|
taskSettingsMap,
|
||||||
chunkingSettings,
|
chunkingSettings,
|
||||||
null,
|
null,
|
||||||
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
|
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
|
||||||
ConfigurationParseContext.PERSISTENT
|
ConfigurationParseContext.PERSISTENT
|
||||||
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract HuggingFaceModel createModel(
|
protected abstract HuggingFaceModel createModel(HuggingFaceModelParameters input);
|
||||||
String inferenceEntityId,
|
|
||||||
TaskType taskType,
|
|
||||||
Map<String, Object> serviceSettings,
|
|
||||||
ChunkingSettings chunkingSettings,
|
|
||||||
Map<String, Object> secretSettings,
|
|
||||||
String failureMessage,
|
|
||||||
ConfigurationParseContext context
|
|
||||||
);
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doInfer(
|
public void doInfer(
|
||||||
|
|
|
@ -11,6 +11,7 @@ import org.elasticsearch.common.settings.SecureString;
|
||||||
import org.elasticsearch.core.Nullable;
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.inference.ModelConfigurations;
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
import org.elasticsearch.inference.ModelSecrets;
|
import org.elasticsearch.inference.ModelSecrets;
|
||||||
|
import org.elasticsearch.inference.TaskSettings;
|
||||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||||
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
|
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
|
||||||
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
||||||
|
@ -35,6 +36,13 @@ public abstract class HuggingFaceModel extends RateLimitGroupingModel {
|
||||||
apiKey = ServiceUtils.apiKey(apiKeySecrets);
|
apiKey = ServiceUtils.apiKey(apiKeySecrets);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected HuggingFaceModel(HuggingFaceModel model, TaskSettings taskSettings) {
|
||||||
|
super(model, taskSettings);
|
||||||
|
|
||||||
|
rateLimitServiceSettings = model.rateLimitServiceSettings();
|
||||||
|
apiKey = model.apiKey();
|
||||||
|
}
|
||||||
|
|
||||||
public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() {
|
public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() {
|
||||||
return rateLimitServiceSettings;
|
return rateLimitServiceSettings;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface;
|
||||||
|
|
||||||
|
import org.elasticsearch.inference.ChunkingSettings;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public record HuggingFaceModelParameters(
|
||||||
|
String inferenceEntityId,
|
||||||
|
TaskType taskType,
|
||||||
|
Map<String, Object> serviceSettings,
|
||||||
|
Map<String, Object> taskSettings,
|
||||||
|
ChunkingSettings chunkingSettings,
|
||||||
|
Map<String, Object> secretSettings,
|
||||||
|
String failureMessage,
|
||||||
|
ConfigurationParseContext context
|
||||||
|
) {}
|
|
@ -12,10 +12,8 @@ import org.elasticsearch.TransportVersion;
|
||||||
import org.elasticsearch.TransportVersions;
|
import org.elasticsearch.TransportVersions;
|
||||||
import org.elasticsearch.action.ActionListener;
|
import org.elasticsearch.action.ActionListener;
|
||||||
import org.elasticsearch.common.util.LazyInitializable;
|
import org.elasticsearch.common.util.LazyInitializable;
|
||||||
import org.elasticsearch.core.Nullable;
|
|
||||||
import org.elasticsearch.core.TimeValue;
|
import org.elasticsearch.core.TimeValue;
|
||||||
import org.elasticsearch.inference.ChunkedInference;
|
import org.elasticsearch.inference.ChunkedInference;
|
||||||
import org.elasticsearch.inference.ChunkingSettings;
|
|
||||||
import org.elasticsearch.inference.InferenceServiceConfiguration;
|
import org.elasticsearch.inference.InferenceServiceConfiguration;
|
||||||
import org.elasticsearch.inference.InferenceServiceResults;
|
import org.elasticsearch.inference.InferenceServiceResults;
|
||||||
import org.elasticsearch.inference.InputType;
|
import org.elasticsearch.inference.InputType;
|
||||||
|
@ -32,7 +30,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
|
||||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||||
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator;
|
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator;
|
||||||
|
@ -40,6 +37,7 @@ import org.elasticsearch.xpack.inference.services.huggingface.completion.Hugging
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
|
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
|
||||||
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
|
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
|
||||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||||
|
@ -62,6 +60,7 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
||||||
|
|
||||||
private static final String SERVICE_NAME = "Hugging Face";
|
private static final String SERVICE_NAME = "Hugging Face";
|
||||||
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(
|
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(
|
||||||
|
TaskType.RERANK,
|
||||||
TaskType.TEXT_EMBEDDING,
|
TaskType.TEXT_EMBEDDING,
|
||||||
TaskType.SPARSE_EMBEDDING,
|
TaskType.SPARSE_EMBEDDING,
|
||||||
TaskType.COMPLETION,
|
TaskType.COMPLETION,
|
||||||
|
@ -77,35 +76,43 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected HuggingFaceModel createModel(
|
protected HuggingFaceModel createModel(HuggingFaceModelParameters params) {
|
||||||
String inferenceEntityId,
|
return switch (params.taskType()) {
|
||||||
TaskType taskType,
|
case RERANK -> new HuggingFaceRerankModel(
|
||||||
Map<String, Object> serviceSettings,
|
params.inferenceEntityId(),
|
||||||
ChunkingSettings chunkingSettings,
|
params.taskType(),
|
||||||
@Nullable Map<String, Object> secretSettings,
|
NAME,
|
||||||
String failureMessage,
|
params.serviceSettings(),
|
||||||
ConfigurationParseContext context
|
params.taskSettings(),
|
||||||
) {
|
params.secretSettings(),
|
||||||
return switch (taskType) {
|
params.context()
|
||||||
|
);
|
||||||
case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(
|
case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(
|
||||||
inferenceEntityId,
|
params.inferenceEntityId(),
|
||||||
taskType,
|
params.taskType(),
|
||||||
NAME,
|
NAME,
|
||||||
serviceSettings,
|
params.serviceSettings(),
|
||||||
chunkingSettings,
|
params.chunkingSettings(),
|
||||||
secretSettings,
|
params.secretSettings(),
|
||||||
context
|
params.context()
|
||||||
|
);
|
||||||
|
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(
|
||||||
|
params.inferenceEntityId(),
|
||||||
|
params.taskType(),
|
||||||
|
NAME,
|
||||||
|
params.serviceSettings(),
|
||||||
|
params.secretSettings(),
|
||||||
|
params.context()
|
||||||
);
|
);
|
||||||
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
|
|
||||||
case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel(
|
case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel(
|
||||||
inferenceEntityId,
|
params.inferenceEntityId(),
|
||||||
taskType,
|
params.taskType(),
|
||||||
NAME,
|
NAME,
|
||||||
serviceSettings,
|
params.serviceSettings(),
|
||||||
secretSettings,
|
params.secretSettings(),
|
||||||
context
|
params.context()
|
||||||
);
|
);
|
||||||
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
|
default -> throw new ElasticsearchStatusException(params.failureMessage(), RestStatus.BAD_REQUEST);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecut
|
||||||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
|
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||||
|
@ -23,8 +24,11 @@ import org.elasticsearch.xpack.inference.services.huggingface.completion.Hugging
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
|
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity;
|
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity;
|
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceRerankResponseEntity;
|
||||||
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
|
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
|
||||||
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
|
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
|
||||||
|
|
||||||
|
@ -37,12 +41,27 @@ import static org.elasticsearch.core.Strings.format;
|
||||||
*/
|
*/
|
||||||
public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
|
public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
|
||||||
|
|
||||||
|
private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE =
|
||||||
|
"Failed to send Hugging Face %s request from inference entity id [%s]";
|
||||||
|
private static final String INVALID_REQUEST_TYPE_MESSAGE = "Invalid request type: expected HuggingFace %s request but got %s";
|
||||||
public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions";
|
public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions";
|
||||||
static final String USER_ROLE = "user";
|
static final String USER_ROLE = "user";
|
||||||
static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
|
static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
|
||||||
"hugging face completion",
|
"hugging face completion",
|
||||||
OpenAiChatCompletionResponseEntity::fromResponse
|
OpenAiChatCompletionResponseEntity::fromResponse
|
||||||
);
|
);
|
||||||
|
private static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler("hugging face rerank", (request, response) -> {
|
||||||
|
if ((request instanceof HuggingFaceRerankRequest) == false) {
|
||||||
|
var errorMessage = format(
|
||||||
|
INVALID_REQUEST_TYPE_MESSAGE,
|
||||||
|
"RERANK",
|
||||||
|
request != null ? request.getClass().getSimpleName() : "null"
|
||||||
|
);
|
||||||
|
throw new IllegalArgumentException(errorMessage);
|
||||||
|
}
|
||||||
|
return HuggingFaceRerankResponseEntity.fromResponse((HuggingFaceRerankRequest) request, response);
|
||||||
|
});
|
||||||
|
|
||||||
private final Sender sender;
|
private final Sender sender;
|
||||||
private final ServiceComponents serviceComponents;
|
private final ServiceComponents serviceComponents;
|
||||||
|
|
||||||
|
@ -51,6 +70,26 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
|
||||||
this.serviceComponents = Objects.requireNonNull(serviceComponents);
|
this.serviceComponents = Objects.requireNonNull(serviceComponents);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ExecutableAction create(HuggingFaceRerankModel model) {
|
||||||
|
var overriddenModel = HuggingFaceRerankModel.of(model, model.getTaskSettings());
|
||||||
|
var manager = new GenericRequestManager<>(
|
||||||
|
serviceComponents.threadPool(),
|
||||||
|
overriddenModel,
|
||||||
|
RERANK_HANDLER,
|
||||||
|
inputs -> new HuggingFaceRerankRequest(
|
||||||
|
inputs.getQuery(),
|
||||||
|
inputs.getChunks(),
|
||||||
|
inputs.getReturnDocuments(),
|
||||||
|
inputs.getTopN(),
|
||||||
|
model
|
||||||
|
),
|
||||||
|
QueryAndDocsInputs.class
|
||||||
|
);
|
||||||
|
var errorMessage = buildErrorMessage(TaskType.RERANK, model.getInferenceEntityId());
|
||||||
|
return new SenderExecutableAction(sender, manager, errorMessage);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ExecutableAction create(HuggingFaceEmbeddingsModel model) {
|
public ExecutableAction create(HuggingFaceEmbeddingsModel model) {
|
||||||
var responseHandler = new HuggingFaceResponseHandler(
|
var responseHandler = new HuggingFaceResponseHandler(
|
||||||
|
@ -95,6 +134,6 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
|
||||||
}
|
}
|
||||||
|
|
||||||
public static String buildErrorMessage(TaskType requestType, String inferenceId) {
|
public static String buildErrorMessage(TaskType requestType, String inferenceId) {
|
||||||
return format("Failed to send Hugging Face %s request from inference entity id [%s]", requestType.toString(), inferenceId);
|
return format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, requestType.toString(), inferenceId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,8 +11,11 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
|
||||||
|
|
||||||
public interface HuggingFaceActionVisitor {
|
public interface HuggingFaceActionVisitor {
|
||||||
|
ExecutableAction create(HuggingFaceRerankModel model);
|
||||||
|
|
||||||
ExecutableAction create(HuggingFaceEmbeddingsModel model);
|
ExecutableAction create(HuggingFaceEmbeddingsModel model);
|
||||||
|
|
||||||
ExecutableAction create(HuggingFaceElserModel model);
|
ExecutableAction create(HuggingFaceElserModel model);
|
||||||
|
|
|
@ -13,11 +13,9 @@ import org.elasticsearch.TransportVersions;
|
||||||
import org.elasticsearch.action.ActionListener;
|
import org.elasticsearch.action.ActionListener;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.util.LazyInitializable;
|
import org.elasticsearch.common.util.LazyInitializable;
|
||||||
import org.elasticsearch.core.Nullable;
|
|
||||||
import org.elasticsearch.core.TimeValue;
|
import org.elasticsearch.core.TimeValue;
|
||||||
import org.elasticsearch.inference.ChunkInferenceInput;
|
import org.elasticsearch.inference.ChunkInferenceInput;
|
||||||
import org.elasticsearch.inference.ChunkedInference;
|
import org.elasticsearch.inference.ChunkedInference;
|
||||||
import org.elasticsearch.inference.ChunkingSettings;
|
|
||||||
import org.elasticsearch.inference.InferenceServiceConfiguration;
|
import org.elasticsearch.inference.InferenceServiceConfiguration;
|
||||||
import org.elasticsearch.inference.InferenceServiceResults;
|
import org.elasticsearch.inference.InferenceServiceResults;
|
||||||
import org.elasticsearch.inference.InputType;
|
import org.elasticsearch.inference.InputType;
|
||||||
|
@ -35,10 +33,10 @@ import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
|
||||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService;
|
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
|
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModelParameters;
|
||||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||||
|
|
||||||
|
@ -69,18 +67,17 @@ public class HuggingFaceElserService extends HuggingFaceBaseService {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected HuggingFaceModel createModel(
|
protected HuggingFaceModel createModel(HuggingFaceModelParameters input) {
|
||||||
String inferenceEntityId,
|
return switch (input.taskType()) {
|
||||||
TaskType taskType,
|
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(
|
||||||
Map<String, Object> serviceSettings,
|
input.inferenceEntityId(),
|
||||||
ChunkingSettings chunkingSettings,
|
input.taskType(),
|
||||||
@Nullable Map<String, Object> secretSettings,
|
NAME,
|
||||||
String failureMessage,
|
input.serviceSettings(),
|
||||||
ConfigurationParseContext context
|
input.secretSettings(),
|
||||||
) {
|
input.context()
|
||||||
return switch (taskType) {
|
);
|
||||||
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
|
default -> throw new ElasticsearchStatusException(input.failureMessage(), RestStatus.BAD_REQUEST);
|
||||||
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,99 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.request.rerank;
|
||||||
|
|
||||||
|
import org.apache.http.HttpHeaders;
|
||||||
|
import org.apache.http.client.methods.HttpPost;
|
||||||
|
import org.apache.http.entity.ByteArrayEntity;
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||||
|
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceAccount;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
|
||||||
|
|
||||||
|
import java.net.URI;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
|
||||||
|
|
||||||
|
public class HuggingFaceRerankRequest implements Request {
|
||||||
|
|
||||||
|
private final HuggingFaceAccount account;
|
||||||
|
private final String query;
|
||||||
|
private final List<String> input;
|
||||||
|
private final Boolean returnDocuments;
|
||||||
|
private final Integer topN;
|
||||||
|
private final HuggingFaceRerankModel model;
|
||||||
|
|
||||||
|
public HuggingFaceRerankRequest(
|
||||||
|
String query,
|
||||||
|
List<String> input,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
HuggingFaceRerankModel model
|
||||||
|
) {
|
||||||
|
Objects.requireNonNull(model);
|
||||||
|
|
||||||
|
this.account = HuggingFaceAccount.of(model);
|
||||||
|
this.input = Objects.requireNonNull(input);
|
||||||
|
this.query = Objects.requireNonNull(query);
|
||||||
|
this.returnDocuments = returnDocuments;
|
||||||
|
this.topN = topN;
|
||||||
|
this.model = model;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public HttpRequest createHttpRequest() {
|
||||||
|
HttpPost httpPost = new HttpPost(account.uri());
|
||||||
|
|
||||||
|
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
||||||
|
Strings.toString(new HuggingFaceRerankRequestEntity(query, input, returnDocuments, getTopN(), model.getTaskSettings()))
|
||||||
|
.getBytes(StandardCharsets.UTF_8)
|
||||||
|
);
|
||||||
|
httpPost.setEntity(byteEntity);
|
||||||
|
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
|
||||||
|
|
||||||
|
decorateWithAuth(httpPost);
|
||||||
|
|
||||||
|
return new HttpRequest(httpPost, getInferenceEntityId());
|
||||||
|
}
|
||||||
|
|
||||||
|
void decorateWithAuth(HttpPost httpPost) {
|
||||||
|
httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getInferenceEntityId() {
|
||||||
|
return model.getInferenceEntityId();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public URI getURI() {
|
||||||
|
return account.uri();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getTopN() {
|
||||||
|
return topN != null ? topN : model.getTaskSettings().getTopNDocumentsOnly();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Request truncate() {
|
||||||
|
// Not applicable for rerank, only used in text embedding requests
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean[] getTruncationInfo() {
|
||||||
|
// Not applicable for rerank, only used in text embedding requests
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,60 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.request.rerank;
|
||||||
|
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.xcontent.ToXContentObject;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
public record HuggingFaceRerankRequestEntity(
|
||||||
|
String query,
|
||||||
|
List<String> documents,
|
||||||
|
@Nullable Boolean returnDocuments,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
HuggingFaceRerankTaskSettings taskSettings
|
||||||
|
) implements ToXContentObject {
|
||||||
|
|
||||||
|
private static final String RETURN_TEXT = "return_text";
|
||||||
|
private static final String DOCUMENTS_FIELD = "texts";
|
||||||
|
private static final String QUERY_FIELD = "query";
|
||||||
|
|
||||||
|
public HuggingFaceRerankRequestEntity {
|
||||||
|
Objects.requireNonNull(query);
|
||||||
|
Objects.requireNonNull(documents);
|
||||||
|
Objects.requireNonNull(taskSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
|
||||||
|
builder.field(DOCUMENTS_FIELD, documents);
|
||||||
|
builder.field(QUERY_FIELD, query);
|
||||||
|
|
||||||
|
// prefer the root level return_documents over task settings
|
||||||
|
if (returnDocuments != null) {
|
||||||
|
builder.field(RETURN_TEXT, returnDocuments);
|
||||||
|
} else if (taskSettings.getReturnDocuments() != null) {
|
||||||
|
builder.field(RETURN_TEXT, taskSettings.getReturnDocuments());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (topN != null) {
|
||||||
|
builder.field(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, topN);
|
||||||
|
} else if (taskSettings.getTopNDocumentsOnly() != null) {
|
||||||
|
builder.field(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly());
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,91 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
|
||||||
|
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.ModelSecrets;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||||
|
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
|
||||||
|
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class HuggingFaceRerankModel extends HuggingFaceModel {
|
||||||
|
public static HuggingFaceRerankModel of(HuggingFaceRerankModel model, HuggingFaceRerankTaskSettings taskSettings) {
|
||||||
|
return new HuggingFaceRerankModel(model, HuggingFaceRerankTaskSettings.of(model.getTaskSettings(), taskSettings));
|
||||||
|
}
|
||||||
|
|
||||||
|
public HuggingFaceRerankModel(
|
||||||
|
String inferenceEntityId,
|
||||||
|
TaskType taskType,
|
||||||
|
String service,
|
||||||
|
Map<String, Object> serviceSettings,
|
||||||
|
Map<String, Object> taskSettings,
|
||||||
|
@Nullable Map<String, Object> secrets,
|
||||||
|
ConfigurationParseContext context
|
||||||
|
) {
|
||||||
|
this(
|
||||||
|
inferenceEntityId,
|
||||||
|
taskType,
|
||||||
|
service,
|
||||||
|
HuggingFaceRerankServiceSettings.fromMap(serviceSettings, context),
|
||||||
|
HuggingFaceRerankTaskSettings.fromMap(taskSettings),
|
||||||
|
DefaultSecretSettings.fromMap(secrets)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should only be used directly for testing
|
||||||
|
HuggingFaceRerankModel(
|
||||||
|
String inferenceEntityId,
|
||||||
|
TaskType taskType,
|
||||||
|
String service,
|
||||||
|
HuggingFaceRerankServiceSettings serviceSettings,
|
||||||
|
HuggingFaceRerankTaskSettings taskSettings,
|
||||||
|
@Nullable DefaultSecretSettings secrets
|
||||||
|
) {
|
||||||
|
super(
|
||||||
|
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
|
||||||
|
new ModelSecrets(secrets),
|
||||||
|
serviceSettings,
|
||||||
|
secrets
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private HuggingFaceRerankModel(HuggingFaceRerankModel model, HuggingFaceRerankTaskSettings taskSettings) {
|
||||||
|
super(model, taskSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public HuggingFaceRerankServiceSettings getServiceSettings() {
|
||||||
|
return (HuggingFaceRerankServiceSettings) super.getServiceSettings();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public HuggingFaceRerankTaskSettings getTaskSettings() {
|
||||||
|
return (HuggingFaceRerankTaskSettings) super.getTaskSettings();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DefaultSecretSettings getSecretSettings() {
|
||||||
|
return (DefaultSecretSettings) super.getSecretSettings();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Integer getTokenLimit() {
|
||||||
|
throw new UnsupportedOperationException("Token Limit for rerank is sent in request and not retrieved from the model");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ExecutableAction accept(HuggingFaceActionVisitor creator) {
|
||||||
|
return creator.create(this);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,139 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
|
||||||
|
|
||||||
|
import org.elasticsearch.TransportVersion;
|
||||||
|
import org.elasticsearch.TransportVersions;
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.inference.ServiceSettings;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
|
||||||
|
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
|
||||||
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri;
|
||||||
|
|
||||||
|
public class HuggingFaceRerankServiceSettings extends FilteredXContentObject
|
||||||
|
implements
|
||||||
|
ServiceSettings,
|
||||||
|
HuggingFaceRateLimitServiceSettings {
|
||||||
|
|
||||||
|
public static final String NAME = "hugging_face_rerank_service_settings";
|
||||||
|
public static final String URL = "url";
|
||||||
|
|
||||||
|
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
|
||||||
|
|
||||||
|
public static HuggingFaceRerankServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
|
||||||
|
ValidationException validationException = new ValidationException();
|
||||||
|
var uri = extractUri(map, URL, validationException);
|
||||||
|
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
|
||||||
|
map,
|
||||||
|
DEFAULT_RATE_LIMIT_SETTINGS,
|
||||||
|
validationException,
|
||||||
|
HuggingFaceService.NAME,
|
||||||
|
context
|
||||||
|
);
|
||||||
|
|
||||||
|
if (validationException.validationErrors().isEmpty() == false) {
|
||||||
|
throw validationException;
|
||||||
|
}
|
||||||
|
return new HuggingFaceRerankServiceSettings(uri, rateLimitSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final URI uri;
|
||||||
|
private final RateLimitSettings rateLimitSettings;
|
||||||
|
|
||||||
|
public HuggingFaceRerankServiceSettings(String url) {
|
||||||
|
uri = createUri(url);
|
||||||
|
rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS;
|
||||||
|
}
|
||||||
|
|
||||||
|
HuggingFaceRerankServiceSettings(URI uri, @Nullable RateLimitSettings rateLimitSettings) {
|
||||||
|
this.uri = Objects.requireNonNull(uri);
|
||||||
|
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
|
||||||
|
}
|
||||||
|
|
||||||
|
public HuggingFaceRerankServiceSettings(StreamInput in) throws IOException {
|
||||||
|
uri = createUri(in.readString());
|
||||||
|
rateLimitSettings = new RateLimitSettings(in);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RateLimitSettings rateLimitSettings() {
|
||||||
|
return rateLimitSettings;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public URI uri() {
|
||||||
|
return uri;
|
||||||
|
}
|
||||||
|
|
||||||
|
// model is not defined in the service settings.
|
||||||
|
// since hugging face requires that the model be chosen when initializing a deployment within their service.
|
||||||
|
@Override
|
||||||
|
public String modelId() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
toXContentFragmentOfExposedFields(builder, params);
|
||||||
|
builder.endObject();
|
||||||
|
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.field(URL, uri.toString());
|
||||||
|
rateLimitSettings.toXContent(builder, params);
|
||||||
|
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getWriteableName() {
|
||||||
|
return NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TransportVersion getMinimalSupportedVersion() {
|
||||||
|
return TransportVersions.ML_INFERENCE_HUGGING_FACE_RERANK_ADDED;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
|
out.writeString(uri.toString());
|
||||||
|
rateLimitSettings.writeTo(out);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
HuggingFaceRerankServiceSettings that = (HuggingFaceRerankServiceSettings) o;
|
||||||
|
return Objects.equals(uri, that.uri) && Objects.equals(rateLimitSettings, that.rateLimitSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(uri, rateLimitSettings);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,156 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
|
||||||
|
|
||||||
|
import org.elasticsearch.TransportVersion;
|
||||||
|
import org.elasticsearch.TransportVersions;
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamInput;
|
||||||
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.TaskSettings;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
|
||||||
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
|
||||||
|
|
||||||
|
public class HuggingFaceRerankTaskSettings implements TaskSettings {
|
||||||
|
|
||||||
|
public static final String NAME = "hugging_face_rerank_task_settings";
|
||||||
|
public static final String RETURN_DOCUMENTS = "return_documents";
|
||||||
|
public static final String TOP_N_DOCS_ONLY = "top_n";
|
||||||
|
|
||||||
|
static final HuggingFaceRerankTaskSettings EMPTY_SETTINGS = new HuggingFaceRerankTaskSettings(null, null);
|
||||||
|
|
||||||
|
public static HuggingFaceRerankTaskSettings fromMap(Map<String, Object> map) {
|
||||||
|
ValidationException validationException = new ValidationException();
|
||||||
|
|
||||||
|
if (map == null || map.isEmpty()) {
|
||||||
|
return EMPTY_SETTINGS;
|
||||||
|
}
|
||||||
|
|
||||||
|
Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException);
|
||||||
|
Integer topNDocumentsOnly = extractOptionalPositiveInteger(
|
||||||
|
map,
|
||||||
|
TOP_N_DOCS_ONLY,
|
||||||
|
ModelConfigurations.TASK_SETTINGS,
|
||||||
|
validationException
|
||||||
|
);
|
||||||
|
|
||||||
|
if (validationException.validationErrors().isEmpty() == false) {
|
||||||
|
throw validationException;
|
||||||
|
}
|
||||||
|
|
||||||
|
return of(topNDocumentsOnly, returnDocuments);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new {@link HuggingFaceRerankTaskSettings}
|
||||||
|
* by preferring non-null fields from the request settings over the original settings.
|
||||||
|
*
|
||||||
|
* @param originalSettings the settings stored as part of the inference entity configuration
|
||||||
|
* @param requestTaskSettings the settings passed in within the task_settings field of the request
|
||||||
|
* @return a constructed {@link HuggingFaceRerankTaskSettings}
|
||||||
|
*/
|
||||||
|
public static HuggingFaceRerankTaskSettings of(
|
||||||
|
HuggingFaceRerankTaskSettings originalSettings,
|
||||||
|
HuggingFaceRerankTaskSettings requestTaskSettings
|
||||||
|
) {
|
||||||
|
return new HuggingFaceRerankTaskSettings(
|
||||||
|
requestTaskSettings.getTopNDocumentsOnly() != null
|
||||||
|
? requestTaskSettings.getTopNDocumentsOnly()
|
||||||
|
: originalSettings.getTopNDocumentsOnly(),
|
||||||
|
requestTaskSettings.getReturnDocuments() != null
|
||||||
|
? requestTaskSettings.getReturnDocuments()
|
||||||
|
: originalSettings.getReturnDocuments()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static HuggingFaceRerankTaskSettings of(Integer topNDocumentsOnly, Boolean returnDocuments) {
|
||||||
|
return new HuggingFaceRerankTaskSettings(topNDocumentsOnly, returnDocuments);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final Integer topNDocumentsOnly;
|
||||||
|
private final Boolean returnDocuments;
|
||||||
|
|
||||||
|
public HuggingFaceRerankTaskSettings(StreamInput in) throws IOException {
|
||||||
|
this(in.readOptionalVInt(), in.readOptionalBoolean());
|
||||||
|
}
|
||||||
|
|
||||||
|
public HuggingFaceRerankTaskSettings(@Nullable Integer topNDocumentsOnly, @Nullable Boolean doReturnDocuments) {
|
||||||
|
this.topNDocumentsOnly = topNDocumentsOnly;
|
||||||
|
this.returnDocuments = doReturnDocuments;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isEmpty() {
|
||||||
|
return topNDocumentsOnly == null && returnDocuments == null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
if (topNDocumentsOnly != null) {
|
||||||
|
builder.field(TOP_N_DOCS_ONLY, topNDocumentsOnly);
|
||||||
|
}
|
||||||
|
if (returnDocuments != null) {
|
||||||
|
builder.field(RETURN_DOCUMENTS, returnDocuments);
|
||||||
|
}
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getWriteableName() {
|
||||||
|
return NAME;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TransportVersion getMinimalSupportedVersion() {
|
||||||
|
return TransportVersions.ML_INFERENCE_HUGGING_FACE_RERANK_ADDED;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void writeTo(StreamOutput out) throws IOException {
|
||||||
|
out.writeOptionalVInt(topNDocumentsOnly);
|
||||||
|
out.writeOptionalBoolean(returnDocuments);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
HuggingFaceRerankTaskSettings that = (HuggingFaceRerankTaskSettings) o;
|
||||||
|
return Objects.equals(returnDocuments, that.returnDocuments) && Objects.equals(topNDocumentsOnly, that.topNDocumentsOnly);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(returnDocuments, topNDocumentsOnly);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getTopNDocumentsOnly() {
|
||||||
|
return topNDocumentsOnly;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Boolean getReturnDocuments() {
|
||||||
|
return returnDocuments;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
|
||||||
|
HuggingFaceRerankTaskSettings updatedSettings = HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(newSettings));
|
||||||
|
return HuggingFaceRerankTaskSettings.of(this, updatedSettings);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,110 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.response;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.xcontent.ParseField;
|
||||||
|
import org.elasticsearch.xcontent.XContentFactory;
|
||||||
|
import org.elasticsearch.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||||
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
|
||||||
|
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
|
||||||
|
|
||||||
|
public class HuggingFaceRerankResponseEntity {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parses the Hugging Face rerank response.
|
||||||
|
|
||||||
|
* For a request like:
|
||||||
|
*
|
||||||
|
* <pre>
|
||||||
|
* <code>
|
||||||
|
* {
|
||||||
|
* "texts": ["luke", "leia"],
|
||||||
|
* "query": "star wars main character",
|
||||||
|
* "return_text": true
|
||||||
|
* }
|
||||||
|
* </code>
|
||||||
|
* </pre>
|
||||||
|
|
||||||
|
* The response would look like:
|
||||||
|
|
||||||
|
* <pre>
|
||||||
|
* <code>
|
||||||
|
* [
|
||||||
|
* {
|
||||||
|
* "index": 0,
|
||||||
|
* "score": -0.07996220886707306,
|
||||||
|
* "text": "luke"
|
||||||
|
* },
|
||||||
|
* {
|
||||||
|
* "index": 1,
|
||||||
|
* "score": -0.08393221348524094,
|
||||||
|
* "text": "leia"
|
||||||
|
* }
|
||||||
|
* ]
|
||||||
|
* </code>
|
||||||
|
* </pre>
|
||||||
|
*/
|
||||||
|
|
||||||
|
public static RankedDocsResults fromResponse(HuggingFaceRerankRequest request, HttpResult response) throws IOException {
|
||||||
|
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
|
||||||
|
|
||||||
|
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
|
||||||
|
moveToFirstToken(jsonParser);
|
||||||
|
var rankedDocs = doParse(jsonParser);
|
||||||
|
var rankedDocsByRelevanceStream = rankedDocs.stream()
|
||||||
|
.sorted(Comparator.comparingDouble(RankedDocsResults.RankedDoc::relevanceScore).reversed());
|
||||||
|
var rankedDocStreamTopN = request.getTopN() == null
|
||||||
|
? rankedDocsByRelevanceStream
|
||||||
|
: rankedDocsByRelevanceStream.limit(request.getTopN());
|
||||||
|
return new RankedDocsResults(rankedDocStreamTopN.toList());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<RankedDocsResults.RankedDoc> doParse(XContentParser parser) throws IOException {
|
||||||
|
return parseList(parser, (listParser, index) -> {
|
||||||
|
var parsedRankedDoc = HuggingFaceRerankResponseEntity.RankedDocEntry.parse(parser);
|
||||||
|
return new RankedDocsResults.RankedDoc(parsedRankedDoc.index, parsedRankedDoc.score, parsedRankedDoc.text);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private record RankedDocEntry(Integer index, Float score, @Nullable String text) {
|
||||||
|
|
||||||
|
private static final ParseField TEXT = new ParseField("text");
|
||||||
|
private static final ParseField SCORE = new ParseField("score");
|
||||||
|
private static final ParseField INDEX = new ParseField("index");
|
||||||
|
private static final ConstructingObjectParser<HuggingFaceRerankResponseEntity.RankedDocEntry, Void> PARSER =
|
||||||
|
new ConstructingObjectParser<>(
|
||||||
|
"hugging_face_rerank_response",
|
||||||
|
true,
|
||||||
|
args -> new HuggingFaceRerankResponseEntity.RankedDocEntry((int) args[0], (float) args[1], (String) args[2])
|
||||||
|
);
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareInt(ConstructingObjectParser.constructorArg(), INDEX);
|
||||||
|
PARSER.declareFloat(ConstructingObjectParser.constructorArg(), SCORE);
|
||||||
|
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), TEXT);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static RankedDocEntry parse(XContentParser parser) {
|
||||||
|
return PARSER.apply(parser, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -425,7 +425,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
|
||||||
doAnswer(ans -> {
|
doAnswer(ans -> {
|
||||||
listenerAction.accept(ans.getArgument(9));
|
listenerAction.accept(ans.getArgument(9));
|
||||||
return null;
|
return null;
|
||||||
}).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
}).when(service).infer(any(), any(), anyBoolean(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
||||||
doAnswer(ans -> {
|
doAnswer(ans -> {
|
||||||
listenerAction.accept(ans.getArgument(3));
|
listenerAction.accept(ans.getArgument(3));
|
||||||
return null;
|
return null;
|
||||||
|
|
|
@ -1292,7 +1292,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
{
|
{
|
||||||
"service": "hugging_face",
|
"service": "hugging_face",
|
||||||
"name": "Hugging Face",
|
"name": "Hugging Face",
|
||||||
"task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"],
|
"task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"],
|
||||||
"configurations": {
|
"configurations": {
|
||||||
"api_key": {
|
"api_key": {
|
||||||
"description": "API Key for the provider you're connecting to.",
|
"description": "API Key for the provider you're connecting to.",
|
||||||
|
@ -1301,7 +1301,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
"sensitive": true,
|
"sensitive": true,
|
||||||
"updatable": true,
|
"updatable": true,
|
||||||
"type": "str",
|
"type": "str",
|
||||||
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
|
"supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"]
|
||||||
},
|
},
|
||||||
"rate_limit.requests_per_minute": {
|
"rate_limit.requests_per_minute": {
|
||||||
"description": "Minimize the number of rate limit errors.",
|
"description": "Minimize the number of rate limit errors.",
|
||||||
|
@ -1310,7 +1310,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
"sensitive": false,
|
"sensitive": false,
|
||||||
"updatable": false,
|
"updatable": false,
|
||||||
"type": "int",
|
"type": "int",
|
||||||
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
|
"supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"]
|
||||||
},
|
},
|
||||||
"url": {
|
"url": {
|
||||||
"description": "The URL endpoint to use for the requests.",
|
"description": "The URL endpoint to use for the requests.",
|
||||||
|
@ -1319,7 +1319,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
"sensitive": false,
|
"sensitive": false,
|
||||||
"updatable": false,
|
"updatable": false,
|
||||||
"type": "str",
|
"type": "str",
|
||||||
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
|
"supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,12 +27,14 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
|
||||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests;
|
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
|
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
|
||||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;
|
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModelTests;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
|
|
||||||
|
@ -311,6 +313,66 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testSend_FailsFromInvalidResponseFormat_ForRerankAction() throws IOException {
|
||||||
|
var settings = buildSettingsWithRetryFields(
|
||||||
|
TimeValue.timeValueMillis(1),
|
||||||
|
TimeValue.timeValueMinutes(1),
|
||||||
|
TimeValue.timeValueSeconds(0)
|
||||||
|
);
|
||||||
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
|
||||||
|
|
||||||
|
try (var sender = createSender(senderFactory)) {
|
||||||
|
sender.start();
|
||||||
|
|
||||||
|
String responseJson = """
|
||||||
|
{
|
||||||
|
"rerank": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"relevance_score": -0.07996031,
|
||||||
|
"text": "luke"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
""";
|
||||||
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
|
||||||
|
var model = HuggingFaceRerankModelTests.createModel(getUrl(webServer), "secret", "model", 8, true);
|
||||||
|
var actionCreator = new HuggingFaceActionCreator(
|
||||||
|
sender,
|
||||||
|
new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator())
|
||||||
|
);
|
||||||
|
var action = actionCreator.create(model);
|
||||||
|
|
||||||
|
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||||
|
action.execute(new QueryAndDocsInputs("popular name", List.of("Luke")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||||
|
|
||||||
|
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||||
|
assertThat(
|
||||||
|
thrownException.getMessage(),
|
||||||
|
is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
assertRerankActionCreator(List.of("Luke"), "popular name", 8, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertRerankActionCreator(List<String> documents, String query, int topN, boolean returnText) throws IOException {
|
||||||
|
assertThat(webServer.requests(), hasSize(1));
|
||||||
|
assertNull(webServer.requests().get(0).getUri().getQuery());
|
||||||
|
assertThat(
|
||||||
|
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
|
||||||
|
equalTo(XContentType.JSON.mediaTypeWithoutParameters())
|
||||||
|
);
|
||||||
|
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
|
||||||
|
|
||||||
|
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
||||||
|
assertThat(requestMap.size(), is(4));
|
||||||
|
assertThat(requestMap.get("texts"), is(documents));
|
||||||
|
assertThat(requestMap.get("query"), is(query));
|
||||||
|
assertThat(requestMap.get("top_n"), is(topN));
|
||||||
|
assertThat(requestMap.get("return_text"), is(returnText));
|
||||||
|
}
|
||||||
|
|
||||||
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException {
|
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException {
|
||||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.request.rerank;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.Strings;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xcontent.ToXContent;
|
||||||
|
import org.elasticsearch.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.xcontent.XContentFactory;
|
||||||
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace;
|
||||||
|
|
||||||
|
public class HuggingFaceRerankRequestEntityTests extends ESTestCase {
|
||||||
|
private static final String INPUT = "texts";
|
||||||
|
private static final String QUERY = "query";
|
||||||
|
private static final Integer TOP_N = 8;
|
||||||
|
private static final Boolean RETURN_DOCUMENTS = false;
|
||||||
|
|
||||||
|
public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
|
||||||
|
var entity = new HuggingFaceRerankRequestEntity(
|
||||||
|
QUERY,
|
||||||
|
List.of(INPUT),
|
||||||
|
Boolean.TRUE,
|
||||||
|
TOP_N,
|
||||||
|
new HuggingFaceRerankTaskSettings(TOP_N, RETURN_DOCUMENTS)
|
||||||
|
);
|
||||||
|
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||||
|
String xContentResult = Strings.toString(builder);
|
||||||
|
String expected = """
|
||||||
|
{"texts":["texts"],
|
||||||
|
"query":"query",
|
||||||
|
"return_text":true,
|
||||||
|
"top_n":8}""";
|
||||||
|
assertEquals(stripWhitespace(expected), xContentResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testXContent_WritesMinimalFields() throws IOException {
|
||||||
|
var entity = new HuggingFaceRerankRequestEntity(QUERY, List.of(INPUT), null, null, new HuggingFaceRerankTaskSettings(null, null));
|
||||||
|
|
||||||
|
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||||
|
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||||
|
String xContentResult = Strings.toString(builder);
|
||||||
|
String expected = """
|
||||||
|
{"texts":["texts"],"query":"query"}""";
|
||||||
|
assertEquals(stripWhitespace(expected), xContentResult);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,99 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.request.rerank;
|
||||||
|
|
||||||
|
import org.apache.http.HttpHeaders;
|
||||||
|
import org.apache.http.client.methods.HttpPost;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xcontent.XContentType;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModelTests;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||||
|
import static org.hamcrest.Matchers.aMapWithSize;
|
||||||
|
import static org.hamcrest.Matchers.instanceOf;
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
|
||||||
|
public class HuggingFaceRerankRequestTests extends ESTestCase {
|
||||||
|
private static final String INPUT = "texts";
|
||||||
|
private static final String QUERY = "query";
|
||||||
|
private static final String INFERENCE_ID = "model";
|
||||||
|
private static final Integer TOP_N = 8;
|
||||||
|
private static final Boolean RETURN_TEXT = false;
|
||||||
|
|
||||||
|
private static final String AUTH_HEADER_VALUE = "foo";
|
||||||
|
|
||||||
|
public void testCreateRequest_WithMinimalFieldsSet() throws IOException {
|
||||||
|
testCreateRequest(null, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testCreateRequest_WithTopN() throws IOException {
|
||||||
|
testCreateRequest(TOP_N, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testCreateRequest_WithReturnDocuments() throws IOException {
|
||||||
|
testCreateRequest(null, RETURN_TEXT);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void testCreateRequest(Integer topN, Boolean returnDocuments) throws IOException {
|
||||||
|
var request = createRequest(topN, returnDocuments);
|
||||||
|
var httpRequest = request.createHttpRequest();
|
||||||
|
|
||||||
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||||
|
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||||
|
|
||||||
|
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters()));
|
||||||
|
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
|
||||||
|
|
||||||
|
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||||
|
|
||||||
|
assertThat(requestMap.get(INPUT), is(List.of(INPUT)));
|
||||||
|
assertThat(requestMap.get(QUERY), is(QUERY));
|
||||||
|
// input and query must exist
|
||||||
|
int itemsCount = 2;
|
||||||
|
if (topN != null) {
|
||||||
|
assertThat(requestMap.get("top_n"), is(topN));
|
||||||
|
itemsCount++;
|
||||||
|
}
|
||||||
|
if (returnDocuments != null) {
|
||||||
|
assertThat(requestMap.get("return_text"), is(returnDocuments));
|
||||||
|
itemsCount++;
|
||||||
|
}
|
||||||
|
assertThat(requestMap, aMapWithSize(itemsCount));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static HuggingFaceRerankRequest createRequest(@Nullable Integer topN, @Nullable Boolean returnDocuments) {
|
||||||
|
var rerankModel = HuggingFaceRerankModelTests.createModel(randomAlphaOfLength(10), "secret", INFERENCE_ID, topN, returnDocuments);
|
||||||
|
|
||||||
|
return new HuggingFaceRerankWithoutAuthRequest(QUERY, List.of(INPUT), rerankModel, topN, returnDocuments);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* We use this class to fake the auth implementation to avoid static mocking of {@link HuggingFaceRerankRequest}
|
||||||
|
*/
|
||||||
|
private static class HuggingFaceRerankWithoutAuthRequest extends HuggingFaceRerankRequest {
|
||||||
|
HuggingFaceRerankWithoutAuthRequest(
|
||||||
|
String query,
|
||||||
|
List<String> input,
|
||||||
|
HuggingFaceRerankModel model,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
@Nullable Boolean returnDocuments
|
||||||
|
) {
|
||||||
|
super(query, input, returnDocuments, topN, model);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void decorateWithAuth(HttpPost httpPost) {
|
||||||
|
httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.settings.SecureString;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.containsString;
|
||||||
|
|
||||||
|
public class HuggingFaceRerankModelTests extends ESTestCase {
|
||||||
|
|
||||||
|
public void testThrowsURISyntaxException_ForInvalidUrl() {
|
||||||
|
var thrownException = expectThrows(IllegalArgumentException.class, () -> createModel("^^", "secret", "model", 8, false));
|
||||||
|
assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static HuggingFaceRerankModel createModel(
|
||||||
|
String url,
|
||||||
|
String apiKey,
|
||||||
|
String modelId,
|
||||||
|
@Nullable Integer topN,
|
||||||
|
@Nullable Boolean returnDocuments
|
||||||
|
) {
|
||||||
|
return new HuggingFaceRerankModel(
|
||||||
|
modelId,
|
||||||
|
TaskType.RERANK,
|
||||||
|
"service",
|
||||||
|
new HuggingFaceRerankServiceSettings(url),
|
||||||
|
new HuggingFaceRerankTaskSettings(topN, returnDocuments),
|
||||||
|
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,124 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
|
||||||
|
|
||||||
|
import org.elasticsearch.TransportVersion;
|
||||||
|
import org.elasticsearch.common.ValidationException;
|
||||||
|
import org.elasticsearch.common.io.stream.Writeable;
|
||||||
|
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.containsString;
|
||||||
|
|
||||||
|
public class HuggingFaceRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase<HuggingFaceRerankTaskSettings> {
|
||||||
|
|
||||||
|
public static HuggingFaceRerankTaskSettings createRandom() {
|
||||||
|
var returnDocuments = randomBoolean() ? randomBoolean() : null;
|
||||||
|
var topNDocsOnly = randomBoolean() ? randomIntBetween(1, 10) : null;
|
||||||
|
|
||||||
|
return new HuggingFaceRerankTaskSettings(topNDocsOnly, returnDocuments);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromMap_WithValidValues_ReturnsSettings() {
|
||||||
|
Map<String, Object> taskMap = Map.of(
|
||||||
|
HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS,
|
||||||
|
true,
|
||||||
|
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY,
|
||||||
|
5
|
||||||
|
);
|
||||||
|
var settings = HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap));
|
||||||
|
assertTrue(settings.getReturnDocuments());
|
||||||
|
assertEquals(5, settings.getTopNDocumentsOnly().intValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() {
|
||||||
|
var settings = HuggingFaceRerankTaskSettings.fromMap(Map.of());
|
||||||
|
assertNull(settings.getReturnDocuments());
|
||||||
|
assertNull(settings.getTopNDocumentsOnly());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() {
|
||||||
|
Map<String, Object> taskMap = Map.of(
|
||||||
|
HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS,
|
||||||
|
"invalid",
|
||||||
|
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY,
|
||||||
|
5
|
||||||
|
);
|
||||||
|
var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap)));
|
||||||
|
assertThat(thrownException.getMessage(), containsString("field [return_documents] is not of the expected type"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() {
|
||||||
|
Map<String, Object> taskMap = Map.of(
|
||||||
|
HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS,
|
||||||
|
true,
|
||||||
|
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY,
|
||||||
|
"invalid"
|
||||||
|
);
|
||||||
|
var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap)));
|
||||||
|
assertThat(thrownException.getMessage(), containsString("field [top_n] is not of the expected type"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void UpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() {
|
||||||
|
var initialSettings = new HuggingFaceRerankTaskSettings(5, true);
|
||||||
|
HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of());
|
||||||
|
assertEquals(initialSettings, updatedSettings);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() {
|
||||||
|
var initialSettings = new HuggingFaceRerankTaskSettings(5, true);
|
||||||
|
Map<String, Object> newSettings = Map.of(HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS, false);
|
||||||
|
HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
|
||||||
|
assertFalse(updatedSettings.getReturnDocuments());
|
||||||
|
assertEquals(initialSettings.getTopNDocumentsOnly(), updatedSettings.getTopNDocumentsOnly());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() {
|
||||||
|
var initialSettings = new HuggingFaceRerankTaskSettings(5, true);
|
||||||
|
Map<String, Object> newSettings = Map.of(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, 7);
|
||||||
|
HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
|
||||||
|
assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue());
|
||||||
|
assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() {
|
||||||
|
var initialSettings = new HuggingFaceRerankTaskSettings(5, true);
|
||||||
|
Map<String, Object> newSettings = Map.of(
|
||||||
|
HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS,
|
||||||
|
false,
|
||||||
|
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY,
|
||||||
|
7
|
||||||
|
);
|
||||||
|
HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
|
||||||
|
assertFalse(updatedSettings.getReturnDocuments());
|
||||||
|
assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Writeable.Reader<HuggingFaceRerankTaskSettings> instanceReader() {
|
||||||
|
return HuggingFaceRerankTaskSettings::new;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected HuggingFaceRerankTaskSettings createTestInstance() {
|
||||||
|
return createRandom();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected HuggingFaceRerankTaskSettings mutateInstance(HuggingFaceRerankTaskSettings instance) throws IOException {
|
||||||
|
return randomValueOtherThan(instance, HuggingFaceRerankTaskSettingsTests::createRandom);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected HuggingFaceRerankTaskSettings mutateInstanceForVersion(HuggingFaceRerankTaskSettings instance, TransportVersion version) {
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,147 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.services.huggingface.response;
|
||||||
|
|
||||||
|
import org.apache.http.HttpResponse;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
|
||||||
|
import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
|
||||||
|
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||||
|
import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests.buildExpectationRerank;
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
public class HuggingFaceRerankResponseEntityTests extends ESTestCase {
|
||||||
|
private static final String MISSED_FIELD_INDEX = "index";
|
||||||
|
private static final String MISSED_FIELD_SCORE = "score";
|
||||||
|
private static final String RESPONSE_JSON_TWO_DOCS = """
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"index": 4,
|
||||||
|
"score": -0.22222222222222222,
|
||||||
|
"text": "ranked second"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 1,
|
||||||
|
"score": 1.11111111111111111,
|
||||||
|
"text": "ranked first"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
""";
|
||||||
|
private static final List<RankedDocsResultsTests.RerankExpectation> EXPECTED_TWO_DOCS = List.of(
|
||||||
|
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 1.11111111111111111F, "text", "ranked first")),
|
||||||
|
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 4, "relevance_score", -0.22222222222222222F, "text", "ranked second"))
|
||||||
|
);
|
||||||
|
|
||||||
|
private static final String RESPONSE_JSON_FIVE_DOCS = """
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"index": 1,
|
||||||
|
"score": 1.11111111111111111,
|
||||||
|
"text": "ranked first"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 3,
|
||||||
|
"score": -0.33333333333333333,
|
||||||
|
"text": "ranked third"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"score": -0.55555555555555555,
|
||||||
|
"text": "ranked fifth"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 2,
|
||||||
|
"score": -0.44444444444444444,
|
||||||
|
"text": "ranked fourth"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 4,
|
||||||
|
"score": -0.22222222222222222,
|
||||||
|
"text": "ranked second"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
""";
|
||||||
|
private static final List<RankedDocsResultsTests.RerankExpectation> EXPECTED_FIVE_DOCS = List.of(
|
||||||
|
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 1.11111111111111111F, "text", "ranked first")),
|
||||||
|
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 4, "relevance_score", -0.22222222222222222F, "text", "ranked second")),
|
||||||
|
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 3, "relevance_score", -0.33333333333333333F, "text", "ranked third")),
|
||||||
|
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 2, "relevance_score", -0.44444444444444444F, "text", "ranked fourth")),
|
||||||
|
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", -0.55555555555555555F, "text", "ranked fifth"))
|
||||||
|
);
|
||||||
|
|
||||||
|
private static final HuggingFaceRerankRequest REQUEST_MOCK = mock(HuggingFaceRerankRequest.class);
|
||||||
|
|
||||||
|
public void testFromResponse_CreatesRankedDocsResults_TopNNull_FiveDocs_NoLimit() throws IOException {
|
||||||
|
assertTopNLimit(null, RESPONSE_JSON_FIVE_DOCS, EXPECTED_FIVE_DOCS);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromResponse_CreatesRankedDocsResults_TopN5_TwoDocs_NoLimit() throws IOException {
|
||||||
|
assertTopNLimit(5, RESPONSE_JSON_TWO_DOCS, EXPECTED_TWO_DOCS);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFromResponse_CreatesRankedDocsResults_TopN2_FiveDocs_Limits() throws IOException {
|
||||||
|
assertTopNLimit(2, RESPONSE_JSON_FIVE_DOCS, EXPECTED_TWO_DOCS);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFails_CreateRankedDocsResults_IndexFieldNull() {
|
||||||
|
String responseJson = """
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"score": 1.11111111111111111,
|
||||||
|
"text": "ranked first"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
""";
|
||||||
|
assertMissingFieldThrowsIllegalArgumentException(responseJson, MISSED_FIELD_INDEX);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFails_CreateRankedDocsResults_ScoreFieldNull() {
|
||||||
|
String responseJson = """
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"index": 1,
|
||||||
|
"text": "ranked first"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
""";
|
||||||
|
assertMissingFieldThrowsIllegalArgumentException(responseJson, MISSED_FIELD_SCORE);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertMissingFieldThrowsIllegalArgumentException(String responseJson, String missingField) {
|
||||||
|
when(REQUEST_MOCK.getTopN()).thenReturn(1);
|
||||||
|
|
||||||
|
var thrownException = expectThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() -> HuggingFaceRerankResponseEntity.fromResponse(
|
||||||
|
REQUEST_MOCK,
|
||||||
|
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||||
|
)
|
||||||
|
);
|
||||||
|
assertThat(thrownException.getMessage(), is("Required [" + missingField + "]"));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertTopNLimit(
|
||||||
|
Integer topN, String responseJson, List<RankedDocsResultsTests.RerankExpectation> expectation) throws IOException {
|
||||||
|
when(REQUEST_MOCK.getTopN()).thenReturn(topN);
|
||||||
|
|
||||||
|
RankedDocsResults parsedResults = HuggingFaceRerankResponseEntity.fromResponse(
|
||||||
|
REQUEST_MOCK,
|
||||||
|
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||||
|
);
|
||||||
|
assertThat(parsedResults.asMap(), is(buildExpectationRerank(expectation)));
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue