mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 17:34:17 -04:00
[ML] Default the HF service to cosine similarity (#109967)
This commit is contained in:
parent
b18ee11b75
commit
d79f18d069
3 changed files with 22 additions and 6 deletions
5
docs/changelog/109967.yaml
Normal file
5
docs/changelog/109967.yaml
Normal file
|
@ -0,0 +1,5 @@
|
|||
pr: 109967
|
||||
summary: Default the HF service to cosine similarity
|
||||
area: Machine Learning
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -17,6 +17,7 @@ import org.elasticsearch.inference.ChunkedInferenceServiceResults;
|
|||
import org.elasticsearch.inference.ChunkingOptions;
|
||||
import org.elasticsearch.inference.InputType;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.SimilarityMeasure;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
|
||||
|
@ -78,9 +79,14 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
|||
}
|
||||
|
||||
private static HuggingFaceEmbeddingsModel updateModelWithEmbeddingDetails(HuggingFaceEmbeddingsModel model, int embeddingSize) {
|
||||
// default to cosine similarity
|
||||
var similarity = model.getServiceSettings().similarity() == null
|
||||
? SimilarityMeasure.COSINE
|
||||
: model.getServiceSettings().similarity();
|
||||
|
||||
var serviceSettings = new HuggingFaceServiceSettings(
|
||||
model.getServiceSettings().uri(),
|
||||
model.getServiceSettings().similarity(), // we don't know the similarity but use whatever the user specified
|
||||
similarity,
|
||||
embeddingSize,
|
||||
model.getTokenLimit(),
|
||||
model.getServiceSettings().rateLimitSettings()
|
||||
|
|
|
@ -529,12 +529,15 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
|||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1);
|
||||
var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.DOT_PRODUCT);
|
||||
PlainActionFuture<Model> listener = new PlainActionFuture<>();
|
||||
service.checkModelConfig(model, listener);
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
assertThat(result, is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1)));
|
||||
assertThat(
|
||||
result,
|
||||
is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.DOT_PRODUCT))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -566,7 +569,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testCheckModelConfig_LeavesSimilarityAsNull_WhenUnspecified() throws IOException {
|
||||
public void testCheckModelConfig_DefaultsSimilarityToCosine() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||
|
@ -587,11 +590,13 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
|||
service.checkModelConfig(model, listener);
|
||||
|
||||
var result = listener.actionGet(TIMEOUT);
|
||||
assertThat(result, is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, null)));
|
||||
assertThat(
|
||||
result,
|
||||
is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.COSINE))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO
|
||||
public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue