mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 09:28:55 -04:00
Add Hugging Face Rerank support (#127966)
* Add Hugging Face Rerank support * Address comments * Add transport version * Add transport version * Add to inference service and crud IT rerank tests * Refactor slightly / error message * correct 'testGetConfiguration' test case * apply suggestions * fix tests * apply suggestions * [CI] Auto commit changes from spotless * add changelog information --------- Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
This commit is contained in:
parent
f6e4a26480
commit
c7cf8507a2
31 changed files with 1485 additions and 85 deletions
5
docs/changelog/127966.yaml
Normal file
5
docs/changelog/127966.yaml
Normal file
|
@ -0,0 +1,5 @@
|
|||
pr: 127966
|
||||
summary: "[ML] Add Rerank support to the Inference Plugin"
|
||||
area: Machine Learning
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -179,6 +179,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion V_8_19_FIELD_CAPS_ADD_CLUSTER_ALIAS = def(8_841_0_32);
|
||||
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME_8_19 = def(8_841_0_34);
|
||||
public static final TransportVersion RERANKER_FAILURES_ALLOWED_8_19 = def(8_841_0_35);
|
||||
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
|
||||
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
|
||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
|
||||
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
|
||||
|
@ -261,7 +262,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
|
||||
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_078_0_00);
|
||||
public static final TransportVersion NODES_STATS_SUPPORTS_MULTI_PROJECT = def(9_079_0_00);
|
||||
|
||||
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED = def(9_080_0_00);
|
||||
/*
|
||||
* STOP! READ THIS FIRST! No, really,
|
||||
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
|
||||
|
|
|
@ -20,9 +20,8 @@ import static org.elasticsearch.test.ESTestCase.randomInt;
|
|||
public class SettingsConfigurationTestUtils {
|
||||
|
||||
public static SettingsConfiguration getRandomSettingsConfigurationField() {
|
||||
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)).setDefaultValue(
|
||||
randomAlphaOfLength(10)
|
||||
)
|
||||
return new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
|
||||
.setDefaultValue(randomAlphaOfLength(10))
|
||||
.setDescription(randomAlphaOfLength(10))
|
||||
.setLabel(randomAlphaOfLength(10))
|
||||
.setRequired(randomBoolean())
|
||||
|
|
|
@ -82,4 +82,13 @@ public class RankedDocsResultsTests extends AbstractChunkedBWCSerializationTestC
|
|||
protected RankedDocsResults doParseInstance(XContentParser parser) throws IOException {
|
||||
return RankedDocsResults.createParser(true).apply(parser, null);
|
||||
}
|
||||
|
||||
public record RerankExpectation(Map<String, Object> rankedDocFields) {}
|
||||
|
||||
public static Map<String, Object> buildExpectationRerank(List<RerankExpectation> rerank) {
|
||||
return Map.of(
|
||||
RankedDocsResults.RERANK,
|
||||
rerank.stream().map(rerankExpectation -> Map.of(RankedDocsResults.RankedDoc.NAME, rerankExpectation.rankedDocFields)).toList()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -171,6 +171,20 @@ public class InferenceBaseRestTest extends ESRestTestCase {
|
|||
""";
|
||||
}
|
||||
|
||||
static String mockRerankServiceModelConfig() {
|
||||
return """
|
||||
{
|
||||
"service": "test_reranking_service",
|
||||
"service_settings": {
|
||||
"model_id": "my_model",
|
||||
"api_key": "abc64"
|
||||
},
|
||||
"task_settings": {
|
||||
}
|
||||
}
|
||||
""";
|
||||
}
|
||||
|
||||
static void deleteModel(String modelId) throws IOException {
|
||||
var request = new Request("DELETE", "_inference/" + modelId);
|
||||
var response = client().performRequest(request);
|
||||
|
@ -484,6 +498,10 @@ public class InferenceBaseRestTest extends ESRestTestCase {
|
|||
@SuppressWarnings("unchecked")
|
||||
protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
|
||||
switch (taskType) {
|
||||
case RERANK -> {
|
||||
var results = (List<Map<String, Object>>) resultMap.get(TaskType.RERANK.toString());
|
||||
assertThat(results, hasSize(expectedNumberOfResults));
|
||||
}
|
||||
case SPARSE_EMBEDDING -> {
|
||||
var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
|
||||
assertThat(results, hasSize(expectedNumberOfResults));
|
||||
|
|
|
@ -53,9 +53,12 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
|
|||
for (int i = 0; i < 4; i++) {
|
||||
putModel("te_model_" + i, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
|
||||
}
|
||||
for (int i = 0; i < 3; i++) {
|
||||
putModel("re-model-" + i, mockRerankServiceModelConfig(), TaskType.RERANK);
|
||||
}
|
||||
|
||||
var getAllModels = getAllModels();
|
||||
int numModels = 12;
|
||||
int numModels = 15;
|
||||
assertThat(getAllModels, hasSize(numModels));
|
||||
|
||||
var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
|
||||
|
@ -71,6 +74,13 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
|
|||
for (var denseModel : getDenseModels) {
|
||||
assertEquals("text_embedding", denseModel.get("task_type"));
|
||||
}
|
||||
|
||||
var getRerankModels = getModels("_all", TaskType.RERANK);
|
||||
int numRerankModels = 4;
|
||||
assertThat(getRerankModels, hasSize(numRerankModels));
|
||||
for (var denseModel : getRerankModels) {
|
||||
assertEquals("rerank", denseModel.get("task_type"));
|
||||
}
|
||||
String oldApiKey;
|
||||
{
|
||||
var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING);
|
||||
|
@ -100,6 +110,9 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
|
|||
for (int i = 0; i < 4; i++) {
|
||||
deleteModel("te_model_" + i, TaskType.TEXT_EMBEDDING);
|
||||
}
|
||||
for (int i = 0; i < 3; i++) {
|
||||
deleteModel("re-model-" + i, TaskType.RERANK);
|
||||
}
|
||||
}
|
||||
|
||||
public void testGetModelWithWrongTaskType() throws IOException {
|
||||
|
|
|
@ -101,7 +101,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
|
||||
public void testGetServicesWithRerankTaskType() throws IOException {
|
||||
List<Object> services = getServices(TaskType.RERANK);
|
||||
assertThat(services.size(), equalTo(7));
|
||||
assertThat(services.size(), equalTo(8));
|
||||
|
||||
var providers = providers(services);
|
||||
|
||||
|
@ -115,7 +115,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
"googlevertexai",
|
||||
"jinaai",
|
||||
"test_reranking_service",
|
||||
"voyageai"
|
||||
"voyageai",
|
||||
"hugging_face"
|
||||
).toArray()
|
||||
)
|
||||
);
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference;
|
||||
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class MockRerankInferenceServiceIT extends InferenceBaseRestTest {
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testMockService() throws IOException {
|
||||
String inferenceEntityId = "test-mock";
|
||||
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
|
||||
var model = getModels(inferenceEntityId, TaskType.RERANK).get(0);
|
||||
|
||||
for (var modelMap : List.of(putModel, model)) {
|
||||
assertEquals(inferenceEntityId, modelMap.get("inference_id"));
|
||||
assertEquals(TaskType.RERANK, TaskType.fromString((String) modelMap.get("task_type")));
|
||||
assertEquals("test_reranking_service", modelMap.get("service"));
|
||||
}
|
||||
|
||||
List<String> input = List.of(randomAlphaOfLength(10));
|
||||
var inference = infer(inferenceEntityId, input);
|
||||
assertNonEmptyInferenceResults(inference, 1, TaskType.RERANK);
|
||||
assertEquals(inference, infer(inferenceEntityId, input));
|
||||
assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))));
|
||||
}
|
||||
|
||||
public void testMockServiceWithMultipleInputs() throws IOException {
|
||||
String inferenceEntityId = "test-mock-with-multi-inputs";
|
||||
putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
|
||||
var queryParams = Map.of("timeout", "120s");
|
||||
|
||||
var inference = infer(
|
||||
inferenceEntityId,
|
||||
TaskType.RERANK,
|
||||
List.of(randomAlphaOfLength(5), randomAlphaOfLength(10)),
|
||||
"What if?",
|
||||
queryParams
|
||||
);
|
||||
|
||||
assertNonEmptyInferenceResults(inference, 2, TaskType.RERANK);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
|
||||
String inferenceEntityId = "test-mock";
|
||||
var putModel = putModel(inferenceEntityId, mockRerankServiceModelConfig(), TaskType.RERANK);
|
||||
var model = getModels(inferenceEntityId, TaskType.RERANK).get(0);
|
||||
|
||||
var serviceSettings = (Map<String, Object>) model.get("service_settings");
|
||||
assertNull(serviceSettings.get("api_key"));
|
||||
assertNotNull(serviceSettings.get("model_id"));
|
||||
|
||||
var putServiceSettings = (Map<String, Object>) putModel.get("service_settings");
|
||||
assertNull(putServiceSettings.get("api_key"));
|
||||
assertNotNull(putServiceSettings.get("model_id"));
|
||||
}
|
||||
}
|
|
@ -80,6 +80,8 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVe
|
|||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
|
||||
|
@ -365,6 +367,16 @@ public class InferenceNamedWriteablesProvider {
|
|||
HuggingFaceChatCompletionServiceSettings::new
|
||||
)
|
||||
);
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(TaskSettings.class, HuggingFaceRerankTaskSettings.NAME, HuggingFaceRerankTaskSettings::new)
|
||||
);
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(
|
||||
ServiceSettings.class,
|
||||
HuggingFaceRerankServiceSettings.NAME,
|
||||
HuggingFaceRerankServiceSettings::new
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
|
||||
|
|
|
@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
|||
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
|
||||
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
@ -94,7 +95,10 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
|
|||
} else if (r.getEndpoints().isEmpty() == false
|
||||
&& r.getEndpoints().get(0).getTaskSettings() instanceof GoogleVertexAiRerankTaskSettings googleVertexAiTaskSettings) {
|
||||
configuredTopN = googleVertexAiTaskSettings.topN();
|
||||
}
|
||||
} else if (r.getEndpoints().isEmpty() == false
|
||||
&& r.getEndpoints().get(0).getTaskSettings() instanceof HuggingFaceRerankTaskSettings huggingFaceRerankTaskSettings) {
|
||||
configuredTopN = huggingFaceRerankTaskSettings.getTopNDocumentsOnly();
|
||||
}
|
||||
if (configuredTopN != null && configuredTopN < rankWindowSize) {
|
||||
l.onFailure(
|
||||
new IllegalArgumentException(
|
||||
|
|
|
@ -57,6 +57,7 @@ public abstract class HuggingFaceBaseService extends SenderService {
|
|||
) {
|
||||
try {
|
||||
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
|
||||
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
|
||||
|
||||
ChunkingSettings chunkingSettings = null;
|
||||
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
|
||||
|
@ -66,17 +67,21 @@ public abstract class HuggingFaceBaseService extends SenderService {
|
|||
}
|
||||
|
||||
var model = createModel(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
serviceSettingsMap,
|
||||
chunkingSettings,
|
||||
serviceSettingsMap,
|
||||
TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
|
||||
ConfigurationParseContext.REQUEST
|
||||
new HuggingFaceModelParameters(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
serviceSettingsMap,
|
||||
taskSettingsMap,
|
||||
chunkingSettings,
|
||||
serviceSettingsMap,
|
||||
TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
|
||||
ConfigurationParseContext.REQUEST
|
||||
)
|
||||
);
|
||||
|
||||
throwIfNotEmptyMap(config, name());
|
||||
throwIfNotEmptyMap(serviceSettingsMap, name());
|
||||
throwIfNotEmptyMap(taskSettingsMap, name());
|
||||
|
||||
parsedModelListener.onResponse(model);
|
||||
} catch (Exception e) {
|
||||
|
@ -92,6 +97,7 @@ public abstract class HuggingFaceBaseService extends SenderService {
|
|||
Map<String, Object> secrets
|
||||
) {
|
||||
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
|
||||
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
|
||||
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
|
||||
|
||||
ChunkingSettings chunkingSettings = null;
|
||||
|
@ -100,19 +106,23 @@ public abstract class HuggingFaceBaseService extends SenderService {
|
|||
}
|
||||
|
||||
return createModel(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
serviceSettingsMap,
|
||||
chunkingSettings,
|
||||
secretSettingsMap,
|
||||
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
new HuggingFaceModelParameters(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
serviceSettingsMap,
|
||||
taskSettingsMap,
|
||||
chunkingSettings,
|
||||
secretSettingsMap,
|
||||
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
|
||||
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
|
||||
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
|
||||
|
||||
ChunkingSettings chunkingSettings = null;
|
||||
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
|
||||
|
@ -120,25 +130,20 @@ public abstract class HuggingFaceBaseService extends SenderService {
|
|||
}
|
||||
|
||||
return createModel(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
serviceSettingsMap,
|
||||
chunkingSettings,
|
||||
null,
|
||||
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
new HuggingFaceModelParameters(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
serviceSettingsMap,
|
||||
taskSettingsMap,
|
||||
chunkingSettings,
|
||||
null,
|
||||
parsePersistedConfigErrorMsg(inferenceEntityId, name()),
|
||||
ConfigurationParseContext.PERSISTENT
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
protected abstract HuggingFaceModel createModel(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
Map<String, Object> serviceSettings,
|
||||
ChunkingSettings chunkingSettings,
|
||||
Map<String, Object> secretSettings,
|
||||
String failureMessage,
|
||||
ConfigurationParseContext context
|
||||
);
|
||||
protected abstract HuggingFaceModel createModel(HuggingFaceModelParameters input);
|
||||
|
||||
@Override
|
||||
public void doInfer(
|
||||
|
|
|
@ -11,6 +11,7 @@ import org.elasticsearch.common.settings.SecureString;
|
|||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.TaskSettings;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
||||
|
@ -35,6 +36,13 @@ public abstract class HuggingFaceModel extends RateLimitGroupingModel {
|
|||
apiKey = ServiceUtils.apiKey(apiKeySecrets);
|
||||
}
|
||||
|
||||
protected HuggingFaceModel(HuggingFaceModel model, TaskSettings taskSettings) {
|
||||
super(model, taskSettings);
|
||||
|
||||
rateLimitServiceSettings = model.rateLimitServiceSettings();
|
||||
apiKey = model.apiKey();
|
||||
}
|
||||
|
||||
public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() {
|
||||
return rateLimitServiceSettings;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface;
|
||||
|
||||
import org.elasticsearch.inference.ChunkingSettings;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public record HuggingFaceModelParameters(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
Map<String, Object> serviceSettings,
|
||||
Map<String, Object> taskSettings,
|
||||
ChunkingSettings chunkingSettings,
|
||||
Map<String, Object> secretSettings,
|
||||
String failureMessage,
|
||||
ConfigurationParseContext context
|
||||
) {}
|
|
@ -12,10 +12,8 @@ import org.elasticsearch.TransportVersion;
|
|||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.common.util.LazyInitializable;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.ChunkedInference;
|
||||
import org.elasticsearch.inference.ChunkingSettings;
|
||||
import org.elasticsearch.inference.InferenceServiceConfiguration;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
|
@ -32,7 +30,6 @@ import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
|||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator;
|
||||
|
@ -40,6 +37,7 @@ import org.elasticsearch.xpack.inference.services.huggingface.completion.Hugging
|
|||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
|
||||
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
|
@ -62,6 +60,7 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
|||
|
||||
private static final String SERVICE_NAME = "Hugging Face";
|
||||
private static final EnumSet<TaskType> SUPPORTED_TASK_TYPES = EnumSet.of(
|
||||
TaskType.RERANK,
|
||||
TaskType.TEXT_EMBEDDING,
|
||||
TaskType.SPARSE_EMBEDDING,
|
||||
TaskType.COMPLETION,
|
||||
|
@ -77,35 +76,43 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected HuggingFaceModel createModel(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
Map<String, Object> serviceSettings,
|
||||
ChunkingSettings chunkingSettings,
|
||||
@Nullable Map<String, Object> secretSettings,
|
||||
String failureMessage,
|
||||
ConfigurationParseContext context
|
||||
) {
|
||||
return switch (taskType) {
|
||||
protected HuggingFaceModel createModel(HuggingFaceModelParameters params) {
|
||||
return switch (params.taskType()) {
|
||||
case RERANK -> new HuggingFaceRerankModel(
|
||||
params.inferenceEntityId(),
|
||||
params.taskType(),
|
||||
NAME,
|
||||
params.serviceSettings(),
|
||||
params.taskSettings(),
|
||||
params.secretSettings(),
|
||||
params.context()
|
||||
);
|
||||
case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
params.inferenceEntityId(),
|
||||
params.taskType(),
|
||||
NAME,
|
||||
serviceSettings,
|
||||
chunkingSettings,
|
||||
secretSettings,
|
||||
context
|
||||
params.serviceSettings(),
|
||||
params.chunkingSettings(),
|
||||
params.secretSettings(),
|
||||
params.context()
|
||||
);
|
||||
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(
|
||||
params.inferenceEntityId(),
|
||||
params.taskType(),
|
||||
NAME,
|
||||
params.serviceSettings(),
|
||||
params.secretSettings(),
|
||||
params.context()
|
||||
);
|
||||
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
|
||||
case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
params.inferenceEntityId(),
|
||||
params.taskType(),
|
||||
NAME,
|
||||
serviceSettings,
|
||||
secretSettings,
|
||||
context
|
||||
params.serviceSettings(),
|
||||
params.secretSettings(),
|
||||
params.context()
|
||||
);
|
||||
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
|
||||
default -> throw new ElasticsearchStatusException(params.failureMessage(), RestStatus.BAD_REQUEST);
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecut
|
|||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||
|
@ -23,8 +24,11 @@ import org.elasticsearch.xpack.inference.services.huggingface.completion.Hugging
|
|||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceRerankResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
|
||||
|
||||
|
@ -37,12 +41,27 @@ import static org.elasticsearch.core.Strings.format;
|
|||
*/
|
||||
public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
|
||||
|
||||
private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE =
|
||||
"Failed to send Hugging Face %s request from inference entity id [%s]";
|
||||
private static final String INVALID_REQUEST_TYPE_MESSAGE = "Invalid request type: expected HuggingFace %s request but got %s";
|
||||
public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions";
|
||||
static final String USER_ROLE = "user";
|
||||
static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
|
||||
"hugging face completion",
|
||||
OpenAiChatCompletionResponseEntity::fromResponse
|
||||
);
|
||||
private static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler("hugging face rerank", (request, response) -> {
|
||||
if ((request instanceof HuggingFaceRerankRequest) == false) {
|
||||
var errorMessage = format(
|
||||
INVALID_REQUEST_TYPE_MESSAGE,
|
||||
"RERANK",
|
||||
request != null ? request.getClass().getSimpleName() : "null"
|
||||
);
|
||||
throw new IllegalArgumentException(errorMessage);
|
||||
}
|
||||
return HuggingFaceRerankResponseEntity.fromResponse((HuggingFaceRerankRequest) request, response);
|
||||
});
|
||||
|
||||
private final Sender sender;
|
||||
private final ServiceComponents serviceComponents;
|
||||
|
||||
|
@ -51,6 +70,26 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
|
|||
this.serviceComponents = Objects.requireNonNull(serviceComponents);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ExecutableAction create(HuggingFaceRerankModel model) {
|
||||
var overriddenModel = HuggingFaceRerankModel.of(model, model.getTaskSettings());
|
||||
var manager = new GenericRequestManager<>(
|
||||
serviceComponents.threadPool(),
|
||||
overriddenModel,
|
||||
RERANK_HANDLER,
|
||||
inputs -> new HuggingFaceRerankRequest(
|
||||
inputs.getQuery(),
|
||||
inputs.getChunks(),
|
||||
inputs.getReturnDocuments(),
|
||||
inputs.getTopN(),
|
||||
model
|
||||
),
|
||||
QueryAndDocsInputs.class
|
||||
);
|
||||
var errorMessage = buildErrorMessage(TaskType.RERANK, model.getInferenceEntityId());
|
||||
return new SenderExecutableAction(sender, manager, errorMessage);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ExecutableAction create(HuggingFaceEmbeddingsModel model) {
|
||||
var responseHandler = new HuggingFaceResponseHandler(
|
||||
|
@ -95,6 +134,6 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
|
|||
}
|
||||
|
||||
public static String buildErrorMessage(TaskType requestType, String inferenceId) {
|
||||
return format("Failed to send Hugging Face %s request from inference entity id [%s]", requestType.toString(), inferenceId);
|
||||
return format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, requestType.toString(), inferenceId);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,8 +11,11 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
|||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
|
||||
|
||||
public interface HuggingFaceActionVisitor {
|
||||
ExecutableAction create(HuggingFaceRerankModel model);
|
||||
|
||||
ExecutableAction create(HuggingFaceEmbeddingsModel model);
|
||||
|
||||
ExecutableAction create(HuggingFaceElserModel model);
|
||||
|
|
|
@ -13,11 +13,9 @@ import org.elasticsearch.TransportVersions;
|
|||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.util.LazyInitializable;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.ChunkInferenceInput;
|
||||
import org.elasticsearch.inference.ChunkedInference;
|
||||
import org.elasticsearch.inference.ChunkingSettings;
|
||||
import org.elasticsearch.inference.InferenceServiceConfiguration;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
|
@ -35,10 +33,10 @@ import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
|
|||
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModelParameters;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
|
||||
|
@ -69,18 +67,17 @@ public class HuggingFaceElserService extends HuggingFaceBaseService {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected HuggingFaceModel createModel(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
Map<String, Object> serviceSettings,
|
||||
ChunkingSettings chunkingSettings,
|
||||
@Nullable Map<String, Object> secretSettings,
|
||||
String failureMessage,
|
||||
ConfigurationParseContext context
|
||||
) {
|
||||
return switch (taskType) {
|
||||
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context);
|
||||
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
|
||||
protected HuggingFaceModel createModel(HuggingFaceModelParameters input) {
|
||||
return switch (input.taskType()) {
|
||||
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(
|
||||
input.inferenceEntityId(),
|
||||
input.taskType(),
|
||||
NAME,
|
||||
input.serviceSettings(),
|
||||
input.secretSettings(),
|
||||
input.context()
|
||||
);
|
||||
default -> throw new ElasticsearchStatusException(input.failureMessage(), RestStatus.BAD_REQUEST);
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.request.rerank;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.apache.http.entity.ByteArrayEntity;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceAccount;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
|
||||
|
||||
import java.net.URI;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
|
||||
|
||||
public class HuggingFaceRerankRequest implements Request {
|
||||
|
||||
private final HuggingFaceAccount account;
|
||||
private final String query;
|
||||
private final List<String> input;
|
||||
private final Boolean returnDocuments;
|
||||
private final Integer topN;
|
||||
private final HuggingFaceRerankModel model;
|
||||
|
||||
public HuggingFaceRerankRequest(
|
||||
String query,
|
||||
List<String> input,
|
||||
@Nullable Boolean returnDocuments,
|
||||
@Nullable Integer topN,
|
||||
HuggingFaceRerankModel model
|
||||
) {
|
||||
Objects.requireNonNull(model);
|
||||
|
||||
this.account = HuggingFaceAccount.of(model);
|
||||
this.input = Objects.requireNonNull(input);
|
||||
this.query = Objects.requireNonNull(query);
|
||||
this.returnDocuments = returnDocuments;
|
||||
this.topN = topN;
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
@Override
|
||||
public HttpRequest createHttpRequest() {
|
||||
HttpPost httpPost = new HttpPost(account.uri());
|
||||
|
||||
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
||||
Strings.toString(new HuggingFaceRerankRequestEntity(query, input, returnDocuments, getTopN(), model.getTaskSettings()))
|
||||
.getBytes(StandardCharsets.UTF_8)
|
||||
);
|
||||
httpPost.setEntity(byteEntity);
|
||||
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
|
||||
|
||||
decorateWithAuth(httpPost);
|
||||
|
||||
return new HttpRequest(httpPost, getInferenceEntityId());
|
||||
}
|
||||
|
||||
void decorateWithAuth(HttpPost httpPost) {
|
||||
httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getInferenceEntityId() {
|
||||
return model.getInferenceEntityId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public URI getURI() {
|
||||
return account.uri();
|
||||
}
|
||||
|
||||
public Integer getTopN() {
|
||||
return topN != null ? topN : model.getTaskSettings().getTopNDocumentsOnly();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Request truncate() {
|
||||
// Not applicable for rerank, only used in text embedding requests
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean[] getTruncationInfo() {
|
||||
// Not applicable for rerank, only used in text embedding requests
|
||||
return null;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.request.rerank;
|
||||
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public record HuggingFaceRerankRequestEntity(
|
||||
String query,
|
||||
List<String> documents,
|
||||
@Nullable Boolean returnDocuments,
|
||||
@Nullable Integer topN,
|
||||
HuggingFaceRerankTaskSettings taskSettings
|
||||
) implements ToXContentObject {
|
||||
|
||||
private static final String RETURN_TEXT = "return_text";
|
||||
private static final String DOCUMENTS_FIELD = "texts";
|
||||
private static final String QUERY_FIELD = "query";
|
||||
|
||||
public HuggingFaceRerankRequestEntity {
|
||||
Objects.requireNonNull(query);
|
||||
Objects.requireNonNull(documents);
|
||||
Objects.requireNonNull(taskSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
|
||||
builder.field(DOCUMENTS_FIELD, documents);
|
||||
builder.field(QUERY_FIELD, query);
|
||||
|
||||
// prefer the root level return_documents over task settings
|
||||
if (returnDocuments != null) {
|
||||
builder.field(RETURN_TEXT, returnDocuments);
|
||||
} else if (taskSettings.getReturnDocuments() != null) {
|
||||
builder.field(RETURN_TEXT, taskSettings.getReturnDocuments());
|
||||
}
|
||||
|
||||
if (topN != null) {
|
||||
builder.field(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, topN);
|
||||
} else if (taskSettings.getTopNDocumentsOnly() != null) {
|
||||
builder.field(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly());
|
||||
}
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,91 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
|
||||
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public class HuggingFaceRerankModel extends HuggingFaceModel {
|
||||
public static HuggingFaceRerankModel of(HuggingFaceRerankModel model, HuggingFaceRerankTaskSettings taskSettings) {
|
||||
return new HuggingFaceRerankModel(model, HuggingFaceRerankTaskSettings.of(model.getTaskSettings(), taskSettings));
|
||||
}
|
||||
|
||||
public HuggingFaceRerankModel(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
String service,
|
||||
Map<String, Object> serviceSettings,
|
||||
Map<String, Object> taskSettings,
|
||||
@Nullable Map<String, Object> secrets,
|
||||
ConfigurationParseContext context
|
||||
) {
|
||||
this(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
service,
|
||||
HuggingFaceRerankServiceSettings.fromMap(serviceSettings, context),
|
||||
HuggingFaceRerankTaskSettings.fromMap(taskSettings),
|
||||
DefaultSecretSettings.fromMap(secrets)
|
||||
);
|
||||
}
|
||||
|
||||
// Should only be used directly for testing
|
||||
HuggingFaceRerankModel(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
String service,
|
||||
HuggingFaceRerankServiceSettings serviceSettings,
|
||||
HuggingFaceRerankTaskSettings taskSettings,
|
||||
@Nullable DefaultSecretSettings secrets
|
||||
) {
|
||||
super(
|
||||
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
|
||||
new ModelSecrets(secrets),
|
||||
serviceSettings,
|
||||
secrets
|
||||
);
|
||||
}
|
||||
|
||||
private HuggingFaceRerankModel(HuggingFaceRerankModel model, HuggingFaceRerankTaskSettings taskSettings) {
|
||||
super(model, taskSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public HuggingFaceRerankServiceSettings getServiceSettings() {
|
||||
return (HuggingFaceRerankServiceSettings) super.getServiceSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
public HuggingFaceRerankTaskSettings getTaskSettings() {
|
||||
return (HuggingFaceRerankTaskSettings) super.getTaskSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DefaultSecretSettings getSecretSettings() {
|
||||
return (DefaultSecretSettings) super.getSecretSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getTokenLimit() {
|
||||
throw new UnsupportedOperationException("Token Limit for rerank is sent in request and not retrieved from the model");
|
||||
}
|
||||
|
||||
@Override
|
||||
public ExecutableAction accept(HuggingFaceActionVisitor creator) {
|
||||
return creator.create(this);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,139 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ServiceSettings;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
|
||||
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
|
||||
import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri;
|
||||
|
||||
public class HuggingFaceRerankServiceSettings extends FilteredXContentObject
|
||||
implements
|
||||
ServiceSettings,
|
||||
HuggingFaceRateLimitServiceSettings {
|
||||
|
||||
public static final String NAME = "hugging_face_rerank_service_settings";
|
||||
public static final String URL = "url";
|
||||
|
||||
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
|
||||
|
||||
public static HuggingFaceRerankServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
|
||||
ValidationException validationException = new ValidationException();
|
||||
var uri = extractUri(map, URL, validationException);
|
||||
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
|
||||
map,
|
||||
DEFAULT_RATE_LIMIT_SETTINGS,
|
||||
validationException,
|
||||
HuggingFaceService.NAME,
|
||||
context
|
||||
);
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
return new HuggingFaceRerankServiceSettings(uri, rateLimitSettings);
|
||||
}
|
||||
|
||||
private final URI uri;
|
||||
private final RateLimitSettings rateLimitSettings;
|
||||
|
||||
public HuggingFaceRerankServiceSettings(String url) {
|
||||
uri = createUri(url);
|
||||
rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS;
|
||||
}
|
||||
|
||||
HuggingFaceRerankServiceSettings(URI uri, @Nullable RateLimitSettings rateLimitSettings) {
|
||||
this.uri = Objects.requireNonNull(uri);
|
||||
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
|
||||
}
|
||||
|
||||
public HuggingFaceRerankServiceSettings(StreamInput in) throws IOException {
|
||||
uri = createUri(in.readString());
|
||||
rateLimitSettings = new RateLimitSettings(in);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RateLimitSettings rateLimitSettings() {
|
||||
return rateLimitSettings;
|
||||
}
|
||||
|
||||
@Override
|
||||
public URI uri() {
|
||||
return uri;
|
||||
}
|
||||
|
||||
// model is not defined in the service settings.
|
||||
// since hugging face requires that the model be chosen when initializing a deployment within their service.
|
||||
@Override
|
||||
public String modelId() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
toXContentFragmentOfExposedFields(builder, params);
|
||||
builder.endObject();
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.field(URL, uri.toString());
|
||||
rateLimitSettings.toXContent(builder, params);
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.ML_INFERENCE_HUGGING_FACE_RERANK_ADDED;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(uri.toString());
|
||||
rateLimitSettings.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
HuggingFaceRerankServiceSettings that = (HuggingFaceRerankServiceSettings) o;
|
||||
return Objects.equals(uri, that.uri) && Objects.equals(rateLimitSettings, that.rateLimitSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(uri, rateLimitSettings);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.TaskSettings;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
|
||||
|
||||
public class HuggingFaceRerankTaskSettings implements TaskSettings {
|
||||
|
||||
public static final String NAME = "hugging_face_rerank_task_settings";
|
||||
public static final String RETURN_DOCUMENTS = "return_documents";
|
||||
public static final String TOP_N_DOCS_ONLY = "top_n";
|
||||
|
||||
static final HuggingFaceRerankTaskSettings EMPTY_SETTINGS = new HuggingFaceRerankTaskSettings(null, null);
|
||||
|
||||
public static HuggingFaceRerankTaskSettings fromMap(Map<String, Object> map) {
|
||||
ValidationException validationException = new ValidationException();
|
||||
|
||||
if (map == null || map.isEmpty()) {
|
||||
return EMPTY_SETTINGS;
|
||||
}
|
||||
|
||||
Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException);
|
||||
Integer topNDocumentsOnly = extractOptionalPositiveInteger(
|
||||
map,
|
||||
TOP_N_DOCS_ONLY,
|
||||
ModelConfigurations.TASK_SETTINGS,
|
||||
validationException
|
||||
);
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
|
||||
return of(topNDocumentsOnly, returnDocuments);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new {@link HuggingFaceRerankTaskSettings}
|
||||
* by preferring non-null fields from the request settings over the original settings.
|
||||
*
|
||||
* @param originalSettings the settings stored as part of the inference entity configuration
|
||||
* @param requestTaskSettings the settings passed in within the task_settings field of the request
|
||||
* @return a constructed {@link HuggingFaceRerankTaskSettings}
|
||||
*/
|
||||
public static HuggingFaceRerankTaskSettings of(
|
||||
HuggingFaceRerankTaskSettings originalSettings,
|
||||
HuggingFaceRerankTaskSettings requestTaskSettings
|
||||
) {
|
||||
return new HuggingFaceRerankTaskSettings(
|
||||
requestTaskSettings.getTopNDocumentsOnly() != null
|
||||
? requestTaskSettings.getTopNDocumentsOnly()
|
||||
: originalSettings.getTopNDocumentsOnly(),
|
||||
requestTaskSettings.getReturnDocuments() != null
|
||||
? requestTaskSettings.getReturnDocuments()
|
||||
: originalSettings.getReturnDocuments()
|
||||
);
|
||||
}
|
||||
|
||||
public static HuggingFaceRerankTaskSettings of(Integer topNDocumentsOnly, Boolean returnDocuments) {
|
||||
return new HuggingFaceRerankTaskSettings(topNDocumentsOnly, returnDocuments);
|
||||
}
|
||||
|
||||
private final Integer topNDocumentsOnly;
|
||||
private final Boolean returnDocuments;
|
||||
|
||||
public HuggingFaceRerankTaskSettings(StreamInput in) throws IOException {
|
||||
this(in.readOptionalVInt(), in.readOptionalBoolean());
|
||||
}
|
||||
|
||||
public HuggingFaceRerankTaskSettings(@Nullable Integer topNDocumentsOnly, @Nullable Boolean doReturnDocuments) {
|
||||
this.topNDocumentsOnly = topNDocumentsOnly;
|
||||
this.returnDocuments = doReturnDocuments;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isEmpty() {
|
||||
return topNDocumentsOnly == null && returnDocuments == null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
if (topNDocumentsOnly != null) {
|
||||
builder.field(TOP_N_DOCS_ONLY, topNDocumentsOnly);
|
||||
}
|
||||
if (returnDocuments != null) {
|
||||
builder.field(RETURN_DOCUMENTS, returnDocuments);
|
||||
}
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.ML_INFERENCE_HUGGING_FACE_RERANK_ADDED;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeOptionalVInt(topNDocumentsOnly);
|
||||
out.writeOptionalBoolean(returnDocuments);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
HuggingFaceRerankTaskSettings that = (HuggingFaceRerankTaskSettings) o;
|
||||
return Objects.equals(returnDocuments, that.returnDocuments) && Objects.equals(topNDocumentsOnly, that.topNDocumentsOnly);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(returnDocuments, topNDocumentsOnly);
|
||||
}
|
||||
|
||||
public Integer getTopNDocumentsOnly() {
|
||||
return topNDocumentsOnly;
|
||||
}
|
||||
|
||||
public Boolean getReturnDocuments() {
|
||||
return returnDocuments;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
|
||||
HuggingFaceRerankTaskSettings updatedSettings = HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(newSettings));
|
||||
return HuggingFaceRerankTaskSettings.of(this, updatedSettings);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,110 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.response;
|
||||
|
||||
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
||||
import org.elasticsearch.xcontent.ParseField;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentParser;
|
||||
import org.elasticsearch.xcontent.XContentParserConfiguration;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
|
||||
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
|
||||
|
||||
public class HuggingFaceRerankResponseEntity {
|
||||
|
||||
/**
|
||||
* Parses the Hugging Face rerank response.
|
||||
|
||||
* For a request like:
|
||||
*
|
||||
* <pre>
|
||||
* <code>
|
||||
* {
|
||||
* "texts": ["luke", "leia"],
|
||||
* "query": "star wars main character",
|
||||
* "return_text": true
|
||||
* }
|
||||
* </code>
|
||||
* </pre>
|
||||
|
||||
* The response would look like:
|
||||
|
||||
* <pre>
|
||||
* <code>
|
||||
* [
|
||||
* {
|
||||
* "index": 0,
|
||||
* "score": -0.07996220886707306,
|
||||
* "text": "luke"
|
||||
* },
|
||||
* {
|
||||
* "index": 1,
|
||||
* "score": -0.08393221348524094,
|
||||
* "text": "leia"
|
||||
* }
|
||||
* ]
|
||||
* </code>
|
||||
* </pre>
|
||||
*/
|
||||
|
||||
public static RankedDocsResults fromResponse(HuggingFaceRerankRequest request, HttpResult response) throws IOException {
|
||||
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
|
||||
|
||||
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
|
||||
moveToFirstToken(jsonParser);
|
||||
var rankedDocs = doParse(jsonParser);
|
||||
var rankedDocsByRelevanceStream = rankedDocs.stream()
|
||||
.sorted(Comparator.comparingDouble(RankedDocsResults.RankedDoc::relevanceScore).reversed());
|
||||
var rankedDocStreamTopN = request.getTopN() == null
|
||||
? rankedDocsByRelevanceStream
|
||||
: rankedDocsByRelevanceStream.limit(request.getTopN());
|
||||
return new RankedDocsResults(rankedDocStreamTopN.toList());
|
||||
}
|
||||
}
|
||||
|
||||
private static List<RankedDocsResults.RankedDoc> doParse(XContentParser parser) throws IOException {
|
||||
return parseList(parser, (listParser, index) -> {
|
||||
var parsedRankedDoc = HuggingFaceRerankResponseEntity.RankedDocEntry.parse(parser);
|
||||
return new RankedDocsResults.RankedDoc(parsedRankedDoc.index, parsedRankedDoc.score, parsedRankedDoc.text);
|
||||
});
|
||||
}
|
||||
|
||||
private record RankedDocEntry(Integer index, Float score, @Nullable String text) {
|
||||
|
||||
private static final ParseField TEXT = new ParseField("text");
|
||||
private static final ParseField SCORE = new ParseField("score");
|
||||
private static final ParseField INDEX = new ParseField("index");
|
||||
private static final ConstructingObjectParser<HuggingFaceRerankResponseEntity.RankedDocEntry, Void> PARSER =
|
||||
new ConstructingObjectParser<>(
|
||||
"hugging_face_rerank_response",
|
||||
true,
|
||||
args -> new HuggingFaceRerankResponseEntity.RankedDocEntry((int) args[0], (float) args[1], (String) args[2])
|
||||
);
|
||||
|
||||
static {
|
||||
PARSER.declareInt(ConstructingObjectParser.constructorArg(), INDEX);
|
||||
PARSER.declareFloat(ConstructingObjectParser.constructorArg(), SCORE);
|
||||
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), TEXT);
|
||||
}
|
||||
|
||||
public static RankedDocEntry parse(XContentParser parser) {
|
||||
return PARSER.apply(parser, null);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -425,7 +425,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
|
|||
doAnswer(ans -> {
|
||||
listenerAction.accept(ans.getArgument(9));
|
||||
return null;
|
||||
}).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
||||
}).when(service).infer(any(), any(), anyBoolean(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
||||
doAnswer(ans -> {
|
||||
listenerAction.accept(ans.getArgument(3));
|
||||
return null;
|
||||
|
|
|
@ -1292,7 +1292,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
|||
{
|
||||
"service": "hugging_face",
|
||||
"name": "Hugging Face",
|
||||
"task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"],
|
||||
"task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"],
|
||||
"configurations": {
|
||||
"api_key": {
|
||||
"description": "API Key for the provider you're connecting to.",
|
||||
|
@ -1301,7 +1301,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
|||
"sensitive": true,
|
||||
"updatable": true,
|
||||
"type": "str",
|
||||
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
|
||||
"supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"]
|
||||
},
|
||||
"rate_limit.requests_per_minute": {
|
||||
"description": "Minimize the number of rate limit errors.",
|
||||
|
@ -1310,7 +1310,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
|||
"sensitive": false,
|
||||
"updatable": false,
|
||||
"type": "int",
|
||||
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
|
||||
"supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"]
|
||||
},
|
||||
"url": {
|
||||
"description": "The URL endpoint to use for the requests.",
|
||||
|
@ -1319,7 +1319,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
|||
"sensitive": false,
|
||||
"updatable": false,
|
||||
"type": "str",
|
||||
"supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"]
|
||||
"supported_task_types": ["text_embedding", "sparse_embedding", "rerank", "completion", "chat_completion"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,12 +27,14 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
|||
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModelTests;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
|
@ -311,6 +313,66 @@ public class HuggingFaceActionCreatorTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testSend_FailsFromInvalidResponseFormat_ForRerankAction() throws IOException {
|
||||
var settings = buildSettingsWithRetryFields(
|
||||
TimeValue.timeValueMillis(1),
|
||||
TimeValue.timeValueMinutes(1),
|
||||
TimeValue.timeValueSeconds(0)
|
||||
);
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
{
|
||||
"rerank": [
|
||||
{
|
||||
"index": 0,
|
||||
"relevance_score": -0.07996031,
|
||||
"text": "luke"
|
||||
}
|
||||
]
|
||||
}
|
||||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var model = HuggingFaceRerankModelTests.createModel(getUrl(webServer), "secret", "model", 8, true);
|
||||
var actionCreator = new HuggingFaceActionCreator(
|
||||
sender,
|
||||
new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator())
|
||||
);
|
||||
var action = actionCreator.create(model);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new QueryAndDocsInputs("popular name", List.of("Luke")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
assertThat(
|
||||
thrownException.getMessage(),
|
||||
is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]")
|
||||
);
|
||||
}
|
||||
assertRerankActionCreator(List.of("Luke"), "popular name", 8, true);
|
||||
}
|
||||
|
||||
private void assertRerankActionCreator(List<String> documents, String query, int topN, boolean returnText) throws IOException {
|
||||
assertThat(webServer.requests(), hasSize(1));
|
||||
assertNull(webServer.requests().get(0).getUri().getQuery());
|
||||
assertThat(
|
||||
webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
|
||||
equalTo(XContentType.JSON.mediaTypeWithoutParameters())
|
||||
);
|
||||
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
|
||||
|
||||
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
||||
assertThat(requestMap.size(), is(4));
|
||||
assertThat(requestMap.get("texts"), is(documents));
|
||||
assertThat(requestMap.get("query"), is(query));
|
||||
assertThat(requestMap.get("top_n"), is(topN));
|
||||
assertThat(requestMap.get("return_text"), is(returnText));
|
||||
}
|
||||
|
||||
public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.request.rerank;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.ToXContent;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace;
|
||||
|
||||
public class HuggingFaceRerankRequestEntityTests extends ESTestCase {
|
||||
private static final String INPUT = "texts";
|
||||
private static final String QUERY = "query";
|
||||
private static final Integer TOP_N = 8;
|
||||
private static final Boolean RETURN_DOCUMENTS = false;
|
||||
|
||||
public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException {
|
||||
var entity = new HuggingFaceRerankRequestEntity(
|
||||
QUERY,
|
||||
List.of(INPUT),
|
||||
Boolean.TRUE,
|
||||
TOP_N,
|
||||
new HuggingFaceRerankTaskSettings(TOP_N, RETURN_DOCUMENTS)
|
||||
);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
String expected = """
|
||||
{"texts":["texts"],
|
||||
"query":"query",
|
||||
"return_text":true,
|
||||
"top_n":8}""";
|
||||
assertEquals(stripWhitespace(expected), xContentResult);
|
||||
}
|
||||
|
||||
public void testXContent_WritesMinimalFields() throws IOException {
|
||||
var entity = new HuggingFaceRerankRequestEntity(QUERY, List.of(INPUT), null, null, new HuggingFaceRerankTaskSettings(null, null));
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
String expected = """
|
||||
{"texts":["texts"],"query":"query"}""";
|
||||
assertEquals(stripWhitespace(expected), xContentResult);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,99 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.request.rerank;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModel;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankModelTests;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||
import static org.hamcrest.Matchers.aMapWithSize;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class HuggingFaceRerankRequestTests extends ESTestCase {
|
||||
private static final String INPUT = "texts";
|
||||
private static final String QUERY = "query";
|
||||
private static final String INFERENCE_ID = "model";
|
||||
private static final Integer TOP_N = 8;
|
||||
private static final Boolean RETURN_TEXT = false;
|
||||
|
||||
private static final String AUTH_HEADER_VALUE = "foo";
|
||||
|
||||
public void testCreateRequest_WithMinimalFieldsSet() throws IOException {
|
||||
testCreateRequest(null, null);
|
||||
}
|
||||
|
||||
public void testCreateRequest_WithTopN() throws IOException {
|
||||
testCreateRequest(TOP_N, null);
|
||||
}
|
||||
|
||||
public void testCreateRequest_WithReturnDocuments() throws IOException {
|
||||
testCreateRequest(null, RETURN_TEXT);
|
||||
}
|
||||
|
||||
private void testCreateRequest(Integer topN, Boolean returnDocuments) throws IOException {
|
||||
var request = createRequest(topN, returnDocuments);
|
||||
var httpRequest = request.createHttpRequest();
|
||||
|
||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
|
||||
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaTypeWithoutParameters()));
|
||||
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
|
||||
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
|
||||
assertThat(requestMap.get(INPUT), is(List.of(INPUT)));
|
||||
assertThat(requestMap.get(QUERY), is(QUERY));
|
||||
// input and query must exist
|
||||
int itemsCount = 2;
|
||||
if (topN != null) {
|
||||
assertThat(requestMap.get("top_n"), is(topN));
|
||||
itemsCount++;
|
||||
}
|
||||
if (returnDocuments != null) {
|
||||
assertThat(requestMap.get("return_text"), is(returnDocuments));
|
||||
itemsCount++;
|
||||
}
|
||||
assertThat(requestMap, aMapWithSize(itemsCount));
|
||||
}
|
||||
|
||||
private static HuggingFaceRerankRequest createRequest(@Nullable Integer topN, @Nullable Boolean returnDocuments) {
|
||||
var rerankModel = HuggingFaceRerankModelTests.createModel(randomAlphaOfLength(10), "secret", INFERENCE_ID, topN, returnDocuments);
|
||||
|
||||
return new HuggingFaceRerankWithoutAuthRequest(QUERY, List.of(INPUT), rerankModel, topN, returnDocuments);
|
||||
}
|
||||
|
||||
/**
|
||||
* We use this class to fake the auth implementation to avoid static mocking of {@link HuggingFaceRerankRequest}
|
||||
*/
|
||||
private static class HuggingFaceRerankWithoutAuthRequest extends HuggingFaceRerankRequest {
|
||||
HuggingFaceRerankWithoutAuthRequest(
|
||||
String query,
|
||||
List<String> input,
|
||||
HuggingFaceRerankModel model,
|
||||
@Nullable Integer topN,
|
||||
@Nullable Boolean returnDocuments
|
||||
) {
|
||||
super(query, input, returnDocuments, topN, model);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void decorateWithAuth(HttpPost httpPost) {
|
||||
httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
|
||||
|
||||
import org.elasticsearch.common.settings.SecureString;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
|
||||
public class HuggingFaceRerankModelTests extends ESTestCase {
|
||||
|
||||
public void testThrowsURISyntaxException_ForInvalidUrl() {
|
||||
var thrownException = expectThrows(IllegalArgumentException.class, () -> createModel("^^", "secret", "model", 8, false));
|
||||
assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]"));
|
||||
}
|
||||
|
||||
public static HuggingFaceRerankModel createModel(
|
||||
String url,
|
||||
String apiKey,
|
||||
String modelId,
|
||||
@Nullable Integer topN,
|
||||
@Nullable Boolean returnDocuments
|
||||
) {
|
||||
return new HuggingFaceRerankModel(
|
||||
modelId,
|
||||
TaskType.RERANK,
|
||||
"service",
|
||||
new HuggingFaceRerankServiceSettings(url),
|
||||
new HuggingFaceRerankTaskSettings(topN, returnDocuments),
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,124 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
|
||||
public class HuggingFaceRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase<HuggingFaceRerankTaskSettings> {
|
||||
|
||||
public static HuggingFaceRerankTaskSettings createRandom() {
|
||||
var returnDocuments = randomBoolean() ? randomBoolean() : null;
|
||||
var topNDocsOnly = randomBoolean() ? randomIntBetween(1, 10) : null;
|
||||
|
||||
return new HuggingFaceRerankTaskSettings(topNDocsOnly, returnDocuments);
|
||||
}
|
||||
|
||||
public void testFromMap_WithValidValues_ReturnsSettings() {
|
||||
Map<String, Object> taskMap = Map.of(
|
||||
HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS,
|
||||
true,
|
||||
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY,
|
||||
5
|
||||
);
|
||||
var settings = HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap));
|
||||
assertTrue(settings.getReturnDocuments());
|
||||
assertEquals(5, settings.getTopNDocumentsOnly().intValue());
|
||||
}
|
||||
|
||||
public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() {
|
||||
var settings = HuggingFaceRerankTaskSettings.fromMap(Map.of());
|
||||
assertNull(settings.getReturnDocuments());
|
||||
assertNull(settings.getTopNDocumentsOnly());
|
||||
}
|
||||
|
||||
public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() {
|
||||
Map<String, Object> taskMap = Map.of(
|
||||
HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS,
|
||||
"invalid",
|
||||
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY,
|
||||
5
|
||||
);
|
||||
var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap)));
|
||||
assertThat(thrownException.getMessage(), containsString("field [return_documents] is not of the expected type"));
|
||||
}
|
||||
|
||||
public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() {
|
||||
Map<String, Object> taskMap = Map.of(
|
||||
HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS,
|
||||
true,
|
||||
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY,
|
||||
"invalid"
|
||||
);
|
||||
var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap)));
|
||||
assertThat(thrownException.getMessage(), containsString("field [top_n] is not of the expected type"));
|
||||
}
|
||||
|
||||
public void UpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() {
|
||||
var initialSettings = new HuggingFaceRerankTaskSettings(5, true);
|
||||
HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(Map.of());
|
||||
assertEquals(initialSettings, updatedSettings);
|
||||
}
|
||||
|
||||
public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() {
|
||||
var initialSettings = new HuggingFaceRerankTaskSettings(5, true);
|
||||
Map<String, Object> newSettings = Map.of(HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS, false);
|
||||
HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
|
||||
assertFalse(updatedSettings.getReturnDocuments());
|
||||
assertEquals(initialSettings.getTopNDocumentsOnly(), updatedSettings.getTopNDocumentsOnly());
|
||||
}
|
||||
|
||||
public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings() {
|
||||
var initialSettings = new HuggingFaceRerankTaskSettings(5, true);
|
||||
Map<String, Object> newSettings = Map.of(HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY, 7);
|
||||
HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
|
||||
assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue());
|
||||
assertEquals(initialSettings.getReturnDocuments(), updatedSettings.getReturnDocuments());
|
||||
}
|
||||
|
||||
public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() {
|
||||
var initialSettings = new HuggingFaceRerankTaskSettings(5, true);
|
||||
Map<String, Object> newSettings = Map.of(
|
||||
HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS,
|
||||
false,
|
||||
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY,
|
||||
7
|
||||
);
|
||||
HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
|
||||
assertFalse(updatedSettings.getReturnDocuments());
|
||||
assertEquals(7, updatedSettings.getTopNDocumentsOnly().intValue());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<HuggingFaceRerankTaskSettings> instanceReader() {
|
||||
return HuggingFaceRerankTaskSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected HuggingFaceRerankTaskSettings createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected HuggingFaceRerankTaskSettings mutateInstance(HuggingFaceRerankTaskSettings instance) throws IOException {
|
||||
return randomValueOtherThan(instance, HuggingFaceRerankTaskSettingsTests::createRandom);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected HuggingFaceRerankTaskSettings mutateInstanceForVersion(HuggingFaceRerankTaskSettings instance, TransportVersion version) {
|
||||
return instance;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,147 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.huggingface.response;
|
||||
|
||||
import org.apache.http.HttpResponse;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
|
||||
import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.request.rerank.HuggingFaceRerankRequest;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests.buildExpectationRerank;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class HuggingFaceRerankResponseEntityTests extends ESTestCase {
|
||||
private static final String MISSED_FIELD_INDEX = "index";
|
||||
private static final String MISSED_FIELD_SCORE = "score";
|
||||
private static final String RESPONSE_JSON_TWO_DOCS = """
|
||||
[
|
||||
{
|
||||
"index": 4,
|
||||
"score": -0.22222222222222222,
|
||||
"text": "ranked second"
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"score": 1.11111111111111111,
|
||||
"text": "ranked first"
|
||||
}
|
||||
]
|
||||
""";
|
||||
private static final List<RankedDocsResultsTests.RerankExpectation> EXPECTED_TWO_DOCS = List.of(
|
||||
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 1.11111111111111111F, "text", "ranked first")),
|
||||
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 4, "relevance_score", -0.22222222222222222F, "text", "ranked second"))
|
||||
);
|
||||
|
||||
private static final String RESPONSE_JSON_FIVE_DOCS = """
|
||||
[
|
||||
{
|
||||
"index": 1,
|
||||
"score": 1.11111111111111111,
|
||||
"text": "ranked first"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"score": -0.33333333333333333,
|
||||
"text": "ranked third"
|
||||
},
|
||||
{
|
||||
"index": 0,
|
||||
"score": -0.55555555555555555,
|
||||
"text": "ranked fifth"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"score": -0.44444444444444444,
|
||||
"text": "ranked fourth"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"score": -0.22222222222222222,
|
||||
"text": "ranked second"
|
||||
}
|
||||
]
|
||||
""";
|
||||
private static final List<RankedDocsResultsTests.RerankExpectation> EXPECTED_FIVE_DOCS = List.of(
|
||||
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 1.11111111111111111F, "text", "ranked first")),
|
||||
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 4, "relevance_score", -0.22222222222222222F, "text", "ranked second")),
|
||||
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 3, "relevance_score", -0.33333333333333333F, "text", "ranked third")),
|
||||
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 2, "relevance_score", -0.44444444444444444F, "text", "ranked fourth")),
|
||||
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", -0.55555555555555555F, "text", "ranked fifth"))
|
||||
);
|
||||
|
||||
private static final HuggingFaceRerankRequest REQUEST_MOCK = mock(HuggingFaceRerankRequest.class);
|
||||
|
||||
public void testFromResponse_CreatesRankedDocsResults_TopNNull_FiveDocs_NoLimit() throws IOException {
|
||||
assertTopNLimit(null, RESPONSE_JSON_FIVE_DOCS, EXPECTED_FIVE_DOCS);
|
||||
}
|
||||
|
||||
public void testFromResponse_CreatesRankedDocsResults_TopN5_TwoDocs_NoLimit() throws IOException {
|
||||
assertTopNLimit(5, RESPONSE_JSON_TWO_DOCS, EXPECTED_TWO_DOCS);
|
||||
}
|
||||
|
||||
public void testFromResponse_CreatesRankedDocsResults_TopN2_FiveDocs_Limits() throws IOException {
|
||||
assertTopNLimit(2, RESPONSE_JSON_FIVE_DOCS, EXPECTED_TWO_DOCS);
|
||||
}
|
||||
|
||||
public void testFails_CreateRankedDocsResults_IndexFieldNull() {
|
||||
String responseJson = """
|
||||
[
|
||||
{
|
||||
"score": 1.11111111111111111,
|
||||
"text": "ranked first"
|
||||
}
|
||||
]
|
||||
""";
|
||||
assertMissingFieldThrowsIllegalArgumentException(responseJson, MISSED_FIELD_INDEX);
|
||||
}
|
||||
|
||||
public void testFails_CreateRankedDocsResults_ScoreFieldNull() {
|
||||
String responseJson = """
|
||||
[
|
||||
{
|
||||
"index": 1,
|
||||
"text": "ranked first"
|
||||
}
|
||||
]
|
||||
""";
|
||||
assertMissingFieldThrowsIllegalArgumentException(responseJson, MISSED_FIELD_SCORE);
|
||||
}
|
||||
|
||||
private void assertMissingFieldThrowsIllegalArgumentException(String responseJson, String missingField) {
|
||||
when(REQUEST_MOCK.getTopN()).thenReturn(1);
|
||||
|
||||
var thrownException = expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> HuggingFaceRerankResponseEntity.fromResponse(
|
||||
REQUEST_MOCK,
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
)
|
||||
);
|
||||
assertThat(thrownException.getMessage(), is("Required [" + missingField + "]"));
|
||||
}
|
||||
|
||||
private void assertTopNLimit(
|
||||
Integer topN, String responseJson, List<RankedDocsResultsTests.RerankExpectation> expectation) throws IOException {
|
||||
when(REQUEST_MOCK.getTopN()).thenReturn(topN);
|
||||
|
||||
RankedDocsResults parsedResults = HuggingFaceRerankResponseEntity.fromResponse(
|
||||
REQUEST_MOCK,
|
||||
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
|
||||
);
|
||||
assertThat(parsedResults.asMap(), is(buildExpectationRerank(expectation)));
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue