mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 17:34:17 -04:00
[ML] Default the HF service to cosine similarity (#109967)
This commit is contained in:
parent
b18ee11b75
commit
d79f18d069
3 changed files with 22 additions and 6 deletions
5
docs/changelog/109967.yaml
Normal file
5
docs/changelog/109967.yaml
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
pr: 109967
|
||||||
|
summary: Default the HF service to cosine similarity
|
||||||
|
area: Machine Learning
|
||||||
|
type: enhancement
|
||||||
|
issues: []
|
|
@ -17,6 +17,7 @@ import org.elasticsearch.inference.ChunkedInferenceServiceResults;
|
||||||
import org.elasticsearch.inference.ChunkingOptions;
|
import org.elasticsearch.inference.ChunkingOptions;
|
||||||
import org.elasticsearch.inference.InputType;
|
import org.elasticsearch.inference.InputType;
|
||||||
import org.elasticsearch.inference.Model;
|
import org.elasticsearch.inference.Model;
|
||||||
|
import org.elasticsearch.inference.SimilarityMeasure;
|
||||||
import org.elasticsearch.inference.TaskType;
|
import org.elasticsearch.inference.TaskType;
|
||||||
import org.elasticsearch.rest.RestStatus;
|
import org.elasticsearch.rest.RestStatus;
|
||||||
import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
|
import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
|
||||||
|
@ -78,9 +79,14 @@ public class HuggingFaceService extends HuggingFaceBaseService {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static HuggingFaceEmbeddingsModel updateModelWithEmbeddingDetails(HuggingFaceEmbeddingsModel model, int embeddingSize) {
|
private static HuggingFaceEmbeddingsModel updateModelWithEmbeddingDetails(HuggingFaceEmbeddingsModel model, int embeddingSize) {
|
||||||
|
// default to cosine similarity
|
||||||
|
var similarity = model.getServiceSettings().similarity() == null
|
||||||
|
? SimilarityMeasure.COSINE
|
||||||
|
: model.getServiceSettings().similarity();
|
||||||
|
|
||||||
var serviceSettings = new HuggingFaceServiceSettings(
|
var serviceSettings = new HuggingFaceServiceSettings(
|
||||||
model.getServiceSettings().uri(),
|
model.getServiceSettings().uri(),
|
||||||
model.getServiceSettings().similarity(), // we don't know the similarity but use whatever the user specified
|
similarity,
|
||||||
embeddingSize,
|
embeddingSize,
|
||||||
model.getTokenLimit(),
|
model.getTokenLimit(),
|
||||||
model.getServiceSettings().rateLimitSettings()
|
model.getServiceSettings().rateLimitSettings()
|
||||||
|
|
|
@ -529,12 +529,15 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
""";
|
""";
|
||||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||||
|
|
||||||
var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1);
|
var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.DOT_PRODUCT);
|
||||||
PlainActionFuture<Model> listener = new PlainActionFuture<>();
|
PlainActionFuture<Model> listener = new PlainActionFuture<>();
|
||||||
service.checkModelConfig(model, listener);
|
service.checkModelConfig(model, listener);
|
||||||
|
|
||||||
var result = listener.actionGet(TIMEOUT);
|
var result = listener.actionGet(TIMEOUT);
|
||||||
assertThat(result, is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1)));
|
assertThat(
|
||||||
|
result,
|
||||||
|
is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.DOT_PRODUCT))
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -566,7 +569,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testCheckModelConfig_LeavesSimilarityAsNull_WhenUnspecified() throws IOException {
|
public void testCheckModelConfig_DefaultsSimilarityToCosine() throws IOException {
|
||||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
|
try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||||
|
@ -587,11 +590,13 @@ public class HuggingFaceServiceTests extends ESTestCase {
|
||||||
service.checkModelConfig(model, listener);
|
service.checkModelConfig(model, listener);
|
||||||
|
|
||||||
var result = listener.actionGet(TIMEOUT);
|
var result = listener.actionGet(TIMEOUT);
|
||||||
assertThat(result, is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, null)));
|
assertThat(
|
||||||
|
result,
|
||||||
|
is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.COSINE))
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO
|
|
||||||
public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() throws IOException {
|
public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() throws IOException {
|
||||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue