diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/testfixtures/TestFixtureExtension.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/testfixtures/TestFixtureExtension.java deleted file mode 100644 index 2bcfb7c76d5c..000000000000 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/testfixtures/TestFixtureExtension.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ -package org.elasticsearch.gradle.internal.testfixtures; - -import org.gradle.api.GradleException; -import org.gradle.api.NamedDomainObjectContainer; -import org.gradle.api.Project; - -import java.util.HashMap; -import java.util.Locale; -import java.util.Map; -import java.util.Optional; - -public class TestFixtureExtension { - - private final Project project; - final NamedDomainObjectContainer fixtures; - final Map serviceToProjectUseMap = new HashMap<>(); - - public TestFixtureExtension(Project project) { - this.project = project; - this.fixtures = project.container(Project.class); - } - - public void useFixture() { - useFixture(this.project.getPath()); - } - - public void useFixture(String path) { - addFixtureProject(path); - serviceToProjectUseMap.put(path, this.project.getPath()); - } - - public void useFixture(String path, String serviceName) { - addFixtureProject(path); - String key = getServiceNameKey(path, serviceName); - serviceToProjectUseMap.put(key, this.project.getPath()); - - Optional otherProject = this.findOtherProjectUsingService(key); - if (otherProject.isPresent()) { - throw new GradleException( - String.format( - Locale.ROOT, - "Projects %s and %s both claim the %s 
service defined in the docker-compose.yml of " - + "%sThis is not supported because it breaks running in parallel. Configure dedicated " - + "services for each project and use those instead.", - otherProject.get(), - this.project.getPath(), - serviceName, - path - ) - ); - } - } - - private String getServiceNameKey(String fixtureProjectPath, String serviceName) { - return fixtureProjectPath + "::" + serviceName; - } - - private Optional findOtherProjectUsingService(String serviceName) { - return this.project.getRootProject() - .getAllprojects() - .stream() - .filter(p -> p.equals(this.project) == false) - .filter(p -> p.getExtensions().findByType(TestFixtureExtension.class) != null) - .map(project -> project.getExtensions().getByType(TestFixtureExtension.class)) - .flatMap(ext -> ext.serviceToProjectUseMap.entrySet().stream()) - .filter(entry -> entry.getKey().equals(serviceName)) - .map(Map.Entry::getValue) - .findAny(); - } - - private void addFixtureProject(String path) { - Project fixtureProject = this.project.findProject(path); - if (fixtureProject == null) { - throw new IllegalArgumentException("Could not find test fixture " + fixtureProject); - } - if (fixtureProject.file(TestFixturesPlugin.DOCKER_COMPOSE_YML).exists() == false) { - throw new IllegalArgumentException( - "Project " + path + " is not a valid test fixture: missing " + TestFixturesPlugin.DOCKER_COMPOSE_YML - ); - } - fixtures.add(fixtureProject); - // Check for exclusive access - Optional otherProject = this.findOtherProjectUsingService(path); - if (otherProject.isPresent()) { - throw new GradleException( - String.format( - Locale.ROOT, - "Projects %s and %s both claim all services from %s. This is not supported because it" - + " breaks running in parallel. 
Configure specific services in docker-compose.yml " - + "for each and add the service name to `useFixture`", - otherProject.get(), - this.project.getPath(), - path - ) - ); - } - } - - boolean isServiceRequired(String serviceName, String fixtureProject) { - if (serviceToProjectUseMap.containsKey(fixtureProject)) { - return true; - } - return serviceToProjectUseMap.containsKey(getServiceNameKey(fixtureProject, serviceName)); - } -} diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/testfixtures/TestFixturesPlugin.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/testfixtures/TestFixturesPlugin.java index c50ff97498c3..4c5f2abb9515 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/testfixtures/TestFixturesPlugin.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/testfixtures/TestFixturesPlugin.java @@ -70,7 +70,6 @@ public class TestFixturesPlugin implements Plugin { project.getRootProject().getPluginManager().apply(DockerSupportPlugin.class); TaskContainer tasks = project.getTasks(); - TestFixtureExtension extension = project.getExtensions().create("testFixtures", TestFixtureExtension.class, project); Provider dockerComposeThrottle = project.getGradle() .getSharedServices() .registerIfAbsent(DOCKER_COMPOSE_THROTTLE, DockerComposeThrottle.class, spec -> spec.getMaxParallelUsages().set(1)); @@ -84,73 +83,63 @@ public class TestFixturesPlugin implements Plugin { File testFixturesDir = project.file("testfixtures_shared"); ext.set("testFixturesDir", testFixturesDir); - if (project.file(DOCKER_COMPOSE_YML).exists()) { - project.getPluginManager().apply(BasePlugin.class); - project.getPluginManager().apply(DockerComposePlugin.class); - TaskProvider preProcessFixture = project.getTasks().register("preProcessFixture", TestFixtureTask.class, t -> { - t.getFixturesDir().set(testFixturesDir); - t.doFirst(task -> { - try { - 
Files.createDirectories(testFixturesDir.toPath()); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - }); - }); - TaskProvider buildFixture = project.getTasks() - .register("buildFixture", t -> t.dependsOn(preProcessFixture, tasks.named("composeUp"))); - - TaskProvider postProcessFixture = project.getTasks() - .register("postProcessFixture", TestFixtureTask.class, task -> { - task.getFixturesDir().set(testFixturesDir); - task.dependsOn(buildFixture); - configureServiceInfoForTask( - task, - project, - false, - (name, port) -> task.getExtensions().getByType(ExtraPropertiesExtension.class).set(name, port) - ); - }); - - maybeSkipTask(dockerSupport, preProcessFixture); - maybeSkipTask(dockerSupport, postProcessFixture); - maybeSkipTask(dockerSupport, buildFixture); - - ComposeExtension composeExtension = project.getExtensions().getByType(ComposeExtension.class); - composeExtension.setProjectName(project.getName()); - composeExtension.getUseComposeFiles().addAll(Collections.singletonList(DOCKER_COMPOSE_YML)); - composeExtension.getRemoveContainers().set(true); - composeExtension.getCaptureContainersOutput() - .set(EnumSet.of(LogLevel.INFO, LogLevel.DEBUG).contains(project.getGradle().getStartParameter().getLogLevel())); - composeExtension.getUseDockerComposeV2().set(false); - composeExtension.getExecutable().set(this.providerFactory.provider(() -> { - String composePath = dockerSupport.get().getDockerAvailability().dockerComposePath(); - LOGGER.debug("Docker Compose path: {}", composePath); - return composePath != null ? 
composePath : "/usr/bin/docker-compose"; - })); - - tasks.named("composeUp").configure(t -> { - // Avoid running docker-compose tasks in parallel in CI due to some issues on certain Linux distributions - if (BuildParams.isCi()) { - t.usesService(dockerComposeThrottle); - } - t.mustRunAfter(preProcessFixture); - }); - tasks.named("composePull").configure(t -> t.mustRunAfter(preProcessFixture)); - tasks.named("composeDown").configure(t -> t.doLast(t2 -> getFileSystemOperations().delete(d -> d.delete(testFixturesDir)))); - } else { - project.afterEvaluate(spec -> { - if (extension.fixtures.isEmpty()) { - // if only one fixture is used, that's this one, but without a compose file that's not a valid configuration - throw new IllegalStateException( - "No " + DOCKER_COMPOSE_YML + " found for " + project.getPath() + " nor does it use other fixtures." - ); - } - }); + if (project.file(DOCKER_COMPOSE_YML).exists() == false) { + // if only one fixture is used, that's this one, but without a compose file that's not a valid configuration + throw new IllegalStateException("No " + DOCKER_COMPOSE_YML + " found for " + project.getPath() + "."); } + project.getPluginManager().apply(BasePlugin.class); + project.getPluginManager().apply(DockerComposePlugin.class); + TaskProvider preProcessFixture = project.getTasks().register("preProcessFixture", TestFixtureTask.class, t -> { + t.getFixturesDir().set(testFixturesDir); + t.doFirst(task -> { + try { + Files.createDirectories(testFixturesDir.toPath()); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + }); + TaskProvider buildFixture = project.getTasks() + .register("buildFixture", t -> t.dependsOn(preProcessFixture, tasks.named("composeUp"))); - extension.fixtures.matching(fixtureProject -> fixtureProject.equals(project) == false) - .all(fixtureProject -> project.evaluationDependsOn(fixtureProject.getPath())); + TaskProvider postProcessFixture = project.getTasks() + .register("postProcessFixture", 
TestFixtureTask.class, task -> { + task.getFixturesDir().set(testFixturesDir); + task.dependsOn(buildFixture); + configureServiceInfoForTask( + task, + project, + false, + (name, port) -> task.getExtensions().getByType(ExtraPropertiesExtension.class).set(name, port) + ); + }); + + maybeSkipTask(dockerSupport, preProcessFixture); + maybeSkipTask(dockerSupport, postProcessFixture); + maybeSkipTask(dockerSupport, buildFixture); + + ComposeExtension composeExtension = project.getExtensions().getByType(ComposeExtension.class); + composeExtension.setProjectName(project.getName()); + composeExtension.getUseComposeFiles().addAll(Collections.singletonList(DOCKER_COMPOSE_YML)); + composeExtension.getRemoveContainers().set(true); + composeExtension.getCaptureContainersOutput() + .set(EnumSet.of(LogLevel.INFO, LogLevel.DEBUG).contains(project.getGradle().getStartParameter().getLogLevel())); + composeExtension.getUseDockerComposeV2().set(false); + composeExtension.getExecutable().set(this.providerFactory.provider(() -> { + String composePath = dockerSupport.get().getDockerAvailability().dockerComposePath(); + LOGGER.debug("Docker Compose path: {}", composePath); + return composePath != null ? 
composePath : "/usr/bin/docker-compose"; + })); + + tasks.named("composeUp").configure(t -> { + // Avoid running docker-compose tasks in parallel in CI due to some issues on certain Linux distributions + if (BuildParams.isCi()) { + t.usesService(dockerComposeThrottle); + } + t.mustRunAfter(preProcessFixture); + }); + tasks.named("composePull").configure(t -> t.mustRunAfter(preProcessFixture)); + tasks.named("composeDown").configure(t -> t.doLast(t2 -> getFileSystemOperations().delete(d -> d.delete(testFixturesDir)))); // Skip docker compose tasks if it is unavailable maybeSkipTasks(tasks, dockerSupport, Test.class); @@ -161,17 +150,18 @@ public class TestFixturesPlugin implements Plugin { maybeSkipTasks(tasks, dockerSupport, ComposePull.class); maybeSkipTasks(tasks, dockerSupport, ComposeDown.class); - tasks.withType(Test.class).configureEach(task -> extension.fixtures.all(fixtureProject -> { - task.dependsOn(fixtureProject.getTasks().named("postProcessFixture")); - task.finalizedBy(fixtureProject.getTasks().named("composeDown")); + tasks.withType(Test.class).configureEach(testTask -> { + testTask.dependsOn(postProcessFixture); + testTask.finalizedBy(tasks.named("composeDown")); configureServiceInfoForTask( - task, - fixtureProject, + testTask, + project, true, - (name, host) -> task.getExtensions().getByType(SystemPropertyCommandLineArgumentProvider.class).systemProperty(name, host) + (name, host) -> testTask.getExtensions() + .getByType(SystemPropertyCommandLineArgumentProvider.class) + .systemProperty(name, host) ); - })); - + }); } private void maybeSkipTasks(TaskContainer tasks, Provider dockerSupport, Class taskClass) { @@ -203,28 +193,20 @@ public class TestFixturesPlugin implements Plugin { task.doFirst(new Action() { @Override public void execute(Task theTask) { - TestFixtureExtension extension = theTask.getProject().getExtensions().getByType(TestFixtureExtension.class); - - fixtureProject.getExtensions() - .getByType(ComposeExtension.class) - 
.getServicesInfos() - .entrySet() - .stream() - .filter(entry -> enableFilter == false || extension.isServiceRequired(entry.getKey(), fixtureProject.getPath())) - .forEach(entry -> { - String service = entry.getKey(); - ServiceInfo infos = entry.getValue(); - infos.getTcpPorts().forEach((container, host) -> { - String name = "test.fixtures." + service + ".tcp." + container; - theTask.getLogger().info("port mapping property: {}={}", name, host); - consumer.accept(name, host); - }); - infos.getUdpPorts().forEach((container, host) -> { - String name = "test.fixtures." + service + ".udp." + container; - theTask.getLogger().info("port mapping property: {}={}", name, host); - consumer.accept(name, host); - }); + fixtureProject.getExtensions().getByType(ComposeExtension.class).getServicesInfos().entrySet().stream().forEach(entry -> { + String service = entry.getKey(); + ServiceInfo infos = entry.getValue(); + infos.getTcpPorts().forEach((container, host) -> { + String name = "test.fixtures." + service + ".tcp." + container; + theTask.getLogger().info("port mapping property: {}={}", name, host); + consumer.accept(name, host); }); + infos.getUdpPorts().forEach((container, host) -> { + String name = "test.fixtures." + service + ".udp." 
+ container; + theTask.getLogger().info("port mapping property: {}={}", name, host); + consumer.accept(name, host); + }); + }); } }); } diff --git a/distribution/docker/build.gradle b/distribution/docker/build.gradle index a3bb202780c7..68ff2028b92a 100644 --- a/distribution/docker/build.gradle +++ b/distribution/docker/build.gradle @@ -72,8 +72,6 @@ if (useDra == false) { } } -testFixtures.useFixture() - configurations { aarch64DockerSource { attributes { diff --git a/docs/changelog/109044.yaml b/docs/changelog/109044.yaml new file mode 100644 index 000000000000..9e50c377606a --- /dev/null +++ b/docs/changelog/109044.yaml @@ -0,0 +1,5 @@ +pr: 109044 +summary: Enable fallback synthetic source for `token_count` +area: Mapping +type: feature +issues: [] diff --git a/docs/reference/mapping/fields/synthetic-source.asciidoc b/docs/reference/mapping/fields/synthetic-source.asciidoc index 1eba9dfba8b5..a0e7aed177a9 100644 --- a/docs/reference/mapping/fields/synthetic-source.asciidoc +++ b/docs/reference/mapping/fields/synthetic-source.asciidoc @@ -64,6 +64,7 @@ types: ** <> ** <> ** <> +** <> ** <> ** <> diff --git a/docs/reference/mapping/types/token-count.asciidoc b/docs/reference/mapping/types/token-count.asciidoc index 23bbc775243a..7d9dffcc8208 100644 --- a/docs/reference/mapping/types/token-count.asciidoc +++ b/docs/reference/mapping/types/token-count.asciidoc @@ -64,10 +64,10 @@ The following parameters are accepted by `token_count` fields: value. Required. For best performance, use an analyzer without token filters. -`enable_position_increments`:: +`enable_position_increments`:: -Indicates if position increments should be counted. -Set to `false` if you don't want to count tokens removed by analyzer filters (like <>). +Indicates if position increments should be counted. +Set to `false` if you don't want to count tokens removed by analyzer filters (like <>). Defaults to `true`. <>:: @@ -91,3 +91,17 @@ Defaults to `true`. 
Whether the field value should be stored and retrievable separately from the <> field. Accepts `true` or `false` (default). + +[[token-count-synthetic-source]] +===== Synthetic `_source` + +IMPORTANT: Synthetic `_source` is Generally Available only for TSDB indices +(indices that have `index.mode` set to `time_series`). For other indices +synthetic `_source` is in technical preview. Features in technical preview may +be changed or removed in a future release. Elastic will work to fix +any issues, but features in technical preview are not subject to the support SLA +of official GA features. + +`token_count` fields support <> in their +default configuration. Synthetic `_source` cannot be used together with +<>. diff --git a/docs/reference/transform/transforms-at-scale.asciidoc b/docs/reference/transform/transforms-at-scale.asciidoc index f1d47c994324..f052b2e8a528 100644 --- a/docs/reference/transform/transforms-at-scale.asciidoc +++ b/docs/reference/transform/transforms-at-scale.asciidoc @@ -15,7 +15,7 @@ relevant considerations in this guide to improve performance. It also helps to understand how {transforms} work as different considerations apply depending on whether or not your transform is running in continuous mode or in batch. -In this guide, you’ll learn how to: +In this guide, you'll learn how to: * Understand the impact of configuration options on the performance of {transforms}. @@ -111,10 +111,17 @@ group of IPs, in order to calculate the total `bytes_sent`. If this second search matches many shards, then this could be resource intensive. Consider limiting the scope that the source index pattern and query will match. -Use an absolute time value as a date range filter in your source query (for -example, greater than `2020-01-01T00:00:00`) to limit which historical indices -are accessed. If you use a relative time value (for example, `now-30d`) then -this date range is re-evaluated at the point of each checkpoint execution. 
+To limit which historical indices are accessed, exclude certain tiers (for +example `"must_not": { "terms": { "_tier": [ "data_frozen", "data_cold" ] } }`) +and/or use an absolute time value +as a date range filter in your source query +(for example, greater than 2024-01-01T00:00:00). If you use a relative time +value (for example, gte now-30d/d) then ensure date rounding is applied to take +advantage of query caching and ensure that the relative time is much larger than +the largest of `frequency` or `time.sync.delay` or the date histogram bucket, +otherwise data may be missed. Do not use date filters which are less than a date +value (for example, `lt`: less than or `lte`: less than or equal to) as this +conflicts with the logic applied at each checkpoint execution and data may be +missed. Consider using <> in your index names to reduce the number of indices to resolve in your queries. Add a date pattern diff --git a/docs/reference/troubleshooting/common-issues/high-jvm-memory-pressure.asciidoc b/docs/reference/troubleshooting/common-issues/high-jvm-memory-pressure.asciidoc index e88927f159f2..267d6594b802 100644 --- a/docs/reference/troubleshooting/common-issues/high-jvm-memory-pressure.asciidoc +++ b/docs/reference/troubleshooting/common-issues/high-jvm-memory-pressure.asciidoc @@ -30,7 +30,8 @@ collection. **Capture a JVM heap dump** To determine the exact reason for the high JVM memory pressure, capture a heap -dump of the JVM while its memory usage is high. +dump of the JVM while its memory usage is high, and also capture the +<> covering the same time period.
[discrete] [[reduce-jvm-memory-pressure]] diff --git a/docs/reference/troubleshooting/network-timeouts.asciidoc b/docs/reference/troubleshooting/network-timeouts.asciidoc index 1920dafe6221..ef942ac1d268 100644 --- a/docs/reference/troubleshooting/network-timeouts.asciidoc +++ b/docs/reference/troubleshooting/network-timeouts.asciidoc @@ -4,8 +4,8 @@ usually by the `JvmMonitorService` in the main node logs. Use these logs to confirm whether or not the node is experiencing high heap usage with long GC pauses. If so, <> has some suggestions for further investigation but typically you -will need to capture a heap dump during a time of high heap usage to fully -understand the problem. +will need to capture a heap dump and the <> +during a time of high heap usage to fully understand the problem. * VM pauses also affect other processes on the same host. A VM pause also typically causes a discontinuity in the system clock, which {es} will report in diff --git a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapper.java b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapper.java index 831306a8e859..c538c7641a01 100644 --- a/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapper.java +++ b/modules/mapper-extras/src/main/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapper.java @@ -215,4 +215,9 @@ public class TokenCountFieldMapper extends FieldMapper { public FieldMapper.Builder getMergeBuilder() { return new Builder(simpleName()).init(this); } + + @Override + protected SyntheticSourceMode syntheticSourceMode() { + return SyntheticSourceMode.FALLBACK; + } } diff --git a/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapperTests.java b/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapperTests.java index 1636def53536..d34d9c3178c7 100644 --- 
a/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapperTests.java +++ b/modules/mapper-extras/src/test/java/org/elasticsearch/index/mapper/extras/TokenCountFieldMapperTests.java @@ -33,7 +33,11 @@ import java.io.IOException; import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Collectors; import static org.hamcrest.Matchers.equalTo; @@ -196,7 +200,66 @@ public class TokenCountFieldMapperTests extends MapperTestCase { @Override protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) { - throw new AssumptionViolatedException("not supported"); + assertFalse(ignoreMalformed); + + var nullValue = usually() ? null : randomNonNegativeInt(); + return new SyntheticSourceSupport() { + @Override + public boolean preservesExactSource() { + return true; + } + + public SyntheticSourceExample example(int maxValues) { + if (randomBoolean()) { + var value = generateValue(); + return new SyntheticSourceExample(value.text, value.text, value.tokenCount, this::mapping); + } + + var values = randomList(1, 5, this::generateValue); + + var textArray = values.stream().map(Value::text).toList(); + + var blockExpectedList = values.stream().map(Value::tokenCount).filter(Objects::nonNull).toList(); + var blockExpected = blockExpectedList.size() == 1 ? 
blockExpectedList.get(0) : blockExpectedList; + + return new SyntheticSourceExample(textArray, textArray, blockExpected, this::mapping); + } + + private record Value(String text, Integer tokenCount) {} + + private Value generateValue() { + if (rarely()) { + return new Value(null, null); + } + + var text = randomList(0, 10, () -> randomAlphaOfLengthBetween(0, 10)).stream().collect(Collectors.joining(" ")); + // with keyword analyzer token count is always 1 + return new Value(text, 1); + } + + private void mapping(XContentBuilder b) throws IOException { + b.field("type", "token_count").field("analyzer", "keyword"); + if (rarely()) { + b.field("index", false); + } + if (rarely()) { + b.field("store", true); + } + if (nullValue != null) { + b.field("null_value", nullValue); + } + } + + @Override + public List invalidExample() throws IOException { + return List.of(); + } + }; + } + + protected Function loadBlockExpected() { + // we can get either a number from doc values or null + return v -> v != null ? 
(Number) v : null; } @Override diff --git a/modules/mapper-extras/src/yamlRestTest/resources/rest-api-spec/test/token_count/10_basic.yml b/modules/mapper-extras/src/yamlRestTest/resources/rest-api-spec/test/token_count/10_basic.yml new file mode 100644 index 000000000000..03b72a262349 --- /dev/null +++ b/modules/mapper-extras/src/yamlRestTest/resources/rest-api-spec/test/token_count/10_basic.yml @@ -0,0 +1,65 @@ +"Test token count": + - requires: + cluster_features: ["gte_v7.10.0"] + reason: "support for token_count was introduced in 7.10" + - do: + indices.create: + index: test + body: + mappings: + properties: + count: + type: token_count + analyzer: standard + count_without_dv: + type: token_count + analyzer: standard + doc_values: false + + - do: + index: + index: test + id: "1" + refresh: true + body: + count: "some text" + - do: + search: + index: test + body: + fields: [count, count_without_dv] + + - is_true: hits.hits.0._id + - match: { hits.hits.0.fields.count: [2] } + - is_false: hits.hits.0.fields.count_without_dv + +--- +"Synthetic source": + - requires: + cluster_features: ["mapper.track_ignored_source"] + reason: requires tracking ignored source + - do: + indices.create: + index: test + body: + mappings: + _source: + mode: synthetic + properties: + count: + type: token_count + analyzer: standard + + - do: + index: + index: test + id: "1" + refresh: true + body: + count: "quick brown fox jumps over a lazy dog" + - do: + get: + index: test + id: "1" + + - match: { _source.count: "quick brown fox jumps over a lazy dog" } diff --git a/modules/repository-s3/build.gradle b/modules/repository-s3/build.gradle index 8b1f30a1bba6..1732fd39794b 100644 --- a/modules/repository-s3/build.gradle +++ b/modules/repository-s3/build.gradle @@ -164,6 +164,8 @@ tasks.named("processYamlRestTestResources").configure { tasks.named("internalClusterTest").configure { // this is tested explicitly in a separate test task exclude '**/S3RepositoryThirdPartyTests.class' + // TODO:
remove once https://github.com/elastic/elasticsearch/issues/101608 is fixed + systemProperty 'es.insecure_network_trace_enabled', 'true' } tasks.named("yamlRestTest").configure { diff --git a/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java b/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java index ebc60e8027d8..030f791feee1 100644 --- a/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java +++ b/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java @@ -627,6 +627,8 @@ public class S3BlobStoreRepositoryTests extends ESMockAPIBasedRepositoryIntegTes trackRequest("HeadObject"); metricsCount.computeIfAbsent(new S3BlobStore.StatsKey(S3BlobStore.Operation.HEAD_OBJECT, purpose), k -> new AtomicLong()) .incrementAndGet(); + } else { + logger.info("--> rawRequest not tracked [{}] with parsed purpose [{}]", request, purpose.getKey()); } } diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java index 25195a1176fb..77333677120a 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedContinuationsIT.java @@ -60,7 +60,7 @@ import org.elasticsearch.http.HttpServerTransport; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.BaseRestHandler; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import 
org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; @@ -370,31 +370,31 @@ public class Netty4ChunkedContinuationsIT extends ESNetty4IntegTestCase { TransportAction.localOnly(); } - public ChunkedRestResponseBody getChunkedBody() { - return getChunkBatch(0); + public ChunkedRestResponseBodyPart getFirstResponseBodyPart() { + return getResponseBodyPart(0); } - private ChunkedRestResponseBody getChunkBatch(int batchIndex) { + private ChunkedRestResponseBodyPart getResponseBodyPart(int batchIndex) { if (batchIndex == failIndex && randomBoolean()) { throw new ElasticsearchException("simulated failure creating next batch"); } - return new ChunkedRestResponseBody() { + return new ChunkedRestResponseBodyPart() { private final Iterator lines = Iterators.forRange(0, 3, i -> "batch-" + batchIndex + "-chunk-" + i + "\n"); @Override - public boolean isDone() { + public boolean isPartComplete() { return lines.hasNext() == false; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { return batchIndex == 2; } @Override - public void getContinuation(ActionListener listener) { - executor.execute(ActionRunnable.supply(listener, () -> getChunkBatch(batchIndex + 1))); + public void getNextPart(ActionListener listener) { + executor.execute(ActionRunnable.supply(listener, () -> getResponseBodyPart(batchIndex + 1))); } @Override @@ -486,11 +486,12 @@ public class Netty4ChunkedContinuationsIT extends ESNetty4IntegTestCase { @Override protected void processResponse(Response response) { try { - final var responseBody = response.getChunkedBody(); // might fail, so do this before acquiring ref + final var responseBody = response.getFirstResponseBodyPart(); + // preceding line might fail, so needs to be done before acquiring the sendResponse ref refs.mustIncRef(); channel.sendResponse(RestResponse.chunked(RestStatus.OK, responseBody, refs::decRef)); } finally { - refs.decRef(); + refs.decRef(); 
// release the ref acquired at the top of accept() } } }); @@ -534,26 +535,26 @@ public class Netty4ChunkedContinuationsIT extends ESNetty4IntegTestCase { TransportAction.localOnly(); } - public ChunkedRestResponseBody getChunkedBody() { - return new ChunkedRestResponseBody() { + public ChunkedRestResponseBodyPart getResponseBodyPart() { + return new ChunkedRestResponseBodyPart() { private final Iterator lines = Iterators.single("infinite response\n"); @Override - public boolean isDone() { + public boolean isPartComplete() { return lines.hasNext() == false; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { return false; } @Override - public void getContinuation(ActionListener listener) { + public void getNextPart(ActionListener listener) { computingContinuation = true; executor.execute(ActionRunnable.supply(listener, () -> { computingContinuation = false; - return getChunkedBody(); + return getResponseBodyPart(); })); } @@ -628,7 +629,7 @@ public class Netty4ChunkedContinuationsIT extends ESNetty4IntegTestCase { client.execute(TYPE, new Request(), new RestActionListener<>(channel) { @Override protected void processResponse(Response response) { - channel.sendResponse(RestResponse.chunked(RestStatus.OK, response.getChunkedBody(), () -> { + channel.sendResponse(RestResponse.chunked(RestStatus.OK, response.getResponseBodyPart(), () -> { // cancellation notification only happens while processing a continuation, not while computing // the next one; prompt cancellation requires use of something like RestCancellableNodeClient assertFalse(response.computingContinuation); diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java index b2a54e202730..e3f60ea7a48e 100644 --- 
a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java +++ b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4ChunkedEncodingIT.java @@ -37,7 +37,7 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.BaseRestHandler; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; @@ -245,19 +245,19 @@ public class Netty4ChunkedEncodingIT extends ESNetty4IntegTestCase { private static void sendChunksResponse(RestChannel channel, Iterator chunkIterator) { final var localRefs = refs; // single volatile read if (localRefs != null && localRefs.tryIncRef()) { - channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBody() { + channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBodyPart() { @Override - public boolean isDone() { + public boolean isPartComplete() { return chunkIterator.hasNext() == false; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { return true; } @Override - public void getContinuation(ActionListener listener) { + public void getNextPart(ActionListener listener) { assert false : "no continuations"; } diff --git a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java index ce8da0c08af5..89a76dd26e28 100644 --- a/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java +++ 
b/modules/transport-netty4/src/internalClusterTest/java/org/elasticsearch/http/netty4/Netty4PipeliningIT.java @@ -34,7 +34,7 @@ import org.elasticsearch.http.HttpServerTransport; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.BaseRestHandler; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestRequest; @@ -243,21 +243,21 @@ public class Netty4PipeliningIT extends ESNetty4IntegTestCase { throw new IllegalArgumentException("[" + FAIL_AFTER_BYTES_PARAM + "] must be present and non-negative"); } return channel -> randomExecutor(client.threadPool()).execute( - () -> channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBody() { + () -> channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBodyPart() { int bytesRemaining = failAfterBytes; @Override - public boolean isDone() { + public boolean isPartComplete() { return false; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { return true; } @Override - public void getContinuation(ActionListener listener) { + public void getNextPart(ActionListener listener) { fail("no continuations here"); } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpContinuation.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpContinuation.java index 156f1c27aa67..cde024921698 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpContinuation.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpContinuation.java @@ -10,16 +10,16 @@ package org.elasticsearch.http.netty4; import io.netty.util.concurrent.PromiseCombiner; -import 
org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; final class Netty4ChunkedHttpContinuation implements Netty4HttpResponse { private final int sequence; - private final ChunkedRestResponseBody body; + private final ChunkedRestResponseBodyPart bodyPart; private final PromiseCombiner combiner; - Netty4ChunkedHttpContinuation(int sequence, ChunkedRestResponseBody body, PromiseCombiner combiner) { + Netty4ChunkedHttpContinuation(int sequence, ChunkedRestResponseBodyPart bodyPart, PromiseCombiner combiner) { this.sequence = sequence; - this.body = body; + this.bodyPart = bodyPart; this.combiner = combiner; } @@ -28,8 +28,8 @@ final class Netty4ChunkedHttpContinuation implements Netty4HttpResponse { return sequence; } - public ChunkedRestResponseBody body() { - return body; + public ChunkedRestResponseBodyPart bodyPart() { + return bodyPart; } public PromiseCombiner combiner() { diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpResponse.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpResponse.java index 783c02da0bbc..3abab9fa2526 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpResponse.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4ChunkedHttpResponse.java @@ -13,7 +13,7 @@ import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpVersion; import org.elasticsearch.http.HttpResponse; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestStatus; /** @@ -23,16 +23,16 @@ final class Netty4ChunkedHttpResponse extends DefaultHttpResponse implements Net private final int sequence; - private final ChunkedRestResponseBody body; + private final ChunkedRestResponseBodyPart firstBodyPart; - 
Netty4ChunkedHttpResponse(int sequence, HttpVersion version, RestStatus status, ChunkedRestResponseBody body) { + Netty4ChunkedHttpResponse(int sequence, HttpVersion version, RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) { super(version, HttpResponseStatus.valueOf(status.getStatus())); this.sequence = sequence; - this.body = body; + this.firstBodyPart = firstBodyPart; } - public ChunkedRestResponseBody body() { - return body; + public ChunkedRestResponseBodyPart firstBodyPart() { + return firstBodyPart; } @Override diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java index 8280c438613a..9cf210c2a8aa 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java @@ -34,7 +34,7 @@ import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Tuple; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.transport.Transports; import org.elasticsearch.transport.netty4.Netty4Utils; import org.elasticsearch.transport.netty4.Netty4WriteThrottlingHandler; @@ -58,7 +58,7 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler { private final int maxEventsHeld; private final PriorityQueue> outboundHoldingQueue; - private record ChunkedWrite(PromiseCombiner combiner, ChannelPromise onDone, ChunkedRestResponseBody responseBody) {} + private record ChunkedWrite(PromiseCombiner combiner, ChannelPromise onDone, ChunkedRestResponseBodyPart responseBodyPart) {} /** * The current {@link ChunkedWrite} if a chunked write is executed at the 
moment. @@ -214,9 +214,9 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler { final PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); final ChannelPromise first = ctx.newPromise(); combiner.add((Future) first); - final var responseBody = readyResponse.body(); + final var firstBodyPart = readyResponse.firstBodyPart(); assert currentChunkedWrite == null; - currentChunkedWrite = new ChunkedWrite(combiner, promise, responseBody); + currentChunkedWrite = new ChunkedWrite(combiner, promise, firstBodyPart); if (enqueueWrite(ctx, readyResponse, first)) { // We were able to write out the first chunk directly, try writing out subsequent chunks until the channel becomes unwritable. // NB "writable" means there's space in the downstream ChannelOutboundBuffer, we aren't trying to saturate the physical channel. @@ -232,9 +232,10 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler { private void doWriteChunkedContinuation(ChannelHandlerContext ctx, Netty4ChunkedHttpContinuation continuation, ChannelPromise promise) { final PromiseCombiner combiner = continuation.combiner(); assert currentChunkedWrite == null; - final var responseBody = continuation.body(); - assert responseBody.isDone() == false : "response with continuations must have at least one (possibly-empty) chunk in each part"; - currentChunkedWrite = new ChunkedWrite(combiner, promise, responseBody); + final var bodyPart = continuation.bodyPart(); + assert bodyPart.isPartComplete() == false + : "response with continuations must have at least one (possibly-empty) chunk in each part"; + currentChunkedWrite = new ChunkedWrite(combiner, promise, bodyPart); // NB "writable" means there's space in the downstream ChannelOutboundBuffer, we aren't trying to saturate the physical channel. 
while (ctx.channel().isWritable()) { if (writeChunk(ctx, currentChunkedWrite)) { @@ -251,9 +252,9 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler { } final var finishingWrite = currentChunkedWrite; currentChunkedWrite = null; - final var finishingWriteBody = finishingWrite.responseBody(); - assert finishingWriteBody.isDone(); - final var endOfResponse = finishingWriteBody.isEndOfResponse(); + final var finishingWriteBodyPart = finishingWrite.responseBodyPart(); + assert finishingWriteBodyPart.isPartComplete(); + final var endOfResponse = finishingWriteBodyPart.isLastPart(); if (endOfResponse) { writeSequence++; finishingWrite.combiner().finish(finishingWrite.onDone()); @@ -261,7 +262,7 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler { final var channel = finishingWrite.onDone().channel(); ActionListener.run(ActionListener.assertOnce(new ActionListener<>() { @Override - public void onResponse(ChunkedRestResponseBody continuation) { + public void onResponse(ChunkedRestResponseBodyPart continuation) { channel.writeAndFlush( new Netty4ChunkedHttpContinuation(writeSequence, continuation, finishingWrite.combiner()), finishingWrite.onDone() // pass the terminal listener/promise along the line @@ -296,7 +297,7 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler { } } - }), finishingWriteBody::getContinuation); + }), finishingWriteBodyPart::getNextPart); } } @@ -374,22 +375,22 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler { } private boolean writeChunk(ChannelHandlerContext ctx, ChunkedWrite chunkedWrite) { - final var body = chunkedWrite.responseBody(); + final var bodyPart = chunkedWrite.responseBodyPart(); final var combiner = chunkedWrite.combiner(); - assert body.isDone() == false : "should not continue to try and serialize once done"; + assert bodyPart.isPartComplete() == false : "should not continue to try and serialize once done"; final ReleasableBytesReference bytes; 
try { - bytes = body.encodeChunk(Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE, serverTransport.recycler()); + bytes = bodyPart.encodeChunk(Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE, serverTransport.recycler()); } catch (Exception e) { return handleChunkingFailure(ctx, chunkedWrite, e); } final ByteBuf content = Netty4Utils.toByteBuf(bytes); - final boolean done = body.isDone(); - final boolean lastChunk = done && body.isEndOfResponse(); - final ChannelFuture f = ctx.write(lastChunk ? new DefaultLastHttpContent(content) : new DefaultHttpContent(content)); + final boolean isPartComplete = bodyPart.isPartComplete(); + final boolean isBodyComplete = isPartComplete && bodyPart.isLastPart(); + final ChannelFuture f = ctx.write(isBodyComplete ? new DefaultLastHttpContent(content) : new DefaultHttpContent(content)); f.addListener(ignored -> bytes.close()); combiner.add(f); - return done; + return isPartComplete; } private boolean handleChunkingFailure(ChannelHandlerContext ctx, ChunkedWrite chunkedWrite, Exception e) { diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java index 0e1bb527fed9..1e35f084c87e 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java @@ -22,7 +22,7 @@ import io.netty.handler.codec.http.cookie.ServerCookieEncoder; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.http.HttpRequest; import org.elasticsearch.http.HttpResponse; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.transport.netty4.Netty4Utils; @@ -176,8 +176,8 @@ public class 
Netty4HttpRequest implements HttpRequest { } @Override - public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) { - return new Netty4ChunkedHttpResponse(sequence, request.protocolVersion(), status, content); + public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) { + return new Netty4ChunkedHttpResponse(sequence, request.protocolVersion(), status, firstBodyPart); } @Override diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java index bb4a0939c98f..4dca3d17bf07 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java @@ -36,7 +36,7 @@ import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.bytes.ZeroBytesReference; import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.http.HttpResponse; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.transport.netty4.Netty4Utils; @@ -502,23 +502,23 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase { }; } - private static ChunkedRestResponseBody getRepeatedChunkResponseBody(int chunkCount, BytesReference chunk) { - return new ChunkedRestResponseBody() { + private static ChunkedRestResponseBodyPart getRepeatedChunkResponseBody(int chunkCount, BytesReference chunk) { + return new ChunkedRestResponseBodyPart() { private int remaining = chunkCount; @Override - public boolean isDone() { + public boolean isPartComplete() { return remaining == 0; } @Override - public 
boolean isEndOfResponse() { + public boolean isLastPart() { return true; } @Override - public void getContinuation(ActionListener listener) { + public void getNextPart(ActionListener listener) { fail("no continuations here"); } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java index d2be4212cf41..bc6e5fef834e 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java @@ -71,7 +71,7 @@ import org.elasticsearch.http.HttpTransportSettings; import org.elasticsearch.http.NullDispatcher; import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils; import org.elasticsearch.http.netty4.internal.HttpValidator; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestResponse; @@ -692,7 +692,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { try { channel.sendResponse( - RestResponse.chunked(OK, ChunkedRestResponseBody.fromXContent(ignored -> Iterators.single((builder, params) -> { + RestResponse.chunked(OK, ChunkedRestResponseBodyPart.fromXContent(ignored -> Iterators.single((builder, params) -> { throw new AssertionError("should not be called for HEAD REQUEST"); }), ToXContent.EMPTY_PARAMS, channel), null) ); @@ -1048,7 +1048,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT assertEquals(request.uri(), url); final var response = RestResponse.chunked( OK, - 
ChunkedRestResponseBody.fromTextChunks(RestResponse.TEXT_CONTENT_TYPE, Collections.emptyIterator()), + ChunkedRestResponseBodyPart.fromTextChunks(RestResponse.TEXT_CONTENT_TYPE, Collections.emptyIterator()), responseReleasedLatch::countDown ); transportClosedFuture.addListener(ActionListener.running(() -> channel.sendResponse(response))); diff --git a/muted-tests.yml b/muted-tests.yml index 8a2c56d94c5b..4f6449246637 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -35,12 +35,6 @@ tests: - class: "org.elasticsearch.upgrades.MlTrainedModelsUpgradeIT" issue: "https://github.com/elastic/elasticsearch/issues/108993" method: "testTrainedModelInference" -- class: "org.elasticsearch.xpack.security.authc.esnative.ReservedRealmElasticAutoconfigIntegTests" - issue: "https://github.com/elastic/elasticsearch/issues/109058" - method: "testAutoconfigSucceedsAfterPromotionFailure" -- class: "org.elasticsearch.xpack.security.authc.esnative.ReservedRealmElasticAutoconfigIntegTests" - issue: "https://github.com/elastic/elasticsearch/issues/109059" - method: "testAutoconfigFailedPasswordPromotion" # Examples: # # Mute a single test case in a YAML test suite: diff --git a/qa/apm/build.gradle b/qa/apm/build.gradle index b26efdf1f9a6..ff22334462fd 100644 --- a/qa/apm/build.gradle +++ b/qa/apm/build.gradle @@ -16,8 +16,6 @@ apply plugin: 'elasticsearch.standalone-rest-test' apply plugin: 'elasticsearch.test.fixtures' apply plugin: 'elasticsearch.internal-distribution-download' -testFixtures.useFixture() - dockerCompose { environment.put 'STACK_VERSION', BuildParams.snapshotBuild ? 
VersionProperties.elasticsearch : VersionProperties.elasticsearch + "-SNAPSHOT" } diff --git a/qa/remote-clusters/build.gradle b/qa/remote-clusters/build.gradle index 0475b7e0eeb8..67f62c0fee04 100644 --- a/qa/remote-clusters/build.gradle +++ b/qa/remote-clusters/build.gradle @@ -15,8 +15,6 @@ apply plugin: 'elasticsearch.standalone-rest-test' apply plugin: 'elasticsearch.test.fixtures' apply plugin: 'elasticsearch.internal-distribution-download' -testFixtures.useFixture() - tasks.register("copyNodeKeyMaterial", Sync) { from project(':x-pack:plugin:core') .files( diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/330_fetch_fields.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/330_fetch_fields.yml index 52b55098ec4d..703f2a0352fb 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/330_fetch_fields.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/330_fetch_fields.yml @@ -262,41 +262,6 @@ - match: { hits.hits.0.fields.date.0: "1990/12/29" } --- -"Test token count": - - requires: - cluster_features: ["gte_v7.10.0"] - reason: "support for token_count was instroduced in 7.10" - - do: - indices.create: - index: test - body: - mappings: - properties: - count: - type: token_count - analyzer: standard - count_without_dv: - type: token_count - analyzer: standard - doc_values: false - - - do: - index: - index: test - id: "1" - refresh: true - body: - count: "some text" - - do: - search: - index: test - body: - fields: [count, count_without_dv] - - - is_true: hits.hits.0._id - - match: { hits.hits.0.fields.count: [2] } - - is_false: hits.hits.0.fields.count_without_dv ---- Test unmapped field: - requires: cluster_features: "gte_v7.11.0" diff --git a/server/src/internalClusterTest/java/org/elasticsearch/discovery/ClusterDisruptionIT.java b/server/src/internalClusterTest/java/org/elasticsearch/discovery/ClusterDisruptionIT.java index 7f94809e64fa..cd9adea500db 100644 --- 
a/server/src/internalClusterTest/java/org/elasticsearch/discovery/ClusterDisruptionIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/discovery/ClusterDisruptionIT.java @@ -16,6 +16,7 @@ import org.elasticsearch.action.NoShardAvailableActionException; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.action.shard.ShardStateAction; @@ -26,6 +27,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.routing.Murmur3HashFunction; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.routing.ShardRoutingState; +import org.elasticsearch.cluster.service.ClusterApplierService; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; @@ -542,7 +544,7 @@ public class ClusterDisruptionIT extends AbstractDisruptionTestCase { }); final ClusterService dataClusterService = internalCluster().getInstance(ClusterService.class, dataNode); - final PlainActionFuture failedLeader = new PlainActionFuture<>() { + final PlainActionFuture failedLeader = new UnsafePlainActionFuture<>(ClusterApplierService.CLUSTER_UPDATE_THREAD_NAME) { @Override protected boolean blockingAllowed() { // we're deliberately blocking the cluster applier on the master until the data node starts to rejoin diff --git a/server/src/internalClusterTest/java/org/elasticsearch/rest/RestControllerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/rest/RestControllerIT.java index 809ecbc85870..b76bec065273 100644 --- 
a/server/src/internalClusterTest/java/org/elasticsearch/rest/RestControllerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/rest/RestControllerIT.java @@ -82,7 +82,7 @@ public class RestControllerIT extends ESIntegTestCase { return channel -> { final var response = RestResponse.chunked( RestStatus.OK, - ChunkedRestResponseBody.fromXContent( + ChunkedRestResponseBodyPart.fromXContent( params -> Iterators.single((b, p) -> b.startObject().endObject()), request, channel diff --git a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/CorruptedBlobStoreRepositoryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/CorruptedBlobStoreRepositoryIT.java index f507e27c6073..9eb9041aa51f 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/CorruptedBlobStoreRepositoryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/CorruptedBlobStoreRepositoryIT.java @@ -299,7 +299,8 @@ public class CorruptedBlobStoreRepositoryIT extends AbstractSnapshotIntegTestCas final ThreadPool threadPool = internalCluster().getCurrentMasterNodeInstance(ThreadPool.class); assertThat( PlainActionFuture.get( - f -> threadPool.generic() + // any other executor than generic and management + f -> threadPool.executor(ThreadPool.Names.SNAPSHOT) .execute( ActionRunnable.supply( f, diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index d372f4ee023b..22460775300f 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -178,6 +178,7 @@ public class TransportVersions { public static final TransportVersion GET_SHUTDOWN_STATUS_TIMEOUT = def(8_669_00_0); public static final TransportVersion FAILURE_STORE_TELEMETRY = def(8_670_00_0); public static final TransportVersion ADD_METADATA_FLATTENED_TO_ROLES = def(8_671_00_0); + public static 
final TransportVersion ML_INFERENCE_GOOGLE_AI_STUDIO_COMPLETION_ADDED = def(8_672_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java b/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java index e2b8fcbf2825..938fe4c84480 100644 --- a/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java +++ b/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java @@ -9,10 +9,12 @@ package org.elasticsearch.action.support; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterApplierService; import org.elasticsearch.cluster.service.MasterService; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.FutureUtils; import org.elasticsearch.common.util.concurrent.UncategorizedExecutionException; import org.elasticsearch.core.CheckedConsumer; @@ -37,6 +39,7 @@ public class PlainActionFuture implements ActionFuture, ActionListener @Override public void onFailure(Exception e) { + assert assertCompleteAllowed(); if (sync.setException(Objects.requireNonNull(e))) { done(false); } @@ -113,6 +116,7 @@ public class PlainActionFuture implements ActionFuture, ActionListener @Override public boolean cancel(boolean mayInterruptIfRunning) { + assert assertCompleteAllowed(); if (sync.cancel() == false) { return false; } @@ -130,6 +134,7 @@ public class PlainActionFuture implements ActionFuture, ActionListener * @return true if the state was successfully changed. 
*/ protected final boolean set(@Nullable T value) { + assert assertCompleteAllowed(); boolean result = sync.set(value); if (result) { done(true); @@ -399,4 +404,27 @@ public class PlainActionFuture implements ActionFuture, ActionListener e.accept(fut); return fut.actionGet(timeout, unit); } + + private boolean assertCompleteAllowed() { + Thread waiter = sync.getFirstQueuedThread(); + // todo: reenable assertion once downstream code is updated + assert true || waiter == null || allowedExecutors(waiter, Thread.currentThread()) + : "cannot complete future on thread " + + Thread.currentThread() + + " with waiter on thread " + + waiter + + ", could deadlock if pool was full\n" + + ExceptionsHelper.formatStackTrace(waiter.getStackTrace()); + return true; + } + + // only used in assertions + boolean allowedExecutors(Thread thread1, Thread thread2) { + // this should only be used to validate thread interactions, like not waiting for a future completed on the same + // executor, hence calling it with the same thread indicates a bug in the assertion using this. + assert thread1 != thread2 : "only call this for different threads"; + String thread1Name = EsExecutors.executorName(thread1); + String thread2Name = EsExecutors.executorName(thread2); + return thread1Name == null || thread2Name == null || thread1Name.equals(thread2Name) == false; + } } diff --git a/server/src/main/java/org/elasticsearch/action/support/UnsafePlainActionFuture.java b/server/src/main/java/org/elasticsearch/action/support/UnsafePlainActionFuture.java new file mode 100644 index 000000000000..2d9585bd26b5 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/action/support/UnsafePlainActionFuture.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.action.support; + +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.CheckedConsumer; + +import java.util.Objects; + +/** + * An unsafe future. You should not need to use this for new code, rather you should be able to convert that code to be async + * or use a clear hierarchy of thread pool executors around the future. + * + * This future is unsafe, since it allows notifying the future on the same thread pool executor that it is being waited on. This + * is a common deadlock scenario, since all threads may be waiting and thus no thread may be able to complete the future. + */ +@Deprecated(forRemoval = true) +public class UnsafePlainActionFuture extends PlainActionFuture { + + private final String unsafeExecutor; + private final String unsafeExecutor2; + + public UnsafePlainActionFuture(String unsafeExecutor) { + this(unsafeExecutor, null); + } + + public UnsafePlainActionFuture(String unsafeExecutor, String unsafeExecutor2) { + Objects.requireNonNull(unsafeExecutor); + this.unsafeExecutor = unsafeExecutor; + this.unsafeExecutor2 = unsafeExecutor2; + } + + @Override + boolean allowedExecutors(Thread thread1, Thread thread2) { + return super.allowedExecutors(thread1, thread2) + || unsafeExecutor.equals(EsExecutors.executorName(thread1)) + || unsafeExecutor2 == null + || unsafeExecutor2.equals(EsExecutors.executorName(thread1)); + } + + public static T get(CheckedConsumer, E> e, String allowedExecutor) throws E { + PlainActionFuture fut = new UnsafePlainActionFuture<>(allowedExecutor); + e.accept(fut); + return fut.actionGet(); + } +} diff --git a/server/src/main/java/org/elasticsearch/client/internal/support/AbstractClient.java 
b/server/src/main/java/org/elasticsearch/client/internal/support/AbstractClient.java index 966299408a67..f4e86c8a4eca 100644 --- a/server/src/main/java/org/elasticsearch/client/internal/support/AbstractClient.java +++ b/server/src/main/java/org/elasticsearch/client/internal/support/AbstractClient.java @@ -59,6 +59,7 @@ import org.elasticsearch.action.search.TransportMultiSearchAction; import org.elasticsearch.action.search.TransportSearchAction; import org.elasticsearch.action.search.TransportSearchScrollAction; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.action.termvectors.MultiTermVectorsAction; import org.elasticsearch.action.termvectors.MultiTermVectorsRequest; import org.elasticsearch.action.termvectors.MultiTermVectorsRequestBuilder; @@ -410,7 +411,13 @@ public abstract class AbstractClient implements Client { * on the result before it goes out of scope. * @param reference counted result type */ - private static class RefCountedFuture extends PlainActionFuture { + // todo: the use of UnsafePlainActionFuture here is quite broad, we should find a better way to be more specific + // (unless making all usages safe is easy). + private static class RefCountedFuture extends UnsafePlainActionFuture { + + private RefCountedFuture() { + super(ThreadPool.Names.GENERIC); + } @Override public final void onResponse(R result) { diff --git a/server/src/main/java/org/elasticsearch/common/time/DateFormatters.java b/server/src/main/java/org/elasticsearch/common/time/DateFormatters.java index 1133eac3f8f7..55c421b87196 100644 --- a/server/src/main/java/org/elasticsearch/common/time/DateFormatters.java +++ b/server/src/main/java/org/elasticsearch/common/time/DateFormatters.java @@ -53,6 +53,9 @@ public class DateFormatters { * If a string cannot be parsed by the ISO parser, it then tries the java.time one. 
* If there's lots of these strings, trying the ISO parser, then the java.time parser, might cause a performance drop. * So provide a JVM option so that users can just use the java.time parsers, if they really need to. + *

+ * Note that this property is sometimes set by {@code ESTestCase.setTestSysProps} to flip between implementations in tests, + * to ensure both are fully tested */ @UpdateForV9 // evaluate if we need to deprecate/remove this private static final boolean JAVA_TIME_PARSERS_ONLY = Booleans.parseBoolean(System.getProperty("es.datetime.java_time_parsers"), false); diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java index 015d3899ab90..9bf381e6f471 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/EsExecutors.java @@ -276,6 +276,10 @@ public class EsExecutors { return threadName.substring(executorNameStart + 1, executorNameEnd); } + public static String executorName(Thread thread) { + return executorName(thread.getName()); + } + public static ThreadFactory daemonThreadFactory(Settings settings, String namePrefix) { return daemonThreadFactory(threadName(settings, namePrefix)); } diff --git a/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java index 9719716c57ce..f04b8f13bfe7 100644 --- a/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java +++ b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java @@ -21,8 +21,8 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.rest.AbstractRestChannel; -import org.elasticsearch.rest.ChunkedRestResponseBody; -import org.elasticsearch.rest.LoggingChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; +import org.elasticsearch.rest.LoggingChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestResponse; import 
org.elasticsearch.rest.RestStatus; @@ -113,7 +113,7 @@ public class DefaultRestChannel extends AbstractRestChannel { try { final HttpResponse httpResponse; if (isHeadRequest == false && restResponse.isChunked()) { - ChunkedRestResponseBody chunkedContent = restResponse.chunkedContent(); + ChunkedRestResponseBodyPart chunkedContent = restResponse.chunkedContent(); if (httpLogger != null && httpLogger.isBodyTracerEnabled()) { final var loggerStream = httpLogger.openResponseBodyLoggingStream(request.getRequestId()); toClose.add(() -> { @@ -123,7 +123,7 @@ public class DefaultRestChannel extends AbstractRestChannel { assert false : e; // nothing much to go wrong here } }); - chunkedContent = new LoggingChunkedRestResponseBody(chunkedContent, loggerStream); + chunkedContent = new LoggingChunkedRestResponseBodyPart(chunkedContent, loggerStream); } httpResponse = httpRequest.createResponse(restResponse.status(), chunkedContent); diff --git a/server/src/main/java/org/elasticsearch/http/HttpRequest.java b/server/src/main/java/org/elasticsearch/http/HttpRequest.java index b82947e42308..2757fa15ce47 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpRequest.java +++ b/server/src/main/java/org/elasticsearch/http/HttpRequest.java @@ -10,7 +10,7 @@ package org.elasticsearch.http; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.core.Nullable; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestStatus; import java.util.List; @@ -40,7 +40,7 @@ public interface HttpRequest extends HttpPreRequest { */ HttpResponse createResponse(RestStatus status, BytesReference content); - HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content); + HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart); @Nullable Exception getInboundException(); diff --git 
a/server/src/main/java/org/elasticsearch/index/engine/CompletionStatsCache.java b/server/src/main/java/org/elasticsearch/index/engine/CompletionStatsCache.java index f66b85647189..91eea9f6b1b1 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/CompletionStatsCache.java +++ b/server/src/main/java/org/elasticsearch/index/engine/CompletionStatsCache.java @@ -15,10 +15,12 @@ import org.apache.lucene.search.ReferenceManager; import org.apache.lucene.search.suggest.document.CompletionTerms; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.common.FieldMemoryStats; import org.elasticsearch.common.regex.Regex; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.search.suggest.completion.CompletionStats; +import org.elasticsearch.threadpool.ThreadPool; import java.util.HashMap; import java.util.Map; @@ -42,7 +44,7 @@ public class CompletionStatsCache implements ReferenceManager.RefreshListener { } public CompletionStats get(String... 
fieldNamePatterns) { - final PlainActionFuture newFuture = new PlainActionFuture<>(); + final PlainActionFuture newFuture = new UnsafePlainActionFuture<>(ThreadPool.Names.MANAGEMENT); final PlainActionFuture oldFuture = completionStatsFutureRef.compareAndExchange(null, newFuture); if (oldFuture != null) { diff --git a/server/src/main/java/org/elasticsearch/index/engine/Engine.java b/server/src/main/java/org/elasticsearch/index/engine/Engine.java index 65f47dd3994a..c219e16659c9 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/Engine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/Engine.java @@ -36,6 +36,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.cluster.service.ClusterApplierService; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.logging.Loggers; @@ -75,6 +76,7 @@ import org.elasticsearch.index.store.Store; import org.elasticsearch.index.translog.Translog; import org.elasticsearch.index.translog.TranslogStats; import org.elasticsearch.search.suggest.completion.CompletionStats; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.Transports; import java.io.Closeable; @@ -1956,7 +1958,7 @@ public abstract class Engine implements Closeable { logger.debug("drainForClose(): draining ops"); releaseEnsureOpenRef.close(); - final var future = new PlainActionFuture() { + final var future = new UnsafePlainActionFuture(ThreadPool.Names.GENERIC) { @Override protected boolean blockingAllowed() { // TODO remove this blocking, or at least do it elsewhere, see https://github.com/elastic/elasticsearch/issues/89821 diff --git a/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBody.java 
b/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java similarity index 79% rename from server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBody.java rename to server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java index 2f7fc458ca02..4888b59f1956 100644 --- a/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBody.java +++ b/server/src/main/java/org/elasticsearch/rest/ChunkedRestResponseBodyPart.java @@ -39,32 +39,32 @@ import java.util.Iterator; * materializing the full response into on-heap buffers up front, instead serializing only as much of the response as can be flushed to the * network right away.

* - *

Each {@link ChunkedRestResponseBody} represents a sequence of chunks that are ready for immediate transmission: if {@link #isDone} - * returns {@code false} then {@link #encodeChunk} can be called at any time and must synchronously return the next chunk to be sent. - * Many HTTP responses will be a single such sequence. However, if an implementation's {@link #isEndOfResponse} returns {@code false} at the - * end of the sequence then the transmission is paused and {@link #getContinuation} is called to compute the next sequence of chunks + *

Each {@link ChunkedRestResponseBodyPart} represents a sequence of chunks that are ready for immediate transmission: if + * {@link #isPartComplete} returns {@code false} then {@link #encodeChunk} can be called at any time and must synchronously return the next + * chunk to be sent. Many HTTP responses will be a single part, but if an implementation's {@link #isLastPart} returns {@code false} at the + * end of the part then the transmission is paused and {@link #getNextPart} is called to compute the next sequence of chunks * asynchronously.

*/ -public interface ChunkedRestResponseBody { +public interface ChunkedRestResponseBodyPart { - Logger logger = LogManager.getLogger(ChunkedRestResponseBody.class); + Logger logger = LogManager.getLogger(ChunkedRestResponseBodyPart.class); /** - * @return {@code true} if this body contains no more chunks and the REST layer should check for a possible continuation by calling - * {@link #isEndOfResponse}, or {@code false} if the REST layer should request another chunk from this body using {@link #encodeChunk}. + * @return {@code true} if this body part contains no more chunks and the REST layer should check for a possible continuation by calling + * {@link #isLastPart}, or {@code false} if the REST layer should request another chunk from this body using {@link #encodeChunk}. */ - boolean isDone(); + boolean isPartComplete(); /** - * @return {@code true} if this is the last chunked body in the response, or {@code false} if the REST layer should request further - * chunked bodies by calling {@link #getContinuation}. + * @return {@code true} if this is the last chunked body part in the response, or {@code false} if the REST layer should request further + * chunked bodies by calling {@link #getNextPart}. */ - boolean isEndOfResponse(); + boolean isLastPart(); /** - *

Asynchronously retrieves the next part of the body. Called if {@link #isEndOfResponse} returns {@code false}.

+ *

Asynchronously retrieves the next part of the response body. Called if {@link #isLastPart} returns {@code false}.

* - *

Note that this is called on a transport thread, so implementations must take care to dispatch any nontrivial work elsewhere.

+ *

Note that this is called on a transport thread: implementations must take care to dispatch any nontrivial work elsewhere.

*

Note that the {@link Task} corresponding to any invocation of {@link Client#execute} completes as soon as the client action * returns its response, so it no longer exists when this method is called and cannot be used to receive cancellation notifications. @@ -78,7 +78,7 @@ public interface ChunkedRestResponseBody { * the body of the response, so there's no good ways to handle an exception here. Completing the listener exceptionally * will log an error, abort sending the response, and close the HTTP connection. */ - void getContinuation(ActionListener listener); + void getNextPart(ActionListener listener); /** * Serializes approximately as many bytes of the response as request by {@code sizeHint} to a {@link ReleasableBytesReference} that @@ -97,17 +97,17 @@ public interface ChunkedRestResponseBody { String getResponseContentTypeString(); /** - * Create a chunked response body to be written to a specific {@link RestChannel} from a {@link ChunkedToXContent}. + * Create a one-part chunked response body to be written to a specific {@link RestChannel} from a {@link ChunkedToXContent}. 
* * @param chunkedToXContent chunked x-content instance to serialize * @param params parameters to use for serialization * @param channel channel the response will be written to * @return chunked rest response body */ - static ChunkedRestResponseBody fromXContent(ChunkedToXContent chunkedToXContent, ToXContent.Params params, RestChannel channel) + static ChunkedRestResponseBodyPart fromXContent(ChunkedToXContent chunkedToXContent, ToXContent.Params params, RestChannel channel) throws IOException { - return new ChunkedRestResponseBody() { + return new ChunkedRestResponseBodyPart() { private final OutputStream out = new OutputStream() { @Override @@ -135,17 +135,17 @@ public interface ChunkedRestResponseBody { private BytesStream target; @Override - public boolean isDone() { + public boolean isPartComplete() { return serialization.hasNext() == false; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { return true; } @Override - public void getContinuation(ActionListener listener) { + public void getNextPart(ActionListener listener) { assert false : "no continuations"; listener.onFailure(new IllegalStateException("no continuations available")); } @@ -191,11 +191,11 @@ public interface ChunkedRestResponseBody { } /** - * Create a chunked response body to be written to a specific {@link RestChannel} from a stream of text chunks, each represented as a - * consumer of a {@link Writer}. + * Create a one-part chunked response body to be written to a specific {@link RestChannel} from a stream of UTF-8-encoded text chunks, + * each represented as a consumer of a {@link Writer}. 
*/ - static ChunkedRestResponseBody fromTextChunks(String contentType, Iterator> chunkIterator) { - return new ChunkedRestResponseBody() { + static ChunkedRestResponseBodyPart fromTextChunks(String contentType, Iterator> chunkIterator) { + return new ChunkedRestResponseBodyPart() { private RecyclerBytesStreamOutput currentOutput; private final Writer writer = new OutputStreamWriter(new OutputStream() { @Override @@ -224,17 +224,17 @@ public interface ChunkedRestResponseBody { }, StandardCharsets.UTF_8); @Override - public boolean isDone() { + public boolean isPartComplete() { return chunkIterator.hasNext() == false; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { return true; } @Override - public void getContinuation(ActionListener listener) { + public void getNextPart(ActionListener listener) { assert false : "no continuations"; listener.onFailure(new IllegalStateException("no continuations available")); } diff --git a/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBody.java b/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBodyPart.java similarity index 68% rename from server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBody.java rename to server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBodyPart.java index 865f433e25aa..f7a018eaacf7 100644 --- a/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBody.java +++ b/server/src/main/java/org/elasticsearch/rest/LoggingChunkedRestResponseBodyPart.java @@ -16,29 +16,29 @@ import org.elasticsearch.common.recycler.Recycler; import java.io.IOException; import java.io.OutputStream; -public class LoggingChunkedRestResponseBody implements ChunkedRestResponseBody { +public class LoggingChunkedRestResponseBodyPart implements ChunkedRestResponseBodyPart { - private final ChunkedRestResponseBody inner; + private final ChunkedRestResponseBodyPart inner; private final OutputStream loggerStream; - 
public LoggingChunkedRestResponseBody(ChunkedRestResponseBody inner, OutputStream loggerStream) { + public LoggingChunkedRestResponseBodyPart(ChunkedRestResponseBodyPart inner, OutputStream loggerStream) { this.inner = inner; this.loggerStream = loggerStream; } @Override - public boolean isDone() { - return inner.isDone(); + public boolean isPartComplete() { + return inner.isPartComplete(); } @Override - public boolean isEndOfResponse() { - return inner.isEndOfResponse(); + public boolean isLastPart() { + return inner.isLastPart(); } @Override - public void getContinuation(ActionListener listener) { - inner.getContinuation(listener.map(continuation -> new LoggingChunkedRestResponseBody(continuation, loggerStream))); + public void getNextPart(ActionListener listener) { + inner.getNextPart(listener.map(continuation -> new LoggingChunkedRestResponseBodyPart(continuation, loggerStream))); } @Override diff --git a/server/src/main/java/org/elasticsearch/rest/RestController.java b/server/src/main/java/org/elasticsearch/rest/RestController.java index 0c08520a5dd0..b08f6ed81017 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestController.java +++ b/server/src/main/java/org/elasticsearch/rest/RestController.java @@ -857,7 +857,7 @@ public class RestController implements HttpServerTransport.Dispatcher { final var headers = response.getHeaders(); response = RestResponse.chunked( response.status(), - new EncodedLengthTrackingChunkedRestResponseBody(response.chunkedContent(), responseLengthRecorder), + new EncodedLengthTrackingChunkedRestResponseBodyPart(response.chunkedContent(), responseLengthRecorder), Releasables.wrap(responseLengthRecorder, response) ); for (final var header : headers.entrySet()) { @@ -916,13 +916,13 @@ public class RestController implements HttpServerTransport.Dispatcher { } } - private static class EncodedLengthTrackingChunkedRestResponseBody implements ChunkedRestResponseBody { + private static class 
EncodedLengthTrackingChunkedRestResponseBodyPart implements ChunkedRestResponseBodyPart { - private final ChunkedRestResponseBody delegate; + private final ChunkedRestResponseBodyPart delegate; private final ResponseLengthRecorder responseLengthRecorder; - private EncodedLengthTrackingChunkedRestResponseBody( - ChunkedRestResponseBody delegate, + private EncodedLengthTrackingChunkedRestResponseBodyPart( + ChunkedRestResponseBodyPart delegate, ResponseLengthRecorder responseLengthRecorder ) { this.delegate = delegate; @@ -930,19 +930,19 @@ public class RestController implements HttpServerTransport.Dispatcher { } @Override - public boolean isDone() { - return delegate.isDone(); + public boolean isPartComplete() { + return delegate.isPartComplete(); } @Override - public boolean isEndOfResponse() { - return delegate.isEndOfResponse(); + public boolean isLastPart() { + return delegate.isLastPart(); } @Override - public void getContinuation(ActionListener listener) { - delegate.getContinuation( - listener.map(continuation -> new EncodedLengthTrackingChunkedRestResponseBody(continuation, responseLengthRecorder)) + public void getNextPart(ActionListener listener) { + delegate.getNextPart( + listener.map(continuation -> new EncodedLengthTrackingChunkedRestResponseBodyPart(continuation, responseLengthRecorder)) ); } @@ -950,7 +950,7 @@ public class RestController implements HttpServerTransport.Dispatcher { public ReleasableBytesReference encodeChunk(int sizeHint, Recycler recycler) throws IOException { final ReleasableBytesReference bytesReference = delegate.encodeChunk(sizeHint, recycler); responseLengthRecorder.addChunkLength(bytesReference.length()); - if (isDone() && isEndOfResponse()) { + if (isPartComplete() && isLastPart()) { responseLengthRecorder.close(); } return bytesReference; diff --git a/server/src/main/java/org/elasticsearch/rest/RestResponse.java b/server/src/main/java/org/elasticsearch/rest/RestResponse.java index 9862ab31bd53..8cc0e35a6480 100644 --- 
a/server/src/main/java/org/elasticsearch/rest/RestResponse.java +++ b/server/src/main/java/org/elasticsearch/rest/RestResponse.java @@ -48,7 +48,7 @@ public final class RestResponse implements Releasable { private final BytesReference content; @Nullable - private final ChunkedRestResponseBody chunkedResponseBody; + private final ChunkedRestResponseBodyPart chunkedResponseBody; private final String responseMediaType; private Map> customHeaders; @@ -84,9 +84,9 @@ public final class RestResponse implements Releasable { this(status, responseMediaType, content, null, releasable); } - public static RestResponse chunked(RestStatus restStatus, ChunkedRestResponseBody content, @Nullable Releasable releasable) { - if (content.isDone()) { - assert content.isEndOfResponse() : "response with continuations must have at least one (possibly-empty) chunk in each part"; + public static RestResponse chunked(RestStatus restStatus, ChunkedRestResponseBodyPart content, @Nullable Releasable releasable) { + if (content.isPartComplete()) { + assert content.isLastPart() : "response with continuations must have at least one (possibly-empty) chunk in each part"; return new RestResponse(restStatus, content.getResponseContentTypeString(), BytesArray.EMPTY, releasable); } else { return new RestResponse(restStatus, content.getResponseContentTypeString(), null, content, releasable); @@ -100,7 +100,7 @@ public final class RestResponse implements Releasable { RestStatus status, String responseMediaType, @Nullable BytesReference content, - @Nullable ChunkedRestResponseBody chunkedResponseBody, + @Nullable ChunkedRestResponseBodyPart chunkedResponseBody, @Nullable Releasable releasable ) { this.status = status; @@ -162,7 +162,7 @@ public final class RestResponse implements Releasable { } @Nullable - public ChunkedRestResponseBody chunkedContent() { + public ChunkedRestResponseBodyPart chunkedContent() { return chunkedResponseBody; } diff --git 
a/server/src/main/java/org/elasticsearch/rest/action/RestChunkedToXContentListener.java b/server/src/main/java/org/elasticsearch/rest/action/RestChunkedToXContentListener.java index 3798f2b6b6fb..ef2aa8418eef 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/RestChunkedToXContentListener.java +++ b/server/src/main/java/org/elasticsearch/rest/action/RestChunkedToXContentListener.java @@ -10,7 +10,7 @@ package org.elasticsearch.rest.action; import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.core.Releasable; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.RestStatus; @@ -40,7 +40,7 @@ public class RestChunkedToXContentListener e channel.sendResponse( RestResponse.chunked( getRestStatus(response), - ChunkedRestResponseBody.fromXContent(response, params, channel), + ChunkedRestResponseBodyPart.fromXContent(response, params, channel), releasableFromResponse(response) ) ); diff --git a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestNodesHotThreadsAction.java b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestNodesHotThreadsAction.java index bcf0d9932559..9cf2d6a2ed39 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestNodesHotThreadsAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestNodesHotThreadsAction.java @@ -27,7 +27,7 @@ import java.io.IOException; import java.util.List; import java.util.Locale; -import static org.elasticsearch.rest.ChunkedRestResponseBody.fromTextChunks; +import static org.elasticsearch.rest.ChunkedRestResponseBodyPart.fromTextChunks; import static org.elasticsearch.rest.RestRequest.Method.GET; import static org.elasticsearch.rest.RestResponse.TEXT_CONTENT_TYPE; import static 
org.elasticsearch.rest.RestUtils.getTimeout; diff --git a/server/src/main/java/org/elasticsearch/rest/action/cat/RestTable.java b/server/src/main/java/org/elasticsearch/rest/action/cat/RestTable.java index 5999d1b81da4..2f94e3ab90cb 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/cat/RestTable.java +++ b/server/src/main/java/org/elasticsearch/rest/action/cat/RestTable.java @@ -17,7 +17,7 @@ import org.elasticsearch.common.unit.SizeValue; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestResponse; @@ -63,7 +63,7 @@ public class RestTable { return RestResponse.chunked( RestStatus.OK, - ChunkedRestResponseBody.fromXContent( + ChunkedRestResponseBodyPart.fromXContent( ignored -> Iterators.concat( Iterators.single((builder, params) -> builder.startArray()), Iterators.map(rowOrder.iterator(), row -> (builder, params) -> { @@ -94,7 +94,7 @@ public class RestTable { return RestResponse.chunked( RestStatus.OK, - ChunkedRestResponseBody.fromTextChunks( + ChunkedRestResponseBodyPart.fromTextChunks( RestResponse.TEXT_CONTENT_TYPE, Iterators.concat( // optional header diff --git a/server/src/main/java/org/elasticsearch/rest/action/info/RestClusterInfoAction.java b/server/src/main/java/org/elasticsearch/rest/action/info/RestClusterInfoAction.java index 8be023bb4a18..0a38d59d2972 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/info/RestClusterInfoAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/info/RestClusterInfoAction.java @@ -19,7 +19,7 @@ import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.http.HttpStats; import org.elasticsearch.ingest.IngestStats; import 
org.elasticsearch.rest.BaseRestHandler; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestResponse; import org.elasticsearch.rest.RestStatus; @@ -122,7 +122,7 @@ public class RestClusterInfoAction extends BaseRestHandler { return RestResponse.chunked( RestStatus.OK, - ChunkedRestResponseBody.fromXContent( + ChunkedRestResponseBodyPart.fromXContent( outerParams -> Iterators.concat( ChunkedToXContentHelper.startObject(), Iterators.single((builder, params) -> builder.field("cluster_name", response.getClusterName().value())), diff --git a/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java b/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java index a887a2be558e..88c507404e76 100644 --- a/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java +++ b/server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java @@ -152,12 +152,14 @@ public class ThreadPool implements ReportingService, Scheduler { public static final Map THREAD_POOL_TYPES = Map.ofEntries( entry(Names.GENERIC, ThreadPoolType.SCALING), + entry(Names.CLUSTER_COORDINATION, ThreadPoolType.FIXED), entry(Names.GET, ThreadPoolType.FIXED), entry(Names.ANALYZE, ThreadPoolType.FIXED), entry(Names.WRITE, ThreadPoolType.FIXED), entry(Names.SEARCH, ThreadPoolType.FIXED), entry(Names.SEARCH_WORKER, ThreadPoolType.FIXED), entry(Names.SEARCH_COORDINATION, ThreadPoolType.FIXED), + entry(Names.AUTO_COMPLETE, ThreadPoolType.FIXED), entry(Names.MANAGEMENT, ThreadPoolType.SCALING), entry(Names.FLUSH, ThreadPoolType.SCALING), entry(Names.REFRESH, ThreadPoolType.SCALING), diff --git a/server/src/test/java/org/elasticsearch/common/time/DateFormattersTests.java b/server/src/test/java/org/elasticsearch/common/time/DateFormattersTests.java index fa333ddf6b0c..3b0935e8f7b5 100644 --- 
a/server/src/test/java/org/elasticsearch/common/time/DateFormattersTests.java +++ b/server/src/test/java/org/elasticsearch/common/time/DateFormattersTests.java @@ -12,8 +12,10 @@ import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.common.util.LocaleUtils; import org.elasticsearch.index.mapper.DateFieldMapper; import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matcher; import java.time.Clock; +import java.time.DateTimeException; import java.time.Instant; import java.time.LocalDateTime; import java.time.ZoneId; @@ -39,12 +41,25 @@ import static org.hamcrest.Matchers.sameInstance; public class DateFormattersTests extends ESTestCase { - private IllegalArgumentException assertParseException(String input, String format) { + private void assertParseException(String input, String format) { DateFormatter javaTimeFormatter = DateFormatter.forPattern(format); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> javaTimeFormatter.parse(input)); assertThat(e.getMessage(), containsString(input)); assertThat(e.getMessage(), containsString(format)); - return e; + assertThat(e.getCause(), instanceOf(DateTimeException.class)); + } + + private void assertParseException(String input, String format, int errorIndex) { + assertParseException(input, format, equalTo(errorIndex)); + } + + private void assertParseException(String input, String format, Matcher indexMatcher) { + DateFormatter javaTimeFormatter = DateFormatter.forPattern(format); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> javaTimeFormatter.parse(input)); + assertThat(e.getMessage(), containsString(input)); + assertThat(e.getMessage(), containsString(format)); + assertThat(e.getCause(), instanceOf(DateTimeParseException.class)); + assertThat(((DateTimeParseException) e.getCause()).getErrorIndex(), indexMatcher); } private void assertParses(String input, String format) { @@ -698,7 +713,7 @@ public class DateFormattersTests 
extends ESTestCase { ES java.time implementation does not suffer from this, but we intentionally not allow parsing timezone without a time part as it is not allowed in iso8601 */ - assertParseException("2016-11-30T+01", "strict_date_optional_time"); + assertParseException("2016-11-30T+01", "strict_date_optional_time", 11); assertParses("2016-11-30T12+01", "strict_date_optional_time"); assertParses("2016-11-30T12:00+01", "strict_date_optional_time"); @@ -792,8 +807,8 @@ public class DateFormattersTests extends ESTestCase { assertParses("2001-01-01T00:00:00.123Z", javaFormatter); assertParses("2001-01-01T00:00:00,123Z", javaFormatter); - assertParseException("2001-01-01T00:00:00.123,456Z", "strict_date_optional_time"); - assertParseException("2001-01-01T00:00:00.123,456Z", "date_optional_time"); + assertParseException("2001-01-01T00:00:00.123,456Z", "strict_date_optional_time", 23); + assertParseException("2001-01-01T00:00:00.123,456Z", "date_optional_time", 23); // This should fail, but java is ok with this because the field has the same value // assertJavaTimeParseException("2001-01-01T00:00:00.123,123Z", "strict_date_optional_time_nanos"); } @@ -911,7 +926,7 @@ public class DateFormattersTests extends ESTestCase { assertParses("2018-12-31T12:12:12.123456789", "date_hour_minute_second_fraction"); assertParses("2018-12-31T12:12:12.1", "date_hour_minute_second_millis"); assertParses("2018-12-31T12:12:12.123", "date_hour_minute_second_millis"); - assertParseException("2018-12-31T12:12:12.123456789", "date_hour_minute_second_millis"); + assertParseException("2018-12-31T12:12:12.123456789", "date_hour_minute_second_millis", 23); assertParses("2018-12-31T12:12:12.1", "date_hour_minute_second_millis"); assertParses("2018-12-31T12:12:12.1", "date_hour_minute_second_fraction"); @@ -981,11 +996,11 @@ public class DateFormattersTests extends ESTestCase { assertParses("12:12:12.123", "hour_minute_second_fraction"); assertParses("12:12:12.123456789", 
"hour_minute_second_fraction"); assertParses("12:12:12.1", "hour_minute_second_fraction"); - assertParseException("12:12:12", "hour_minute_second_fraction"); + assertParseException("12:12:12", "hour_minute_second_fraction", 8); assertParses("12:12:12.123", "hour_minute_second_millis"); - assertParseException("12:12:12.123456789", "hour_minute_second_millis"); + assertParseException("12:12:12.123456789", "hour_minute_second_millis", 12); assertParses("12:12:12.1", "hour_minute_second_millis"); - assertParseException("12:12:12", "hour_minute_second_millis"); + assertParseException("12:12:12", "hour_minute_second_millis", 8); assertParses("2018-128", "ordinal_date"); assertParses("2018-1", "ordinal_date"); @@ -1025,8 +1040,8 @@ public class DateFormattersTests extends ESTestCase { assertParses("10:15:3.123Z", "time"); assertParses("10:15:3.123+0100", "time"); assertParses("10:15:3.123+01:00", "time"); - assertParseException("10:15:3.1", "time"); - assertParseException("10:15:3Z", "time"); + assertParseException("10:15:3.1", "time", 9); + assertParseException("10:15:3Z", "time", 7); assertParses("10:15:30Z", "time_no_millis"); assertParses("10:15:30+0100", "time_no_millis"); @@ -1043,7 +1058,7 @@ public class DateFormattersTests extends ESTestCase { assertParses("10:15:3Z", "time_no_millis"); assertParses("10:15:3+0100", "time_no_millis"); assertParses("10:15:3+01:00", "time_no_millis"); - assertParseException("10:15:3", "time_no_millis"); + assertParseException("10:15:3", "time_no_millis", 7); assertParses("T10:15:30.1Z", "t_time"); assertParses("T10:15:30.123Z", "t_time"); @@ -1061,8 +1076,8 @@ public class DateFormattersTests extends ESTestCase { assertParses("T10:15:3.123Z", "t_time"); assertParses("T10:15:3.123+0100", "t_time"); assertParses("T10:15:3.123+01:00", "t_time"); - assertParseException("T10:15:3.1", "t_time"); - assertParseException("T10:15:3Z", "t_time"); + assertParseException("T10:15:3.1", "t_time", 10); + assertParseException("T10:15:3Z", "t_time", 
8); assertParses("T10:15:30Z", "t_time_no_millis"); assertParses("T10:15:30+0100", "t_time_no_millis"); @@ -1076,12 +1091,12 @@ public class DateFormattersTests extends ESTestCase { assertParses("T10:15:3Z", "t_time_no_millis"); assertParses("T10:15:3+0100", "t_time_no_millis"); assertParses("T10:15:3+01:00", "t_time_no_millis"); - assertParseException("T10:15:3", "t_time_no_millis"); + assertParseException("T10:15:3", "t_time_no_millis", 8); assertParses("2012-W48-6", "week_date"); assertParses("2012-W01-6", "week_date"); assertParses("2012-W1-6", "week_date"); - assertParseException("2012-W1-8", "week_date"); + assertParseException("2012-W1-8", "week_date", 0); assertParses("2012-W48-6T10:15:30.1Z", "week_date_time"); assertParses("2012-W48-6T10:15:30.123Z", "week_date_time"); @@ -1135,17 +1150,12 @@ public class DateFormattersTests extends ESTestCase { } public void testExceptionWhenCompositeParsingFails() { - assertParseException("2014-06-06T12:01:02.123", "yyyy-MM-dd'T'HH:mm:ss||yyyy-MM-dd'T'HH:mm:ss.SS"); - } - - public void testExceptionErrorIndex() { - Exception e = assertParseException("2024-01-01j", "iso8601||strict_date_optional_time"); - assertThat(((DateTimeParseException) e.getCause()).getErrorIndex(), equalTo(10)); + assertParseException("2014-06-06T12:01:02.123", "yyyy-MM-dd'T'HH:mm:ss||yyyy-MM-dd'T'HH:mm:ss.SS", 19); } public void testStrictParsing() { assertParses("2018W313", "strict_basic_week_date"); - assertParseException("18W313", "strict_basic_week_date"); + assertParseException("18W313", "strict_basic_week_date", 0); assertParses("2018W313T121212.1Z", "strict_basic_week_date_time"); assertParses("2018W313T121212.123Z", "strict_basic_week_date_time"); assertParses("2018W313T121212.123456789Z", "strict_basic_week_date_time"); @@ -1153,52 +1163,52 @@ public class DateFormattersTests extends ESTestCase { assertParses("2018W313T121212.123+0100", "strict_basic_week_date_time"); assertParses("2018W313T121212.1+01:00", 
"strict_basic_week_date_time"); assertParses("2018W313T121212.123+01:00", "strict_basic_week_date_time"); - assertParseException("2018W313T12128.123Z", "strict_basic_week_date_time"); - assertParseException("2018W313T12128.123456789Z", "strict_basic_week_date_time"); - assertParseException("2018W313T81212.123Z", "strict_basic_week_date_time"); - assertParseException("2018W313T12812.123Z", "strict_basic_week_date_time"); - assertParseException("2018W313T12812.1Z", "strict_basic_week_date_time"); + assertParseException("2018W313T12128.123Z", "strict_basic_week_date_time", 13); + assertParseException("2018W313T12128.123456789Z", "strict_basic_week_date_time", 13); + assertParseException("2018W313T81212.123Z", "strict_basic_week_date_time", 13); + assertParseException("2018W313T12812.123Z", "strict_basic_week_date_time", 13); + assertParseException("2018W313T12812.1Z", "strict_basic_week_date_time", 13); assertParses("2018W313T121212Z", "strict_basic_week_date_time_no_millis"); assertParses("2018W313T121212+0100", "strict_basic_week_date_time_no_millis"); assertParses("2018W313T121212+01:00", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T12128Z", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T12128+0100", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T12128+01:00", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T81212Z", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T81212+0100", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T81212+01:00", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T12812Z", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T12812+0100", "strict_basic_week_date_time_no_millis"); - assertParseException("2018W313T12812+01:00", "strict_basic_week_date_time_no_millis"); + assertParseException("2018W313T12128Z", 
"strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T12128+0100", "strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T12128+01:00", "strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T81212Z", "strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T81212+0100", "strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T81212+01:00", "strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T12812Z", "strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T12812+0100", "strict_basic_week_date_time_no_millis", 13); + assertParseException("2018W313T12812+01:00", "strict_basic_week_date_time_no_millis", 13); assertParses("2018-12-31", "strict_date"); - assertParseException("10000-12-31", "strict_date"); - assertParseException("2018-8-31", "strict_date"); + assertParseException("10000-12-31", "strict_date", 0); + assertParseException("2018-8-31", "strict_date", 5); assertParses("2018-12-31T12", "strict_date_hour"); - assertParseException("2018-12-31T8", "strict_date_hour"); + assertParseException("2018-12-31T8", "strict_date_hour", 11); assertParses("2018-12-31T12:12", "strict_date_hour_minute"); - assertParseException("2018-12-31T8:3", "strict_date_hour_minute"); + assertParseException("2018-12-31T8:3", "strict_date_hour_minute", 11); assertParses("2018-12-31T12:12:12", "strict_date_hour_minute_second"); - assertParseException("2018-12-31T12:12:1", "strict_date_hour_minute_second"); + assertParseException("2018-12-31T12:12:1", "strict_date_hour_minute_second", 17); assertParses("2018-12-31T12:12:12.1", "strict_date_hour_minute_second_fraction"); assertParses("2018-12-31T12:12:12.123", "strict_date_hour_minute_second_fraction"); assertParses("2018-12-31T12:12:12.123456789", "strict_date_hour_minute_second_fraction"); assertParses("2018-12-31T12:12:12.123", "strict_date_hour_minute_second_millis"); 
assertParses("2018-12-31T12:12:12.1", "strict_date_hour_minute_second_millis"); assertParses("2018-12-31T12:12:12.1", "strict_date_hour_minute_second_fraction"); - assertParseException("2018-12-31T12:12:12", "strict_date_hour_minute_second_millis"); - assertParseException("2018-12-31T12:12:12", "strict_date_hour_minute_second_fraction"); + assertParseException("2018-12-31T12:12:12", "strict_date_hour_minute_second_millis", 19); + assertParseException("2018-12-31T12:12:12", "strict_date_hour_minute_second_fraction", 19); assertParses("2018-12-31", "strict_date_optional_time"); - assertParseException("2018-12-1", "strict_date_optional_time"); - assertParseException("2018-1-31", "strict_date_optional_time"); - assertParseException("10000-01-31", "strict_date_optional_time"); + assertParseException("2018-12-1", "strict_date_optional_time", 7); + assertParseException("2018-1-31", "strict_date_optional_time", 4); + assertParseException("10000-01-31", "strict_date_optional_time", 4); assertParses("2010-01-05T02:00", "strict_date_optional_time"); assertParses("2018-12-31T10:15:30", "strict_date_optional_time"); assertParses("2018-12-31T10:15:30Z", "strict_date_optional_time"); assertParses("2018-12-31T10:15:30+0100", "strict_date_optional_time"); assertParses("2018-12-31T10:15:30+01:00", "strict_date_optional_time"); - assertParseException("2018-12-31T10:15:3", "strict_date_optional_time"); - assertParseException("2018-12-31T10:5:30", "strict_date_optional_time"); - assertParseException("2018-12-31T9:15:30", "strict_date_optional_time"); + assertParseException("2018-12-31T10:15:3", "strict_date_optional_time", 16); + assertParseException("2018-12-31T10:5:30", "strict_date_optional_time", 13); + assertParseException("2018-12-31T9:15:30", "strict_date_optional_time", 11); assertParses("2015-01-04T00:00Z", "strict_date_optional_time"); assertParses("2018-12-31T10:15:30.1Z", "strict_date_time"); assertParses("2018-12-31T10:15:30.123Z", "strict_date_time"); @@ -1210,33 +1220,33 
@@ public class DateFormattersTests extends ESTestCase { assertParses("2018-12-31T10:15:30.11Z", "strict_date_time"); assertParses("2018-12-31T10:15:30.11+0100", "strict_date_time"); assertParses("2018-12-31T10:15:30.11+01:00", "strict_date_time"); - assertParseException("2018-12-31T10:15:3.123Z", "strict_date_time"); - assertParseException("2018-12-31T10:5:30.123Z", "strict_date_time"); - assertParseException("2018-12-31T1:15:30.123Z", "strict_date_time"); + assertParseException("2018-12-31T10:15:3.123Z", "strict_date_time", 17); + assertParseException("2018-12-31T10:5:30.123Z", "strict_date_time", 14); + assertParseException("2018-12-31T1:15:30.123Z", "strict_date_time", 11); assertParses("2018-12-31T10:15:30Z", "strict_date_time_no_millis"); assertParses("2018-12-31T10:15:30+0100", "strict_date_time_no_millis"); assertParses("2018-12-31T10:15:30+01:00", "strict_date_time_no_millis"); - assertParseException("2018-12-31T10:5:30Z", "strict_date_time_no_millis"); - assertParseException("2018-12-31T10:15:3Z", "strict_date_time_no_millis"); - assertParseException("2018-12-31T1:15:30Z", "strict_date_time_no_millis"); + assertParseException("2018-12-31T10:5:30Z", "strict_date_time_no_millis", 14); + assertParseException("2018-12-31T10:15:3Z", "strict_date_time_no_millis", 17); + assertParseException("2018-12-31T1:15:30Z", "strict_date_time_no_millis", 11); assertParses("12", "strict_hour"); assertParses("01", "strict_hour"); - assertParseException("1", "strict_hour"); + assertParseException("1", "strict_hour", 0); assertParses("12:12", "strict_hour_minute"); assertParses("12:01", "strict_hour_minute"); - assertParseException("12:1", "strict_hour_minute"); + assertParseException("12:1", "strict_hour_minute", 3); assertParses("12:12:12", "strict_hour_minute_second"); assertParses("12:12:01", "strict_hour_minute_second"); - assertParseException("12:12:1", "strict_hour_minute_second"); + assertParseException("12:12:1", "strict_hour_minute_second", 6); 
assertParses("12:12:12.123", "strict_hour_minute_second_fraction"); assertParses("12:12:12.123456789", "strict_hour_minute_second_fraction"); assertParses("12:12:12.1", "strict_hour_minute_second_fraction"); - assertParseException("12:12:12", "strict_hour_minute_second_fraction"); + assertParseException("12:12:12", "strict_hour_minute_second_fraction", 8); assertParses("12:12:12.123", "strict_hour_minute_second_millis"); assertParses("12:12:12.1", "strict_hour_minute_second_millis"); - assertParseException("12:12:12", "strict_hour_minute_second_millis"); + assertParseException("12:12:12", "strict_hour_minute_second_millis", 8); assertParses("2018-128", "strict_ordinal_date"); - assertParseException("2018-1", "strict_ordinal_date"); + assertParseException("2018-1", "strict_ordinal_date", 5); assertParses("2018-128T10:15:30.1Z", "strict_ordinal_date_time"); assertParses("2018-128T10:15:30.123Z", "strict_ordinal_date_time"); @@ -1245,23 +1255,23 @@ public class DateFormattersTests extends ESTestCase { assertParses("2018-128T10:15:30.123+0100", "strict_ordinal_date_time"); assertParses("2018-128T10:15:30.1+01:00", "strict_ordinal_date_time"); assertParses("2018-128T10:15:30.123+01:00", "strict_ordinal_date_time"); - assertParseException("2018-1T10:15:30.123Z", "strict_ordinal_date_time"); + assertParseException("2018-1T10:15:30.123Z", "strict_ordinal_date_time", 5); assertParses("2018-128T10:15:30Z", "strict_ordinal_date_time_no_millis"); assertParses("2018-128T10:15:30+0100", "strict_ordinal_date_time_no_millis"); assertParses("2018-128T10:15:30+01:00", "strict_ordinal_date_time_no_millis"); - assertParseException("2018-1T10:15:30Z", "strict_ordinal_date_time_no_millis"); + assertParseException("2018-1T10:15:30Z", "strict_ordinal_date_time_no_millis", 5); assertParses("10:15:30.1Z", "strict_time"); assertParses("10:15:30.123Z", "strict_time"); assertParses("10:15:30.123456789Z", "strict_time"); assertParses("10:15:30.123+0100", "strict_time"); 
assertParses("10:15:30.123+01:00", "strict_time"); - assertParseException("1:15:30.123Z", "strict_time"); - assertParseException("10:1:30.123Z", "strict_time"); - assertParseException("10:15:3.123Z", "strict_time"); - assertParseException("10:15:3.1", "strict_time"); - assertParseException("10:15:3Z", "strict_time"); + assertParseException("1:15:30.123Z", "strict_time", 0); + assertParseException("10:1:30.123Z", "strict_time", 3); + assertParseException("10:15:3.123Z", "strict_time", 6); + assertParseException("10:15:3.1", "strict_time", 6); + assertParseException("10:15:3Z", "strict_time", 6); assertParses("10:15:30Z", "strict_time_no_millis"); assertParses("10:15:30+0100", "strict_time_no_millis"); @@ -1269,10 +1279,10 @@ public class DateFormattersTests extends ESTestCase { assertParses("01:15:30Z", "strict_time_no_millis"); assertParses("01:15:30+0100", "strict_time_no_millis"); assertParses("01:15:30+01:00", "strict_time_no_millis"); - assertParseException("1:15:30Z", "strict_time_no_millis"); - assertParseException("10:5:30Z", "strict_time_no_millis"); - assertParseException("10:15:3Z", "strict_time_no_millis"); - assertParseException("10:15:3", "strict_time_no_millis"); + assertParseException("1:15:30Z", "strict_time_no_millis", 0); + assertParseException("10:5:30Z", "strict_time_no_millis", 3); + assertParseException("10:15:3Z", "strict_time_no_millis", 6); + assertParseException("10:15:3", "strict_time_no_millis", 6); assertParses("T10:15:30.1Z", "strict_t_time"); assertParses("T10:15:30.123Z", "strict_t_time"); @@ -1281,28 +1291,28 @@ public class DateFormattersTests extends ESTestCase { assertParses("T10:15:30.123+0100", "strict_t_time"); assertParses("T10:15:30.1+01:00", "strict_t_time"); assertParses("T10:15:30.123+01:00", "strict_t_time"); - assertParseException("T1:15:30.123Z", "strict_t_time"); - assertParseException("T10:1:30.123Z", "strict_t_time"); - assertParseException("T10:15:3.123Z", "strict_t_time"); - assertParseException("T10:15:3.1", 
"strict_t_time"); - assertParseException("T10:15:3Z", "strict_t_time"); + assertParseException("T1:15:30.123Z", "strict_t_time", 1); + assertParseException("T10:1:30.123Z", "strict_t_time", 4); + assertParseException("T10:15:3.123Z", "strict_t_time", 7); + assertParseException("T10:15:3.1", "strict_t_time", 7); + assertParseException("T10:15:3Z", "strict_t_time", 7); assertParses("T10:15:30Z", "strict_t_time_no_millis"); assertParses("T10:15:30+0100", "strict_t_time_no_millis"); assertParses("T10:15:30+01:00", "strict_t_time_no_millis"); - assertParseException("T1:15:30Z", "strict_t_time_no_millis"); - assertParseException("T10:1:30Z", "strict_t_time_no_millis"); - assertParseException("T10:15:3Z", "strict_t_time_no_millis"); - assertParseException("T10:15:3", "strict_t_time_no_millis"); + assertParseException("T1:15:30Z", "strict_t_time_no_millis", 1); + assertParseException("T10:1:30Z", "strict_t_time_no_millis", 4); + assertParseException("T10:15:3Z", "strict_t_time_no_millis", 7); + assertParseException("T10:15:3", "strict_t_time_no_millis", 7); assertParses("2012-W48-6", "strict_week_date"); assertParses("2012-W01-6", "strict_week_date"); - assertParseException("2012-W1-6", "strict_week_date"); - assertParseException("2012-W1-8", "strict_week_date"); + assertParseException("2012-W1-6", "strict_week_date", 6); + assertParseException("2012-W1-8", "strict_week_date", 6); assertParses("2012-W48-6", "strict_week_date"); assertParses("2012-W01-6", "strict_week_date"); - assertParseException("2012-W1-6", "strict_week_date"); + assertParseException("2012-W1-6", "strict_week_date", 6); assertParseException("2012-W01-8", "strict_week_date"); assertParses("2012-W48-6T10:15:30.1Z", "strict_week_date_time"); @@ -1312,38 +1322,38 @@ public class DateFormattersTests extends ESTestCase { assertParses("2012-W48-6T10:15:30.123+0100", "strict_week_date_time"); assertParses("2012-W48-6T10:15:30.1+01:00", "strict_week_date_time"); assertParses("2012-W48-6T10:15:30.123+01:00", 
"strict_week_date_time"); - assertParseException("2012-W1-6T10:15:30.123Z", "strict_week_date_time"); + assertParseException("2012-W1-6T10:15:30.123Z", "strict_week_date_time", 6); assertParses("2012-W48-6T10:15:30Z", "strict_week_date_time_no_millis"); assertParses("2012-W48-6T10:15:30+0100", "strict_week_date_time_no_millis"); assertParses("2012-W48-6T10:15:30+01:00", "strict_week_date_time_no_millis"); - assertParseException("2012-W1-6T10:15:30Z", "strict_week_date_time_no_millis"); + assertParseException("2012-W1-6T10:15:30Z", "strict_week_date_time_no_millis", 6); assertParses("2012", "strict_year"); - assertParseException("1", "strict_year"); + assertParseException("1", "strict_year", 0); assertParses("-2000", "strict_year"); assertParses("2012-12", "strict_year_month"); - assertParseException("1-1", "strict_year_month"); + assertParseException("1-1", "strict_year_month", 0); assertParses("2012-12-31", "strict_year_month_day"); - assertParseException("1-12-31", "strict_year_month_day"); - assertParseException("2012-1-31", "strict_year_month_day"); - assertParseException("2012-12-1", "strict_year_month_day"); + assertParseException("1-12-31", "strict_year_month_day", 0); + assertParseException("2012-1-31", "strict_year_month_day", 4); + assertParseException("2012-12-1", "strict_year_month_day", 7); assertParses("2018", "strict_weekyear"); - assertParseException("1", "strict_weekyear"); + assertParseException("1", "strict_weekyear", 0); assertParses("2018", "strict_weekyear"); assertParses("2017", "strict_weekyear"); - assertParseException("1", "strict_weekyear"); + assertParseException("1", "strict_weekyear", 0); assertParses("2018-W29", "strict_weekyear_week"); assertParses("2018-W01", "strict_weekyear_week"); - assertParseException("2018-W1", "strict_weekyear_week"); + assertParseException("2018-W1", "strict_weekyear_week", 6); assertParses("2012-W31-5", "strict_weekyear_week_day"); - assertParseException("2012-W1-1", "strict_weekyear_week_day"); + 
assertParseException("2012-W1-1", "strict_weekyear_week_day", 6); } public void testDateFormatterWithLocale() { diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/EsExecutorsTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/EsExecutorsTests.java index 037df07d1e07..df7b02c2309a 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/EsExecutorsTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/EsExecutorsTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.Processors; +import org.elasticsearch.node.Node; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.hamcrest.Matcher; @@ -22,6 +23,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; +import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -628,7 +630,19 @@ public class EsExecutorsTests extends ESTestCase { public void testParseExecutorName() throws InterruptedException { final var executorName = randomAlphaOfLength(10); - final var threadFactory = EsExecutors.daemonThreadFactory(rarely() ? null : randomAlphaOfLength(10), executorName); + final String nodeName = rarely() ? 
null : randomIdentifier(); + final ThreadFactory threadFactory; + if (nodeName == null) { + threadFactory = EsExecutors.daemonThreadFactory(Settings.EMPTY, executorName); + } else if (randomBoolean()) { + threadFactory = EsExecutors.daemonThreadFactory( + Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), nodeName).build(), + executorName + ); + } else { + threadFactory = EsExecutors.daemonThreadFactory(nodeName, executorName); + } + final var thread = threadFactory.newThread(() -> {}); try { assertThat(EsExecutors.executorName(thread.getName()), equalTo(executorName)); diff --git a/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java index f12d8ea5c631..d49347a0dd3f 100644 --- a/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java +++ b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java @@ -31,7 +31,7 @@ import org.elasticsearch.common.util.MockPageCacheRecycler; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestResponse; @@ -530,20 +530,20 @@ public class DefaultRestChannelTests extends ESTestCase { { // chunked response final var isClosed = new AtomicBoolean(); - channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBody() { + channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBodyPart() { @Override - public boolean isDone() { + public boolean isPartComplete() { return false; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { throw new AssertionError("should not check for end-of-response 
for HEAD request"); } @Override - public void getContinuation(ActionListener listener) { + public void getNextPart(ActionListener listener) { throw new AssertionError("should not get any continuations for HEAD request"); } @@ -688,25 +688,25 @@ public class DefaultRestChannelTests extends ESTestCase { HttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/") { @Override - public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) { + public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) { try (var bso = new BytesStreamOutput()) { - writeContent(bso, content); + writeContent(bso, firstBodyPart); return new TestHttpResponse(status, bso.bytes()); } catch (IOException e) { return fail(e); } } - private static void writeContent(OutputStream bso, ChunkedRestResponseBody content) throws IOException { - while (content.isDone() == false) { + private static void writeContent(OutputStream bso, ChunkedRestResponseBodyPart content) throws IOException { + while (content.isPartComplete() == false) { try (var bytes = content.encodeChunk(1 << 14, BytesRefRecycler.NON_RECYCLING_INSTANCE)) { bytes.writeTo(bso); } } - if (content.isEndOfResponse()) { + if (content.isLastPart()) { return; } - writeContent(bso, PlainActionFuture.get(content::getContinuation)); + writeContent(bso, PlainActionFuture.get(content::getNextPart)); } }; @@ -735,14 +735,14 @@ public class DefaultRestChannelTests extends ESTestCase { ) ); - final var parts = new ArrayList(); - class TestBody implements ChunkedRestResponseBody { + final var parts = new ArrayList(); + class TestBodyPart implements ChunkedRestResponseBodyPart { boolean isDone; final BytesReference thisChunk; final BytesReference remainingChunks; final int remainingContinuations; - TestBody(BytesReference content, int remainingContinuations) { + TestBodyPart(BytesReference content, int remainingContinuations) { if 
(remainingContinuations == 0) { thisChunk = content; remainingChunks = BytesArray.EMPTY; @@ -755,18 +755,18 @@ public class DefaultRestChannelTests extends ESTestCase { } @Override - public boolean isDone() { + public boolean isPartComplete() { return isDone; } @Override - public boolean isEndOfResponse() { + public boolean isLastPart() { return remainingContinuations == 0; } @Override - public void getContinuation(ActionListener listener) { - final var continuation = new TestBody(remainingChunks, remainingContinuations - 1); + public void getNextPart(ActionListener listener) { + final var continuation = new TestBodyPart(remainingChunks, remainingContinuations - 1); parts.add(continuation); listener.onResponse(continuation); } @@ -785,7 +785,7 @@ public class DefaultRestChannelTests extends ESTestCase { } final var isClosed = new AtomicBoolean(); - final var firstPart = new TestBody(responseBody, between(0, 3)); + final var firstPart = new TestBodyPart(responseBody, between(0, 3)); parts.add(firstPart); assertEquals( responseBody, @@ -797,8 +797,8 @@ public class DefaultRestChannelTests extends ESTestCase { () -> channel.sendResponse(RestResponse.chunked(RestStatus.OK, firstPart, () -> { assertTrue(isClosed.compareAndSet(false, true)); for (int i = 0; i < parts.size(); i++) { - assertTrue("isDone " + i, parts.get(i).isDone()); - assertEquals("isEndOfResponse " + i, i == parts.size() - 1, parts.get(i).isEndOfResponse()); + assertTrue("isPartComplete " + i, parts.get(i).isPartComplete()); + assertEquals("isLastPart " + i, i == parts.size() - 1, parts.get(i).isLastPart()); } })) ) diff --git a/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java b/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java index 4e30dde5e5e7..e7b0232afa24 100644 --- a/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java +++ b/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java @@ -10,7 +10,7 @@ package org.elasticsearch.http; import 
org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestStatus; @@ -78,7 +78,7 @@ class TestHttpRequest implements HttpRequest { } @Override - public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) { + public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) { throw new UnsupportedOperationException("chunked responses not supported"); } diff --git a/server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyTests.java b/server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyPartTests.java similarity index 81% rename from server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyTests.java rename to server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyPartTests.java index cce2a8db25c8..9c703d83e7d0 100644 --- a/server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyTests.java +++ b/server/src/test/java/org/elasticsearch/rest/ChunkedRestResponseBodyPartTests.java @@ -30,7 +30,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -public class ChunkedRestResponseBodyTests extends ESTestCase { +public class ChunkedRestResponseBodyPartTests extends ESTestCase { public void testEncodesChunkedXContentCorrectly() throws IOException { final ChunkedToXContent chunkedToXContent = (ToXContent.Params outerParams) -> Iterators.forArray( @@ -50,7 +50,7 @@ public class ChunkedRestResponseBodyTests extends ESTestCase { } final var bytesDirect = BytesReference.bytes(builderDirect); - var chunkedResponse = ChunkedRestResponseBody.fromXContent( + var firstBodyPart = ChunkedRestResponseBodyPart.fromXContent( chunkedToXContent, ToXContent.EMPTY_PARAMS, new FakeRestChannel( @@ -61,20 +61,25 @@ public class 
ChunkedRestResponseBodyTests extends ESTestCase { ); final List refsGenerated = new ArrayList<>(); - while (chunkedResponse.isDone() == false) { - refsGenerated.add(chunkedResponse.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE)); + while (firstBodyPart.isPartComplete() == false) { + refsGenerated.add(firstBodyPart.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE)); } + assertTrue(firstBodyPart.isLastPart()); assertEquals(bytesDirect, CompositeBytesReference.of(refsGenerated.toArray(new BytesReference[0]))); } public void testFromTextChunks() throws IOException { final var chunks = randomList(1000, () -> randomUnicodeOfLengthBetween(1, 100)); - var body = ChunkedRestResponseBody.fromTextChunks("text/plain", Iterators.map(chunks.iterator(), s -> w -> w.write(s))); + var firstBodyPart = ChunkedRestResponseBodyPart.fromTextChunks( + "text/plain", + Iterators.map(chunks.iterator(), s -> w -> w.write(s)) + ); final List refsGenerated = new ArrayList<>(); - while (body.isDone() == false) { - refsGenerated.add(body.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE)); + while (firstBodyPart.isPartComplete() == false) { + refsGenerated.add(firstBodyPart.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE)); } + assertTrue(firstBodyPart.isLastPart()); final BytesReference chunkedBytes = CompositeBytesReference.of(refsGenerated.toArray(new BytesReference[0])); try (var outputStream = new ByteArrayOutputStream(); var writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8)) { diff --git a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java index 37300f1c19b1..10ea83e59c0a 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java @@ -733,7 +733,7 @@ public class 
RestControllerTests extends ESTestCase { } @Override - public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) { + public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) { throw new AssertionError("should not be called"); } diff --git a/server/src/test/java/org/elasticsearch/rest/RestResponseTests.java b/server/src/test/java/org/elasticsearch/rest/RestResponseTests.java index 41a54ac580a5..eaef60e15822 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestResponseTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestResponseTests.java @@ -97,7 +97,7 @@ public class RestResponseTests extends ESTestCase { public void testEmptyChunkedBody() { RestResponse response = RestResponse.chunked( RestStatus.OK, - ChunkedRestResponseBody.fromTextChunks(RestResponse.TEXT_CONTENT_TYPE, Collections.emptyIterator()), + ChunkedRestResponseBodyPart.fromTextChunks(RestResponse.TEXT_CONTENT_TYPE, Collections.emptyIterator()), null ); assertFalse(response.isChunked()); diff --git a/server/src/test/java/org/elasticsearch/rest/action/cat/RestTableTests.java b/server/src/test/java/org/elasticsearch/rest/action/cat/RestTableTests.java index dff6b52e470d..cb98eaddb77c 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/cat/RestTableTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/cat/RestTableTests.java @@ -432,14 +432,15 @@ public class RestTableTests extends ESTestCase { }; final var bodyChunks = new ArrayList(); - final var chunkedRestResponseBody = response.chunkedContent(); + final var firstBodyPart = response.chunkedContent(); - while (chunkedRestResponseBody.isDone() == false) { - try (var chunk = chunkedRestResponseBody.encodeChunk(pageSize, recycler)) { + while (firstBodyPart.isPartComplete() == false) { + try (var chunk = firstBodyPart.encodeChunk(pageSize, recycler)) { assertThat(chunk.length(), greaterThan(0)); bodyChunks.add(chunk.utf8ToString()); } } + 
assertTrue(firstBodyPart.isLastPart()); assertEquals(0, openPages.get()); return bodyChunks; } diff --git a/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java index a389020cdcde..442a8c3b82dc 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/shard/IndexShardTestCase.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.indices.flush.FlushRequest; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.action.support.replication.TransportReplicationAction; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.MappingMetadata; @@ -869,7 +870,7 @@ public abstract class IndexShardTestCase extends ESTestCase { routingTable ); try { - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC); recovery.recoverToTarget(future); future.actionGet(); recoveryTarget.markAsDone(); diff --git a/test/framework/src/main/java/org/elasticsearch/rest/RestResponseUtils.java b/test/framework/src/main/java/org/elasticsearch/rest/RestResponseUtils.java index 1b1331fe25bb..fe2df39b2159 100644 --- a/test/framework/src/main/java/org/elasticsearch/rest/RestResponseUtils.java +++ b/test/framework/src/main/java/org/elasticsearch/rest/RestResponseUtils.java @@ -28,8 +28,8 @@ public class RestResponseUtils { return restResponse.content(); } - final var chunkedRestResponseBody = restResponse.chunkedContent(); - assert chunkedRestResponseBody.isDone() == false; + final var firstResponseBodyPart = restResponse.chunkedContent(); + assert 
firstResponseBodyPart.isPartComplete() == false; final int pageSize; try (var page = NON_RECYCLING_INSTANCE.obtain()) { @@ -37,12 +37,12 @@ public class RestResponseUtils { } try (var out = new BytesStreamOutput()) { - while (chunkedRestResponseBody.isDone() == false) { - try (var chunk = chunkedRestResponseBody.encodeChunk(pageSize, NON_RECYCLING_INSTANCE)) { + while (firstResponseBodyPart.isPartComplete() == false) { + try (var chunk = firstResponseBodyPart.encodeChunk(pageSize, NON_RECYCLING_INSTANCE)) { chunk.writeTo(out); } } - assert chunkedRestResponseBody.isEndOfResponse() : "RestResponseUtils#getBodyContent does not support continuations (yet)"; + assert firstResponseBodyPart.isLastPart() : "RestResponseUtils#getBodyContent does not support continuations (yet)"; out.flush(); return out.bytes(); diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index eeb94beff04d..14269a8835f5 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -374,6 +374,15 @@ public abstract class ESTestCase extends LuceneTestCase { // We have to disable setting the number of available processors as tests in the same JVM randomize processors and will step on each // other if we allow them to set the number of available processors as it's set-once in Netty. 
System.setProperty("es.set.netty.runtime.available.processors", "false"); + + // sometimes use the java.time date formatters + // we can't use randomBoolean here, the random context isn't set properly + // so read it directly from the test seed in an unfortunately hacky way + String testSeed = System.getProperty("tests.seed", "0"); + boolean firstBit = (Integer.parseInt(testSeed.substring(testSeed.length() - 1), 16) & 1) == 1; + if (firstBit) { + System.setProperty("es.datetime.java_time_parsers", "true"); + } } protected final Logger logger = LogManager.getLogger(getClass()); diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java b/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java index 726d2ec0d963..3a9c4b371c9d 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java @@ -16,7 +16,7 @@ import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.http.HttpChannel; import org.elasticsearch.http.HttpRequest; import org.elasticsearch.http.HttpResponse; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -129,7 +129,7 @@ public class FakeRestRequest extends RestRequest { } @Override - public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) { + public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) { return createResponse(status, BytesArray.EMPTY); } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index 3dc7201535e0..d966a21a56b5 
100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -20,6 +20,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.support.ChannelActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.cluster.node.VersionInformation; @@ -960,7 +961,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { protected void doRun() throws Exception { safeAwait(go); for (int iter = 0; iter < 10; iter++) { - PlainActionFuture listener = new PlainActionFuture<>(); + PlainActionFuture listener = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC); final String info = sender + "_B_" + iter; serviceB.sendRequest( nodeA, @@ -996,7 +997,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { protected void doRun() throws Exception { go.await(); for (int iter = 0; iter < 10; iter++) { - PlainActionFuture listener = new PlainActionFuture<>(); + PlainActionFuture listener = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC); final String info = sender + "_" + iter; final DiscoveryNode node = nodeB; // capture now try { @@ -3464,7 +3465,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase { * @param connectionProfile the connection profile to use when connecting to this node */ public static void connectToNode(TransportService service, DiscoveryNode node, ConnectionProfile connectionProfile) { - PlainActionFuture.get(fut -> service.connectToNode(node, connectionProfile, fut.map(x -> null))); + UnsafePlainActionFuture.get(fut -> 
service.connectToNode(node, connectionProfile, fut.map(x -> null)), ThreadPool.Names.GENERIC); } /** diff --git a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java index 481f39d67341..c5ef1d7c2bf1 100644 --- a/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java +++ b/x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.RefCountingListener; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.blobcache.BlobCacheMetrics; import org.elasticsearch.blobcache.BlobCacheUtils; import org.elasticsearch.blobcache.common.ByteRange; @@ -36,6 +37,7 @@ import org.elasticsearch.env.Environment; import org.elasticsearch.env.NodeEnvironment; import org.elasticsearch.monitor.fs.FsProbe; import org.elasticsearch.node.NodeRoleSettings; +import org.elasticsearch.repositories.blobstore.BlobStoreRepository; import org.elasticsearch.threadpool.ThreadPool; import java.io.IOException; @@ -1136,7 +1138,9 @@ public class SharedBlobCacheService implements Releasable { int startRegion, int endRegion ) throws InterruptedException, ExecutionException { - final PlainActionFuture readsComplete = new PlainActionFuture<>(); + final PlainActionFuture readsComplete = new UnsafePlainActionFuture<>( + BlobStoreRepository.STATELESS_SHARD_PREWARMING_THREAD_NAME + ); final AtomicInteger bytesRead = new AtomicInteger(); try (var listeners = new RefCountingListener(1, readsComplete)) { for (int region = startRegion; region <= endRegion; region++) { diff --git 
a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java index baf1509c7388..67c4c769d21d 100644 --- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java +++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/repository/CcrRepository.java @@ -26,6 +26,7 @@ import org.elasticsearch.action.admin.indices.stats.ShardStats; import org.elasticsearch.action.support.ListenerTimeouts; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.ThreadedActionListener; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.RemoteClusterClient; import org.elasticsearch.cluster.ClusterName; @@ -599,7 +600,11 @@ public class CcrRepository extends AbstractLifecycleComponent implements Reposit Client followerClient, Index followerIndex ) { - final PlainActionFuture indexMetadataFuture = new PlainActionFuture<>(); + // todo: this could manifest in production and seems we could make this async easily. 
+ final PlainActionFuture indexMetadataFuture = new UnsafePlainActionFuture<>( + Ccr.CCR_THREAD_POOL_NAME, + ThreadPool.Names.GENERIC + ); final long startTimeInNanos = System.nanoTime(); final Supplier timeout = () -> { final long elapsedInNanos = System.nanoTime() - startTimeInNanos; diff --git a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTaskReplicationTests.java b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTaskReplicationTests.java index 3a16f368d322..04a97ad9e7f9 100644 --- a/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTaskReplicationTests.java +++ b/x-pack/plugin/ccr/src/test/java/org/elasticsearch/xpack/ccr/action/ShardFollowTaskReplicationTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.delete.DeleteRequest; import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.action.support.replication.PostWriteRefresh; import org.elasticsearch.action.support.replication.ReplicationResponse; import org.elasticsearch.action.support.replication.TransportWriteAction; @@ -802,7 +803,7 @@ public class ShardFollowTaskReplicationTests extends ESIndexLevelReplicationTest @Override protected void performOnPrimary(IndexShard primary, BulkShardOperationsRequest request, ActionListener listener) { - final PlainActionFuture permitFuture = new PlainActionFuture<>(); + final PlainActionFuture permitFuture = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC); primary.acquirePrimaryOperationPermit(permitFuture, EsExecutors.DIRECT_EXECUTOR_SERVICE); final TransportWriteAction.WritePrimaryResult ccrResult; final var threadpool = mock(ThreadPool.class); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResponseListener.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResponseListener.java index 34f2906d003a..0ed77b624f5b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResponseListener.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResponseListener.java @@ -12,7 +12,7 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; -import org.elasticsearch.rest.ChunkedRestResponseBody; +import org.elasticsearch.rest.ChunkedRestResponseBodyPart; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestResponse; @@ -132,13 +132,13 @@ public final class EsqlResponseListener extends RestRefCountedChunkedToXContentL if (mediaType instanceof TextFormat format) { restResponse = RestResponse.chunked( RestStatus.OK, - ChunkedRestResponseBody.fromTextChunks(format.contentType(restRequest), format.format(restRequest, esqlResponse)), + ChunkedRestResponseBodyPart.fromTextChunks(format.contentType(restRequest), format.format(restRequest, esqlResponse)), releasable ); } else { restResponse = RestResponse.chunked( RestStatus.OK, - ChunkedRestResponseBody.fromXContent(esqlResponse, channel.request(), channel), + ChunkedRestResponseBodyPart.fromXContent(esqlResponse, channel.request(), channel), releasable ); } diff --git a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestMoveToStepAction.java b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestMoveToStepAction.java index 9256a61addd8..64ce857a0198 100644 --- a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestMoveToStepAction.java +++ b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestMoveToStepAction.java @@ -36,13 +36,22 @@ public class RestMoveToStepAction extends BaseRestHandler { @Override 
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - String index = restRequest.param("name"); - TransportMoveToStepAction.Request request; + final var masterNodeTimeout = getMasterNodeTimeout(restRequest); + final var ackTimeout = getAckTimeout(restRequest); + final var index = restRequest.param("name"); + final TransportMoveToStepAction.Request request; try (XContentParser parser = restRequest.contentParser()) { - request = TransportMoveToStepAction.Request.parseRequest(index, parser); + request = TransportMoveToStepAction.Request.parseRequest( + (currentStepKey, nextStepKey) -> new TransportMoveToStepAction.Request( + masterNodeTimeout, + ackTimeout, + index, + currentStepKey, + nextStepKey + ), + parser + ); } - request.ackTimeout(getAckTimeout(restRequest)); - request.masterNodeTimeout(getMasterNodeTimeout(restRequest)); return channel -> client.execute(ILMActions.MOVE_TO_STEP, request, new RestToXContentListener<>(channel)); } } diff --git a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestRetryAction.java b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestRetryAction.java index 10a3fa38df67..1000bd1e6824 100644 --- a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestRetryAction.java +++ b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/RestRetryAction.java @@ -36,10 +36,8 @@ public class RestRetryAction extends BaseRestHandler { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) { - String[] indices = Strings.splitStringByCommaToArray(restRequest.param("index")); - TransportRetryAction.Request request = new TransportRetryAction.Request(indices); - request.ackTimeout(getAckTimeout(restRequest)); - request.masterNodeTimeout(getMasterNodeTimeout(restRequest)); + final var indices = Strings.splitStringByCommaToArray(restRequest.param("index")); + final var request = new 
TransportRetryAction.Request(getMasterNodeTimeout(restRequest), getAckTimeout(restRequest), indices); request.indices(indices); request.indicesOptions(IndicesOptions.fromRequest(restRequest, IndicesOptions.strictExpandOpen())); return channel -> client.execute(ILMActions.RETRY, request, new RestToXContentListener<>(channel)); diff --git a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportMoveToStepAction.java b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportMoveToStepAction.java index 87c93a919821..ec905c0e9eb4 100644 --- a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportMoveToStepAction.java +++ b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/ilm/action/TransportMoveToStepAction.java @@ -32,6 +32,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -188,15 +189,20 @@ public class TransportMoveToStepAction extends TransportMasterNodeAction implements ToXContentObject { + + public interface Factory { + Request create(Step.StepKey currentStepKey, PartialStepKey nextStepKey); + } + static final ParseField CURRENT_KEY_FIELD = new ParseField("current_step"); static final ParseField NEXT_KEY_FIELD = new ParseField("next_step"); - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "move_to_step_request", false, - (a, index) -> { + (a, factory) -> { Step.StepKey currentStepKey = (Step.StepKey) a[0]; PartialStepKey nextStepKey = (PartialStepKey) a[1]; - return new Request(index, currentStepKey, nextStepKey); + return 
factory.create(currentStepKey, nextStepKey); } ); @@ -207,12 +213,18 @@ public class TransportMoveToStepAction extends TransportMasterNodeAction PartialStepKey.parse(p), NEXT_KEY_FIELD); } - private String index; - private Step.StepKey currentStepKey; - private PartialStepKey nextStepKey; + private final String index; + private final Step.StepKey currentStepKey; + private final PartialStepKey nextStepKey; - public Request(String index, Step.StepKey currentStepKey, PartialStepKey nextStepKey) { - super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT); + public Request( + TimeValue masterNodeTimeout, + TimeValue ackTimeout, + String index, + Step.StepKey currentStepKey, + PartialStepKey nextStepKey + ) { + super(masterNodeTimeout, ackTimeout); this.index = index; this.currentStepKey = currentStepKey; this.nextStepKey = nextStepKey; @@ -225,10 +237,6 @@ public class TransportMoveToStepAction extends TransportMasterNodeAction implements IndicesRequest.Replaceable { - private String[] indices = Strings.EMPTY_ARRAY; + private String[] indices; private IndicesOptions indicesOptions = IndicesOptions.strictExpandOpen(); - public Request(String... indices) { - super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT); + public Request(TimeValue masterNodeTimeout, TimeValue ackTimeout, String... 
indices) { + super(masterNodeTimeout, ackTimeout); this.indices = indices; } @@ -128,10 +128,6 @@ public class TransportRetryAction extends TransportMasterNodeAction new TransportMoveToStepAction.Request( + TEST_REQUEST_TIMEOUT, + TEST_REQUEST_TIMEOUT, + index, + currentStepKey, + nextStepKey + ), + parser + ); } @Override @@ -52,7 +67,7 @@ public class MoveToStepRequestTests extends AbstractXContentSerializingTestCase< default -> throw new AssertionError("Illegal randomisation branch"); } - return new TransportMoveToStepAction.Request(indexName, currentStepKey, nextStepKey); + return new TransportMoveToStepAction.Request(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, indexName, currentStepKey, nextStepKey); } private static TransportMoveToStepAction.Request.PartialStepKey randomStepSpecification() { diff --git a/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/RetryRequestTests.java b/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/RetryRequestTests.java index e4f3c58fe6e6..4f053ddc2caa 100644 --- a/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/RetryRequestTests.java +++ b/x-pack/plugin/ilm/src/test/java/org/elasticsearch/xpack/ilm/action/RetryRequestTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ilm.action; import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; @@ -17,10 +18,11 @@ public class RetryRequestTests extends AbstractWireSerializingTestCase throw new AssertionError("Illegal randomisation branch"); } - TransportRetryAction.Request newRequest = new TransportRetryAction.Request(); - newRequest.indices(indices); + final var newRequest = new TransportRetryAction.Request(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, indices); newRequest.indicesOptions(indicesOptions); return newRequest; } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 4931b4da6f72..edea0104ded1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -44,6 +44,7 @@ import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInt import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallInternalServiceSettings; import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserSecretSettings; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings; @@ -106,6 +107,7 @@ public class InferenceNamedWriteablesProvider { addCohereNamedWriteables(namedWriteables); addAzureOpenAiNamedWriteables(namedWriteables); addAzureAiStudioNamedWriteables(namedWriteables); + addGoogleAiStudioNamedWritables(namedWriteables); return namedWriteables; } @@ -254,6 +256,16 @@ public class InferenceNamedWriteablesProvider { ); } + private static void addGoogleAiStudioNamedWritables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + GoogleAiStudioCompletionServiceSettings.NAME, + GoogleAiStudioCompletionServiceSettings::new + ) + ); + } + private static void addInternalElserNamedWriteables(List namedWriteables) { 
namedWriteables.add( new NamedWriteableRegistry.Entry(ServiceSettings.class, ElserInternalServiceSettings.NAME, ElserInternalServiceSettings::new) @@ -318,4 +330,5 @@ public class InferenceNamedWriteablesProvider { new NamedWriteableRegistry.Entry(InferenceServiceResults.class, RankedDocsResults.NAME, RankedDocsResults::new) ); } + } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index aa3bfb6c224f..83fc7323eab4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -68,6 +68,7 @@ import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService import org.elasticsearch.xpack.inference.services.cohere.CohereService; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; @@ -194,6 +195,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP context -> new CohereService(httpFactory.get(), serviceComponents.get()), context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()), context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()), + context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionCreator.java new file mode 100644 index 000000000000..51a8cc7a0bd5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionCreator.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.googleaistudio; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; + +import java.util.Map; +import java.util.Objects; + +public class GoogleAiStudioActionCreator implements GoogleAiStudioActionVisitor { + + private final Sender sender; + + private final ServiceComponents serviceComponents; + + public GoogleAiStudioActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(GoogleAiStudioCompletionModel model, Map taskSettings) { + // no overridden model as task settings are always empty for Google AI Studio completion model + return new GoogleAiStudioCompletionAction(sender, model, serviceComponents.threadPool()); + } +} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionVisitor.java new file mode 100644 index 000000000000..090d3f9a6971 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionVisitor.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.googleaistudio; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; + +import java.util.Map; + +public interface GoogleAiStudioActionVisitor { + + ExecutableAction create(GoogleAiStudioCompletionModel model, Map taskSettings); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionAction.java new file mode 100644 index 000000000000..7f918ae9a7db --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionAction.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.action.googleaistudio; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.GoogleAiStudioCompletionRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; + +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class GoogleAiStudioCompletionAction implements ExecutableAction { + + private final String failedToSendRequestErrorMessage; + + private final GoogleAiStudioCompletionRequestManager requestManager; + + private final Sender sender; + + public GoogleAiStudioCompletionAction(Sender sender, GoogleAiStudioCompletionModel model, ThreadPool threadPool) { + Objects.requireNonNull(threadPool); + Objects.requireNonNull(model); + this.sender = Objects.requireNonNull(sender); + this.requestManager = new GoogleAiStudioCompletionRequestManager(model, threadPool); + this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "Google AI Studio 
completion"); + } + + @Override + public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { + if (inferenceInputs instanceof DocumentsOnlyInput == false) { + listener.onFailure(new ElasticsearchStatusException("Invalid inference input type", RestStatus.INTERNAL_SERVER_ERROR)); + return; + } + + var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs; + if (docsOnlyInput.getInputs().size() > 1) { + listener.onFailure( + new ElasticsearchStatusException("Google AI Studio completion only accepts 1 input", RestStatus.BAD_REQUEST) + ); + return; + } + + try { + ActionListener wrappedListener = wrapFailuresInElasticsearchException( + failedToSendRequestErrorMessage, + listener + ); + sender.send(requestManager, inferenceInputs, timeout, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java new file mode 100644 index 000000000000..1138cfcb7cdc --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandler.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.googleaistudio; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.googleaistudio.GoogleAiStudioErrorResponseEntity; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody; + +public class GoogleAiStudioResponseHandler extends BaseResponseHandler { + + static final String GOOGLE_AI_STUDIO_UNAVAILABLE = "The Google AI Studio service may be temporarily overloaded or down"; + + public GoogleAiStudioResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, GoogleAiStudioErrorResponseEntity::fromResponse); + } + + @Override + public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) + throws RetryException { + checkForFailureStatusCode(request, result); + checkForEmptyBody(throttlerManager, logger, request, result); + } + + /** + * Validates the status code and throws a RetryException if not in the range [200, 300). + * + * The Google AI Studio error codes are documented here. 
+ * @param request The originating request + * @param result The http response and body + * @throws RetryException Throws if status code is {@code >= 300 or < 200 } + */ + void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException { + int statusCode = result.response().getStatusLine().getStatusCode(); + if (statusCode >= 200 && statusCode < 300) { + return; + } + + // handle error codes + if (statusCode == 500) { + throw new RetryException(true, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 503) { + throw new RetryException(true, buildError(GOOGLE_AI_STUDIO_UNAVAILABLE, request, result)); + } else if (statusCode > 500) { + throw new RetryException(false, buildError(SERVER_ERROR, request, result)); + } else if (statusCode == 429) { + throw new RetryException(true, buildError(RATE_LIMIT, request, result)); + } else if (statusCode == 404) { + throw new RetryException(false, buildError(resourceNotFoundError(request), request, result)); + } else if (statusCode == 403) { + throw new RetryException(false, buildError(PERMISSION_DENIED, request, result)); + } else if (statusCode >= 300 && statusCode < 400) { + throw new RetryException(false, buildError(REDIRECTION, request, result)); + } else { + throw new RetryException(false, buildError(UNSUCCESSFUL, request, result)); + } + } + + private static String resourceNotFoundError(Request request) { + return format("Resource not found at [%s]", request.getURI()); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java index b703cf2f14b7..f793cb358692 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java @@ -23,6 +23,7 @@ public abstract class BaseResponseHandler implements ResponseHandler { public static final String SERVER_ERROR = "Received a server error status code"; public static final String RATE_LIMIT = "Received a rate limit status code"; public static final String AUTHENTICATION = "Received an authentication error status code"; + public static final String PERMISSION_DENIED = "Received a permission denied error status code"; public static final String REDIRECTION = "Unhandled redirection"; public static final String CONTENT_TOO_LARGE = "Received a content too large status code"; public static final String UNSUCCESSFUL = "Received an unsuccessful status code"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java new file mode 100644 index 000000000000..eb9baa680446 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.googleaistudio.GoogleAiStudioResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.googleaistudio.GoogleAiStudioCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +public class GoogleAiStudioCompletionRequestManager extends GoogleAiStudioRequestManager { + + private static final Logger logger = LogManager.getLogger(GoogleAiStudioCompletionRequestManager.class); + + private static final ResponseHandler HANDLER = createCompletionHandler(); + + private final GoogleAiStudioCompletionModel model; + + private static ResponseHandler createCompletionHandler() { + return new GoogleAiStudioResponseHandler("google ai studio completion", GoogleAiStudioCompletionResponseEntity::fromResponse); + } + + public GoogleAiStudioCompletionRequestManager(GoogleAiStudioCompletionModel model, ThreadPool threadPool) { + super(threadPool, model); + this.model = Objects.requireNonNull(model); + } + + @Override + public Runnable create( + String query, + List input, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + HttpClientContext context, + ActionListener listener + ) { + 
GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(input, model); + return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioRequestManager.java new file mode 100644 index 000000000000..670c00f9a280 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioRequestManager.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel; + +import java.util.Objects; + +public abstract class GoogleAiStudioRequestManager extends BaseRequestManager { + GoogleAiStudioRequestManager(ThreadPool threadPool, GoogleAiStudioModel model) { + super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + } + + record RateLimitGrouping(int modelIdHash) { + public static RateLimitGrouping of(GoogleAiStudioModel model) { + Objects.requireNonNull(model); + + return new RateLimitGrouping(model.rateLimitServiceSettings().modelId().hashCode()); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java new file mode 100644 index 000000000000..f52fe623e791 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; + +public class GoogleAiStudioCompletionRequest implements GoogleAiStudioRequest { + + private final List input; + + private final URI uri; + + private final GoogleAiStudioCompletionModel model; + + public GoogleAiStudioCompletionRequest(List input, GoogleAiStudioCompletionModel model) { + this.input = input; + this.model = Objects.requireNonNull(model); + this.uri = model.uri(); + } + + @Override + public HttpRequest createHttpRequest() { + var httpPost = new HttpPost(uri); + var requestEntity = Strings.toString(new GoogleAiStudioCompletionRequestEntity(input)); + + ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); + 
httpPost.setEntity(byteEntity); + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + GoogleAiStudioRequest.decorateWithApiKeyParameter(httpPost, model.getSecretSettings()); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return this.uri; + } + + @Override + public Request truncate() { + // No truncation for Google AI Studio completion + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // No truncation for Google AI Studio completion + return null; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequestEntity.java new file mode 100644 index 000000000000..85e4d616c16e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequestEntity.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record GoogleAiStudioCompletionRequestEntity(List input) implements ToXContentObject { + + private static final String CONTENTS_FIELD = "contents"; + + private static final String PARTS_FIELD = "parts"; + + private static final String TEXT_FIELD = "text"; + + private static final String GENERATION_CONFIG_FIELD = "generationConfig"; + + private static final String CANDIDATE_COUNT_FIELD = "candidateCount"; + + private static final String ROLE_FIELD = "role"; + + private static final String ROLE_USER = "user"; + + public GoogleAiStudioCompletionRequestEntity { + Objects.requireNonNull(input); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startArray(CONTENTS_FIELD); + + { + for (String content : input) { + builder.startObject(); + + { + builder.startArray(PARTS_FIELD); + builder.startObject(); + + { + builder.field(TEXT_FIELD, content); + } + + builder.endObject(); + builder.endArray(); + } + + builder.field(ROLE_FIELD, ROLE_USER); + + builder.endObject(); + } + } + + builder.endArray(); + + builder.startObject(GENERATION_CONFIG_FIELD); + + { + // default is already 1, but we want to guard ourselves against API changes so setting it explicitly + builder.field(CANDIDATE_COUNT_FIELD, 1); + } + + builder.endObject(); + + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequest.java new file mode 100644 index 000000000000..ede9c6193aa2 
--- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequest.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio; + +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioSecretSettings; + +public interface GoogleAiStudioRequest extends Request { + + String API_KEY_PARAMETER = "key"; + + static void decorateWithApiKeyParameter(HttpPost httpPost, GoogleAiStudioSecretSettings secretSettings) { + try { + var uri = httpPost.getURI(); + var uriWithApiKey = new URIBuilder().setScheme(uri.getScheme()) + .setHost(uri.getHost()) + .setPort(uri.getPort()) + .setPath(uri.getPath()) + .addParameter(API_KEY_PARAMETER, secretSettings.apiKey().toString()) + .build(); + + httpPost.setURI(uriWithApiKey); + } catch (Exception e) { + ValidationException validationException = new ValidationException(e); + validationException.addValidationError(e.getMessage()); + throw validationException; + } + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioUtils.java new file mode 100644 index 000000000000..d63a0bbe2af9 --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioUtils.java @@ -0,0 +1,22 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio; + +public class GoogleAiStudioUtils { + + public static final String HOST_SUFFIX = "generativelanguage.googleapis.com"; + + public static final String V1 = "v1"; + + public static final String MODELS = "models"; + + public static final String GENERATE_CONTENT_ACTION = "generateContent"; + + private GoogleAiStudioUtils() {} + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioCompletionResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioCompletionResponseEntity.java new file mode 100644 index 000000000000..852f25705d6f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioCompletionResponseEntity.java @@ -0,0 +1,109 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.googleaistudio; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class GoogleAiStudioCompletionResponseEntity { + + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = + "Failed to find required field [%s] in Google AI Studio completion response"; + + /** + * Parses the Google AI Studio completion response. + * + * For a request like: + * + *

+     *     
+     *         {
+     *           "contents": [
+     *                          {
+     *                              "parts": [{
+     *                                  "text": "input"
+     *                              }]
+     *                          }
+     *                      ]
+     *          }
+     *     
+     * 
+ * + * The response would look like: + * + *
+     *     
+     *         {
+     *     "candidates": [
+     *         {
+     *             "content": {
+     *                 "parts": [
+     *                     {
+     *                         "text": "response"
+     *                     }
+     *                 ],
+     *                 "role": "model"
+     *             },
+     *             "finishReason": "STOP",
+     *             "index": 0,
+     *             "safetyRatings": [...]
+     *         }
+     *     ],
+     *     "usageMetadata": { ... }
+     * }
+     *     
+     * 
+ * + */ + + public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "candidates", FAILED_TO_FIND_FIELD_TEMPLATE); + + jsonParser.nextToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser); + + positionParserAtTokenAfterField(jsonParser, "content", FAILED_TO_FIND_FIELD_TEMPLATE); + + token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "parts", FAILED_TO_FIND_FIELD_TEMPLATE); + + jsonParser.nextToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "text", FAILED_TO_FIND_FIELD_TEMPLATE); + + XContentParser.Token contentToken = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, contentToken, jsonParser); + String content = jsonParser.text(); + + return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(content))); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioErrorResponseEntity.java new file mode 100644 index 000000000000..f57f672e10b1 --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioErrorResponseEntity.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.googleaistudio; + +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorMessage; + +import java.util.Map; + +public class GoogleAiStudioErrorResponseEntity implements ErrorMessage { + + private final String errorMessage; + + private GoogleAiStudioErrorResponseEntity(String errorMessage) { + this.errorMessage = errorMessage; + } + + @Override + public String getErrorMessage() { + return errorMessage; + } + + /** + * An example error response for invalid auth would look like + * + * { + * "error": { + * "code": 400, + * "message": "API key not valid. 
Please pass a valid API key.", + * "status": "INVALID_ARGUMENT", + * "details": [ + * { + * "@type": "type.googleapis.com/google.rpc.ErrorInfo", + * "reason": "API_KEY_INVALID", + * "domain": "googleapis.com", + * "metadata": { + * "service": "generativelanguage.googleapis.com" + * } + * } + * ] + * } + * } + * + * @param response The error response + * @return An error entity if the response is JSON with the above structure + * or null if the response does not contain the `error.message` field + */ + + @SuppressWarnings("unchecked") + public static GoogleAiStudioErrorResponseEntity fromResponse(HttpResult response) { + try ( + XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response.body()) + ) { + var responseMap = jsonParser.map(); + var error = (Map) responseMap.get("error"); + if (error != null) { + var message = (String) error.get("message"); + if (message != null) { + return new GoogleAiStudioErrorResponseEntity(message); + } + } + } catch (Exception e) { + // swallow the error + } + + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 25e8afbe1d16..4b5ec48f99b7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -196,8 +196,8 @@ public class ServiceUtils { ); } - public static String invalidUrlErrorMsg(String url, String settingName, String settingScope) { - return Strings.format("[%s] Invalid url [%s] received for field [%s]", settingScope, url, settingName); + public static String invalidUrlErrorMsg(String url, String settingName, String settingScope, String error) { + return Strings.format("[%s] Invalid url [%s] received for field 
[%s]. Error: %s", settingScope, url, settingName, error); } public static String mustBeNonEmptyString(String settingName, String scope) { @@ -231,7 +231,6 @@ public class ServiceUtils { return Strings.format("[%s] does not allow the setting [%s]", scope, settingName); } - // TODO improve URI validation logic public static URI convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) { try { if (url == null) { @@ -239,8 +238,8 @@ public class ServiceUtils { } return createUri(url); - } catch (IllegalArgumentException ignored) { - validationException.addValidationError(ServiceUtils.invalidUrlErrorMsg(url, settingName, settingScope)); + } catch (IllegalArgumentException cause) { + validationException.addValidationError(ServiceUtils.invalidUrlErrorMsg(url, settingName, settingScope, cause.getMessage())); return null; } } @@ -251,7 +250,7 @@ public class ServiceUtils { try { return new URI(url); } catch (URISyntaxException e) { - throw new IllegalArgumentException(format("unable to parse url [%s]", url), e); + throw new IllegalArgumentException(format("unable to parse url [%s]. 
Reason: %s", url, e.getReason()), e); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettings.java index a82ffbba3d68..0b586af5005f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettings.java @@ -66,7 +66,7 @@ public class CustomElandRerankTaskSettings implements TaskSettings { } /** - * Return either the request or orignal settings by preferring non-null fields + * Return either the request or original settings by preferring non-null fields * from the request settings over the original settings. * * @param originalSettings the settings stored as part of the inference entity configuration diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioModel.java new file mode 100644 index 000000000000..4ddffd0bae61 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioModel.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googleaistudio; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor; + +import java.util.Map; +import java.util.Objects; + +public abstract class GoogleAiStudioModel extends Model { + + private final GoogleAiStudioRateLimitServiceSettings rateLimitServiceSettings; + + public GoogleAiStudioModel( + ModelConfigurations configurations, + ModelSecrets secrets, + GoogleAiStudioRateLimitServiceSettings rateLimitServiceSettings + ) { + super(configurations, secrets); + + this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings); + } + + public abstract ExecutableAction accept(GoogleAiStudioActionVisitor creator, Map taskSettings, InputType inputType); + + public GoogleAiStudioRateLimitServiceSettings rateLimitServiceSettings() { + return rateLimitServiceSettings; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioRateLimitServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioRateLimitServiceSettings.java new file mode 100644 index 000000000000..2e443263c7f5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioRateLimitServiceSettings.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googleaistudio; + +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +public interface GoogleAiStudioRateLimitServiceSettings { + + String modelId(); + + RateLimitSettings rateLimitSettings(); + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettings.java new file mode 100644 index 000000000000..bf702d010e2a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettings.java @@ -0,0 +1,106 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googleaistudio; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalSecureString; + +public class GoogleAiStudioSecretSettings implements SecretSettings { + + public static final String NAME = "google_ai_studio_secret_settings"; + public static final String API_KEY = "api_key"; + + private final SecureString apiKey; + + public static GoogleAiStudioSecretSettings fromMap(@Nullable Map map) { + if (map == null) { + return null; + } + + ValidationException validationException = new ValidationException(); + SecureString secureApiKey = extractOptionalSecureString(map, API_KEY, ModelSecrets.SECRET_SETTINGS, validationException); + + if (secureApiKey == null) { + validationException.addValidationError(format("[secret_settings] must have [%s] set", API_KEY)); + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new GoogleAiStudioSecretSettings(secureApiKey); + } + + public GoogleAiStudioSecretSettings(SecureString apiKey) { + Objects.requireNonNull(apiKey); + this.apiKey = apiKey; + } + + public GoogleAiStudioSecretSettings(StreamInput in) throws IOException { + this(in.readOptionalSecureString()); + } + + public SecureString apiKey() { + return apiKey; + } + + @Override + public 
XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (apiKey != null) { + builder.field(API_KEY, apiKey.toString()); + } + + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_GOOGLE_AI_STUDIO_COMPLETION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalSecureString(apiKey); + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + GoogleAiStudioSecretSettings that = (GoogleAiStudioSecretSettings) object; + return Objects.equals(apiKey, that.apiKey); + } + + @Override + public int hashCode() { + return Objects.hash(apiKey); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java new file mode 100644 index 000000000000..f990923cee92 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -0,0 +1,218 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googleaistudio; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionCreator; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +public class 
GoogleAiStudioService extends SenderService { + + public static final String NAME = "googleaistudio"; + + public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + Set platfromArchitectures, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + GoogleAiStudioModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + + } + + private static GoogleAiStudioModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + return switch (taskType) { + case COMPLETION -> new GoogleAiStudioCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings + ); + default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + }; + } + + @Override + public GoogleAiStudioModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = 
removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + private static GoogleAiStudioModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + @Override + public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_GOOGLE_AI_STUDIO_COMPLETION_ADDED; + } + + @Override + protected void doInfer( + Model model, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof GoogleAiStudioModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model; + var actionCreator = new GoogleAiStudioActionCreator(getSender(), getServiceComponents()); + + var action = googleAiStudioModel.accept(actionCreator, taskSettings, inputType); + action.execute(new DocumentsOnlyInput(input), timeout, listener); 
+ } + + @Override + protected void doInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + throw new UnsupportedOperationException("Query input not supported for Google AI Studio"); + } + + @Override + protected void doChunkedInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + ChunkingOptions chunkingOptions, + TimeValue timeout, + ActionListener> listener + ) { + throw new UnsupportedOperationException("Chunked inference not supported yet for Google AI Studio"); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java new file mode 100644 index 000000000000..6a11f678158b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googleaistudio.completion; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor; +import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioSecretSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; + +import static org.elasticsearch.core.Strings.format; + +public class GoogleAiStudioCompletionModel extends GoogleAiStudioModel { + + private URI uri; + + public GoogleAiStudioCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + Map secrets + ) { + this( + inferenceEntityId, + taskType, + service, + GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings), + EmptyTaskSettings.INSTANCE, + GoogleAiStudioSecretSettings.fromMap(secrets) + ); + } + + // Should only be used directly for testing + GoogleAiStudioCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + GoogleAiStudioCompletionServiceSettings serviceSettings, + TaskSettings taskSettings, + @Nullable GoogleAiStudioSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secrets), + serviceSettings + 
); + try { + this.uri = buildUri(serviceSettings.modelId()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + // Should only be used directly for testing + GoogleAiStudioCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + String url, + GoogleAiStudioCompletionServiceSettings serviceSettings, + TaskSettings taskSettings, + @Nullable GoogleAiStudioSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secrets), + serviceSettings + ); + try { + this.uri = new URI(url); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public URI uri() { + return uri; + } + + @Override + public GoogleAiStudioCompletionServiceSettings getServiceSettings() { + return (GoogleAiStudioCompletionServiceSettings) super.getServiceSettings(); + } + + @Override + public GoogleAiStudioSecretSettings getSecretSettings() { + return (GoogleAiStudioSecretSettings) super.getSecretSettings(); + } + + public static URI buildUri(String model) throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(GoogleAiStudioUtils.HOST_SUFFIX) + .setPathSegments( + GoogleAiStudioUtils.V1, + GoogleAiStudioUtils.MODELS, + format("%s:%s", model, GoogleAiStudioUtils.GENERATE_CONTENT_ACTION) + ) + .build(); + } + + @Override + public ExecutableAction accept(GoogleAiStudioActionVisitor visitor, Map taskSettings, InputType inputType) { + return visitor.create(this, taskSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettings.java new file mode 100644 index 000000000000..f8f343be8eb4 --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettings.java @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googleaistudio.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; + +public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + GoogleAiStudioRateLimitServiceSettings { + + public static final String NAME = "google_ai_studio_completion_service_settings"; + + /** + * Rate limits are defined at Google Gemini API Pricing. + * For pay-as-you-go you've 360 requests per minute. 
+ */ + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360); + + public static GoogleAiStudioCompletionServiceSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new GoogleAiStudioCompletionServiceSettings(model, rateLimitSettings); + } + + private final String modelId; + + private final RateLimitSettings rateLimitSettings; + + public GoogleAiStudioCompletionServiceSettings(String modelId, @Nullable RateLimitSettings rateLimitSettings) { + this.modelId = modelId; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public GoogleAiStudioCompletionServiceSettings(StreamInput in) throws IOException { + modelId = in.readString(); + rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public String modelId() { + return modelId; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragmentOfExposedFields(builder, params); + rateLimitSettings.toXContent(builder, params); + + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_GOOGLE_AI_STUDIO_COMPLETION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + rateLimitSettings.writeTo(out); + } + + 
@Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + + return builder; + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + GoogleAiStudioCompletionServiceSettings that = (GoogleAiStudioCompletionServiceSettings) object; + return Objects.equals(modelId, that.modelId) && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/MatchersUtils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/MatchersUtils.java new file mode 100644 index 000000000000..6397e83fc246 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/MatchersUtils.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; + +import java.util.regex.Pattern; + +/** + * Utility class containing custom hamcrest {@link Matcher} implementations or other utility functionality related to hamcrest. + */ +public class MatchersUtils { + + /** + * Custom matcher implementing a matcher operating on json strings ignoring whitespaces, which are not inside a key or a value. + * + * Example: + * { + * "key": "value" + * } + * + * will match + * + * {"key":"value"} + * + * as both json strings are equal ignoring the whitespace, which does not reside in a key or a value. 
+ * + */ + protected static class IsEqualIgnoreWhitespaceInJsonString extends TypeSafeMatcher { + + protected static final Pattern WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN = createPattern(); + + private static Pattern createPattern() { + String regex = "(?<=[:,\\[{])\\s+|\\s+(?=[\\]}:,])|^\\s+|\\s+$"; + return Pattern.compile(regex); + } + + private final String string; + + IsEqualIgnoreWhitespaceInJsonString(String string) { + if (string == null) { + throw new IllegalArgumentException("Non-null value required"); + } + this.string = string; + } + + @Override + protected boolean matchesSafely(String item) { + java.util.regex.Matcher itemMatcher = WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN.matcher(item); + java.util.regex.Matcher stringMatcher = WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN.matcher(string); + + String itemReplacedWhitespaces = itemMatcher.replaceAll(""); + String stringReplacedWhitespaces = stringMatcher.replaceAll(""); + + return itemReplacedWhitespaces.equals(stringReplacedWhitespaces); + } + + @Override + public void describeTo(Description description) { + java.util.regex.Matcher stringMatcher = WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN.matcher(string); + String stringReplacedWhitespaces = stringMatcher.replaceAll(""); + + description.appendText("a string equal to (when all whitespaces are ignored expect in keys and values): ") + .appendValue(stringReplacedWhitespaces); + } + + public static Matcher equalToIgnoringWhitespaceInJsonString(String expectedString) { + return new IsEqualIgnoreWhitespaceInJsonString(expectedString); + } + } + + public static Matcher equalToIgnoringWhitespaceInJsonString(String expectedString) { + return IsEqualIgnoreWhitespaceInJsonString.equalToIgnoringWhitespaceInJsonString(expectedString); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/MatchersUtilsTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/MatchersUtilsTests.java new file mode 100644 index 000000000000..6f30d23a45ae --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/MatchersUtilsTests.java @@ -0,0 +1,186 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Description; +import org.hamcrest.SelfDescribing; + +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.hamcrest.Matchers.is; + +public class MatchersUtilsTests extends ESTestCase { + + public void testIsEqualIgnoreWhitespaceInJsonString_Pattern() { + var json = """ + + { + "key": "value" + } + + """; + + Pattern pattern = MatchersUtils.IsEqualIgnoreWhitespaceInJsonString.WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN; + Matcher matcher = pattern.matcher(json); + String jsonWithRemovedWhitespaces = matcher.replaceAll(""); + + assertThat(jsonWithRemovedWhitespaces, is(""" + {"key":"value"}""")); + } + + public void testIsEqualIgnoreWhitespaceInJsonString_Pattern_DoesNotRemoveWhitespaceInKeysAndValues() { + var json = """ + + { + "key 1": "value 1" + } + + """; + + Pattern pattern = MatchersUtils.IsEqualIgnoreWhitespaceInJsonString.WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN; + Matcher matcher = pattern.matcher(json); + String jsonWithRemovedWhitespaces = matcher.replaceAll(""); + + assertThat(jsonWithRemovedWhitespaces, is(""" + {"key 1":"value 1"}""")); + } + + public void testIsEqualIgnoreWhitespaceInJsonString_MatchesSafely_DoesMatch() { + var json = """ + + { + "key 1": "value 1", + "key 2: { + "key 3: "value 3" + }, + "key 4": [ + "value 4", 
"value 5" + ] + } + + """; + + var jsonWithDifferentSpacing = """ + {"key 1": "value 1", + "key 2: { + "key 3: "value 3" + }, + "key 4": [ + "value 4", "value 5" + ] + } + + """; + + var typeSafeMatcher = new MatchersUtils.IsEqualIgnoreWhitespaceInJsonString(json); + boolean matches = typeSafeMatcher.matchesSafely(jsonWithDifferentSpacing); + + assertTrue(matches); + } + + public void testIsEqualIgnoreWhitespaceInJsonString_MatchesSafely_DoesNotMatch() { + var json = """ + + { + "key 1": "value 1", + "key 2: { + "key 3: "value 3" + }, + "key 4": [ + "value 4", "value 5" + ] + } + + """; + + // one value missing in array + var jsonWithDifferentSpacing = """ + {"key 1": "value 1", + "key 2: { + "key 3: "value 3" + }, + "key 4": [ + "value 4" + ] + } + + """; + + var typeSafeMatcher = new MatchersUtils.IsEqualIgnoreWhitespaceInJsonString(json); + boolean matches = typeSafeMatcher.matchesSafely(jsonWithDifferentSpacing); + + assertFalse(matches); + } + + public void testIsEqualIgnoreWhitespaceInJsonString_DescribeTo() { + var jsonOne = """ + { + "key": "value" + } + """; + + var typeSafeMatcher = new MatchersUtils.IsEqualIgnoreWhitespaceInJsonString(jsonOne); + var description = new TestDescription(""); + + typeSafeMatcher.describeTo(description); + + assertThat(description.toString(), is(""" + a string equal to (when all whitespaces are ignored expect in keys and values): {"key":"value"}""")); + } + + private static class TestDescription implements Description { + + private String descriptionContent; + + TestDescription(String descriptionContent) { + Objects.requireNonNull(descriptionContent); + this.descriptionContent = descriptionContent; + } + + @Override + public Description appendText(String text) { + descriptionContent += text; + return this; + } + + @Override + public Description appendDescriptionOf(SelfDescribing value) { + throw new UnsupportedOperationException(); + } + + @Override + public Description appendValue(Object value) { + descriptionContent += 
value; + return this; + } + + @SafeVarargs + @Override + public final Description appendValueList(String start, String separator, String end, T... values) { + throw new UnsupportedOperationException(); + } + + @Override + public Description appendValueList(String start, String separator, String end, Iterable values) { + throw new UnsupportedOperationException(); + } + + @Override + public Description appendList(String start, String separator, String end, Iterable values) { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() { + return descriptionContent; + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java index c4878c495a94..88d408d309a7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java @@ -20,7 +20,6 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.common.TruncatorTests; -import org.elasticsearch.xpack.inference.external.action.openai.OpenAiChatCompletionActionTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -46,6 +45,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static 
org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -163,10 +163,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase { var result = listener.actionGet(TIMEOUT); - assertThat( - result.asMap(), - is(OpenAiChatCompletionActionTests.buildExpectedChatCompletionResultMap(List.of("test input string"))) - ); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test input string")))); assertThat(webServer.requests(), hasSize(1)); MockRequest request = webServer.requests().get(0); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java index 6ca6985c9e8f..9b0371ad51f8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java @@ -40,9 +40,9 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.external.action.cohere.CohereCompletionActionTests.buildExpectedChatCompletionResultMap; import 
static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.equalTo; @@ -200,7 +200,7 @@ public class CohereActionCreatorTests extends ESTestCase { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType())); @@ -260,7 +260,7 @@ public class CohereActionCreatorTests extends ESTestCase { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType())); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java index 195f2bab1d6b..12c3d132d124 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java @@ -23,7 +23,6 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -44,6 +43,8 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -119,7 +120,7 @@ public class CohereCompletionActionTests extends ESTestCase { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); @@ -180,7 +181,7 @@ public class CohereCompletionActionTests extends ESTestCase { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), 
is(buildExpectedChatCompletionResultMap(List.of("result")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); @@ -198,7 +199,7 @@ public class CohereCompletionActionTests extends ESTestCase { public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOException { try (var sender = mock(Sender.class)) { var thrownException = expectThrows(IllegalArgumentException.class, () -> createAction("a^b", "api key", "model", sender)); - assertThat(thrownException.getMessage(), is("unable to parse url [a^b]")); + assertThat(thrownException.getMessage(), containsString("unable to parse url [a^b]")); } } @@ -338,13 +339,6 @@ public class CohereCompletionActionTests extends ESTestCase { } } - public static Map buildExpectedChatCompletionResultMap(List results) { - return Map.of( - ChatCompletionResults.COMPLETION, - results.stream().map(result -> Map.of(ChatCompletionResults.Result.RESULT, result)).toList() - ); - } - private CohereCompletionAction createAction(String url, String apiKey, @Nullable String modelName, Sender sender) { var model = CohereCompletionModelTests.createModel(url, apiKey, modelName); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java index 6ca4cb305ab3..dbc97fa2e13d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereEmbeddingsActionTests.java @@ -49,6 +49,7 @@ import static 
org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationByte; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -245,7 +246,7 @@ public class CohereEmbeddingsActionTests extends ESTestCase { IllegalArgumentException.class, () -> createAction("^^", "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender) ); - MatcherAssert.assertThat(thrownException.getMessage(), is("unable to parse url [^^]")); + MatcherAssert.assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java new file mode 100644 index 000000000000..09ef5351eb1f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java @@ -0,0 +1,274 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.action.googleaistudio; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModelTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static 
org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioServiceTests.buildExpectationCompletions; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public class GoogleAiStudioCompletionActionTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_ReturnsSuccessfulResponse() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = HttpRequestSenderTests.createSenderWithSingleRequestManager(senderFactory, "test_service")) { + sender.start(); + + String responseJson = """ + { + "candidates": [ + { + "content": { + "parts": [ + { + "text": "result" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + 
"promptTokenCount": 4, + "candidatesTokenCount": 566, + "totalTokenCount": 570 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction(getUrl(webServer), "secret", "model", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("input")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletions(List.of("result")))); + assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests().get(0).getUri().getQuery(), is("key=secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat( + requestMap, + is( + Map.of( + "contents", + List.of(Map.of("role", "user", "parts", List.of(Map.of("text", "input")))), + "generationConfig", + Map.of("candidateCount", 1) + ) + ) + ); + } + } + + public void testExecute_ThrowsElasticsearchException() { + var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "secret", "model", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("failed")); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new 
IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "secret", "model", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is(format("Failed to send Google AI Studio completion request to [%s]", getUrl(webServer))) + ); + } + + public void testExecute_ThrowsException() { + var sender = mock(Sender.class); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "secret", "model", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is(format("Failed to send Google AI Studio completion request to [%s]", getUrl(webServer))) + ); + } + + public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + String responseJson = """ + { + "candidates": [ + { + "content": { + "parts": [ + { + "text": "result" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": 
"NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 566, + "totalTokenCount": 570 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction(getUrl(webServer), "secret", "model", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("Google AI Studio completion only accepts 1 input")); + assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST)); + } + } + + private GoogleAiStudioCompletionAction createAction(String url, String apiKey, String modelName, Sender sender) { + var model = GoogleAiStudioCompletionModelTests.createModel(modelName, url, apiKey); + + return new GoogleAiStudioCompletionAction(sender, model, threadPool); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java index 099ac166dda7..fceea8810f6c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/huggingface/HuggingFaceActionCreatorTests.java @@ -99,7 +99,7 @@ public class HuggingFaceActionCreatorTests extends ESTestCase { assertThat( result.asMap(), is( - SparseEmbeddingResultsTests.buildExpectation( + SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings( List.of(new 
SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(".", 0.13315596f), false)) ) ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java index b62e8fc9865e..496238eaad0e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java @@ -35,11 +35,11 @@ import java.util.concurrent.TimeUnit; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.external.action.openai.OpenAiChatCompletionActionTests.buildExpectedChatCompletionResultMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; @@ -333,7 +333,7 @@ public class OpenAiActionCreatorTests extends ESTestCase { var result = 
listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("Hello there, how may I assist you today?")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); assertThat(webServer.requests(), hasSize(1)); var request = webServer.requests().get(0); @@ -396,7 +396,7 @@ public class OpenAiActionCreatorTests extends ESTestCase { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("Hello there, how may I assist you today?")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); assertThat(webServer.requests(), hasSize(1)); var request = webServer.requests().get(0); @@ -458,7 +458,7 @@ public class OpenAiActionCreatorTests extends ESTestCase { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("Hello there, how may I assist you today?")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); assertThat(webServer.requests(), hasSize(1)); var request = webServer.requests().get(0); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java index e28c3e817b35..914ff12db259 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java @@ -24,7 +24,6 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import 
org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; @@ -45,8 +44,10 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -117,7 +118,7 @@ public class OpenAiChatCompletionActionTests extends ESTestCase { var result = listener.actionGet(TIMEOUT); - assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result content")))); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result content")))); assertThat(webServer.requests(), hasSize(1)); MockRequest request = webServer.requests().get(0); @@ -142,7 +143,7 @@ public class OpenAiChatCompletionActionTests extends ESTestCase { IllegalArgumentException.class, () -> createAction("^^", "org", "secret", "model", "user", sender) ); - assertThat(thrownException.getMessage(), is("unable to parse url [^^]")); + 
assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); } } @@ -276,13 +277,6 @@ public class OpenAiChatCompletionActionTests extends ESTestCase { } } - public static Map buildExpectedChatCompletionResultMap(List results) { - return Map.of( - ChatCompletionResults.COMPLETION, - results.stream().map(result -> Map.of(ChatCompletionResults.Result.RESULT, result)).toList() - ); - } - private OpenAiChatCompletionAction createAction( String url, String org, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java index 260e352fd26c..15b7417912ef 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiEmbeddingsActionTests.java @@ -43,6 +43,7 @@ import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUt import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; @@ -131,7 +132,7 @@ public class OpenAiEmbeddingsActionTests extends ESTestCase { IllegalArgumentException.class, () -> createAction("^^", "org", "secret", "model", "user", sender) ); - assertThat(thrownException.getMessage(), is("unable to parse url [^^]")); + assertThat(thrownException.getMessage(), containsString("unable 
to parse url [^^]")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandlerTests.java new file mode 100644 index 000000000000..ba20799978d4 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/googleaistudio/GoogleAiStudioResponseHandlerTests.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.googleaistudio; + +import org.apache.http.Header; +import org.apache.http.HeaderElement; +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class GoogleAiStudioResponseHandlerTests extends ESTestCase { + + public void testCheckForFailureStatusCode_DoesNotThrowFor200() { + callCheckForFailureStatusCode(200, "id"); + } + + public void testCheckForFailureStatusCode_ThrowsFor500_ShouldRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(500, "id")); + 
assertTrue(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [500]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor503_ShouldRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(503, "id")); + assertTrue(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString( + "The Google AI Studio service may be temporarily overloaded or down for request from inference entity id [id] status [503]" + ) + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor505_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(505, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received a server error status code for request from inference entity id [id] status [505]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + public void testCheckForFailureStatusCode_ThrowsFor429_ShouldRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(429, "id")); + assertTrue(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received a rate limit status code for request from inference entity id [id] status [429]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.TOO_MANY_REQUESTS)); + } + + public void testCheckForFailureStatusCode_ThrowsFor404_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> 
callCheckForFailureStatusCode(404, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Resource not found at [null] for request from inference entity id [id] status [404]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.NOT_FOUND)); + } + + public void testCheckForFailureStatusCode_ThrowsFor403_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(403, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received a permission denied error status code for request from inference entity id [id] status [403]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.FORBIDDEN)); + } + + public void testCheckForFailureStatusCode_ThrowsFor300_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(300, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Unhandled redirection for request from inference entity id [id] status [300]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.MULTIPLE_CHOICES)); + } + + public void testCheckForFailureStatusCode_ThrowsFor425_ShouldNotRetry() { + var exception = expectThrows(RetryException.class, () -> callCheckForFailureStatusCode(425, "id")); + assertFalse(exception.shouldRetry()); + assertThat( + exception.getCause().getMessage(), + containsString("Received an unsuccessful status code for request from inference entity id [id] status [425]") + ); + assertThat(((ElasticsearchStatusException) exception.getCause()).status(), is(RestStatus.BAD_REQUEST)); + } + + private static void callCheckForFailureStatusCode(int statusCode, String modelId) { + var statusLine = mock(StatusLine.class); + 
when(statusLine.getStatusCode()).thenReturn(statusCode); + + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + var header = mock(Header.class); + when(header.getElements()).thenReturn(new HeaderElement[] {}); + when(httpResponse.getFirstHeader(anyString())).thenReturn(header); + + var mockRequest = mock(Request.class); + when(mockRequest.getInferenceEntityId()).thenReturn(modelId); + var httpResult = new HttpResult(httpResponse, new byte[] {}); + var handler = new GoogleAiStudioResponseHandler("", (request, result) -> null); + + handler.checkForFailureStatusCode(mockRequest, httpResult); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequestTests.java new file mode 100644 index 000000000000..d77c88dacd06 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequestTests.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio; + +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioSecretSettings; + +import java.net.URI; +import java.net.URISyntaxException; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class GoogleAiStudioRequestTests extends ESTestCase { + + public void testDecorateWithApiKeyParameter() throws URISyntaxException { + var uriString = "https://localhost:3000"; + var secureApiKey = new SecureString("api_key".toCharArray()); + var httpPost = new HttpPost(uriString); + var secretSettings = new GoogleAiStudioSecretSettings(secureApiKey); + + GoogleAiStudioRequest.decorateWithApiKeyParameter(httpPost, secretSettings); + + assertThat(httpPost.getURI(), is(new URI(Strings.format("%s?key=%s", uriString, secureApiKey)))); + } + + public void testDecorateWithApiKeyParameter_ThrowsValidationException_WhenAnyExceptionIsThrown() { + var errorMessage = "something went wrong"; + var cause = new RuntimeException(errorMessage); + var httpPost = mock(HttpPost.class); + when(httpPost.getURI()).thenThrow(cause); + + ValidationException validationException = expectThrows( + ValidationException.class, + () -> GoogleAiStudioRequest.decorateWithApiKeyParameter( + httpPost, + new GoogleAiStudioSecretSettings(new SecureString("abc".toCharArray())) + ) + ); + assertThat(validationException.getCause(), is(cause)); + assertThat(validationException.getMessage(), containsString(errorMessage)); + } + +} diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestEntityTests.java new file mode 100644 index 000000000000..0b8ded1a4f11 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestEntityTests.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioCompletionRequestEntity; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class GoogleAiStudioCompletionRequestEntityTests extends ESTestCase { + + public void testToXContent_WritesSingleMessage() throws IOException { + var entity = new GoogleAiStudioCompletionRequestEntity(List.of("input")); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "contents": [ + { + "parts": [ + { + "text":"input" + } + ], + "role": 
"user" + } + ], + "generationConfig": { + "candidateCount": 1 + } + }""")); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java new file mode 100644 index 000000000000..7d7ee1dcba6c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio.completion; + +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioCompletionRequest; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModelTests; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.endsWith; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class GoogleAiStudioCompletionRequestTests extends ESTestCase { + + public void testCreateRequest() throws IOException { + var apiKey = "api_key"; + var input = "input"; + + var request = new GoogleAiStudioCompletionRequest(List.of(input), 
GoogleAiStudioCompletionModelTests.createModel("model", apiKey)); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), endsWith(Strings.format("%s=%s", "key", apiKey))); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat( + requestMap, + is( + Map.of( + "contents", + List.of(Map.of("role", "user", "parts", List.of(Map.of("text", input)))), + "generationConfig", + Map.of("candidateCount", 1) + ) + ) + ); + } + + public void testTruncate_ReturnsSameInstance() { + var request = new GoogleAiStudioCompletionRequest( + List.of("input"), + GoogleAiStudioCompletionModelTests.createModel("model", "api key") + ); + var truncatedRequest = request.truncate(); + + assertThat(truncatedRequest, sameInstance(request)); + } + + public void testTruncationInfo_ReturnsNull() { + var request = new GoogleAiStudioCompletionRequest( + List.of("input"), + GoogleAiStudioCompletionModelTests.createModel("model", "api key") + ); + + assertNull(request.getTruncationInfo()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioCompletionResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioCompletionResponseEntityTests.java new file mode 100644 index 000000000000..ea4dd6ce47e2 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioCompletionResponseEntityTests.java @@ -0,0 +1,189 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.response.googleaistudio; + +import org.apache.http.HttpResponse; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class GoogleAiStudioCompletionResponseEntityTests extends ESTestCase { + + public void testFromResponse_CreatesResultsForASingleItem() throws IOException { + String responseJson = """ + { + "candidates": [ + { + "content": { + "parts": [ + { + "text": "result" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 312, + "totalTokenCount": 316 + } + } + """; + + ChatCompletionResults chatCompletionResults = GoogleAiStudioCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(chatCompletionResults.getResults().size(), is(1)); + assertThat(chatCompletionResults.getResults().get(0).content(), is("result")); + } + + public void testFromResponse_FailsWhenCandidatesFieldIsNotPresent() { + String responseJson = """ + { + "not_candidates": [ + { + "content": { + "parts": [ + { + "text": "result" + } + ], + "role": 
"model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 312, + "totalTokenCount": 316 + } + } + """; + + var thrownException = expectThrows( + IllegalStateException.class, + () -> GoogleAiStudioCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), is("Failed to find required field [candidates] in Google AI Studio completion response")); + } + + public void testFromResponse_FailsWhenTextFieldIsNotAString() { + String responseJson = """ + { + "candidates": [ + { + "content": { + "parts": [ + { + "text": { + "key": "value" + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 312, + "totalTokenCount": 316 + } + } + """; + + var thrownException = expectThrows( + ParsingException.class, + () -> GoogleAiStudioCompletionResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to 
parse object: expecting token of type [VALUE_STRING] but found [START_OBJECT]") + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioErrorResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioErrorResponseEntityTests.java new file mode 100644 index 000000000000..61448f2e35bd --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioErrorResponseEntityTests.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.googleaistudio; + +import org.apache.http.HttpResponse; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class GoogleAiStudioErrorResponseEntityTests extends ESTestCase { + + private static HttpResult getMockResult(String jsonString) { + var response = mock(HttpResponse.class); + return new HttpResult(response, Strings.toUTF8Bytes(jsonString)); + } + + public void testErrorResponse_ExtractsError() { + var result = getMockResult(""" + { + "error": { + "code": 400, + "message": "error message", + "status": "INVALID_ARGUMENT", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.BadRequest", + "fieldViolations": [ + { + "description": "Invalid JSON payload received. Unknown name \\"abc\\": Cannot find field." 
+ } + ] + } + ] + } + } + """); + + var error = GoogleAiStudioErrorResponseEntity.fromResponse(result); + assertNotNull(error); + assertThat(error.getErrorMessage(), is("error message")); + } + + public void testErrorResponse_ReturnsNullIfNoError() { + var result = getMockResult(""" + { + "foo": "bar" + } + """); + + var error = GoogleAiStudioErrorResponseEntity.fromResponse(result); + assertNull(error); + } + + public void testErrorResponse_ReturnsNullIfNotJson() { + var result = getMockResult("error message"); + + var error = GoogleAiStudioErrorResponseEntity.fromResponse(result); + assertNull(error); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java index bdb8e38fa822..c3c416d8fe65 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntityTests.java @@ -22,7 +22,7 @@ import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; -import static org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests.buildExpectation; +import static org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; @@ -46,7 +46,7 @@ public class HuggingFaceElserResponseEntityTests extends ESTestCase { assertThat( parsedResults.asMap(), is( - buildExpectation( + buildExpectationSparseEmbeddings( List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(".", 0.13315596f, "the", 
0.67472112f), false)) ) ) @@ -73,7 +73,7 @@ public class HuggingFaceElserResponseEntityTests extends ESTestCase { assertThat( parsedResults.asMap(), is( - buildExpectation( + buildExpectationSparseEmbeddings( List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(".", 0.13315596f, "the", 0.67472112f), true)) ) ) @@ -101,7 +101,7 @@ public class HuggingFaceElserResponseEntityTests extends ESTestCase { assertThat( parsedResults.asMap(), is( - buildExpectation( + buildExpectationSparseEmbeddings( List.of( new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(".", 0.13315596f, "the", 0.67472112f), false), new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("hi", 0.13315596f, "super", 0.67472112f), false) @@ -135,7 +135,7 @@ public class HuggingFaceElserResponseEntityTests extends ESTestCase { assertThat( parsedResults.asMap(), is( - buildExpectation( + buildExpectationSparseEmbeddings( List.of( new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(".", 0.13315596f, "the", 0.67472112f), true), new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("hi", 0.13315596f, "super", 0.67472112f), false) @@ -169,7 +169,7 @@ public class HuggingFaceElserResponseEntityTests extends ESTestCase { assertThat( parsedResults.asMap(), is( - buildExpectation( + buildExpectationSparseEmbeddings( List.of( new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(".", 0.13315596f, "the", 0.67472112f), false), new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("hi", 0.13315596f, "super", 0.67472112f), false) @@ -239,7 +239,11 @@ public class HuggingFaceElserResponseEntityTests extends ESTestCase { assertThat( parsedResults.asMap(), - is(buildExpectation(List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("field", 1.0f), false)))) + is( + buildExpectationSparseEmbeddings( + List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("field", 1.0f), false)) + ) + ) ); } @@ -259,7 +263,11 @@ public class 
HuggingFaceElserResponseEntityTests extends ESTestCase { assertThat( parsedResults.asMap(), - is(buildExpectation(List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("field", 4.0294965E10F), false)))) + is( + buildExpectationSparseEmbeddings( + List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of("field", 4.0294965E10F), false)) + ) + ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChatCompletionResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChatCompletionResultsTests.java index 6bbe6eea5394..1b9b2db660bf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChatCompletionResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChatCompletionResultsTests.java @@ -125,6 +125,13 @@ public class ChatCompletionResultsTests extends AbstractWireSerializingTestCase< return new ChatCompletionResults(chatCompletionResults); } + public static Map buildExpectationCompletion(List results) { + return Map.of( + ChatCompletionResults.COMPLETION, + results.stream().map(result -> Map.of(ChatCompletionResults.Result.RESULT, result)).toList() + ); + } + private static ChatCompletionResults.Result createRandomChatCompletionResult() { return new ChatCompletionResults.Result(randomAlphaOfLengthBetween(10, 300)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java index 727df98d27bb..acc0ef6eed26 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java @@ -87,7 +87,7 @@ public class 
SparseEmbeddingResultsTests extends AbstractWireSerializingTestCase public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { var entity = createSparseResult(List.of(createEmbedding(List.of(new SparseEmbedding.WeightedToken("token", 0.1F)), false))); - assertThat(entity.asMap(), is(buildExpectation(List.of(new EmbeddingExpectation(Map.of("token", 0.1F), false))))); + assertThat(entity.asMap(), is(buildExpectationSparseEmbeddings(List.of(new EmbeddingExpectation(Map.of("token", 0.1F), false))))); String xContentResult = Strings.toString(entity, true, true); assertThat(xContentResult, is(""" { @@ -118,7 +118,7 @@ public class SparseEmbeddingResultsTests extends AbstractWireSerializingTestCase assertThat( entity.asMap(), is( - buildExpectation( + buildExpectationSparseEmbeddings( List.of( new EmbeddingExpectation(Map.of("token", 0.1F, "token2", 0.2F), false), new EmbeddingExpectation(Map.of("token3", 0.3F, "token4", 0.4F), false) @@ -170,7 +170,7 @@ public class SparseEmbeddingResultsTests extends AbstractWireSerializingTestCase public record EmbeddingExpectation(Map tokens, boolean isTruncated) {} - public static Map buildExpectation(List embeddings) { + public static Map buildExpectationSparseEmbeddings(List embeddings) { return Map.of( SparseEmbeddingResults.SPARSE_EMBEDDING, embeddings.stream() diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java index 48784b9bd865..a07f75ec2c53 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingByteResultsTests.java @@ -131,7 +131,7 @@ public class TextEmbeddingByteResultsTests extends AbstractWireSerializingTestCa } } - public static 
Map buildExpectation(List> embeddings) { + public static Map buildExpectationByte(List> embeddings) { return Map.of( TextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, embeddings.stream().map(embedding -> Map.of(ByteEmbedding.EMBEDDING, embedding)).toList() diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index 0a34de7b342e..edd9637d92dd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -307,7 +307,19 @@ public class ServiceUtilsTests extends ESTestCase { assertNull(uri); assertThat(validation.validationErrors().size(), is(1)); - assertThat(validation.validationErrors().get(0), is("[scope] Invalid url [^^] received for field [name]")); + assertThat(validation.validationErrors().get(0), containsString("[scope] Invalid url [^^] received for field [name]")); + } + + public void testConvertToUri_AddsValidationError_WhenUrlIsInvalid_PreservesReason() { + var validation = new ValidationException(); + var uri = convertToUri("^^", "name", "scope", validation); + + assertNull(uri); + assertThat(validation.validationErrors().size(), is(1)); + assertThat( + validation.validationErrors().get(0), + is("[scope] Invalid url [^^] received for field [name]. Error: unable to parse url [^^]. 
Reason: Illegal character in path") + ); } public void testCreateUri_CreatesUri() { @@ -320,7 +332,7 @@ public class ServiceUtilsTests extends ESTestCase { public void testCreateUri_ThrowsException_WithInvalidUrl() { var exception = expectThrows(IllegalArgumentException.class, () -> createUri("^^")); - assertThat(exception.getMessage(), is("unable to parse url [^^]")); + assertThat(exception.getMessage(), containsString("unable to parse url [^^]")); } public void testCreateUri_ThrowsException_WithNullUrl() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java index 303ed1cab2c5..f4dad7546c8a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java @@ -229,7 +229,9 @@ public class CohereServiceSettingsTests extends AbstractWireSerializingTestCase< MatcherAssert.assertThat( thrownException.getMessage(), - is(Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s];", url, ServiceFields.URL)) + containsString( + Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", url, ServiceFields.URL) + ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettingsTests.java new file mode 100644 index 000000000000..05515bf9e386 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettingsTests.java @@ 
-0,0 +1,115 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.HashMap; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class CustomElandRerankTaskSettingsTests extends AbstractWireSerializingTestCase { + + public void testDefaultsFromMap_MapIsNull_ReturnsDefaultSettings() { + var customElandRerankTaskSettings = CustomElandRerankTaskSettings.defaultsFromMap(null); + + assertThat(customElandRerankTaskSettings, sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS)); + } + + public void testDefaultsFromMap_MapIsEmpty_ReturnsDefaultSettings() { + var customElandRerankTaskSettings = CustomElandRerankTaskSettings.defaultsFromMap(new HashMap<>()); + + assertThat(customElandRerankTaskSettings, sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS)); + } + + public void testDefaultsFromMap_ExtractedReturnDocumentsNull_SetsReturnDocumentToTrue() { + var customElandRerankTaskSettings = CustomElandRerankTaskSettings.defaultsFromMap(new HashMap<>()); + + assertThat(customElandRerankTaskSettings.returnDocuments(), is(Boolean.TRUE)); + } + + public void testFromMap_MapIsNull_ReturnsDefaultSettings() { + var customElandRerankTaskSettings = CustomElandRerankTaskSettings.fromMap(null); + + assertThat(customElandRerankTaskSettings, 
sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS)); + } + + public void testFromMap_MapIsEmpty_ReturnsDefaultSettings() { + var customElandRerankTaskSettings = CustomElandRerankTaskSettings.fromMap(new HashMap<>()); + + assertThat(customElandRerankTaskSettings, sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS)); + } + + public void testToXContent_WritesAllValues() throws IOException { + var serviceSettings = new CustomElandRerankTaskSettings(Boolean.TRUE); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"return_documents":true}""")); + } + + public void testToXContent_DoesNotWriteReturnDocuments_IfNull() throws IOException { + Boolean bool = null; + var serviceSettings = new CustomElandRerankTaskSettings(bool); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {}""")); + } + + public void testOf_PrefersNonNullRequestTaskSettings() { + var originalSettings = new CustomElandRerankTaskSettings(Boolean.FALSE); + var requestTaskSettings = new CustomElandRerankTaskSettings(Boolean.TRUE); + + var taskSettings = CustomElandRerankTaskSettings.of(originalSettings, requestTaskSettings); + + assertThat(taskSettings, sameInstance(requestTaskSettings)); + } + + public void testOf_UseOriginalSettings_IfRequestSettingsValuesAreNull() { + Boolean bool = null; + var originalSettings = new CustomElandRerankTaskSettings(Boolean.TRUE); + var requestTaskSettings = new CustomElandRerankTaskSettings(bool); + + var taskSettings = CustomElandRerankTaskSettings.of(originalSettings, requestTaskSettings); + + assertThat(taskSettings, sameInstance(originalSettings)); + } + + private static CustomElandRerankTaskSettings createRandom() { + 
return new CustomElandRerankTaskSettings(randomOptionalBoolean()); + } + + @Override + protected Writeable.Reader instanceReader() { + return CustomElandRerankTaskSettings::new; + } + + @Override + protected CustomElandRerankTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected CustomElandRerankTaskSettings mutateInstance(CustomElandRerankTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, CustomElandRerankTaskSettingsTests::createRandom); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettingsTests.java new file mode 100644 index 000000000000..a0339934783d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettingsTests.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googleaistudio; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class GoogleAiStudioSecretSettingsTests extends AbstractWireSerializingTestCase { + + public static GoogleAiStudioSecretSettings createRandom() { + return new GoogleAiStudioSecretSettings(randomSecureStringOfLength(15)); + } + + public void testFromMap() { + var apiKey = "abc"; + var secretSettings = GoogleAiStudioSecretSettings.fromMap(new HashMap<>(Map.of(GoogleAiStudioSecretSettings.API_KEY, apiKey))); + + assertThat(new GoogleAiStudioSecretSettings(new SecureString(apiKey.toCharArray())), is(secretSettings)); + } + + public void testFromMap_ReturnsNull_WhenMapIsNull() { + assertNull(GoogleAiStudioSecretSettings.fromMap(null)); + } + + public void testFromMap_ThrowsError_WhenApiKeyIsNull() { + var throwException = expectThrows(ValidationException.class, () -> GoogleAiStudioSecretSettings.fromMap(new HashMap<>())); + + assertThat(throwException.getMessage(), containsString("[secret_settings] must have [api_key] set")); + } + + public void testFromMap_ThrowsError_WhenApiKeyIsEmpty() { + var thrownException = expectThrows( + ValidationException.class, + () -> GoogleAiStudioSecretSettings.fromMap(new HashMap<>(Map.of(GoogleAiStudioSecretSettings.API_KEY, ""))) + ); + + assertThat( + thrownException.getMessage(), + containsString("[secret_settings] Invalid value empty string. 
[api_key] must be a non-empty string") + ); + } + + @Override + protected Writeable.Reader instanceReader() { + return GoogleAiStudioSecretSettings::new; + } + + @Override + protected GoogleAiStudioSecretSettings createTestInstance() { + return createRandom(); + } + + @Override + protected GoogleAiStudioSecretSettings mutateInstance(GoogleAiStudioSecretSettings instance) throws IOException { + return randomValueOtherThan(instance, GoogleAiStudioSecretSettingsTests::createRandom); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java new file mode 100644 index 000000000000..f157622ea729 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -0,0 +1,630 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googleaistudio; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import 
java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; +import static org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModelTests.createModel; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasSize; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class GoogleAiStudioServiceTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), 
mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesAGoogleAiStudioCompletionModel() throws IOException { + var apiKey = "apiKey"; + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + }, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.COMPLETION, + getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + new HashMap<>(Map.of()), + getSecretSettingsMap(apiKey) + ), + Set.of(), + modelListener + ); + } + } + + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { + try (var service = createGoogleAiStudioService()) { + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "The [googleaistudio] service does not support task type [sparse_embedding]" + ); + + service.parseRequestConfig( + "id", + TaskType.SPARSE_EMBEDDING, + getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model")), + new HashMap<>(Map.of()), + getSecretSettingsMap("secret") + ), + Set.of(), + failureListener + ); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createGoogleAiStudioService()) { + var config = getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model")), + getTaskSettingsMapEmpty(), + getSecretSettingsMap("secret") + ); + config.put("extra_key", "value"); + + var 
failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" + ); + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + try (var service = createGoogleAiStudioService()) { + Map serviceSettings = new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model")); + serviceSettings.put("extra_key", "value"); + + var config = getRequestConfigMap(serviceSettings, getTaskSettingsMapEmpty(), getSecretSettingsMap("api_key")); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" + ); + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + try (var service = createGoogleAiStudioService()) { + Map taskSettingsMap = new HashMap<>(); + taskSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model")), + taskSettingsMap, + getSecretSettingsMap("secret") + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" + ); + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + try (var service = createGoogleAiStudioService()) { + Map secretSettings = getSecretSettingsMap("secret"); + secretSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + 
new HashMap<>(Map.of(ServiceFields.MODEL_ID, "model")), + getTaskSettingsMapEmpty(), + secretSettings + ); + + var failureListener = getModelListenerForException( + ElasticsearchStatusException.class, + "Model configuration contains settings [{extra_key=value}] unknown to the [googleaistudio] service" + ); + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), failureListener); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioCompletionModel() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + getSecretSettingsMap(apiKey) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + getSecretSettingsMap(apiKey) + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + 
+ var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(completionModel.getSecretSettings().apiKey(), is(apiKey)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + var secretSettingsMap = getSecretSettingsMap(apiKey); + secretSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + secretSettingsMap + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + Map serviceSettingsMap = new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMapEmpty(), getSecretSettingsMap(apiKey)); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, 
instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + Map taskSettings = getTaskSettingsMapEmpty(); + taskSettings.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + taskSettings, + getSecretSettingsMap(apiKey) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(completionModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + + public void testParsePersistedConfig_CreatesAGoogleAiStudioCompletionModel() throws IOException { + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), getTaskSettingsMapEmpty()); + + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + 
assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertNull(completionModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), getTaskSettingsMapEmpty()); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertNull(completionModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + Map serviceSettingsMap = new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMapEmpty()); + + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertNull(completionModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + var modelId = "model"; + + try (var service = 
createGoogleAiStudioService()) { + Map taskSettings = getTaskSettingsMapEmpty(); + taskSettings.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), taskSettings); + + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(GoogleAiStudioCompletionModel.class)); + + var completionModel = (GoogleAiStudioCompletionModel) model; + assertThat(completionModel.getServiceSettings().modelId(), is(modelId)); + assertThat(completionModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertNull(completionModel.getSecretSettings()); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender(anyString())).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name"); + + try (var service = new GoogleAiStudioService(factory, createWithEmptySettings(threadPool))) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + List.of(""), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + MatcherAssert.assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(anyString()); + verify(sender, times(1)).start(); + } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + + public void testInfer_SendsRequest() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = 
new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + String responseJson = """ + { + "candidates": [ + { + "content": { + "parts": [ + { + "text": "result" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 215, + "totalTokenCount": 219 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel("model", getUrl(webServer), "secret"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("input"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletions(List.of("result")))); + assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests().get(0).getUri().getQuery(), is("key=secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat( + requestMap, + is( + Map.of( + "contents", + List.of(Map.of("role", "user", "parts", List.of(Map.of("text", "input")))), + "generationConfig", + Map.of("candidateCount", 1) + ) + ) + ); + } + } + + public void testInfer_ResourceNotFound() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new GoogleAiStudioService(senderFactory, 
createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "error": { + "message": "error" + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + + var model = createModel("model", getUrl(webServer), "secret"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(error.getMessage(), containsString("Resource not found at ")); + assertThat(error.getMessage(), containsString("Error message: [error]")); + assertThat(webServer.requests(), hasSize(1)); + } + } + + public static Map buildExpectationCompletions(List completions) { + return Map.of( + ChatCompletionResults.COMPLETION, + completions.stream().map(completion -> Map.of(ChatCompletionResults.Result.RESULT, completion)).collect(Collectors.toList()) + ); + } + + private static ActionListener getModelListenerForException(Class exceptionClass, String expectedMessage) { + return ActionListener.wrap((model) -> fail("Model parsing should have failed"), e -> { + assertThat(e, Matchers.instanceOf(exceptionClass)); + assertThat(e.getMessage(), is(expectedMessage)); + }); + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>( + Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) + ); + } + + private GoogleAiStudioService createGoogleAiStudioService() { + return new GoogleAiStudioService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + } + + private PersistedConfig getPersistedConfigMap( + Map 
serviceSettings, + Map taskSettings, + Map secretSettings + ) { + + return new PersistedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + new HashMap<>(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)) + ); + } + + private PersistedConfig getPersistedConfigMap(Map serviceSettings, Map taskSettings) { + return new PersistedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + null + ); + } + + private record PersistedConfig(Map config, Map secrets) {} + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java new file mode 100644 index 000000000000..1f8233f7eb10 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.googleaistudio.completion; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioSecretSettings; + +import java.net.URISyntaxException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class GoogleAiStudioCompletionModelTests extends ESTestCase { + + public void testCreateModel_AlwaysWithEmptyTaskSettings() { + var model = new GoogleAiStudioCompletionModel( + "inference entity id", + TaskType.COMPLETION, + "service", + new HashMap<>(Map.of("model_id", "model")), + new HashMap<>(Map.of()), + null + ); + + assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + } + + public void testBuildUri() throws URISyntaxException { + assertThat( + GoogleAiStudioCompletionModel.buildUri("model").toString(), + is("https://generativelanguage.googleapis.com/v1/models/model:generateContent") + ); + } + + public static GoogleAiStudioCompletionModel createModel(String model, String apiKey) { + return new GoogleAiStudioCompletionModel( + "id", + TaskType.COMPLETION, + "service", + new GoogleAiStudioCompletionServiceSettings(model, null), + EmptyTaskSettings.INSTANCE, + new GoogleAiStudioSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static GoogleAiStudioCompletionModel createModel(String model, String url, String apiKey) { + return new GoogleAiStudioCompletionModel( + "id", + TaskType.COMPLETION, + "service", + url, + new GoogleAiStudioCompletionServiceSettings(model, null), + EmptyTaskSettings.INSTANCE, + new GoogleAiStudioSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +} diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettingsTests.java new file mode 100644 index 000000000000..46e6e60af493 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionServiceSettingsTests.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googleaistudio.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class GoogleAiStudioCompletionServiceSettingsTests extends AbstractWireSerializingTestCase { + + public static GoogleAiStudioCompletionServiceSettings createRandom() { + return new GoogleAiStudioCompletionServiceSettings(randomAlphaOfLength(8), randomFrom(RateLimitSettingsTests.createRandom(), null)); + } + + public void testFromMap_Request_CreatesSettingsCorrectly() { + var model = "some model"; + + var serviceSettings = GoogleAiStudioCompletionServiceSettings.fromMap(new 
HashMap<>(Map.of(ServiceFields.MODEL_ID, model))); + + assertThat(serviceSettings, is(new GoogleAiStudioCompletionServiceSettings(model, null))); + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new GoogleAiStudioCompletionServiceSettings("model", null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"model_id":"model","rate_limit":{"requests_per_minute":360}}""")); + } + + public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException { + var entity = new GoogleAiStudioCompletionServiceSettings("model", null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + var filteredXContent = entity.getFilteredXContentObject(); + filteredXContent.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"model_id":"model"}""")); + } + + @Override + protected Writeable.Reader instanceReader() { + return GoogleAiStudioCompletionServiceSettings::new; + } + + @Override + protected GoogleAiStudioCompletionServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected GoogleAiStudioCompletionServiceSettings mutateInstance(GoogleAiStudioCompletionServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, GoogleAiStudioCompletionServiceSettingsTests::createRandom); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java index 9d92f756dd31..91b91593adee 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceSettingsTests.java @@ -141,7 +141,9 @@ public class HuggingFaceServiceSettingsTests extends AbstractWireSerializingTest assertThat( thrownException.getMessage(), - is(Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s];", url, ServiceFields.URL)) + containsString( + Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", url, ServiceFields.URL) + ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 54eef58fb2f7..914775bf9fa6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -494,7 +494,7 @@ public class HuggingFaceServiceTests extends ESTestCase { assertThat( result.asMap(), Matchers.is( - SparseEmbeddingResultsTests.buildExpectation( + SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings( List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(".", 0.13315596f), false)) ) ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModelTests.java index 33dbee2a32b9..2ad2c12b4a97 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModelTests.java @@ -11,13 +11,13 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.containsString; public class HuggingFaceElserModelTests extends ESTestCase { public void testThrowsURISyntaxException_ForInvalidUrl() { var thrownException = expectThrows(IllegalArgumentException.class, () -> createModel("^^", "secret")); - assertThat(thrownException.getMessage(), is("unable to parse url [^^]")); + assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); } public static HuggingFaceElserModel createModel(String url, String apiKey) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java index bd6a5007b72e..57f9c59b65e1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserServiceSettingsTests.java @@ -77,9 +77,9 @@ public class HuggingFaceElserServiceSettingsTests extends AbstractWireSerializin assertThat( thrownException.getMessage(), - is( + containsString( Strings.format( - "Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s];", + "Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", url, 
HuggingFaceElserServiceSettings.URL ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java index d579da2d9fbc..baf5467d8fe0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/embeddings/HuggingFaceEmbeddingsModelTests.java @@ -16,13 +16,13 @@ import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; -import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.containsString; public class HuggingFaceEmbeddingsModelTests extends ESTestCase { public void testThrowsURISyntaxException_ForInvalidUrl() { var thrownException = expectThrows(IllegalArgumentException.class, () -> createModel("^^", "secret")); - assertThat(thrownException.getMessage(), is("unable to parse url [^^]")); + assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); } public static HuggingFaceEmbeddingsModel createModel(String url, String apiKey) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java index 75ea63eba8a3..186ca8942641 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java 
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionServiceSettingsTests.java @@ -170,7 +170,9 @@ public class OpenAiChatCompletionServiceSettingsTests extends AbstractWireSerial assertThat( thrownException.getMessage(), - is(Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s];", url, ServiceFields.URL)) + containsString( + Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", url, ServiceFields.URL) + ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java index 1be70ee58683..438f895fe48a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java @@ -335,7 +335,9 @@ public class OpenAiEmbeddingsServiceSettingsTests extends AbstractWireSerializin assertThat( thrownException.getMessage(), - is(Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s];", url, ServiceFields.URL)) + containsString( + Strings.format("Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", url, ServiceFields.URL) + ) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java index 637b37853363..06075363997c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java @@ -17,6 +17,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.common.settings.Settings; @@ -31,6 +32,7 @@ import org.elasticsearch.tasks.TaskId; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.dataframe.DestinationIndex; import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker; import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; @@ -100,7 +102,9 @@ public class InferenceRunner { LOGGER.info("[{}] Started inference on test data against model [{}]", config.getId(), modelId); try { - PlainActionFuture localModelPlainActionFuture = new PlainActionFuture<>(); + PlainActionFuture localModelPlainActionFuture = new UnsafePlainActionFuture<>( + MachineLearning.UTILITY_THREAD_POOL_NAME + ); modelLoadingService.getModelForInternalInference(modelId, localModelPlainActionFuture); InferenceState inferenceState = restoreInferenceState(); dataCountsTracker.setTestDocsCount(inferenceState.processedTestDocsCount); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java index e181e1fc8668..7052e6f147b3 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java @@ -13,6 +13,7 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.search.SearchPhaseExecutionException; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterState; @@ -205,7 +206,9 @@ public class TrainedModelAssignmentNodeService implements ClusterStateListener { if (stopped) { return; } - final PlainActionFuture listener = new PlainActionFuture<>(); + final PlainActionFuture listener = new UnsafePlainActionFuture<>( + MachineLearning.UTILITY_THREAD_POOL_NAME + ); try { deploymentManager.startDeployment(loadingTask, listener); // This needs to be synchronous here in the utility thread to keep queueing order diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/CacheService.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/CacheService.java index 6e480a21d507..636d138c8a3e 100644 --- a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/CacheService.java +++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/CacheService.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.store.AlreadyClosedException; import org.elasticsearch.action.support.PlainActionFuture; +import 
org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.blobcache.common.ByteRange; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.UUIDs; @@ -347,7 +348,7 @@ public class CacheService extends AbstractLifecycleComponent { if (allowShardsEvictions) { final ShardEviction shardEviction = new ShardEviction(snapshotUUID, snapshotIndexName, shardId); pendingShardsEvictions.computeIfAbsent(shardEviction, shard -> { - final PlainActionFuture future = new PlainActionFuture<>(); + final PlainActionFuture future = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC); threadPool.generic().execute(new AbstractRunnable() { @Override protected void doRun() { diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java index 16a3ea53eeea..2eb45021a5bf 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/test/SecuritySingleNodeTestCase.java @@ -93,7 +93,7 @@ public abstract class SecuritySingleNodeTestCase extends ESSingleNodeTestCase { return getTaskWithId(state, TASK_NAME) == null; } - protected void awaitSecurityMigration() { + private void awaitSecurityMigration() { final var latch = new CountDownLatch(1); ClusterService clusterService = getInstanceFromNode(ClusterService.class); clusterService.addListener((event) -> { diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/esnative/ReservedRealmElasticAutoconfigIntegTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/esnative/ReservedRealmElasticAutoconfigIntegTests.java index c04630d45795..ae48d7563494 100644 --- 
a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/esnative/ReservedRealmElasticAutoconfigIntegTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/esnative/ReservedRealmElasticAutoconfigIntegTests.java @@ -15,7 +15,10 @@ import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.Request; import org.elasticsearch.client.RequestOptions; import org.elasticsearch.client.ResponseException; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.MockSecureSettings; import org.elasticsearch.common.settings.SecureString; @@ -29,7 +32,10 @@ import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken import org.elasticsearch.xpack.core.security.test.TestRestrictedIndices; import org.junit.BeforeClass; +import java.util.concurrent.CountDownLatch; + import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.xpack.core.security.action.UpdateIndexMigrationVersionAction.MIGRATION_VERSION_CUSTOM_KEY; import static org.elasticsearch.xpack.security.support.SecuritySystemIndices.SECURITY_MAIN_ALIAS; import static org.hamcrest.Matchers.is; @@ -64,6 +70,25 @@ public class ReservedRealmElasticAutoconfigIntegTests extends SecuritySingleNode return null; // no bootstrap password for this test } + private boolean isMigrationComplete(ClusterState state) { + IndexMetadata indexMetadata = state.metadata().getIndices().get(TestRestrictedIndices.INTERNAL_SECURITY_MAIN_INDEX_7); + return indexMetadata.getCustomData(MIGRATION_VERSION_CUSTOM_KEY) != null; + } + + private void awaitSecurityMigrationRanOnce() { + final var latch = new CountDownLatch(1); + 
ClusterService clusterService = getInstanceFromNode(ClusterService.class); + clusterService.addListener((event) -> { + if (isMigrationComplete(event.state())) { + latch.countDown(); + } + }); + if (isMigrationComplete(clusterService.state())) { + latch.countDown(); + } + safeAwait(latch); + } + public void testAutoconfigFailedPasswordPromotion() { try { // prevents the .security index from being created automatically (after elastic user authentication) @@ -80,7 +105,7 @@ public class ReservedRealmElasticAutoconfigIntegTests extends SecuritySingleNode assertThat(getIndexResponse.getIndices().length, is(1)); assertThat(getIndexResponse.getIndices()[0], is(TestRestrictedIndices.INTERNAL_SECURITY_MAIN_INDEX_7)); // Security migration needs to finish before deleting the index - awaitSecurityMigration(); + awaitSecurityMigrationRanOnce(); DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest(getIndexResponse.getIndices()); assertAcked(client().admin().indices().delete(deleteIndexRequest).actionGet()); } @@ -140,7 +165,7 @@ public class ReservedRealmElasticAutoconfigIntegTests extends SecuritySingleNode putUserRequest.roles(Strings.EMPTY_ARRAY); client().execute(PutUserAction.INSTANCE, putUserRequest).get(); // Security migration needs to finish before making the cluster read only - awaitSecurityMigration(); + awaitSecurityMigrationRanOnce(); // but then make the cluster read-only ClusterUpdateSettingsRequest updateSettingsRequest = new ClusterUpdateSettingsRequest(); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java index ecfce2f85842..84fa92bb7d2d 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java @@ -25,6 +25,7 @@ import org.elasticsearch.action.ActionResponse; import 
org.elasticsearch.action.support.ActionFilter; import org.elasticsearch.action.support.DestructiveOperations; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.UnsafePlainActionFuture; import org.elasticsearch.bootstrap.BootstrapCheck; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; @@ -2130,7 +2131,7 @@ public class Security extends Plugin return; } - final PlainActionFuture future = new PlainActionFuture<>(); + final PlainActionFuture future = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC); getClient().execute( ActionTypes.RELOAD_REMOTE_CLUSTER_CREDENTIALS_ACTION, new TransportReloadRemoteClusterCredentialsAction.Request(settingsWithKeystore),