[ML] Add mixed cluster tests for inference (#108392)

* mixed cluster tests are executable

* add tests from upgrade tests

* [ML] Add mixed cluster tests for existing services

* clean up

* review improvements

* spotless

* remove blocked AzureOpenAI mixed IT

* improvements from DK review

* temp for testing

* refactoring and documentation

* Revert manual testing configs of "temp for testing"

This reverts parts of commit fca46fd2b6.

* revert TESTING.asciidoc formatting

* Update TESTING.asciidoc to avoid reformatting

* add minimum version for tests to match minimum version in services

* spotless
This commit is contained in:
Max Hniebergall 2024-05-15 15:13:09 -04:00 committed by GitHub
parent 74ec90bf1d
commit c88a6fe481
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 896 additions and 5 deletions

View file

@ -551,13 +551,19 @@ When running `./gradlew check`, minimal bwc checks are also run against compatib
==== BWC Testing against a specific remote/branch
Sometimes a backward compatibility change spans two versions. A common case is a new functionality
that needs a BWC bridge in an unreleased version of a release branch (for example, 5.x).
To test the changes, you can instruct Gradle to build the BWC version from another remote/branch combination instead of
pulling the release branch from GitHub. You do so using the `bwc.remote` and `bwc.refspec.BRANCH` system properties:
Sometimes a backward compatibility change spans two versions.
A common case is a new functionality that needs a BWC bridge in an unreleased version of a release branch (for example, 5.x).
Another use case, since the introduction of serverless, is to test BWC against main in addition to the other released branches.
To do so, specify the `bwc.refspec` remote and branch to use for the BWC build as `origin/main`.
To test against main, you will also need to create a new version in link:./server/src/main/java/org/elasticsearch/Version.java[Version.java],
increment `elasticsearch` in link:./build-tools-internal/version.properties[version.properties], and hard-code the `project.version` for ml-cpp
in link:./x-pack/plugin/ml/build.gradle[ml/build.gradle].
In general, to test the changes, you can instruct Gradle to build the BWC version from another remote/branch combination instead of pulling the release branch from GitHub.
You do so using the `bwc.refspec.{VERSION}` system property:
-------------------------------------------------
./gradlew check -Dbwc.remote=${remote} -Dbwc.refspec.5.x=index_req_bwc_5.x
./gradlew check -Dtests.bwc.refspec.8.15=origin/main
-------------------------------------------------
The branch needs to be available on the remote that the BWC makes of the

View file

@ -0,0 +1,37 @@
import org.elasticsearch.gradle.Version
import org.elasticsearch.gradle.VersionProperties
import org.elasticsearch.gradle.util.GradleUtils
import org.elasticsearch.gradle.internal.info.BuildParams
import org.elasticsearch.gradle.testclusters.StandaloneRestIntegTestTask
// Build plugins for this QA project. NOTE(review): judging by their ids they provide the
// javaRestTest source set/tasks, the testArtifact(...) helper and the BWC test wiring
// (bwcTaskName, usesBwcDistribution) used below -- confirm against build-tools-internal.
apply plugin: 'elasticsearch.internal-java-rest-test'
apply plugin: 'elasticsearch.internal-test-artifact-base'
apply plugin: 'elasticsearch.bwc-test'
// Dependencies of the mixed-cluster tests:
//  - the shared inference-service-tests project and x-pack core on the test/compile classpath,
//  - x-pack core's test artifact and the inference plugin itself for the javaRestTest source set,
//  - the test-service-plugin, installed into the test cluster as a cluster plugin.
dependencies {
testImplementation project(path: ':x-pack:plugin:inference:qa:inference-service-tests')
compileOnly project(':x-pack:plugin:core')
javaRestTestImplementation(testArtifact(project(xpackModule('core'))))
javaRestTestImplementation project(path: xpackModule('inference'))
clusterPlugins project(
':x-pack:plugin:inference:qa:test-service-plugin'
)
}
// inference is available in 8.11 or later
// Predicate used below to filter BWC versions: true only for versions on or after 8.11.0.
def supportedVersion = bwcVersion -> {
return bwcVersion.onOrAfter(Version.fromString("8.11.0"));
}
// For every wire-compatible BWC version accepted by supportedVersion, register:
//  1. a "v<version>#javaRestTest" task that runs the REST tests against that BWC distribution, and
//  2. the aggregate BWC task (named by bwcTaskName) that depends on it.
BuildParams.bwcVersions.withWireCompatible(supportedVersion) { bwcVersion, baseName ->
def javaRestTest = tasks.register("v${bwcVersion}#javaRestTest", StandaloneRestIntegTestTask) {
usesBwcDistribution(bwcVersion)
// Read by the Java tests (MixedClustersSpec / MixedClusterSpecTestCase) to pick the old node's version.
systemProperty("tests.old_cluster_version", bwcVersion)
// Run the test classes in a single forked JVM.
maxParallelForks = 1
}
tasks.register(bwcTaskName(bwcVersion)) {
dependsOn javaRestTest
}
}

View file

@ -0,0 +1,129 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.qa.mixed;
import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.hamcrest.Matchers;
import java.io.IOException;
import java.util.List;
import java.util.Map;
/**
 * Base class for the inference mixed-cluster REST tests.
 * <p>
 * Provides thin wrappers around the {@code _inference} REST endpoints (create, read, delete and run inference).
 * Every wrapper sends its request through the shared {@link ESRestTestCase#client()} and asserts that the
 * response status is OK before returning.
 */
public abstract class BaseMixedTestCase extends MixedClusterSpecTestCase {

    /**
     * Builds the base URL of a mock web server.
     *
     * @param webServer the mock server whose host name and port are used
     * @return a URL of the form {@code http://host:port}
     */
    protected static String getUrl(MockWebServer webServer) {
        return Strings.format("http://%s:%s", webServer.getHostName(), webServer.getPort());
    }

    /**
     * Configures the REST client to send a basic-auth {@code Authorization} header for the
     * {@code x_pack_rest_user} test user on every request.
     */
    @Override
    protected Settings restClientSettings() {
        String token = ESRestTestCase.basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
        return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
    }

    /**
     * Deletes the inference endpoint addressed as {@code _inference/{taskType}/{inferenceId}}.
     *
     * @throws IOException if the request fails
     */
    protected void delete(String inferenceId, TaskType taskType) throws IOException {
        var request = new Request("DELETE", Strings.format("_inference/%s/%s", taskType, inferenceId));
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
    }

    /**
     * Deletes the inference endpoint addressed as {@code _inference/{inferenceId}}.
     *
     * @throws IOException if the request fails
     */
    protected void delete(String inferenceId) throws IOException {
        var request = new Request("DELETE", Strings.format("_inference/%s", inferenceId));
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
    }

    /**
     * Issues {@code GET _inference/_all}.
     *
     * @return the response body parsed as a map
     * @throws IOException if the request fails
     */
    protected Map<String, Object> getAll() throws IOException {
        var request = new Request("GET", "_inference/_all");
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return ESRestTestCase.entityAsMap(response);
    }

    /**
     * Issues {@code GET _inference/{inferenceId}}.
     *
     * @return the response body parsed as a map
     * @throws IOException if the request fails
     */
    protected Map<String, Object> get(String inferenceId) throws IOException {
        var endpoint = Strings.format("_inference/%s", inferenceId);
        var request = new Request("GET", endpoint);
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return ESRestTestCase.entityAsMap(response);
    }

    /**
     * Issues {@code GET _inference/{taskType}/{inferenceId}}.
     *
     * @return the response body parsed as a map
     * @throws IOException if the request fails
     */
    protected Map<String, Object> get(TaskType taskType, String inferenceId) throws IOException {
        var endpoint = Strings.format("_inference/%s/%s", taskType, inferenceId);
        var request = new Request("GET", endpoint);
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return ESRestTestCase.entityAsMap(response);
    }

    /**
     * Runs inference by POSTing a single-element {@code input} array to {@code _inference/{taskType}/{inferenceId}}.
     *
     * @param input the text to send; it is JSON-escaped before being embedded in the request body
     * @return the response body parsed as a map
     * @throws IOException if the request fails
     */
    protected Map<String, Object> inference(String inferenceId, TaskType taskType, String input) throws IOException {
        var endpoint = Strings.format("_inference/%s/%s", taskType, inferenceId);
        var request = new Request("POST", endpoint);
        request.setJsonEntity("{\"input\": [" + jsonString(input) + "]}");
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return ESRestTestCase.entityAsMap(response);
    }

    /**
     * Runs a rerank request by POSTing {@code query} and {@code inputs} to {@code _inference/rerank/{inferenceId}}.
     *
     * @param inputs the documents to rerank; each one is JSON-escaped before being embedded in the request body
     * @param query  the query text; JSON-escaped as well
     * @return the response body parsed as a map
     * @throws IOException if the request fails
     */
    protected Map<String, Object> rerank(String inferenceId, List<String> inputs, String query) throws IOException {
        var endpoint = Strings.format("_inference/rerank/%s", inferenceId);
        var request = new Request("POST", endpoint);
        StringBuilder body = new StringBuilder("{").append("\"query\":").append(jsonString(query)).append(",").append("\"input\":[");
        for (int i = 0; i < inputs.size(); i++) {
            body.append(jsonString(inputs.get(i)));
            if (i < inputs.size() - 1) {
                body.append(",");
            }
        }
        body.append("]}");
        request.setJsonEntity(body.toString());
        var response = ESRestTestCase.client().performRequest(request);
        ESRestTestCase.assertOK(response);
        return ESRestTestCase.entityAsMap(response);
    }

    /**
     * Creates an inference endpoint with {@code PUT _inference/{taskType}/{inferenceId}?error_trace}.
     *
     * @param modelConfig the JSON request body describing the endpoint
     * @throws IOException if the request fails
     */
    protected void put(String inferenceId, String modelConfig, TaskType taskType) throws IOException {
        String endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, inferenceId);
        var request = new Request("PUT", endpoint);
        request.setJsonEntity(modelConfig);
        var response = ESRestTestCase.client().performRequest(request);
        // Log through the test logger only; the duplicate System.out debug print was removed.
        logger.warn("PUT response: {}", response.toString());
        ESRestTestCase.assertOKAndConsume(response);
    }

    /**
     * Asserts that the response status is 200 or 201; on any other status the response body is used as the
     * assertion failure message.
     *
     * @throws IOException if the response entity cannot be read
     */
    protected static void assertOkOrCreated(Response response) throws IOException {
        int statusCode = response.getStatusLine().getStatusCode();
        // Once EntityUtils.toString(entity) is called the entity cannot be reused.
        // Avoid that call with check here.
        if (statusCode == 200 || statusCode == 201) {
            return;
        }

        String responseStr = EntityUtils.toString(response.getEntity());
        ESTestCase.assertThat(responseStr, statusCode, Matchers.anyOf(Matchers.equalTo(200), Matchers.equalTo(201)));
    }

    /**
     * Renders a value as a double-quoted JSON string literal, escaping quotes, backslashes and control characters
     * so that arbitrary test input cannot break the hand-built request bodies.
     * A {@code null} value is rendered as the text {@code "null"}, matching plain string concatenation.
     */
    private static String jsonString(String value) {
        String raw = String.valueOf(value);
        StringBuilder quoted = new StringBuilder(raw.length() + 2).append('"');
        for (int i = 0; i < raw.length(); i++) {
            char c = raw.charAt(i);
            if (c == '"' || c == '\\') {
                quoted.append('\\').append(c);
            } else if (c == '\n') {
                quoted.append("\\n");
            } else if (c == '\r') {
                quoted.append("\\r");
            } else if (c == '\t') {
                quoted.append("\\t");
            } else if (c < 0x20) {
                // Remaining control characters must be written as \\uXXXX escapes in JSON.
                quoted.append(Strings.format("\\u%04x", (int) c));
            } else {
                quoted.append(c);
            }
        }
        return quoted.append('"').toString();
    }
}

View file

@ -0,0 +1,271 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.qa.mixed;
import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
import org.hamcrest.Matchers;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.inference.qa.mixed.MixedClusterSpecTestCase.bwcVersion;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.oneOf;
/**
 * Mixed-cluster REST tests for the {@code cohere} inference service.
 * <p>
 * Each test creates inference endpoints that point at a local {@link MockWebServer}, reads the stored
 * configuration back and runs inference against canned responses. Tests are skipped unless the old cluster
 * version is at least {@link #MINIMUM_SUPPORTED_VERSION}.
 */
public class CohereServiceMixedIT extends BaseMixedTestCase {

    private static final String COHERE_EMBEDDINGS_ADDED = "8.13.0";
    private static final String COHERE_RERANK_ADDED = "8.14.0";
    private static final String BYTE_ALIAS_FOR_INT8_ADDED = "8.14.0";
    private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0";

    private static MockWebServer cohereEmbeddingsServer;
    private static MockWebServer cohereRerankServer;

    /** Starts one mock server per Cohere task type before any test runs. */
    @BeforeClass
    public static void startWebServer() throws IOException {
        cohereEmbeddingsServer = new MockWebServer();
        cohereEmbeddingsServer.start();

        cohereRerankServer = new MockWebServer();
        cohereRerankServer.start();
    }

    /** Closes both mock servers once all tests have run. */
    @AfterClass
    public static void shutdown() {
        cohereEmbeddingsServer.close();
        cohereRerankServer.close();
    }

    /** Creates int8 and float embedding endpoints, checks their stored settings, runs inference and deletes them. */
    @SuppressWarnings("unchecked")
    public void testCohereEmbeddings() throws IOException {
        final boolean serviceAvailable = bwcVersion.onOrAfter(Version.fromString(COHERE_EMBEDDINGS_ADDED));
        assumeTrue("Cohere embedding service added in " + COHERE_EMBEDDINGS_ADDED, serviceAvailable);
        final boolean meetsMinimum = bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION));
        assumeTrue("Cohere service requires at least " + MINIMUM_SUPPORTED_VERSION, meetsMinimum);

        final String int8Id = "mixed-cluster-cohere-embeddings-int8";
        final String floatId = "mixed-cluster-cohere-embeddings-float";

        // The PUT calls the service, so a canned response has to be queued before each one.
        cohereEmbeddingsServer.enqueue(okResponse(embeddingResponseByte()));
        put(int8Id, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);

        cohereEmbeddingsServer.enqueue(okResponse(embeddingResponseFloat()));
        put(floatId, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING);

        List<Map<String, Object>> int8Endpoints = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, int8Id).get("endpoints");
        assertEquals("cohere", int8Endpoints.get(0).get("service"));
        Map<String, Object> int8Settings = (Map<String, Object>) int8Endpoints.get(0).get("service_settings");
        assertThat(int8Settings, hasEntry("model_id", "embed-english-light-v3.0"));
        // An upgraded node will report the embedding type as byte, an old node int8
        assertThat(int8Settings.get("embedding_type"), Matchers.is(oneOf("int8", "byte")));

        List<Map<String, Object>> floatEndpoints = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, floatId).get("endpoints");
        Map<String, Object> floatSettings = (Map<String, Object>) floatEndpoints.get(0).get("service_settings");
        assertThat(floatSettings, hasEntry("embedding_type", "float"));

        assertEmbeddingInference(int8Id, CohereEmbeddingType.BYTE);
        assertEmbeddingInference(floatId, CohereEmbeddingType.FLOAT);

        delete(floatId);
        delete(int8Id);
    }

    /**
     * Queues the canned response matching {@code type} (nothing is queued for any other type) and asserts that
     * running text-embedding inference on the endpoint returns a non-empty result.
     */
    void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) throws IOException {
        final String cannedBody = switch (type) {
            case INT8, BYTE -> embeddingResponseByte();
            case FLOAT -> embeddingResponseFloat();
            default -> null;
        };
        if (cannedBody != null) {
            cohereEmbeddingsServer.enqueue(okResponse(cannedBody));
        }

        Map<String, Object> result = inference(inferenceId, TaskType.TEXT_EMBEDDING, "some text");
        assertThat(result.entrySet(), not(empty()));
    }

    /** Creates a rerank endpoint, checks its stored settings and runs rerank before and after reading them. */
    @SuppressWarnings("unchecked")
    public void testRerank() throws IOException {
        final boolean serviceAvailable = bwcVersion.onOrAfter(Version.fromString(COHERE_RERANK_ADDED));
        assumeTrue("Cohere rerank service added in " + COHERE_RERANK_ADDED, serviceAvailable);
        final boolean meetsMinimum = bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION));
        assumeTrue("Cohere service requires at least " + MINIMUM_SUPPORTED_VERSION, meetsMinimum);

        final String rerankId = "mixed-cluster-rerank";

        put(rerankId, rerankConfig(getUrl(cohereRerankServer)), TaskType.RERANK);
        assertRerank(rerankId);

        List<Map<String, Object>> endpoints = (List<Map<String, Object>>) get(TaskType.RERANK, rerankId).get("endpoints");
        assertThat(endpoints, hasSize(1));
        assertEquals("cohere", endpoints.get(0).get("service"));

        Map<String, Object> storedServiceSettings = (Map<String, Object>) endpoints.get(0).get("service_settings");
        assertThat(storedServiceSettings, hasEntry("model_id", "rerank-english-v3.0"));
        Map<String, Object> storedTaskSettings = (Map<String, Object>) endpoints.get(0).get("task_settings");
        assertThat(storedTaskSettings, hasEntry("top_n", 3));

        assertRerank(rerankId);
    }

    /** Queues a canned rerank response and asserts that a rerank call on the endpoint returns a non-empty result. */
    private void assertRerank(String inferenceId) throws IOException {
        cohereRerankServer.enqueue(okResponse(rerankResponse()));

        final List<String> documents = List.of("luke", "like", "leia", "chewy", "r2d2", "star", "wars");
        Map<String, Object> result = rerank(inferenceId, documents, "star wars main character");
        assertThat(result.entrySet(), not(empty()));
    }

    /** Wraps a body in a mock HTTP 200 response. */
    private static MockResponse okResponse(String body) {
        return new MockResponse().setResponseCode(200).setBody(body);
    }

    private String embeddingConfigByte(String url) {
        return embeddingConfigTemplate(url, "byte");
    }

    private String embeddingConfigInt8(String url) {
        return embeddingConfigTemplate(url, "int8");
    }

    private String embeddingConfigFloat(String url) {
        return embeddingConfigTemplate(url, "float");
    }

    /** Renders the endpoint configuration for a Cohere embeddings model with the given URL and embedding type. */
    private String embeddingConfigTemplate(String url, String embeddingType) {
        return Strings.format("""
            {
            "service": "cohere",
            "service_settings": {
            "url": "%s",
            "api_key": "XXXX",
            "model_id": "embed-english-light-v3.0",
            "embedding_type": "%s"
            }
            }
            """, url, embeddingType);
    }

    /** Canned service response carrying integer embedding values. */
    private String embeddingResponseByte() {
        return """
            {
            "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
            "texts": [
            "hello"
            ],
            "embeddings": [
            [
            12,
            56
            ]
            ],
            "meta": {
            "api_version": {
            "version": "1"
            },
            "billed_units": {
            "input_tokens": 1
            }
            },
            "response_type": "embeddings_bytes"
            }
            """;
    }

    /** Canned service response carrying floating-point embedding values. */
    private String embeddingResponseFloat() {
        return """
            {
            "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4",
            "texts": [
            "hello"
            ],
            "embeddings": [
            [
            -0.0018434525,
            0.01777649
            ]
            ],
            "meta": {
            "api_version": {
            "version": "1"
            },
            "billed_units": {
            "input_tokens": 1
            }
            },
            "response_type": "embeddings_floats"
            }
            """;
    }

    /** Renders the endpoint configuration for a Cohere rerank model with the given URL. */
    private String rerankConfig(String url) {
        return Strings.format("""
            {
            "service": "cohere",
            "service_settings": {
            "api_key": "XXXX",
            "model_id": "rerank-english-v3.0",
            "url": "%s"
            },
            "task_settings": {
            "return_documents": false,
            "top_n": 3
            }
            }
            """, url);
    }

    /** Canned rerank service response with three scored results. */
    private String rerankResponse() {
        return """
            {
            "index": "d0760819-5a73-4d58-b163-3956d3648b62",
            "results": [
            {
            "index": 2,
            "relevance_score": 0.98005307
            },
            {
            "index": 3,
            "relevance_score": 0.27904198
            },
            {
            "index": 0,
            "relevance_score": 0.10194652
            }
            ],
            "meta": {
            "api_version": {
            "version": "1"
            },
            "billed_units": {
            "search_units": 1
            }
            }
            }
            """;
    }
}

View file

@ -0,0 +1,147 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.qa.mixed;
import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.not;
/**
 * Mixed-cluster REST tests for the {@code hugging_face} inference service.
 * <p>
 * Each test creates an inference endpoint that points at a local {@link MockWebServer}, reads the stored
 * configuration back and runs inference against a canned response. Tests are skipped unless the old cluster
 * version is at least {@link #MINIMUM_SUPPORTED_VERSION}.
 */
public class HuggingFaceServiceMixedIT extends BaseMixedTestCase {

    private static final String HF_EMBEDDINGS_ADDED = "8.12.0";
    private static final String HF_ELSER_ADDED = "8.12.0";
    private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0";

    private static MockWebServer embeddingsServer;
    private static MockWebServer elserServer;

    /** Starts one mock server per task type before any test runs. */
    @BeforeClass
    public static void startWebServer() throws IOException {
        embeddingsServer = new MockWebServer();
        embeddingsServer.start();

        elserServer = new MockWebServer();
        elserServer.start();
    }

    /** Closes both mock servers once all tests have run. */
    @AfterClass
    public static void shutdown() {
        embeddingsServer.close();
        elserServer.close();
    }

    /** Creates a text-embedding endpoint, checks that it is stored as {@code hugging_face} and runs inference. */
    @SuppressWarnings("unchecked")
    public void testHFEmbeddings() throws IOException {
        var embeddingsSupported = bwcVersion.onOrAfter(Version.fromString(HF_EMBEDDINGS_ADDED));
        assumeTrue("Hugging Face embedding service added in " + HF_EMBEDDINGS_ADDED, embeddingsSupported);
        assumeTrue(
            "HuggingFace service requires at least " + MINIMUM_SUPPORTED_VERSION,
            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
        );

        final String inferenceId = "mixed-cluster-embeddings";

        // A response is queued before the PUT so one is available if creating the endpoint calls the service.
        embeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
        put(inferenceId, embeddingConfig(getUrl(embeddingsServer)), TaskType.TEXT_EMBEDDING);
        var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints");
        assertThat(configs, hasSize(1));
        assertEquals("hugging_face", configs.get(0).get("service"));
        assertEmbeddingInference(inferenceId);
    }

    /** Queues a canned embedding response and asserts that text-embedding inference returns a non-empty result. */
    void assertEmbeddingInference(String inferenceId) throws IOException {
        embeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
        var inferenceMap = inference(inferenceId, TaskType.TEXT_EMBEDDING, "some text");
        assertThat(inferenceMap.entrySet(), not(empty()));
    }

    /** Creates a sparse-embedding endpoint, checks that it is stored as {@code hugging_face} and runs inference. */
    @SuppressWarnings("unchecked")
    public void testElser() throws IOException {
        var supported = bwcVersion.onOrAfter(Version.fromString(HF_ELSER_ADDED));
        assumeTrue("HF elser service added in " + HF_ELSER_ADDED, supported);
        assumeTrue(
            "HuggingFace service requires at least " + MINIMUM_SUPPORTED_VERSION,
            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
        );

        final String inferenceId = "mixed-cluster-elser";

        put(inferenceId, elserConfig(getUrl(elserServer)), TaskType.SPARSE_EMBEDDING);

        var configs = (List<Map<String, Object>>) get(TaskType.SPARSE_EMBEDDING, inferenceId).get("endpoints");
        assertThat(configs, hasSize(1));
        assertEquals("hugging_face", configs.get(0).get("service"));
        assertElser(inferenceId);
    }

