[ML] Default the HF service to cosine similarity (#109967)

This commit is contained in:
David Kyle 2024-06-21 10:39:20 +01:00 committed by GitHub
parent b18ee11b75
commit d79f18d069
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 22 additions and 6 deletions

View file

@ -0,0 +1,5 @@
pr: 109967
summary: Default the Hugging Face service to cosine similarity
area: Machine Learning
type: enhancement
issues: []

View file

@ -17,6 +17,7 @@ import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions; import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model; import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
@ -78,9 +79,14 @@ public class HuggingFaceService extends HuggingFaceBaseService {
} }
private static HuggingFaceEmbeddingsModel updateModelWithEmbeddingDetails(HuggingFaceEmbeddingsModel model, int embeddingSize) { private static HuggingFaceEmbeddingsModel updateModelWithEmbeddingDetails(HuggingFaceEmbeddingsModel model, int embeddingSize) {
// Default to cosine similarity when the service settings do not specify a similarity measure
var similarity = model.getServiceSettings().similarity() == null
? SimilarityMeasure.COSINE
: model.getServiceSettings().similarity();
var serviceSettings = new HuggingFaceServiceSettings( var serviceSettings = new HuggingFaceServiceSettings(
model.getServiceSettings().uri(), model.getServiceSettings().uri(),
model.getServiceSettings().similarity(), // we don't know the similarity but use whatever the user specified similarity,
embeddingSize, embeddingSize,
model.getTokenLimit(), model.getTokenLimit(),
model.getServiceSettings().rateLimitSettings() model.getServiceSettings().rateLimitSettings()

View file

@ -529,12 +529,15 @@ public class HuggingFaceServiceTests extends ESTestCase {
"""; """;
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1); var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.DOT_PRODUCT);
PlainActionFuture<Model> listener = new PlainActionFuture<>(); PlainActionFuture<Model> listener = new PlainActionFuture<>();
service.checkModelConfig(model, listener); service.checkModelConfig(model, listener);
var result = listener.actionGet(TIMEOUT); var result = listener.actionGet(TIMEOUT);
assertThat(result, is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1))); assertThat(
result,
is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.DOT_PRODUCT))
);
} }
} }
@ -566,7 +569,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
} }
} }
public void testCheckModelConfig_LeavesSimilarityAsNull_WhenUnspecified() throws IOException { public void testCheckModelConfig_DefaultsSimilarityToCosine() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
@ -587,11 +590,13 @@ public class HuggingFaceServiceTests extends ESTestCase {
service.checkModelConfig(model, listener); service.checkModelConfig(model, listener);
var result = listener.actionGet(TIMEOUT); var result = listener.actionGet(TIMEOUT);
assertThat(result, is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, null))); assertThat(
result,
is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.COSINE))
);
} }
} }
// TODO
public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() throws IOException { public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);