Add Hugging Face Rerank support (#127966)

* Add Hugging Face Rerank support

* Address comments

* Add transport version

* Add transport version

* Add to inference service and crud IT rerank tests

* Refactor slightly / error message

* correct 'testGetConfiguration' test case

* apply suggestions

* fix tests

* apply suggestions

* [CI] Auto commit changes from spotless

* add changelog information

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
This commit is contained in:
Evgenii-Kazannik 2025-05-22 21:47:41 +02:00 committed by GitHub
parent f6e4a26480
commit c7cf8507a2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
31 changed files with 1485 additions and 85 deletions

View file

@ -0,0 +1,5 @@
pr: 127966
summary: "[ML] Add Hugging Face Rerank support to the Inference Plugin"
area: Machine Learning
type: enhancement
issues: []

View file

@ -179,6 +179,7 @@ public class TransportVersions {
public static final TransportVersion V_8_19_FIELD_CAPS_ADD_CLUSTER_ALIAS = def(8_841_0_32);
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34);
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@ -261,7 +262,7 @@ public class TransportVersions {
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_078_0_00);
public static final TransportVersion NODES_STATS_SUPPORTS_MULTI_PROJECT = def(9_079_0_00);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

View file

@ -20,9 +20,8 @@ import static org.elasticsearch.test.ESTestCase.randomInt;
public class SettingsConfigurationTestUtils {
public static SettingsConfiguration getRandomSettingsConfigurationField() {
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue(
randomAlphaOfLength(10)
)
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
.setDefaultValue(randomAlphaOfLength(10))
.setDescription(randomAlphaOfLength(10))
.setLabel(randomAlphaOfLength(10))
.setRequired(randomBoolean())

View file

@ -82,4 +82,13 @@ public class RankedDocsResultsTests extends AbstractChunkedBWCSerializationTestC
protected RankedDocsResults doParseInstance(XContentParser parser) throws IOException {
return RankedDocsResults.createParser(true).apply(parser, null);
}
/**
 * Test helper describing the expected content of a single ranked document.
 *
 * @param rankedDocFields the expected field name/value pairs of one ranked document entry
 */
public record RerankExpectation(Map<String, Object> rankedDocFields) {}
/**
 * Builds the expected map form of a rerank result for use in test assertions.
 * <p>
 * Every expectation is wrapped in a single-entry map keyed by {@code RankedDocsResults.RankedDoc.NAME};
 * the resulting list is stored under {@code RankedDocsResults.RERANK}.
 *
 * @param rerank the expected ranked documents, in the order they should appear
 * @return an immutable map with a single {@code RankedDocsResults.RERANK} entry
 */
public static Map<String, Object> buildExpectationRerank(List<RerankExpectation> rerank) {
    var expectedRankedDocs = rerank.stream()
        .map(expectation -> Map.of(RankedDocsResults.RankedDoc.NAME, expectation.rankedDocFields()))
        .toList();
    return Map.of(RankedDocsResults.RERANK, expectedRankedDocs);
}
}

View file

@ -171,6 +171,20 @@ public class InferenceBaseRestTest extends ESRestTestCase {
""";
}
/**
 * Returns the JSON request body used by the REST tests to create an inference endpoint backed by
 * the mock {@code test_reranking_service}.
 * <p>
 * The body carries a {@code model_id} and an {@code api_key} under {@code service_settings}
 * and an empty {@code task_settings} object.
 *
 * @return the endpoint configuration as a JSON string
 */
static String mockRerankServiceModelConfig() {
return """
{
"service": "test_reranking_service",
"service_settings": {
"model_id": "my_model",
"api_key": "abc64"
},
"task_settings": {
}
}
""";
}
static void deleteModel(String modelId) throws IOException {
var request = new Request("DELETE", "_inference/" + modelId);
var response = client().performRequest(request);
@ -484,6 +498,10 @@ public class InferenceBaseRestTest extends ESRestTestCase {
@SuppressWarnings("unchecked")
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
switch (taskType) {
case RERANK -> {
var results = (List<Map<String, Object>>) resultMap.get(TaskType.RERANK.toString());
assertThat(results, hasSize(expectedNumberOfResults));
}
case SPARSE_EMBEDDING -> {
var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
assertThat(results, hasSize(expectedNumberOfResults));

View file

@ -53,9 +53,12 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
for (int i = 0; i < 4; i++) {
putModel("te_model_" + i, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
}
for (int i = 0; i < 3; i++) {
putModel("re-model-" + i, mockRerankServiceModelConfig(), TaskType.RERANK);
}
var getAllModels = getAllModels();
int numModels = 12;
int numModels = 15;
assertThat(getAllModels, hasSize(numModels));
var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
@ -71,6 +74,13 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
for (var denseModel : getDenseModels) {
assertEquals("text_embedding", denseModel.get("task_type"));
}
var getRerankModels = getModels("_all", TaskType.RERANK);
int numRerankModels = 4;
assertThat(getRerankModels, hasSize(numRerankModels));
for (var denseModel : getRerankModels) {
assertEquals("rerank", denseModel.get("task_type"));
}
String oldApiKey;
{
var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING);
@ -100,6 +110,9 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
for (int i = 0; i < 4; i++) {
deleteModel("te_model_" + i, TaskType.TEXT_EMBEDDING);
}
for (int i = 0; i < 3; i++) {
deleteModel("re-model-" + i, TaskType.RERANK);
}
}
public void testGetModelWithWrongTaskType() throws IOException {

View file

@ -101,7 +101,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(7));
assertThat(services.size(), equalTo(8));
var providers = providers(services);
@ -115,7 +115,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
"googlevertexai",
"jinaai",
"test_reranking_service",
"voyageai"
"voyageai",
"hugging_face"
).toArray()
)
);

View file

@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference;
import org.elasticsearch.inference.TaskType;
import java.io.IOException;
import java.util.List;
import java.util.Map;
/**
 * REST integration tests for the mock rerank inference service ({@code test_reranking_service}).
 * <p>
 * Covers endpoint creation and retrieval, inference with single and multiple inputs, and the
 * guarantee that secret settings are never echoed back in PUT or GET responses.
 */
public class MockRerankInferenceServiceIT extends InferenceBaseRestTest {

    /**
     * Creates a rerank endpoint, checks the PUT and GET representations agree, and verifies that
     * inference is repeatable for equal input and differs for different input.
     */
    @SuppressWarnings("unchecked")
    public void testMockService() throws IOException {
        String inferenceEntityId = "test-mock";
        var createdModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
        var fetchedModel = getModels(inferenceEntityId, TaskType.RERANK).get(0);

        // The PUT response and the stored model must describe the same endpoint.
        List.of(createdModel, fetchedModel).forEach(modelMap -> {
            assertEquals(inferenceEntityId, modelMap.get("inference_id"));
            assertEquals(TaskType.RERANK, TaskType.fromString((String) modelMap.get("task_type")));
            assertEquals("test_reranking_service", modelMap.get("service"));
        });

        List<String> input = List.of(randomAlphaOfLength(10));
        var firstResult = infer(inferenceEntityId, input);
        assertNonEmptyInferenceResults(firstResult, 1, TaskType.RERANK);

        // Same input -> same result.
        var repeatedResult = infer(inferenceEntityId, input);
        assertEquals(firstResult, repeatedResult);

        // Different input -> different result.
        var otherInput = randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)));
        assertNotEquals(firstResult, infer(inferenceEntityId, otherInput));
    }

    /**
     * Runs a rerank inference with two documents and an explicit query and expects two results.
     */
    public void testMockServiceWithMultipleInputs() throws IOException {
        String inferenceEntityId = "test-mock-with-multi-inputs";
        putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);

        var queryParams = Map.of("timeout", "120s");
        var documents = List.of(randomAlphaOfLength(5), randomAlphaOfLength(10));
        var result = infer(inferenceEntityId, TaskType.RERANK, documents, "What if?", queryParams);

        assertNonEmptyInferenceResults(result, 2, TaskType.RERANK);
    }

    /**
     * Verifies that neither the GET response nor the PUT response exposes the {@code api_key},
     * while the non-secret {@code model_id} is still returned.
     */
    @SuppressWarnings("unchecked")
    public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
        String inferenceEntityId = "test-mock";
        var createdModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
        var fetchedModel = getModels(inferenceEntityId, TaskType.RERANK).get(0);

        var fetchedServiceSettings = (Map<String, Object>) fetchedModel.get("service_settings");
        assertNull(fetchedServiceSettings.get("api_key"));
        assertNotNull(fetchedServiceSettings.get("model_id"));

        var createdServiceSettings = (Map<String, Object>) createdModel.get("service_settings");
        assertNull(createdServiceSettings.get("api_key"));
        assertNotNull(createdServiceSettings.get("model_id"));
    }
}

View file

@ -80,6 +80,8 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVe
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
@ -365,6 +367,16 @@ public class InferenceNamedWriteablesProvider {
HuggingFaceChatCompletionServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, HuggingFaceRerankTaskSettings.NAME, HuggingFaceRerankTaskSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
HuggingFaceRerankServiceSettings.NAME,
HuggingFaceRerankServiceSettings::new
)
);
}
private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {

View file

@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
import java.util.ArrayList;
import java.util.Arrays;
@ -94,7 +95,10 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
} else if (r.getEndpoints().isEmpty() == false
&& r.getEndpoints().get(0).getTaskSettings() instanceof GoogleVertexAiRerankTaskSettings googleVertexAiTaskSettings) {
configuredTopN = googleVertexAiTaskSettings.topN();
}
} else if (r.getEndpoints().isEmpty() == false
&& r.getEndpoints().get(0).getTaskSettings() instanceof HuggingFaceRerankTaskSettings huggingFaceRerankTaskSettings) {
configuredTopN = huggingFaceRerankTaskSettings.getTopNDocumentsOnly();
}
if (configuredTopN != null && configuredTopN < rankWindowSize) {
l.onFailure(
new IllegalArgumentException(

View file

@ -57,6 +57,7 @@ public abstract class HuggingFaceBaseService extends SenderService {
) {
try {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
@ -66,17 +67,21 @@ public abstract class HuggingFaceBaseService extends SenderService {
}
var model = createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
chunkingSettings,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
ConfigurationParseContext.REQUEST
new HuggingFaceModelParameters(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
ConfigurationParseContext.REQUEST
)
);
throwIfNotEmptyMap(config, name());
throwIfNotEmptyMap(serviceSettingsMap, name());
throwIfNotEmptyMap(taskSettingsMap, name());
parsedModelListener.onResponse(model);
} catch (Exception e) {
@ -92,6 +97,7 @@ public abstract class HuggingFaceBaseService extends SenderService {
Map<String, Object> secrets
) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
ChunkingSettings chunkingSettings = null;
@ -100,19 +106,23 @@ public abstract class HuggingFaceBaseService extends SenderService {
}
return createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
new HuggingFaceModelParameters(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
)
);
}
@Override
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
@ -120,25 +130,20 @@ public abstract class HuggingFaceBaseService extends SenderService {
}
return createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
new HuggingFaceModelParameters(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
chunkingSettings,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
ConfigurationParseContext.PERSISTENT
)
);
}
protected abstract HuggingFaceModel createModel(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
ChunkingSettings chunkingSettings,
Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
);
protected abstract HuggingFaceModel createModel(HuggingFaceModelParameters input);
@Override
public void doInfer(

View file

@ -11,6 +11,7 @@ import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
@ -35,6 +36,13 @@ public abstract class HuggingFaceModel extends RateLimitGroupingModel {
apiKey = ServiceUtils.apiKey(apiKeySecrets);
}
/**
 * Copy constructor that derives a new model from an existing one with different task settings.
 * The rate limit service settings and the API key are taken over from the source model.
 *
 * @param model        the model to copy from
 * @param taskSettings the task settings passed on to the superclass constructor
 */
protected HuggingFaceModel(HuggingFaceModel model, TaskSettings taskSettings) {
    super(model, taskSettings);
    this.apiKey = model.apiKey();
    this.rateLimitServiceSettings = model.rateLimitServiceSettings();
}
public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}

View file

@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import java.util.Map;
/**
 * Parameter object bundling everything a Hugging Face service needs to build a model from a
 * parsed endpoint configuration. It replaces the long positional argument list previously
 * passed to {@code createModel}.
 *
 * @param inferenceEntityId the id of the inference endpoint the model belongs to
 * @param taskType          the task type; services switch on it to pick the concrete model class
 * @param serviceSettings   the raw service settings map taken from the configuration
 * @param taskSettings      the raw task settings map taken from the configuration
 * @param chunkingSettings  the chunking settings; may be {@code null}
 * @param secretSettings    the raw secret settings map; may be {@code null} when a persisted
 *                          configuration is parsed without secrets
 * @param failureMessage    the message of the exception thrown when the task type is not supported
 * @param context           whether the configuration comes from a request or from persisted state
 */
public record HuggingFaceModelParameters(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
ChunkingSettings chunkingSettings,
Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
) {}

View file

@ -12,10 +12,8 @@ import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
@ -32,7 +30,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator;
@ -40,6 +37,7 @@ import org.elasticsearch.xpack.inference.services.huggingface.completion.Hugging
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -62,6 +60,7 @@ public class HuggingFaceService extends HuggingFaceBaseService {
private static final String SERVICE_NAME = "Hugging Face";
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(
TaskType.RERANK,
TaskType.TEXT_EMBEDDING,
TaskType.SPARSE_EMBEDDING,
TaskType.COMPLETION,
@ -77,35 +76,43 @@ public class HuggingFaceService extends HuggingFaceBaseService {
}
@Override
protected HuggingFaceModel createModel(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
protected HuggingFaceModel createModel(HuggingFaceModelParameters params) {
return switch (params.taskType()) {
case RERANK -> new HuggingFaceRerankModel(
params.inferenceEntityId(),
params.taskType(),
NAME,
params.serviceSettings(),
params.taskSettings(),
params.secretSettings(),
params.context()
);
case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(
inferenceEntityId,
taskType,
params.inferenceEntityId(),
params.taskType(),
NAME,
serviceSettings,
chunkingSettings,
secretSettings,
context
params.serviceSettings(),
params.chunkingSettings(),
params.secretSettings(),
params.context()
);
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(
params.inferenceEntityId(),
params.taskType(),
NAME,
params.serviceSettings(),
params.secretSettings(),
params.context()
);
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel(
inferenceEntityId,
taskType,
params.inferenceEntityId(),
params.taskType(),
NAME,
serviceSettings,
secretSettings,
context
params.serviceSettings(),
params.secretSettings(),
params.context()
);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
default -> throw new ElasticsearchStatusException(params.failureMessage(), RestStatus.BAD_REQUEST);
};
}

View file

@ -14,6 +14,7 @@ import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecut
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
@ -23,8 +24,11 @@ import org.elasticsearch.xpack.inference.services.huggingface.completion.Hugging
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity;
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity;
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceRerankResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
@ -37,12 +41,27 @@ import static org.elasticsearch.core.Strings.format;
*/
public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE =
"Failed to send Hugging Face %s request from inference entity id [%s]";
private static final String INVALID_REQUEST_TYPE_MESSAGE = "Invalid request type: expected HuggingFace %s request but got %s";
public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions";
static final String USER_ROLE = "user";
static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
"hugging face completion",
OpenAiChatCompletionResponseEntity::fromResponse
);
// Response handler for rerank calls. The parser only accepts a HuggingFaceRerankRequest; any other
// request type (or null) is rejected with an IllegalArgumentException naming the offending type.
private static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler(
    "hugging face rerank",
    (request, response) -> {
        if (request instanceof HuggingFaceRerankRequest rerankRequest) {
            return HuggingFaceRerankResponseEntity.fromResponse(rerankRequest, response);
        }
        var actualType = request == null ? "null" : request.getClass().getSimpleName();
        throw new IllegalArgumentException(format(INVALID_REQUEST_TYPE_MESSAGE, "RERANK", actualType));
    }
);
private final Sender sender;
private final ServiceComponents serviceComponents;
@ -51,6 +70,26 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
this.serviceComponents = Objects.requireNonNull(serviceComponents);
}
/**
 * Creates the executable action that sends rerank requests for the given model.
 * <p>
 * The model is first passed through {@code HuggingFaceRerankModel.of} together with its task
 * settings. The resulting model is then used consistently for the request manager, for every
 * outgoing {@link HuggingFaceRerankRequest} and for the error message, so the rate-limit grouping
 * and the request payload can never be based on two different model instances.
 *
 * @param model the rerank model to create the action for
 * @return an action that sends {@link QueryAndDocsInputs} to the Hugging Face rerank endpoint
 */
@Override
public ExecutableAction create(HuggingFaceRerankModel model) {
    var overriddenModel = HuggingFaceRerankModel.of(model, model.getTaskSettings());
    var manager = new GenericRequestManager<>(
        serviceComponents.threadPool(),
        overriddenModel,
        RERANK_HANDLER,
        // Build the request from the same model instance the manager was configured with.
        inputs -> new HuggingFaceRerankRequest(
            inputs.getQuery(),
            inputs.getChunks(),
            inputs.getReturnDocuments(),
            inputs.getTopN(),
            overriddenModel
        ),
        QueryAndDocsInputs.class
    );
    var errorMessage = buildErrorMessage(TaskType.RERANK, overriddenModel.getInferenceEntityId());
    return new SenderExecutableAction(sender, manager, errorMessage);
}
@Override
public ExecutableAction create(HuggingFaceEmbeddingsModel model) {
var responseHandler = new HuggingFaceResponseHandler(
@ -95,6 +134,6 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
}
public static String buildErrorMessage(TaskType requestType, String inferenceId) {
return format("Failed to send Hugging Face %s request from inference entity id [%s]", requestType.toString(), inferenceId);
return format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, requestType.toString(), inferenceId);
}
}

View file

@ -11,8 +11,11 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
public interface HuggingFaceActionVisitor {
ExecutableAction create(HuggingFaceRerankModel model);
ExecutableAction create(HuggingFaceEmbeddingsModel model);
ExecutableAction create(HuggingFaceElserModel model);

View file

@ -13,11 +13,9 @@ import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
@ -35,10 +33,10 @@ import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModelParameters;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@ -69,18 +67,17 @@ public class HuggingFaceElserService extends HuggingFaceBaseService {
}
@Override
protected HuggingFaceModel createModel(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
ChunkingSettings chunkingSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
protected HuggingFaceModel createModel(HuggingFaceModelParameters input) {
return switch (input.taskType()) {
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(
input.inferenceEntityId(),
input.taskType(),
NAME,
input.serviceSettings(),
input.secretSettings(),
input.context()
);
default -> throw new ElasticsearchStatusException(input.failureMessage(), RestStatus.BAD_REQUEST);
};
}

View file

@ -0,0 +1,99 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.request.rerank;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceAccount;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
/**
 * HTTP request sent to a Hugging Face rerank endpoint.
 * <p>
 * The request carries the query and the documents to rank. The optional {@code returnDocuments}
 * and {@code topN} values supplied at request level take precedence over the values configured
 * in the model's task settings.
 */
public class HuggingFaceRerankRequest implements Request {

    private final HuggingFaceAccount account;
    private final String query;
    private final List<String> input;
    private final Boolean returnDocuments;
    private final Integer topN;
    private final HuggingFaceRerankModel model;

    /**
     * @param query           the query the documents are ranked against; must not be {@code null}
     * @param input           the documents to rank; must not be {@code null}
     * @param returnDocuments request-level override for returning the document text, may be {@code null}
     * @param topN            request-level override for the number of documents to return, may be {@code null}
     * @param model           the rerank model providing URI, API key and task settings; must not be {@code null}
     */
    public HuggingFaceRerankRequest(
        String query,
        List<String> input,
        @Nullable Boolean returnDocuments,
        @Nullable Integer topN,
        HuggingFaceRerankModel model
    ) {
        this.model = Objects.requireNonNull(model);
        this.account = HuggingFaceAccount.of(model);
        this.input = Objects.requireNonNull(input);
        this.query = Objects.requireNonNull(query);
        this.returnDocuments = returnDocuments;
        this.topN = topN;
    }

    /**
     * Builds the POST request: JSON body from {@link HuggingFaceRerankRequestEntity}, a JSON
     * content type header and the bearer authorization header.
     */
    @Override
    public HttpRequest createHttpRequest() {
        var post = new HttpPost(account.uri());

        var requestEntity = new HuggingFaceRerankRequestEntity(query, input, returnDocuments, getTopN(), model.getTaskSettings());
        byte[] payload = Strings.toString(requestEntity).getBytes(StandardCharsets.UTF_8);
        post.setEntity(new ByteArrayEntity(payload));

        post.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
        decorateWithAuth(post);

        return new HttpRequest(post, getInferenceEntityId());
    }

    /** Adds the bearer authorization header built from the model's API key. */
    void decorateWithAuth(HttpPost httpPost) {
        httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
    }

    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }

    @Override
    public URI getURI() {
        return account.uri();
    }

    /**
     * @return the request-level {@code topN} when present, otherwise the value from the model's task settings
     */
    public Integer getTopN() {
        if (topN != null) {
            return topN;
        }
        return model.getTaskSettings().getTopNDocumentsOnly();
    }

    /** Truncation does not apply to rerank requests (it is only used for text embeddings), so the request is returned as is. */
    @Override
    public Request truncate() {
        return this;
    }

    /** Truncation does not apply to rerank requests (it is only used for text embeddings), so there is no truncation info. */
    @Override
    public boolean[] getTruncationInfo() {
        return null;
    }
}

View file

@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.request.rerank;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
/**
 * JSON body of a Hugging Face rerank request.
 * <p>
 * Serialized shape: {@code {"texts": [...], "query": "...", "return_text": <bool>, "top_n": <int>}}, where
 * {@code return_text} and {@code top_n} are only written when a value is available either at the
 * request level or in the task settings.
 *
 * @param query           the query to rank the documents against; must not be {@code null}
 * @param documents       the documents to rank, written under {@code texts}; must not be {@code null}
 * @param returnDocuments request-level flag, takes precedence over the task settings when non-null
 * @param topN            request-level limit, takes precedence over the task settings when non-null
 * @param taskSettings    fallback source for the two optional values; must not be {@code null}
 */
public record HuggingFaceRerankRequestEntity(
    String query,
    List<String> documents,
    @Nullable Boolean returnDocuments,
    @Nullable Integer topN,
    HuggingFaceRerankTaskSettings taskSettings
) implements ToXContentObject {

    private static final String DOCUMENTS_FIELD = "texts";
    private static final String QUERY_FIELD = "query";
    private static final String RETURN_TEXT = "return_text";

    public HuggingFaceRerankRequestEntity {
        Objects.requireNonNull(query);
        Objects.requireNonNull(documents);
        Objects.requireNonNull(taskSettings);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(DOCUMENTS_FIELD, documents);
        builder.field(QUERY_FIELD, query);

        // Root-level values win over the task settings; a field is omitted when neither source has a value.
        Boolean effectiveReturnText = returnDocuments != null ? returnDocuments : taskSettings.getReturnDocuments();
        if (effectiveReturnText != null) {
            builder.field(RETURN_TEXT, effectiveReturnText);
        }

        Integer effectiveTopN = topN != null ? topN : taskSettings.getTopNDocumentsOnly();
        if (effectiveTopN != null) {
            builder.field(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, effectiveTopN);
        }

        builder.endObject();
        return builder;
    }
}

View file

@ -0,0 +1,91 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import java.util.Map;
/**
 * Hugging Face model configuration for the rerank task: rerank service settings, rerank task settings
 * and the default secret settings.
 */
public class HuggingFaceRerankModel extends HuggingFaceModel {

    /**
     * Derives a model whose task settings are the given model's settings combined with {@code taskSettings}
     * through {@link HuggingFaceRerankTaskSettings#of(HuggingFaceRerankTaskSettings, HuggingFaceRerankTaskSettings)}.
     *
     * @param model        the model to copy
     * @param taskSettings the settings to layer on top of the model's own task settings
     * @return a new model instance carrying the combined task settings
     */
    public static HuggingFaceRerankModel of(HuggingFaceRerankModel model, HuggingFaceRerankTaskSettings taskSettings) {
        var mergedTaskSettings = HuggingFaceRerankTaskSettings.of(model.getTaskSettings(), taskSettings);
        return new HuggingFaceRerankModel(model, mergedTaskSettings);
    }

    /**
     * Creates a model from raw configuration maps, parsing each map into its settings type.
     */
    public HuggingFaceRerankModel(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        Map<String, Object> serviceSettings,
        Map<String, Object> taskSettings,
        @Nullable Map<String, Object> secrets,
        ConfigurationParseContext context
    ) {
        this(
            inferenceEntityId,
            taskType,
            service,
            HuggingFaceRerankServiceSettings.fromMap(serviceSettings, context),
            HuggingFaceRerankTaskSettings.fromMap(taskSettings),
            DefaultSecretSettings.fromMap(secrets)
        );
    }

    // Should only be used directly for testing
    HuggingFaceRerankModel(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        HuggingFaceRerankServiceSettings serviceSettings,
        HuggingFaceRerankTaskSettings taskSettings,
        @Nullable DefaultSecretSettings secrets
    ) {
        super(
            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
            new ModelSecrets(secrets),
            serviceSettings,
            secrets
        );
    }

    // Copy constructor backing of(model, taskSettings)
    private HuggingFaceRerankModel(HuggingFaceRerankModel model, HuggingFaceRerankTaskSettings taskSettings) {
        super(model, taskSettings);
    }

    /** @return the service settings, narrowed to the rerank type */
    @Override
    public HuggingFaceRerankServiceSettings getServiceSettings() {
        return (HuggingFaceRerankServiceSettings) super.getServiceSettings();
    }

    /** @return the task settings, narrowed to the rerank type */
    @Override
    public HuggingFaceRerankTaskSettings getTaskSettings() {
        return (HuggingFaceRerankTaskSettings) super.getTaskSettings();
    }

    /** @return the secret settings, narrowed to {@link DefaultSecretSettings} */
    @Override
    public DefaultSecretSettings getSecretSettings() {
        return (DefaultSecretSettings) super.getSecretSettings();
    }

    /**
     * Not supported for rerank models.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Integer getTokenLimit() {
        throw new UnsupportedOperationException("Token Limit for rerank is sent in request and not retrieved from the model");
    }

    /** Double-dispatch hook: asks the visitor to create the action for this rerank model. */
    @Override
    public ExecutableAction accept(HuggingFaceActionVisitor creator) {
        return creator.create(this);
    }
}

View file

@ -0,0 +1,139 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.io.IOException;
import java.net.URI;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri;
public class HuggingFaceRerankServiceSettings extends FilteredXContentObject
implements
ServiceSettings,
HuggingFaceRateLimitServiceSettings {
public static final String NAME = "hugging_face_rerank_service_settings";
public static final String URL = "url";
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
public static HuggingFaceRerankServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
var uri = extractUri(map, URL, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
HuggingFaceService.NAME,
context
);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
return new HuggingFaceRerankServiceSettings(uri, rateLimitSettings);
}
private final URI uri;
private final RateLimitSettings rateLimitSettings;
public HuggingFaceRerankServiceSettings(String url) {
uri = createUri(url);
rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS;
}
HuggingFaceRerankServiceSettings(URI uri, @Nullable RateLimitSettings rateLimitSettings) {
this.uri = Objects.requireNonNull(uri);
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
}
public HuggingFaceRerankServiceSettings(StreamInput in) throws IOException {
uri = createUri(in.readString());
rateLimitSettings = new RateLimitSettings(in);
}
@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
}
@Override
public URI uri() {
return uri;
}
// model is not defined in the service settings.
// since hugging face requires that the model be chosen when initializing a deployment within their service.
@Override
public String modelId() {
return null;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
toXContentFragmentOfExposedFields(builder, params);
builder.endObject();
return builder;
}
@Override
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
builder.field(URL, uri.toString());
rateLimitSettings.toXContent(builder, params);
return builder;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_HUGGING_FACE_RERANK_ADDED;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(uri.toString());
rateLimitSettings.writeTo(out);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
HuggingFaceRerankServiceSettings that = (HuggingFaceRerankServiceSettings) o;
return Objects.equals(uri, that.uri) && Objects.equals(rateLimitSettings, that.rateLimitSettings);
}
@Override
public int hashCode() {
return Objects.hash(uri, rateLimitSettings);
}
}

View file

@ -0,0 +1,156 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
/**
 * Task settings for Hugging Face rerank: an optional {@code top_n} limit and an optional
 * {@code return_documents} flag. Either value may be {@code null}, meaning "not configured".
 */
public class HuggingFaceRerankTaskSettings implements TaskSettings {

    public static final String NAME = "hugging_face_rerank_task_settings";
    public static final String RETURN_DOCUMENTS = "return_documents";
    public static final String TOP_N_DOCS_ONLY = "top_n";

    static final HuggingFaceRerankTaskSettings EMPTY_SETTINGS = new HuggingFaceRerankTaskSettings(null, null);

    private final Integer topNDocumentsOnly;
    private final Boolean returnDocuments;

    /**
     * Parses task settings from a configuration map.
     *
     * @param map the raw task settings; {@code null} or empty yields {@link #EMPTY_SETTINGS}
     * @return the parsed settings
     * @throws ValidationException if a value could not be extracted as the expected type
     */
    public static HuggingFaceRerankTaskSettings fromMap(Map<String, Object> map) {
        if (map == null || map.isEmpty()) {
            return EMPTY_SETTINGS;
        }

        ValidationException errors = new ValidationException();
        Boolean parsedReturnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, errors);
        Integer parsedTopN = extractOptionalPositiveInteger(map, TOP_N_DOCS_ONLY, ModelConfigurations.TASK_SETTINGS, errors);

        if (errors.validationErrors().isEmpty() == false) {
            throw errors;
        }

        return of(parsedTopN, parsedReturnDocuments);
    }

    /**
     * Creates a new {@link HuggingFaceRerankTaskSettings}
     * by preferring non-null fields from the request settings over the original settings.
     *
     * @param originalSettings    the settings stored as part of the inference entity configuration
     * @param requestTaskSettings the settings passed in within the task_settings field of the request
     * @return a constructed {@link HuggingFaceRerankTaskSettings}
     */
    public static HuggingFaceRerankTaskSettings of(
        HuggingFaceRerankTaskSettings originalSettings,
        HuggingFaceRerankTaskSettings requestTaskSettings
    ) {
        Integer mergedTopN = requestTaskSettings.getTopNDocumentsOnly() != null
            ? requestTaskSettings.getTopNDocumentsOnly()
            : originalSettings.getTopNDocumentsOnly();
        Boolean mergedReturnDocuments = requestTaskSettings.getReturnDocuments() != null
            ? requestTaskSettings.getReturnDocuments()
            : originalSettings.getReturnDocuments();

        return new HuggingFaceRerankTaskSettings(mergedTopN, mergedReturnDocuments);
    }

    /** Factory equivalent of {@link #HuggingFaceRerankTaskSettings(Integer, Boolean)}. */
    public static HuggingFaceRerankTaskSettings of(Integer topNDocumentsOnly, Boolean returnDocuments) {
        return new HuggingFaceRerankTaskSettings(topNDocumentsOnly, returnDocuments);
    }

    /** Reads the settings in the order {@link #writeTo(StreamOutput)} writes them. */
    public HuggingFaceRerankTaskSettings(StreamInput in) throws IOException {
        this(in.readOptionalVInt(), in.readOptionalBoolean());
    }

    public HuggingFaceRerankTaskSettings(@Nullable Integer topNDocumentsOnly, @Nullable Boolean doReturnDocuments) {
        this.topNDocumentsOnly = topNDocumentsOnly;
        this.returnDocuments = doReturnDocuments;
    }

    /** @return {@code true} when neither value is configured */
    @Override
    public boolean isEmpty() {
        return topNDocumentsOnly == null && returnDocuments == null;
    }

    /** Writes an object containing only the values that are configured. */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (topNDocumentsOnly != null) {
            builder.field(TOP_N_DOCS_ONLY, topNDocumentsOnly);
        }
        if (returnDocuments != null) {
            builder.field(RETURN_DOCUMENTS, returnDocuments);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.ML_INFERENCE_HUGGING_FACE_RERANK_ADDED;
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeOptionalVInt(topNDocumentsOnly);
        out.writeOptionalBoolean(returnDocuments);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        var other = (HuggingFaceRerankTaskSettings) o;
        return Objects.equals(returnDocuments, other.returnDocuments) && Objects.equals(topNDocumentsOnly, other.topNDocumentsOnly);
    }

    @Override
    public int hashCode() {
        return Objects.hash(returnDocuments, topNDocumentsOnly);
    }

    public Integer getTopNDocumentsOnly() {
        return topNDocumentsOnly;
    }

    public Boolean getReturnDocuments() {
        return returnDocuments;
    }

    /**
     * Parses {@code newSettings} (from a mutable copy of the map) and layers the result over these settings,
     * so that values present in {@code newSettings} win and missing ones keep their current value.
     */
    @Override
    public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
        HuggingFaceRerankTaskSettings requested = fromMap(new HashMap<>(newSettings));
        return of(this, requested);
    }
}

View file

@ -0,0 +1,110 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.response;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest;
import java.io.IOException;
import java.util.Comparator;
import java.util.List;
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
/**
 * Parser for the response body returned by a Hugging Face rerank endpoint.
 */
public class HuggingFaceRerankResponseEntity {

    /**
     * Parses the Hugging Face rerank response.
     * For a request like:
     *
     * <pre>
     * <code>
     * {
     *     "texts": ["luke", "leia"],
     *     "query": "star wars main character",
     *     "return_text": true
     * }
     * </code>
     * </pre>
     * The response would look like:
     * <pre>
     * <code>
     * [
     *     {
     *         "index": 0,
     *         "score": -0.07996220886707306,
     *         "text": "luke"
     *     },
     *     {
     *         "index": 1,
     *         "score": -0.08393221348524094,
     *         "text": "leia"
     *     }
     * ]
     * </code>
     * </pre>
     *
     * The parsed entries are sorted by descending score and, when the request resolves a non-null
     * {@code top_n}, cut down to that many entries.
     *
     * @param request  the request that produced the response; supplies the effective {@code top_n}
     * @param response the raw HTTP result whose body is parsed as JSON
     * @return the ranked documents, best score first
     * @throws IOException if the body cannot be read or parsed
     */
    public static RankedDocsResults fromResponse(HuggingFaceRerankRequest request, HttpResult response) throws IOException {
        var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);

        try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
            moveToFirstToken(jsonParser);

            var rankedDocs = doParse(jsonParser);
            var rankedDocsByRelevanceStream = rankedDocs.stream()
                .sorted(Comparator.comparingDouble(RankedDocsResults.RankedDoc::relevanceScore).reversed());

            // Resolve top_n once so the null check and the limit are guaranteed to see the same value.
            Integer topN = request.getTopN();
            var rankedDocStreamTopN = topN == null ? rankedDocsByRelevanceStream : rankedDocsByRelevanceStream.limit(topN);

            return new RankedDocsResults(rankedDocStreamTopN.toList());
        }
    }

    /**
     * Parses the array the parser is currently positioned on into ranked documents, one per array element.
     */
    private static List<RankedDocsResults.RankedDoc> doParse(XContentParser parser) throws IOException {
        // Use the parser handed to the callback rather than capturing the outer one.
        return parseList(parser, (listParser, index) -> {
            var parsedRankedDoc = RankedDocEntry.parse(listParser);
            return new RankedDocsResults.RankedDoc(parsedRankedDoc.index, parsedRankedDoc.score, parsedRankedDoc.text);
        });
    }

    /**
     * One element of the response array: required {@code index} and {@code score}, optional {@code text}.
     */
    private record RankedDocEntry(Integer index, Float score, @Nullable String text) {

        private static final ParseField TEXT = new ParseField("text");
        private static final ParseField SCORE = new ParseField("score");
        private static final ParseField INDEX = new ParseField("index");

        private static final ConstructingObjectParser<HuggingFaceRerankResponseEntity.RankedDocEntry, Void> PARSER =
            new ConstructingObjectParser<>(
                "hugging_face_rerank_response",
                true,
                args -> new HuggingFaceRerankResponseEntity.RankedDocEntry((int) args[0], (float) args[1], (String) args[2])
            );

        static {
            // Declaration order defines the positions in the constructor's args array.
            PARSER.declareInt(ConstructingObjectParser.constructorArg(), INDEX);
            PARSER.declareFloat(ConstructingObjectParser.constructorArg(), SCORE);
            PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), TEXT);
        }

        public static RankedDocEntry parse(XContentParser parser) {
            return PARSER.apply(parser, null);
        }
    }
}

View file

@ -425,7 +425,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
doAnswer(ans -> {
listenerAction.accept(ans.getArgument(9));
return null;
}).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
}).when(service).infer(any(), any(), anyBoolean(), any(), any(), anyBoolean(), any(), any(), any(), any());
doAnswer(ans -> {
listenerAction.accept(ans.getArgument(3));
return null;

View file

@ -1292,7 +1292,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
{
"service": "hugging_face",
"name": "Hugging Face",
"task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"],
"task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"],
"configurations": {
"api_key": {
"description": "API Key for the provider you're connecting to.",
@ -1301,7 +1301,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
"sensitive": true,
"updatable": true,
"type": "str",
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
"supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"]
},
"rate_limit.requests_per_minute": {
"description": "Minimize the number of rate limit errors.",
@ -1310,7 +1310,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
"sensitive": false,
"updatable": false,
"type": "int",
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
"supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"]
},
"url": {
"description": "The URL endpoint to use for the requests.",
@ -1319,7 +1319,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
"supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"]
}
}
}

View file

@ -27,12 +27,14 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModelTests;
import org.junit.After;
import org.junit.Before;
@ -311,6 +313,66 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
}
}
/**
 * Verifies that the rerank action fails with a parse error when the service answers with a JSON object
 * instead of a top-level array, and that the request sent to the server was still well-formed.
 */
public void testSend_FailsFromInvalidResponseFormat_ForRerankAction() throws IOException {
var settings = buildSettingsWithRetryFields(
TimeValue.timeValueMillis(1),
TimeValue.timeValueMinutes(1),
TimeValue.timeValueSeconds(0)
);
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
try (var sender = createSender(senderFactory)) {
sender.start();
// Invalid on purpose: the entries are wrapped in an object under "rerank" rather than being the root array.
String responseJson = """
{
"rerank": [
{
"index": 0,
"relevance_score": -0.07996031,
"text": "luke"
}
]
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
// Model configured with top_n = 8 and return_documents = true; both are checked in the request below.
var model = HuggingFaceRerankModelTests.createModel(getUrl(webServer), "secret", "model", 8, true);
var actionCreator = new HuggingFaceActionCreator(
sender,
new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator())
);
var action = actionCreator.create(model);
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new QueryAndDocsInputs("popular name", List.of("Luke")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
thrownException.getMessage(),
is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]")
);
}
// The response was rejected, but the outgoing request must still carry the expected fields.
assertRerankActionCreator(List.of("Luke"), "popular name", 8, true);
}
/**
 * Asserts that exactly one request reached the mock web server and that it is a well-formed rerank request:
 * no query string, JSON content type, bearer authorization for {@code "secret"} and a body consisting of
 * exactly the four fields {@code texts}, {@code query}, {@code top_n} and {@code return_text}.
 *
 * @param documents  the expected {@code texts} value
 * @param query      the expected {@code query} value
 * @param topN       the expected {@code top_n} value
 * @param returnText the expected {@code return_text} value
 */