    /** Queues a canned sparse-embedding response and asserts that inference returns a non-empty result. */
    private void assertElser(String inferenceId) throws IOException {
        elserServer.enqueue(new MockResponse().setResponseCode(200).setBody(elserResponse()));
        var inferenceMap = inference(inferenceId, TaskType.SPARSE_EMBEDDING, "some text");
        assertThat(inferenceMap.entrySet(), not(empty()));
    }

    /** Renders the endpoint configuration for the text-embedding endpoint with the given service URL. */
    private String embeddingConfig(String url) {
        return Strings.format("""
            {
            "service": "hugging_face",
            "service_settings": {
            "url": "%s",
            "api_key": "XXXX"
            }
            }
            """, url);
    }

    /** Canned service response holding a single two-dimensional embedding. */
    private String embeddingResponse() {
        return """
            [
            [
            0.014539449,
            -0.015288644
            ]
            ]
            """;
    }

    /** Renders the endpoint configuration for the sparse-embedding endpoint with the given service URL. */
    private String elserConfig(String url) {
        return Strings.format("""
            {
            "service": "hugging_face",
            "service_settings": {
            "api_key": "XXXX",
            "url": "%s"
            }
            }
            """, url);
    }

    /** Canned service response holding one token-to-weight map. */
    private String elserResponse() {
        return """
            [
            {
            ".": 0.133155956864357,
            "the": 0.6747211217880249
            }
            ]
            """;
    }
}

View file

@ -0,0 +1,53 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.qa.mixed;
import org.elasticsearch.Version;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.test.rest.TestFeatureService;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.ClassRule;
public abstract class MixedClusterSpecTestCase extends ESRestTestCase {
@ClassRule
public static ElasticsearchCluster cluster = MixedClustersSpec.mixedVersionCluster();
@Override
protected String getTestRestCluster() {
return cluster.getHttpAddresses();
}
static final Version bwcVersion = Version.fromString(System.getProperty("tests.old_cluster_version"));
private static TestFeatureService oldClusterTestFeatureService = null;
@Before
public void extractOldClusterFeatures() {
if (oldClusterTestFeatureService == null) {
oldClusterTestFeatureService = testFeatureService;
}
}
protected static boolean oldClusterHasFeature(String featureId) {
assert oldClusterTestFeatureService != null;
return oldClusterTestFeatureService.clusterHasFeature(featureId);
}
protected static boolean oldClusterHasFeature(NodeFeature feature) {
return oldClusterHasFeature(feature.id());
}
@AfterClass
public static void cleanUp() {
oldClusterTestFeatureService = null;
}
}

View file

@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.qa.mixed;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
import org.elasticsearch.test.cluster.util.Version;
/** Factory for the two-node, mixed-version local cluster used by the inference mixed-cluster tests. */
public class MixedClustersSpec {

