diff --git a/docs/changelog/109967.yaml b/docs/changelog/109967.yaml new file mode 100644 index 000000000000..cfc6b6462954 --- /dev/null +++ b/docs/changelog/109967.yaml @@ -0,0 +1,5 @@ +pr: 109967 +summary: Default the HF service to cosine similarity +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 161ab6c47bfe..6e311c39c787 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -17,6 +17,7 @@ import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker; @@ -78,9 +79,14 @@ public class HuggingFaceService extends HuggingFaceBaseService { } private static HuggingFaceEmbeddingsModel updateModelWithEmbeddingDetails(HuggingFaceEmbeddingsModel model, int embeddingSize) { + // default to cosine similarity + var similarity = model.getServiceSettings().similarity() == null + ? SimilarityMeasure.COSINE + : model.getServiceSettings().similarity(); + var serviceSettings = new HuggingFaceServiceSettings( model.getServiceSettings().uri(), - model.getServiceSettings().similarity(), // we don't know the similarity but use whatever the user specified + similarity, embeddingSize, model.getTokenLimit(), model.getServiceSettings().rateLimitSettings() diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index de5c7ec83d57..14fe1451ebac 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -529,12 +529,15 @@ public class HuggingFaceServiceTests extends ESTestCase { """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1); + var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.DOT_PRODUCT); PlainActionFuture listener = new PlainActionFuture<>(); service.checkModelConfig(model, listener); var result = listener.actionGet(TIMEOUT); - assertThat(result, is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1))); + assertThat( + result, + is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.DOT_PRODUCT)) + ); } } @@ -566,7 +569,7 @@ public class HuggingFaceServiceTests extends ESTestCase { } } - public void testCheckModelConfig_LeavesSimilarityAsNull_WhenUnspecified() throws IOException { + public void testCheckModelConfig_DefaultsSimilarityToCosine() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { @@ -587,11 +590,13 @@ public class HuggingFaceServiceTests extends ESTestCase { service.checkModelConfig(model, listener); var result = listener.actionGet(TIMEOUT); - assertThat(result, is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, null))); + assertThat( + result, + is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.COSINE)) + ); } } - // TODO public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);