private void assertRerankActionCreator(List<String> documents, String query, int topN, boolean returnText) throws IOException {
    var recordedRequests = webServer.requests();
    assertThat(recordedRequests, hasSize(1));

    var sentRequest = recordedRequests.get(0);
    assertNull(sentRequest.getUri().getQuery());
    assertThat(sentRequest.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()));
    assertThat(sentRequest.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));

    var body = entityAsMap(sentRequest.getBody());
    assertThat(body.size(), is(4));
    assertThat(body.get("texts"), is(documents));
    assertThat(body.get("query"), is(query));
    assertThat(body.get("top_n"), is(topN));
    assertThat(body.get("return_text"), is(returnText));
}
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

View file

@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.request.rerank;
import org.elasticsearch.common.Strings;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace;
/**
 * Serialization tests for {@link HuggingFaceRerankRequestEntity}.
 */
public class HuggingFaceRerankRequestEntityTests extends ESTestCase {

    private static final String INPUT = "texts";
    private static final String QUERY = "query";
    private static final Integer TOP_N = 8;
    private static final Boolean RETURN_DOCUMENTS = false;

    /**
     * The request-level {@code return_documents} ({@code true}) must win over the task settings value
     * ({@code false}), and every field must be written.
     */
    public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
        var taskSettings = new HuggingFaceRerankTaskSettings(TOP_N, RETURN_DOCUMENTS);
        var entity = new HuggingFaceRerankRequestEntity(QUERY, List.of(INPUT), Boolean.TRUE, TOP_N, taskSettings);

        String expected = """
            {"texts":["texts"],
            "query":"query",
            "return_text":true,
            "top_n":8}""";

        assertEquals(stripWhitespace(expected), toJson(entity));
    }

    /**
     * With no optional value at either level, only {@code texts} and {@code query} are written.
     */
    public void testXContent_WritesMinimalFields() throws IOException {
        var taskSettings = new HuggingFaceRerankTaskSettings(null, null);
        var entity = new HuggingFaceRerankRequestEntity(QUERY, List.of(INPUT), null, null, taskSettings);

        String expected = """
            {"texts":["texts"],"query":"query"}""";

        assertEquals(stripWhitespace(expected), toJson(entity));
    }

    // Renders the entity through a fresh JSON builder
    private static String toJson(HuggingFaceRerankRequestEntity entity) throws IOException {
        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
        entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
        return Strings.toString(builder);
    }
}

View file

@ -0,0 +1,99 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.request.rerank;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModelTests;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
/**
 * Tests the HTTP request produced by {@link HuggingFaceRerankRequest}.
 */
public class HuggingFaceRerankRequestTests extends ESTestCase {

    private static final String INPUT = "texts";
    private static final String QUERY = "query";
    private static final String INFERENCE_ID = "model";
    private static final Integer TOP_N = 8;
    private static final Boolean RETURN_TEXT = false;
    private static final String AUTH_HEADER_VALUE = "foo";