    /**
     * Builds a local cluster using the default distribution with one node on the version named by the
     * {@code tests.old_cluster_version} system property and one node on {@link Version#CURRENT}.
     * Security is disabled and a self-generated trial license is configured.
     *
     * @return the (not yet started) cluster definition
     */
    public static ElasticsearchCluster mixedVersionCluster() {
        final String oldVersionProperty = System.getProperty("tests.old_cluster_version");
        final Version bwcNodeVersion = Version.fromString(oldVersionProperty);

        return ElasticsearchCluster.local()
            .distribution(DistributionType.DEFAULT)
            .withNode(oldNode -> oldNode.version(bwcNodeVersion))
            .withNode(currentNode -> currentNode.version(Version.CURRENT))
            .setting("xpack.security.enabled", "false")
            .setting("xpack.license.self_generated.type", "trial")
            .build();
    }
}

View file

@ -0,0 +1,223 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.qa.mixed;
import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.inference.qa.mixed.MixedClusterSpecTestCase.bwcVersion;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.not;
/**
 * Mixed-cluster REST tests for the {@code openai} inference service.
 * <p>
 * Each test creates an inference endpoint that points at a local {@link MockWebServer}, reads the stored
 * configuration back and runs inference against a canned response. Tests are skipped unless the old cluster
 * version is at least {@link #MINIMUM_SUPPORTED_VERSION}.
 */
public class OpenAIServiceMixedIT extends BaseMixedTestCase {

    private static final String OPEN_AI_EMBEDDINGS_ADDED = "8.12.0";
    private static final String OPEN_AI_EMBEDDINGS_MODEL_SETTING_MOVED = "8.13.0";
    private static final String OPEN_AI_COMPLETIONS_ADDED = "8.14.0";
    private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0";

    private static MockWebServer openAiEmbeddingsServer;
    private static MockWebServer openAiChatCompletionsServer;

    /** Starts one mock server per task type before any test runs. */
    @BeforeClass
    public static void startWebServer() throws IOException {
        openAiEmbeddingsServer = new MockWebServer();
        openAiEmbeddingsServer.start();

        openAiChatCompletionsServer = new MockWebServer();
        openAiChatCompletionsServer.start();
    }

    /** Closes both mock servers once all tests have run. */
    @AfterClass
    public static void shutdown() {
        openAiEmbeddingsServer.close();
        openAiChatCompletionsServer.close();
    }

    /**
     * Creates a text-embedding endpoint with a configuration the old cluster version understands, checks that
     * {@code model_id} is reported in either the service or the task settings, and runs inference.
     */
    @SuppressWarnings("unchecked")
    public void testOpenAiEmbeddings() throws IOException {
        var openAiEmbeddingsSupported = bwcVersion.onOrAfter(Version.fromString(OPEN_AI_EMBEDDINGS_ADDED));
        assumeTrue("OpenAI embedding service added in " + OPEN_AI_EMBEDDINGS_ADDED, openAiEmbeddingsSupported);
        assumeTrue(
            "OpenAI service requires at least " + MINIMUM_SUPPORTED_VERSION,
            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
        );

        final String inferenceId = "mixed-cluster-embeddings";

        String inferenceConfig = oldClusterVersionCompatibleEmbeddingConfig();
        // queue a response as PUT will call the service
        openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
        put(inferenceId, inferenceConfig, TaskType.TEXT_EMBEDDING);

        var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints");
        assertThat(configs, hasSize(1));
        assertEquals("openai", configs.get(0).get("service"));
        var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
        var taskSettings = (Map<String, Object>) configs.get(0).get("task_settings");
        // Depending on which node answers, model_id may be reported in either settings section.
        var modelIdFound = serviceSettings.containsKey("model_id") || taskSettings.containsKey("model_id");
        assertTrue("model_id not found in config: " + configs.toString(), modelIdFound);

        assertEmbeddingInference(inferenceId);
    }

    /** Queues a canned embedding response and asserts that text-embedding inference returns a non-empty result. */
    void assertEmbeddingInference(String inferenceId) throws IOException {
        openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
        var inferenceMap = inference(inferenceId, TaskType.TEXT_EMBEDDING, "some text");
        assertThat(inferenceMap.entrySet(), not(empty()));
    }

    /** Creates a completion endpoint, checks its stored settings and runs inference. */
    @SuppressWarnings("unchecked")
    public void testOpenAiCompletions() throws IOException {
        // Gate on the version that introduced completions; the message below cites the same constant.
        var openAiCompletionsSupported = bwcVersion.onOrAfter(Version.fromString(OPEN_AI_COMPLETIONS_ADDED));
        assumeTrue("OpenAI completions service added in " + OPEN_AI_COMPLETIONS_ADDED, openAiCompletionsSupported);
        assumeTrue(
            "OpenAI service requires at least " + MINIMUM_SUPPORTED_VERSION,
            bwcVersion.onOrAfter(Version.fromString(MINIMUM_SUPPORTED_VERSION))
        );

        final String inferenceId = "mixed-cluster-completions";

        put(inferenceId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), TaskType.COMPLETION);

        var configsMap = get(TaskType.COMPLETION, inferenceId);
        logger.warn("Configs: {}", configsMap);
        var configs = (List<Map<String, Object>>) configsMap.get("endpoints");
        assertThat(configs, hasSize(1));
        assertEquals("openai", configs.get(0).get("service"));
        var serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
        assertThat(serviceSettings, hasEntry("model_id", "gpt-4"));
        var taskSettings = (Map<String, Object>) configs.get(0).get("task_settings");
        assertThat(taskSettings.keySet(), empty());

        assertCompletionInference(inferenceId);
    }

    /** Queues a canned chat-completion response and asserts that completion inference returns a non-empty result. */
    void assertCompletionInference(String inferenceId) throws IOException {
        openAiChatCompletionsServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionsResponse()));
        var inferenceMap = inference(inferenceId, TaskType.COMPLETION, "some text");
        assertThat(inferenceMap.entrySet(), not(empty()));
    }

    /**
     * Picks the embedding configuration matching the old cluster version: the model name goes into
     * {@code task_settings} before {@link #OPEN_AI_EMBEDDINGS_MODEL_SETTING_MOVED}, and into
     * {@code service_settings} from that version on.
     */
    private String oldClusterVersionCompatibleEmbeddingConfig() {
        if (getOldClusterTestVersion().before(OPEN_AI_EMBEDDINGS_MODEL_SETTING_MOVED)) {
            return embeddingConfigWithModelInTaskSettings(getUrl(openAiEmbeddingsServer));
        } else {
            return embeddingConfigWithModelInServiceSettings(getUrl(openAiEmbeddingsServer));
        }
    }

    /** Returns {@code bwcVersion} converted to the test-cluster {@code Version} type via its string form. */
    protected static org.elasticsearch.test.cluster.util.Version getOldClusterTestVersion() {
        return org.elasticsearch.test.cluster.util.Version.fromString(bwcVersion.toString());
    }

    /** Renders an embedding endpoint configuration that names the model under {@code task_settings}. */
    private String embeddingConfigWithModelInTaskSettings(String url) {
        return Strings.format("""
            {
            "service": "openai",
            "service_settings": {
            "api_key": "XXXX",
            "url": "%s"
            },
            "task_settings": {
            "model": "text-embedding-ada-002"
            }
            }
            """, url);
    }

    /** Renders an embedding endpoint configuration that names the model under {@code service_settings}. */
    static String embeddingConfigWithModelInServiceSettings(String url) {
        return Strings.format("""
            {
            "service": "openai",
            "service_settings": {
            "api_key": "XXXX",
            "url": "%s",
            "model_id": "text-embedding-ada-002"
            }
            }
            """, url);
    }

    /** Renders a completion endpoint configuration for the {@code gpt-4} model. */
    private String chatCompletionsConfig(String url) {
        return Strings.format("""
            {
            "service": "openai",
            "service_settings": {
            "api_key": "XXXX",
            "url": "%s",
            "model_id": "gpt-4"
            }
            }
            """, url);
    }

    /** Canned embeddings service response holding a single two-dimensional embedding. */
    static String embeddingResponse() {
        return """
            {
            "object": "list",
            "data": [
            {
            "object": "embedding",
            "index": 0,
            "embedding": [
            0.0123,
            -0.0123
            ]
            }
            ],
            "model": "text-embedding-ada-002",
            "usage": {
            "prompt_tokens": 8,
            "total_tokens": 8
            }
            }
            """;
    }

    /** Canned chat-completions service response holding a single assistant message. */
    private String chatCompletionsResponse() {
        return """
            {
            "id": "some-id",
            "object": "chat.completion",
            "created": 1705397787,
            "model": "gpt-3.5-turbo-0613",
            "choices": [
            {
            "index": 0,
            "message": {
            "role": "assistant",
            "content": "some content"
            },
            "logprobs": null,
            "finish_reason": "stop"
            }
            ],
            "usage": {
            "prompt_tokens": 46,
            "completion_tokens": 39,
            "total_tokens": 85
            },
            "system_fingerprint": null
            }
            """;
    }
}