[ML] Add mixed cluster tests for inference (#108392)

* mixed cluster tests are executable
* add tests from upgrade tests
* [ML] Add mixed cluster tests for existing services
* clean up
* review improvements
* spotless
* remove blocked AzureOpenAI mixed IT
* improvements from DK review
* temp for testing
* refactoring and documentation
* Revert manual testing configs of "temp for testing"
  This reverts parts of commit fca46fd2b6.
* revert TESTING.asciidoc formatting
* Update TESTING.asciidoc to avoid reformatting
* add minimum version for tests to match minimum version in services
* spotless

parent 74ec90bf1d
commit c88a6fe481

8 changed files with 896 additions and 5 deletions
TESTING.asciidoc

@@ -551,13 +551,19 @@ When running `./gradlew check`, minimal bwc checks are also run against compatible versions.

 ==== BWC Testing against a specific remote/branch

-Sometimes a backward compatibility change spans two versions. A common case is a new functionality
-that needs a BWC bridge in an unreleased versioned of a release branch (for example, 5.x).
-To test the changes, you can instruct Gradle to build the BWC version from another remote/branch combination instead of
-pulling the release branch from GitHub. You do so using the `bwc.remote` and `bwc.refspec.BRANCH` system properties:
+Sometimes a backward compatibility change spans two versions.
+A common case is a new functionality that needs a BWC bridge in an unreleased versioned of a release branch (for example, 5.x).
+Another use case, since the introduction of serverless, is to test BWC against main in addition to the other released branches.
+To do so, specify the `bwc.refspec` remote and branch to use for the BWC build as `origin/main`.
+To test against main, you will also need to create a new version in link:./server/src/main/java/org/elasticsearch/Version.java[Version.java],
+increment `elasticsearch` in link:./build-tools-internal/version.properties[version.properties], and hard-code the `project.version` for ml-cpp
+in link:./x-pack/plugin/ml/build.gradle[ml/build.gradle].
+
+In general, to test the changes, you can instruct Gradle to build the BWC version from another remote/branch combination instead of pulling the release branch from GitHub.
+You do so using the `bwc.refspec.{VERSION}` system property:

 -------------------------------------------------
-./gradlew check -Dbwc.remote=${remote} -Dbwc.refspec.5.x=index_req_bwc_5.x
+./gradlew check -Dtests.bwc.refspec.8.15=origin/main
 -------------------------------------------------

 The branch needs to be available on the remote that the BWC makes of the
37  x-pack/plugin/inference/qa/mixed-cluster/build.gradle  Normal file

@@ -0,0 +1,37 @@
import org.elasticsearch.gradle.Version
import org.elasticsearch.gradle.VersionProperties
import org.elasticsearch.gradle.util.GradleUtils
import org.elasticsearch.gradle.internal.info.BuildParams
import org.elasticsearch.gradle.testclusters.StandaloneRestIntegTestTask

apply plugin: 'elasticsearch.internal-java-rest-test'
apply plugin: 'elasticsearch.internal-test-artifact-base'
apply plugin: 'elasticsearch.bwc-test'

dependencies {
  testImplementation project(path: ':x-pack:plugin:inference:qa:inference-service-tests')
  compileOnly project(':x-pack:plugin:core')
  javaRestTestImplementation(testArtifact(project(xpackModule('core'))))
  javaRestTestImplementation project(path: xpackModule('inference'))
  clusterPlugins project(
    ':x-pack:plugin:inference:qa:test-service-plugin'
  )
}

// inference is available in 8.11 or later
def supportedVersion = bwcVersion -> {
  return bwcVersion.onOrAfter(Version.fromString("8.11.0"));
}

BuildParams.bwcVersions.withWireCompatible(supportedVersion) { bwcVersion, baseName ->
  def javaRestTest = tasks.register("v${bwcVersion}#javaRestTest", StandaloneRestIntegTestTask) {
    usesBwcDistribution(bwcVersion)
    systemProperty("tests.old_cluster_version", bwcVersion)
    maxParallelForks = 1
  }

  tasks.register(bwcTaskName(bwcVersion)) {
    dependsOn javaRestTest
  }
}
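The `systemProperty("tests.old_cluster_version", bwcVersion)` line above is how each registered task tells the test JVM which BWC version the old nodes run; the test classes further down read the property back and gate their assertions on it. A minimal sketch of that hand-off, reusing the same `Version`/`assumeTrue` pattern as the ITs in this commit (the class name and the helper method are illustrative only, not part of the commit):

package org.elasticsearch.xpack.inference.qa.mixed;

import org.elasticsearch.Version;

import static org.junit.Assume.assumeTrue;

// Illustrative sketch: shows how a mixed-cluster test can gate itself on the
// BWC version passed in via -Dtests.old_cluster_version by the Gradle task above.
public class OldClusterVersionGateExample {

    // Parsed once from the system property, exactly as MixedClusterSpecTestCase does below.
    static final Version BWC_VERSION = Version.fromString(System.getProperty("tests.old_cluster_version"));

    // Skip a test unless the old nodes are on or after the given version.
    static void requireOldClusterAtLeast(String minimumVersion) {
        assumeTrue(
            "requires an old cluster on or after " + minimumVersion,
            BWC_VERSION.onOrAfter(Version.fromString(minimumVersion))
        );
    }
}

Each per-version suite is addressed through the registered task name, so an invocation would look something like `./gradlew ":x-pack:plugin:inference:qa:mixed-cluster:v8.13.0#javaRestTest"` (the exact version segment depends on the wire-compatible versions Gradle resolves).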
BaseMixedTestCase.java (new file)

@@ -0,0 +1,129 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.qa.mixed;

import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.hamcrest.Matchers;

import java.io.IOException;
import java.util.List;
import java.util.Map;

public abstract class BaseMixedTestCase extends MixedClusterSpecTestCase {
    protected static String getUrl(MockWebServer webServer) {
        return Strings.format("http://%s:%s", webServer.getHostName(), webServer.getPort());
    }

    @Override
    protected Settings restClientSettings() {
        String token = ESRestTestCase.basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
        return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
    }

    protected void delete(String inferenceId, TaskType taskType) throws IOException {
        var request = new Request("DELETE", Strings.format("_inference/%s/%s", taskType, inferenceId));
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
    }

    protected void delete(String inferenceId) throws IOException {
        var request = new Request("DELETE", Strings.format("_inference/%s", inferenceId));
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
    }

    protected Map<String, Object> getAll() throws IOException {
        var request = new Request("GET", "_inference/_all");
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return ESRestTestCase.entityAsMap(response);
    }

    protected Map<String, Object> get(String inferenceId) throws IOException {
        var endpoint = Strings.format("_inference/%s", inferenceId);
        var request = new Request("GET", endpoint);
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return ESRestTestCase.entityAsMap(response);
    }

    protected Map<String, Object> get(TaskType taskType, String inferenceId) throws IOException {
        var endpoint = Strings.format("_inference/%s/%s", taskType, inferenceId);
        var request = new Request("GET", endpoint);
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return ESRestTestCase.entityAsMap(response);
    }

    protected Map<String, Object> inference(String inferenceId, TaskType taskType, String input) throws IOException {
        var endpoint = Strings.format("_inference/%s/%s", taskType, inferenceId);
        var request = new Request("POST", endpoint);
        request.setJsonEntity("{\"input\": [" + '"' + input + '"' + "]}");

        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return ESRestTestCase.entityAsMap(response);
    }

    protected Map<String, Object> rerank(String inferenceId, List<String> inputs, String query) throws IOException {
        var endpoint = Strings.format("_inference/rerank/%s", inferenceId);
        var request = new Request("POST", endpoint);

        StringBuilder body = new StringBuilder("{").append("\"query\":\"").append(query).append("\",").append("\"input\":[");

        for (int i = 0; i < inputs.size(); i++) {
            body.append("\"").append(inputs.get(i)).append("\"");
            if (i < inputs.size() - 1) {
                body.append(",");
            }
        }

        body.append("]}");
        request.setJsonEntity(body.toString());

        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return ESRestTestCase.entityAsMap(response);
    }

    protected void put(String inferenceId, String modelConfig, TaskType taskType) throws IOException {
        String endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, inferenceId);
        var request = new Request("PUT", endpoint);
        request.setJsonEntity(modelConfig);
        var response = ESRestTestCase.client().performRequest(request);
        logger.warn("PUT response: {}", response.toString());
        System.out.println("PUT response: " + response.toString());
        ESRestTestCase.assertOKAndConsume(response);
    }

    protected static void assertOkOrCreated(Response response) throws IOException {
        int statusCode = response.getStatusLine().getStatusCode();
        // Once EntityUtils.toString(entity) is called the entity cannot be reused.
        // Avoid that call with check here.
        if (statusCode == 200 || statusCode == 201) {
            return;
        }

        String responseStr = EntityUtils.toString(response.getEntity());
        ESTestCase.assertThat(
            responseStr,
            response.getStatusLine().getStatusCode(),
            Matchers.anyOf(Matchers.equalTo(200), Matchers.equalTo(201))
        );
    }
}
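Before the concrete service tests, a minimal sketch of how a subclass drives these helpers end to end; the service name, endpoint id and JSON body below are placeholders for illustration, not a real inference service from this commit:

package org.elasticsearch.xpack.inference.qa.mixed;

import org.elasticsearch.inference.TaskType;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.not;

// Illustrative sketch only, not part of the commit.
public class ExampleServiceMixedIT extends BaseMixedTestCase {

    @SuppressWarnings("unchecked")
    public void testEndpointLifecycle() throws IOException {
        final String inferenceId = "mixed-cluster-example";

        // PUT a (placeholder) endpoint configuration, then read it back via _inference.
        put(inferenceId, """
            {
              "service": "some-service",
              "service_settings": { "api_key": "XXXX" }
            }
            """, TaskType.TEXT_EMBEDDING);

        var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints");
        assertEquals("some-service", configs.get(0).get("service"));

        // Run an inference call and check that a non-empty response map came back.
        var inferenceMap = inference(inferenceId, TaskType.TEXT_EMBEDDING, "some text");
        assertThat(inferenceMap.entrySet(), not(empty()));

        delete(inferenceId);
    }
}

The real ITs below follow this shape, but additionally enqueue canned responses on a MockWebServer so the PUT and inference calls have a service endpoint to talk to.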
CohereServiceMixedIT.java (new file)

@@ -0,0 +1,271 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.qa.mixed;

import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
import org.hamcrest.Matchers;
import org.junit.AfterClass;
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.xpack.inference.qa.mixed.MixedClusterSpecTestCase.bwcVersion;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.oneOf;

public class CohereServiceMixedIT extends BaseMixedTestCase {

    private static final String COHERE_EMBEDDINGS_ADDED = "8.13.0";
    private static final String COHERE_RERANK_ADDED = "8.14.0";
    private static final String BYTE_ALIAS_FOR_INT8_ADDED = "8.14.0";
    private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0";

    private static MockWebServer cohereEmbeddingsServer;
    private static MockWebServer cohereRerankServer;

    @BeforeClass
    public static void startWebServer() throws IOException {
        cohereEmbeddingsServer = new MockWebServer();
        cohereEmbeddingsServer.start();

        cohereRerankServer = new MockWebServer();
        cohereRerankServer.start();
    }

    @AfterClass
    public static void shutdown() {
        cohereEmbeddingsServer.close();
        cohereRerankServer.close();
    }

    @SuppressWarnings("unchecked")
    public void testCohereEmbeddings() throws IOException {
        var embeddingsSupported = bwcVersion.onOrAfter(Version.fromString(COHERE_EMBEDDINGS_ADDED));
        assumeTrue("Cohere embedding service added in " + COHERE_EMBEDDINGS_ADDED, embeddingsSupported);
        assumeTrue(
            "Cohere service requires at least " + MINIMUM_SUPPORTED_VERSION,
            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
        );

        final String inferenceIdInt8 = "mixed-cluster-cohere-embeddings-int8";
        final String inferenceIdFloat = "mixed-cluster-cohere-embeddings-float";

        // queue a response as PUT will call the service
        cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
        put(inferenceIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);

        // float model
        cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
        put(inferenceIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);

        var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceIdInt8).get("endpoints");
        assertEquals("cohere", configs.get(0).get("service"));
        var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
        assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0"));
        var embeddingType = serviceSettings.get("embedding_type");
        // An upgraded node will report the embedding type as byte, an old node int8
        assertThat(embeddingType, Matchers.is(oneOf("int8", "byte")));

        configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceIdFloat).get("endpoints");
        serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
        assertThat(serviceSettings, hasEntry("embedding_type", "float"));

        assertEmbeddingInference(inferenceIdInt8, CohereEmbeddingType.BYTE);
        assertEmbeddingInference(inferenceIdFloat, CohereEmbeddingType.FLOAT);

        delete(inferenceIdFloat);
        delete(inferenceIdInt8);

    }

    void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) throws IOException {
        switch (type) {
            case INT8:
            case BYTE:
                cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
                break;
            case FLOAT:
                cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
        }

        var inferenceMap = inference(inferenceId, TaskType.TEXT_EMBEDDING, "some text");
        assertThat(inferenceMap.entrySet(), not(empty()));
    }

    @SuppressWarnings("unchecked")
    public void testRerank() throws IOException {
        var rerankSupported = bwcVersion.onOrAfter(Version.fromString(COHERE_RERANK_ADDED));
        assumeTrue("Cohere rerank service added in " + COHERE_RERANK_ADDED, rerankSupported);
        assumeTrue(
            "Cohere service requires at least " + MINIMUM_SUPPORTED_VERSION,
            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
        );

        final String inferenceId = "mixed-cluster-rerank";

        put(inferenceId, rerankConfig(getUrl(cohereRerankServer)), TaskType.RERANK);
        assertRerank(inferenceId);

        var configs = (List<Map<String, Object>>) get(TaskType.RERANK, inferenceId).get("endpoints");
        assertThat(configs, hasSize(1));
        assertEquals("cohere", configs.get(0).get("service"));
        var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
        assertThat(serviceSettings, hasEntry("model_id", "rerank-english-v3.0"));
        var taskSettings = (Map<String, Object>) configs.get(0).get("task_settings");
        assertThat(taskSettings, hasEntry("top_n", 3));

        assertRerank(inferenceId);

    }

    private void assertRerank(String inferenceId) throws IOException {
        cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse()));
        var inferenceMap = rerank(
            inferenceId,
            List.of("luke", "like", "leia", "chewy", "r2d2", "star", "wars"),
            "star wars main character"
        );
        assertThat(inferenceMap.entrySet(), not(empty()));
    }

    private String embeddingConfigByte(String url) {
        return embeddingConfigTemplate(url, "byte");
    }

    private String embeddingConfigInt8(String url) {
        return embeddingConfigTemplate(url, "int8");
    }

    private String embeddingConfigFloat(String url) {
        return embeddingConfigTemplate(url, "float");
    }

    private String embeddingConfigTemplate(String url, String embeddingType) {
        return Strings.format("""
            {
                "service": "cohere",
                "service_settings": {
                    "url": "%s",
                    "api_key": "XXXX",
                    "model_id": "embed-english-light-v3.0",
                    "embedding_type": "%s"
                }
            }
            """, url, embeddingType);
    }

    private String embeddingResponseByte() {
        return """
            {
                "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
                "texts": [
                    "hello"
                ],
                "embeddings": [
                    [
                        12,
                        56
                    ]
                ],
                "meta": {
                    "api_version": {
                        "version": "1"
                    },
                    "billed_units": {
                        "input_tokens": 1
                    }
                },
                "response_type": "embeddings_bytes"
            }
            """;
    }

    private String embeddingResponseFloat() {
        return """
            {
                "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
                "texts": [
                    "hello"
                ],
                "embeddings": [
                    [
                        -0.0018434525,
                        0.01777649
                    ]
                ],
                "meta": {
                    "api_version": {
                        "version": "1"
                    },
                    "billed_units": {
                        "input_tokens": 1
                    }
                },
                "response_type": "embeddings_floats"
            }
            """;
    }

    private String rerankConfig(String url) {
        return Strings.format("""
            {
                "service": "cohere",
                "service_settings": {
                    "api_key": "XXXX",
                    "model_id": "rerank-english-v3.0",
                    "url": "%s"
                },
                "task_settings": {
                    "return_documents": false,
                    "top_n": 3
                }
            }
            """, url);
    }

    private String rerankResponse() {
        return """
            {
                "index": "d0760819-5a73-4d58-b163-3956d3648b62",
                "results": [
                    {
                        "index": 2,
                        "relevance_score": 0.98005307
                    },
                    {
                        "index": 3,
                        "relevance_score": 0.27904198
                    },
                    {
                        "index": 0,
                        "relevance_score": 0.10194652
                    }
                ],
                "meta": {
                    "api_version": {
                        "version": "1"
                    },
                    "billed_units": {
                        "search_units": 1
                    }
                }
            }
            """;
    }

}
HuggingFaceServiceMixedIT.java (new file)

@@ -0,0 +1,147 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.qa.mixed;

import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.junit.AfterClass;
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.not;

public class HuggingFaceServiceMixedIT extends BaseMixedTestCase {

    private static final String HF_EMBEDDINGS_ADDED = "8.12.0";
    private static final String HF_ELSER_ADDED = "8.12.0";
    private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0";

    private static MockWebServer embeddingsServer;
    private static MockWebServer elserServer;

    @BeforeClass
    public static void startWebServer() throws IOException {
        embeddingsServer = new MockWebServer();
        embeddingsServer.start();

        elserServer = new MockWebServer();
        elserServer.start();
    }

    @AfterClass
    public static void shutdown() {
        embeddingsServer.close();
        elserServer.close();
    }

    @SuppressWarnings("unchecked")
    public void testHFEmbeddings() throws IOException {
        var embeddingsSupported = bwcVersion.onOrAfter(Version.fromString(HF_EMBEDDINGS_ADDED));
        assumeTrue("Hugging Face embedding service added in " + HF_EMBEDDINGS_ADDED, embeddingsSupported);
        assumeTrue(
            "HuggingFace service requires at least " + MINIMUM_SUPPORTED_VERSION,
            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
        );

        final String inferenceId = "mixed-cluster-embeddings";

        embeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
        put(inferenceId, embeddingConfig(getUrl(embeddingsServer)), TaskType.TEXT_EMBEDDING);
        var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints");
        assertThat(configs, hasSize(1));
        assertEquals("hugging_face", configs.get(0).get("service"));
        assertEmbeddingInference(inferenceId);
    }

    void assertEmbeddingInference(String inferenceId) throws IOException {
        embeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
        var inferenceMap = inference(inferenceId, TaskType.TEXT_EMBEDDING, "some text");
        assertThat(inferenceMap.entrySet(), not(empty()));
    }

    @SuppressWarnings("unchecked")
    public void testElser() throws IOException {
        var supported = bwcVersion.onOrAfter(Version.fromString(HF_ELSER_ADDED));
        assumeTrue("HF elser service added in " + HF_ELSER_ADDED, supported);
        assumeTrue(
            "HuggingFace service requires at least " + MINIMUM_SUPPORTED_VERSION,
            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
        );

        final String inferenceId = "mixed-cluster-elser";
        final String upgradedClusterId = "upgraded-cluster-elser";

        put(inferenceId, elserConfig(getUrl(elserServer)), TaskType.SPARSE_EMBEDDING);

        var configs = (List<Map<String, Object>>) get(TaskType.SPARSE_EMBEDDING, inferenceId).get("endpoints");
        assertThat(configs, hasSize(1));
        assertEquals("hugging_face", configs.get(0).get("service"));
        assertElser(inferenceId);
    }

    private void assertElser(String inferenceId) throws IOException {
        elserServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse()));
        var inferenceMap = inference(inferenceId, TaskType.SPARSE_EMBEDDING, "some text");
        assertThat(inferenceMap.entrySet(), not(empty()));
    }

    private String embeddingConfig(String url) {
        return Strings.format("""
            {
                "service": "hugging_face",
                "service_settings": {
                    "url": "%s",
                    "api_key": "XXXX"
                }
            }
            """, url);
    }

    private String embeddingResponse() {
        return """
            [
                [
                    0.014539449,
                    -0.015288644
                ]
            ]
            """;
    }

    private String elserConfig(String url) {
        return Strings.format("""
            {
                "service": "hugging_face",
                "service_settings": {
                    "api_key": "XXXX",
                    "url": "%s"
                }
            }
            """, url);
    }

    private String elserResponse() {
        return """
            [
                {
                    ".": 0.133155956864357,
                    "the": 0.6747211217880249
                }
            ]
            """;
    }

}
MixedClusterSpecTestCase.java (new file)

@@ -0,0 +1,53 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.qa.mixed;

import org.elasticsearch.Version;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.test.rest.TestFeatureService;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.ClassRule;

public abstract class MixedClusterSpecTestCase extends ESRestTestCase {
    @ClassRule
    public static ElasticsearchCluster cluster = MixedClustersSpec.mixedVersionCluster();

    @Override
    protected String getTestRestCluster() {
        return cluster.getHttpAddresses();
    }

    static final Version bwcVersion = Version.fromString(System.getProperty("tests.old_cluster_version"));

    private static TestFeatureService oldClusterTestFeatureService = null;

    @Before
    public void extractOldClusterFeatures() {
        if (oldClusterTestFeatureService == null) {
            oldClusterTestFeatureService = testFeatureService;
        }
    }

    protected static boolean oldClusterHasFeature(String featureId) {
        assert oldClusterTestFeatureService != null;
        return oldClusterTestFeatureService.clusterHasFeature(featureId);
    }

    protected static boolean oldClusterHasFeature(NodeFeature feature) {
        return oldClusterHasFeature(feature.id());
    }

    @AfterClass
    public static void cleanUp() {
        oldClusterTestFeatureService = null;
    }

}
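The feature helpers above are not used by the service ITs in this commit, which gate on `bwcVersion` instead; a minimal sketch of how a subclass could lean on them, with a made-up feature id (illustrative only, not part of the commit):

package org.elasticsearch.xpack.inference.qa.mixed;

// Illustrative sketch; "inference.some_capability" is a placeholder feature id.
public class FeatureGatedExampleIT extends MixedClusterSpecTestCase {

    public void testRunsOnlyWhenOldClusterHasFeature() {
        // Skip unless the feature was already advertised when the old cluster was first queried.
        assumeTrue(
            "old cluster does not support inference.some_capability",
            oldClusterHasFeature("inference.some_capability")
        );

        // ... assertions that rely on the capability would go here ...
    }
}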
MixedClustersSpec.java (new file)

@@ -0,0 +1,25 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.qa.mixed;

import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
import org.elasticsearch.test.cluster.util.Version;

public class MixedClustersSpec {
    public static ElasticsearchCluster mixedVersionCluster() {
        Version oldVersion = Version.fromString(System.getProperty("tests.old_cluster_version"));
        return ElasticsearchCluster.local()
            .distribution(DistributionType.DEFAULT)
            .withNode(node -> node.version(oldVersion))
            .withNode(node -> node.version(Version.CURRENT))
            .setting("xpack.security.enabled", "false")
            .setting("xpack.license.self_generated.type", "trial")
            .build();
    }
}
OpenAIServiceMixedIT.java (new file)

@@ -0,0 +1,223 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.qa.mixed;

import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.junit.AfterClass;
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.xpack.inference.qa.mixed.MixedClusterSpecTestCase.bwcVersion;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.not;

public class OpenAIServiceMixedIT extends BaseMixedTestCase {

    private static final String OPEN_AI_EMBEDDINGS_ADDED = "8.12.0";
    private static final String OPEN_AI_EMBEDDINGS_MODEL_SETTING_MOVED = "8.13.0";
    private static final String OPEN_AI_COMPLETIONS_ADDED = "8.14.0";
    private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0";

    private static MockWebServer openAiEmbeddingsServer;
    private static MockWebServer openAiChatCompletionsServer;

    @BeforeClass
    public static void startWebServer() throws IOException {
        openAiEmbeddingsServer = new MockWebServer();
        openAiEmbeddingsServer.start();

        openAiChatCompletionsServer = new MockWebServer();
        openAiChatCompletionsServer.start();
    }

    @AfterClass
    public static void shutdown() {
        openAiEmbeddingsServer.close();
        openAiChatCompletionsServer.close();
    }

    @SuppressWarnings("unchecked")
    public void testOpenAiEmbeddings() throws IOException {
        var openAiEmbeddingsSupported = bwcVersion.onOrAfter(Version.fromString(OPEN_AI_EMBEDDINGS_ADDED));
        assumeTrue("OpenAI embedding service added in " + OPEN_AI_EMBEDDINGS_ADDED, openAiEmbeddingsSupported);
        assumeTrue(
            "OpenAI service requires at least " + MINIMUM_SUPPORTED_VERSION,
            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
        );

        final String inferenceId = "mixed-cluster-embeddings";

        String inferenceConfig = oldClusterVersionCompatibleEmbeddingConfig();
        // queue a response as PUT will call the service
        openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
        put(inferenceId, inferenceConfig, TaskType.TEXT_EMBEDDING);

        var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints");
        assertThat(configs, hasSize(1));
        assertEquals("openai", configs.get(0).get("service"));
        var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
        var taskSettings = (Map<String, Object>) configs.get(0).get("task_settings");
        var modelIdFound = serviceSettings.containsKey("model_id") || taskSettings.containsKey("model_id");
        assertTrue("model_id not found in config: " + configs.toString(), modelIdFound);

        assertEmbeddingInference(inferenceId);
    }

    void assertEmbeddingInference(String inferenceId) throws IOException {
        openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
        var inferenceMap = inference(inferenceId, TaskType.TEXT_EMBEDDING, "some text");
        assertThat(inferenceMap.entrySet(), not(empty()));
    }

    @SuppressWarnings("unchecked")
    public void testOpenAiCompletions() throws IOException {
        var openAiEmbeddingsSupported = bwcVersion.onOrAfter(Version.fromString(OPEN_AI_EMBEDDINGS_ADDED));
        assumeTrue("OpenAI completions service added in " + OPEN_AI_COMPLETIONS_ADDED, openAiEmbeddingsSupported);
        assumeTrue(
            "OpenAI service requires at least " + MINIMUM_SUPPORTED_VERSION,
            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
        );

        final String inferenceId = "mixed-cluster-completions";
        final String upgradedClusterId = "upgraded-cluster-completions";

        put(inferenceId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), TaskType.COMPLETION);

        var configsMap = get(TaskType.COMPLETION, inferenceId);
        logger.warn("Configs: {}", configsMap);
        var configs = (List<Map<String, Object>>) configsMap.get("endpoints");
        assertThat(configs, hasSize(1));
        assertEquals("openai", configs.get(0).get("service"));
        var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
        assertThat(serviceSettings, hasEntry("model_id", "gpt-4"));
        var taskSettings = (Map<String, Object>) configs.get(0).get("task_settings");
        assertThat(taskSettings.keySet(), empty());

        assertCompletionInference(inferenceId);
    }

    void assertCompletionInference(String inferenceId) throws IOException {
        openAiChatCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionsResponse()));
        var inferenceMap = inference(inferenceId, TaskType.COMPLETION, "some text");
        assertThat(inferenceMap.entrySet(), not(empty()));
    }

    private String oldClusterVersionCompatibleEmbeddingConfig() {
        if (getOldClusterTestVersion().before(OPEN_AI_EMBEDDINGS_MODEL_SETTING_MOVED)) {
            return embeddingConfigWithModelInTaskSettings(getUrl(openAiEmbeddingsServer));
        } else {
            return embeddingConfigWithModelInServiceSettings(getUrl(openAiEmbeddingsServer));
        }
    }

    protected static org.elasticsearch.test.cluster.util.Version getOldClusterTestVersion() {
        return org.elasticsearch.test.cluster.util.Version.fromString(bwcVersion.toString());
    }

    private String embeddingConfigWithModelInTaskSettings(String url) {
        return Strings.format("""
            {
                "service": "openai",
                "service_settings": {
                    "api_key": "XXXX",
                    "url": "%s"
                },
                "task_settings": {
                   "model": "text-embedding-ada-002"
                }
            }
            """, url);
    }

    static String embeddingConfigWithModelInServiceSettings(String url) {
        return Strings.format("""
            {
                "service": "openai",
                "service_settings": {
                    "api_key": "XXXX",
                    "url": "%s",
                    "model_id": "text-embedding-ada-002"
                }
            }
            """, url);
    }

    private String chatCompletionsConfig(String url) {
        return Strings.format("""
            {
                "service": "openai",
                "service_settings": {
                    "api_key": "XXXX",
                    "url": "%s",
                    "model_id": "gpt-4"
                }
            }
            """, url);
    }

    static String embeddingResponse() {
        return """
            {
                "object": "list",
                "data": [
                    {
                        "object": "embedding",
                        "index": 0,
                        "embedding": [
                            0.0123,
                            -0.0123
                        ]
                    }
                ],
                "model": "text-embedding-ada-002",
                "usage": {
                    "prompt_tokens": 8,
                    "total_tokens": 8
                }
            }
            """;
    }

    private String chatCompletionsResponse() {
        return """
            {
                "id": "some-id",
                "object": "chat.completion",
                "created": 1705397787,
                "model": "gpt-3.5-turbo-0613",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "some content"
                        },
                        "logprobs": null,
                        "finish_reason": "stop"
                    }
                ],
                "usage": {
                    "prompt_tokens": 46,
                    "completion_tokens": 39,
                    "total_tokens": 85
                },
                "system_fingerprint": null
            }
            """;
    }

}