    public void testCreateRequest_WithMinimalFieldsSet() throws IOException {
        testCreateRequest(null, null);
    }

    public void testCreateRequest_WithTopN() throws IOException {
        testCreateRequest(TOP_N, null);
    }

    public void testCreateRequest_WithReturnDocuments() throws IOException {
        testCreateRequest(null, RETURN_TEXT);
    }

    /**
     * Creates a request with the given optional values and checks method, headers and body of the
     * resulting HTTP request. The body must contain the optional fields exactly when they were supplied.
     */
    private void testCreateRequest(Integer topN, Boolean returnDocuments) throws IOException {
        var httpRequest = createRequest(topN, returnDocuments).createHttpRequest();

        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
        var post = (HttpPost) httpRequest.httpRequestBase();

        assertThat(post.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters()));
        assertThat(post.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));

        var body = entityAsMap(post.getEntity().getContent());
        assertThat(body.get(INPUT), is(List.of(INPUT)));
        assertThat(body.get(QUERY), is(QUERY));

        // input and query are always present; each supplied optional adds one more field
        int expectedFieldCount = 2;
        if (topN != null) {
            assertThat(body.get("top_n"), is(topN));
            expectedFieldCount += 1;
        }
        if (returnDocuments != null) {
            assertThat(body.get("return_text"), is(returnDocuments));
            expectedFieldCount += 1;
        }
        assertThat(body, aMapWithSize(expectedFieldCount));
    }

    // The same optional values are used for the model's task settings and for the request itself.
    private static HuggingFaceRerankRequest createRequest(@Nullable Integer topN, @Nullable Boolean returnDocuments) {
        var rerankModel = HuggingFaceRerankModelTests.createModel(randomAlphaOfLength(10), "secret", INFERENCE_ID, topN, returnDocuments);
        return new HuggingFaceRerankWithoutAuthRequest(QUERY, List.of(INPUT), rerankModel, topN, returnDocuments);
    }

    /**
     * We use this class to fake the auth implementation to avoid static mocking of {@link HuggingFaceRerankRequest}
     */
    private static class HuggingFaceRerankWithoutAuthRequest extends HuggingFaceRerankRequest {

        HuggingFaceRerankWithoutAuthRequest(
            String query,
            List<String> input,
            HuggingFaceRerankModel model,
            @Nullable Integer topN,
            @Nullable Boolean returnDocuments
        ) {
            super(query, input, returnDocuments, topN, model);
        }

        @Override
        public void decorateWithAuth(HttpPost httpPost) {
            httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE);
        }
    }
}

View file

@ -0,0 +1,41 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import static org.hamcrest.Matchers.containsString;
/**
 * Tests for {@link HuggingFaceRerankModel}; also provides a model factory used by other test classes.
 */
public class HuggingFaceRerankModelTests extends ESTestCase {

    public void testThrowsURISyntaxException_ForInvalidUrl() {
        var exception = expectThrows(IllegalArgumentException.class, () -> createModel("^^", "secret", "model", 8, false));

        assertThat(exception.getMessage(), containsString("unable to parse url [^^]"));
    }

    /**
     * Builds a rerank model for tests.
     *
     * @param url             the endpoint URL for the service settings
     * @param apiKey          the API key stored in the secret settings
     * @param modelId         used as the inference entity id
     * @param topN            optional {@code top_n} task setting
     * @param returnDocuments optional {@code return_documents} task setting
     */
    public static HuggingFaceRerankModel createModel(
        String url,
        String apiKey,
        String modelId,
        @Nullable Integer topN,
        @Nullable Boolean returnDocuments
    ) {
        var serviceSettings = new HuggingFaceRerankServiceSettings(url);
        var taskSettings = new HuggingFaceRerankTaskSettings(topN, returnDocuments);
        var secretSettings = new DefaultSecretSettings(new SecureString(apiKey.toCharArray()));

        return new HuggingFaceRerankModel(modelId, TaskType.RERANK, "service", serviceSettings, taskSettings, secretSettings);
    }
}

View file

@ -0,0 +1,124 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.containsString;
public class HuggingFaceRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase<HuggingFaceRerankTaskSettings> {
/**
 * Creates task settings in which each value is independently either absent or random
 * ({@code top_n} between 1 and 10 inclusive).
 */
public static HuggingFaceRerankTaskSettings createRandom() {
    // Draw order (return_documents first, then top_n) is kept so seeded runs stay reproducible.
    Boolean returnDocuments = null;
    if (randomBoolean()) {
        returnDocuments = randomBoolean();
    }

    Integer topNDocsOnly = null;
    if (randomBoolean()) {
        topNDocsOnly = randomIntBetween(1, 10);
    }

    return new HuggingFaceRerankTaskSettings(topNDocsOnly, returnDocuments);
}
/** Both values present and valid: the parsed settings expose exactly those values. */
public void testFromMap_WithValidValues_ReturnsSettings() {
    Map<String, Object> validSettings = new HashMap<>();
    validSettings.put(HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS, true);
    validSettings.put(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, 5);

    var parsed = HuggingFaceRerankTaskSettings.fromMap(validSettings);

    assertTrue(parsed.getReturnDocuments());
    assertEquals(5, parsed.getTopNDocumentsOnly().intValue());
}
/** An empty map yields settings in which neither value is configured. */
public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() {
    Map<String, Object> noSettings = Map.of();

    var parsed = HuggingFaceRerankTaskSettings.fromMap(noSettings);

    assertNull(parsed.getReturnDocuments());
    assertNull(parsed.getTopNDocumentsOnly());
}
public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() {
Map<String, Object> taskMap = Map.of(
HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS,
"invalid",
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY,
5
);
var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap)));
assertThat(thrownException.getMessage(), containsString("field [return_documents] is not of the expected type"));
}
public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() {
Map<String, Object> taskMap = Map.of(
HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS,
true,
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY,
"invalid"
);
var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap)));
assertThat(thrownException.getMessage(), containsString("field [top_n] is not of the expected type"));
}
public void UpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() {
var initialSettings = new HuggingFaceRerankTaskSettings(5, true);
HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of());
assertEquals(initialSettings, updatedSettings);
}
public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() {
var initialSettings = new HuggingFaceRerankTaskSettings(5, true);
Map<String, Object> newSettings = Map.of(HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS, false);
HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
assertFalse(updatedSettings.getReturnDocuments());
assertEquals(initialSettings.getTopNDocumentsOnly(), updatedSettings.getTopNDocumentsOnly());
}
public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() {
var initialSettings = new HuggingFaceRerankTaskSettings(5, true);
Map<String, Object> newSettings = Map.of(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, 7);
HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue());
assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments());
}
public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() {
var initialSettings = new HuggingFaceRerankTaskSettings(5, true);
Map<String, Object> newSettings = Map.of(
HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS,
false,
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY,
7
);
HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
assertFalse(updatedSettings.getReturnDocuments());
assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue());
}
@Override
protected Writeable.Reader<HuggingFaceRerankTaskSettings> instanceReader() {
return HuggingFaceRerankTaskSettings::new;
}
@Override
protected HuggingFaceRerankTaskSettings createTestInstance() {
return createRandom();
}
@Override
protected HuggingFaceRerankTaskSettings mutateInstance(HuggingFaceRerankTaskSettings instance) throws IOException {
return randomValueOtherThan(instance, HuggingFaceRerankTaskSettingsTests::createRandom);
}
@Override
protected HuggingFaceRerankTaskSettings mutateInstanceForVersion(HuggingFaceRerankTaskSettings instance, TransportVersion version) {
return instance;
}
}

View file

@ -0,0 +1,147 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.huggingface.response;
import org.apache.http.HttpResponse;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests.buildExpectationRerank;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Tests for {@link HuggingFaceRerankResponseEntity#fromResponse}: parsing of the JSON array returned
 * by the rerank endpoint, application of the request's top-N value, and failures on missing required fields.
 */
public class HuggingFaceRerankResponseEntityTests extends ESTestCase {

    private static final String MISSED_FIELD_INDEX = "index";
    private static final String MISSED_FIELD_SCORE = "score";

    // Two entries, deliberately listed with the lower score first.
    private static final String RESPONSE_JSON_TWO_DOCS = """
        [
        {
        "index": 4,
        "score": -0.22222222222222222,
        "text": "ranked second"
        },
        {
        "index": 1,
        "score": 1.11111111111111111,
        "text": "ranked first"
        }
        ]
        """;

    private static final List<RankedDocsResultsTests.RerankExpectation> EXPECTED_TWO_DOCS = List.of(
        expected(1, 1.11111111111111111F, "ranked first"),
        expected(4, -0.22222222222222222F, "ranked second")
    );

    // Five entries in an order that does not follow the scores.
    private static final String RESPONSE_JSON_FIVE_DOCS = """
        [
        {
        "index": 1,
        "score": 1.11111111111111111,
        "text": "ranked first"
        },
        {
        "index": 3,
        "score": -0.33333333333333333,
        "text": "ranked third"
        },
        {
        "index": 0,
        "score": -0.55555555555555555,
        "text": "ranked fifth"
        },
        {
        "index": 2,
        "score": -0.44444444444444444,
        "text": "ranked fourth"
        },
        {
        "index": 4,
        "score": -0.22222222222222222,
        "text": "ranked second"
        }
        ]
        """;

    private static final List<RankedDocsResultsTests.RerankExpectation> EXPECTED_FIVE_DOCS = List.of(
        expected(1, 1.11111111111111111F, "ranked first"),
        expected(4, -0.22222222222222222F, "ranked second"),
        expected(3, -0.33333333333333333F, "ranked third"),
        expected(2, -0.44444444444444444F, "ranked fourth"),
        expected(0, -0.55555555555555555F, "ranked fifth")
    );

    // Shared mock; each test stubs getTopN() before use.
    private static final HuggingFaceRerankRequest REQUEST_MOCK = mock(HuggingFaceRerankRequest.class);

    /** With no top-N set, all five documents are expected back. */
    public void testFromResponse_CreatesRankedDocsResults_TopNNull_FiveDocs_NoLimit() throws IOException {
        assertTopNLimit(null, RESPONSE_JSON_FIVE_DOCS, EXPECTED_FIVE_DOCS);
    }

    /** A top-N larger than the number of returned documents leaves both documents in place. */
    public void testFromResponse_CreatesRankedDocsResults_TopN5_TwoDocs_NoLimit() throws IOException {
        assertTopNLimit(5, RESPONSE_JSON_TWO_DOCS, EXPECTED_TWO_DOCS);
    }

    /** A top-N of two applied to five returned documents is expected to keep the two best-scored ones. */
    public void testFromResponse_CreatesRankedDocsResults_TopN2_FiveDocs_Limits() throws IOException {
        assertTopNLimit(2, RESPONSE_JSON_FIVE_DOCS, EXPECTED_TWO_DOCS);
    }

    /** An entry without {@code index} must be rejected. */
    public void testFails_CreateRankedDocsResults_IndexFieldNull() {
        String jsonWithoutIndex = """
            [
            {
            "score": 1.11111111111111111,
            "text": "ranked first"
            }
            ]
            """;
        assertMissingFieldThrowsIllegalArgumentException(jsonWithoutIndex, MISSED_FIELD_INDEX);
    }

    /** An entry without {@code score} must be rejected. */
    public void testFails_CreateRankedDocsResults_ScoreFieldNull() {
        String jsonWithoutScore = """
            [
            {
            "index": 1,
            "text": "ranked first"
            }
            ]
            """;
        assertMissingFieldThrowsIllegalArgumentException(jsonWithoutScore, MISSED_FIELD_SCORE);
    }

    /**
     * Parses {@code responseJson} with top-N stubbed to 1 and asserts that an
     * {@link IllegalArgumentException} naming {@code missingField} is thrown.
     */
    private void assertMissingFieldThrowsIllegalArgumentException(String responseJson, String missingField) {
        when(REQUEST_MOCK.getTopN()).thenReturn(1);

        // The HttpResult is built inside the lambda so everything runs under expectThrows.
        IllegalArgumentException failure = expectThrows(
            IllegalArgumentException.class,
            () -> HuggingFaceRerankResponseEntity.fromResponse(REQUEST_MOCK, httpResultOf(responseJson))
        );

        String expectedMessage = "Required [" + missingField + "]";
        assertThat(failure.getMessage(), is(expectedMessage));
    }

    /**
     * Stubs the request's top-N with {@code topN}, parses {@code responseJson} and compares the
     * resulting map with the one built from {@code expectation}.
     */
    private void assertTopNLimit(Integer topN, String responseJson, List<RankedDocsResultsTests.RerankExpectation> expectation)
        throws IOException {
        when(REQUEST_MOCK.getTopN()).thenReturn(topN);

        RankedDocsResults ranked = HuggingFaceRerankResponseEntity.fromResponse(REQUEST_MOCK, httpResultOf(responseJson));

        assertThat(ranked.asMap(), is(buildExpectationRerank(expectation)));
    }

    /** Builds one expected ranked-document entry from its index, relevance score and text. */
    private static RankedDocsResultsTests.RerankExpectation expected(int index, float relevanceScore, String text) {
        return new RankedDocsResultsTests.RerankExpectation(Map.of("index", index, "relevance_score", relevanceScore, "text", text));
    }

    /** Wraps the UTF-8 bytes of {@code json} in an {@link HttpResult} backed by a mocked HTTP response. */
    private static HttpResult httpResultOf(String json) {
        return new HttpResult(mock(HttpResponse.class), json.getBytes(StandardCharsets.UTF_8));
    }
}