Merge remote-tracking branch 'origin/main' into lucene_snapshot

This commit is contained in:
elasticsearchmachine 2024-05-28 10:01:58 +00:00
commit 83f51b477e
136 changed files with 4301 additions and 677 deletions

View file

@ -1,112 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.gradle.internal.testfixtures;
import org.gradle.api.GradleException;
import org.gradle.api.NamedDomainObjectContainer;
import org.gradle.api.Project;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
public class TestFixtureExtension {
private final Project project;
final NamedDomainObjectContainer<Project> fixtures;
final Map<String, String> serviceToProjectUseMap = new HashMap<>();
public TestFixtureExtension(Project project) {
this.project = project;
this.fixtures = project.container(Project.class);
}
public void useFixture() {
useFixture(this.project.getPath());
}
public void useFixture(String path) {
addFixtureProject(path);
serviceToProjectUseMap.put(path, this.project.getPath());
}
public void useFixture(String path, String serviceName) {
addFixtureProject(path);
String key = getServiceNameKey(path, serviceName);
serviceToProjectUseMap.put(key, this.project.getPath());
Optional<String> otherProject = this.findOtherProjectUsingService(key);
if (otherProject.isPresent()) {
throw new GradleException(
String.format(
Locale.ROOT,
"Projects %s and %s both claim the %s service defined in the docker-compose.yml of "
+ "%sThis is not supported because it breaks running in parallel. Configure dedicated "
+ "services for each project and use those instead.",
otherProject.get(),
this.project.getPath(),
serviceName,
path
)
);
}
}
private String getServiceNameKey(String fixtureProjectPath, String serviceName) {
return fixtureProjectPath + "::" + serviceName;
}
private Optional<String> findOtherProjectUsingService(String serviceName) {
return this.project.getRootProject()
.getAllprojects()
.stream()
.filter(p -> p.equals(this.project) == false)
.filter(p -> p.getExtensions().findByType(TestFixtureExtension.class) != null)
.map(project -> project.getExtensions().getByType(TestFixtureExtension.class))
.flatMap(ext -> ext.serviceToProjectUseMap.entrySet().stream())
.filter(entry -> entry.getKey().equals(serviceName))
.map(Map.Entry::getValue)
.findAny();
}
private void addFixtureProject(String path) {
Project fixtureProject = this.project.findProject(path);
if (fixtureProject == null) {
throw new IllegalArgumentException("Could not find test fixture " + fixtureProject);
}
if (fixtureProject.file(TestFixturesPlugin.DOCKER_COMPOSE_YML).exists() == false) {
throw new IllegalArgumentException(
"Project " + path + " is not a valid test fixture: missing " + TestFixturesPlugin.DOCKER_COMPOSE_YML
);
}
fixtures.add(fixtureProject);
// Check for exclusive access
Optional<String> otherProject = this.findOtherProjectUsingService(path);
if (otherProject.isPresent()) {
throw new GradleException(
String.format(
Locale.ROOT,
"Projects %s and %s both claim all services from %s. This is not supported because it"
+ " breaks running in parallel. Configure specific services in docker-compose.yml "
+ "for each and add the service name to `useFixture`",
otherProject.get(),
this.project.getPath(),
path
)
);
}
}
boolean isServiceRequired(String serviceName, String fixtureProject) {
if (serviceToProjectUseMap.containsKey(fixtureProject)) {
return true;
}
return serviceToProjectUseMap.containsKey(getServiceNameKey(fixtureProject, serviceName));
}
}

View file

@ -70,7 +70,6 @@ public class TestFixturesPlugin implements Plugin<Project> {
project.getRootProject().getPluginManager().apply(DockerSupportPlugin.class);
TaskContainer tasks = project.getTasks();
TestFixtureExtension extension = project.getExtensions().create("testFixtures", TestFixtureExtension.class, project);
Provider<DockerComposeThrottle> dockerComposeThrottle = project.getGradle()
.getSharedServices()
.registerIfAbsent(DOCKER_COMPOSE_THROTTLE, DockerComposeThrottle.class, spec -> spec.getMaxParallelUsages().set(1));
@ -84,73 +83,63 @@ public class TestFixturesPlugin implements Plugin<Project> {
File testFixturesDir = project.file("testfixtures_shared");
ext.set("testFixturesDir", testFixturesDir);
if (project.file(DOCKER_COMPOSE_YML).exists()) {
project.getPluginManager().apply(BasePlugin.class);
project.getPluginManager().apply(DockerComposePlugin.class);
TaskProvider<TestFixtureTask> preProcessFixture = project.getTasks().register("preProcessFixture", TestFixtureTask.class, t -> {
t.getFixturesDir().set(testFixturesDir);
t.doFirst(task -> {
try {
Files.createDirectories(testFixturesDir.toPath());
} catch (IOException e) {
throw new UncheckedIOException(e);
}
});
});
TaskProvider<Task> buildFixture = project.getTasks()
.register("buildFixture", t -> t.dependsOn(preProcessFixture, tasks.named("composeUp")));
TaskProvider<TestFixtureTask> postProcessFixture = project.getTasks()
.register("postProcessFixture", TestFixtureTask.class, task -> {
task.getFixturesDir().set(testFixturesDir);
task.dependsOn(buildFixture);
configureServiceInfoForTask(
task,
project,
false,
(name, port) -> task.getExtensions().getByType(ExtraPropertiesExtension.class).set(name, port)
);
});
maybeSkipTask(dockerSupport, preProcessFixture);
maybeSkipTask(dockerSupport, postProcessFixture);
maybeSkipTask(dockerSupport, buildFixture);
ComposeExtension composeExtension = project.getExtensions().getByType(ComposeExtension.class);
composeExtension.setProjectName(project.getName());
composeExtension.getUseComposeFiles().addAll(Collections.singletonList(DOCKER_COMPOSE_YML));
composeExtension.getRemoveContainers().set(true);
composeExtension.getCaptureContainersOutput()
.set(EnumSet.of(LogLevel.INFO, LogLevel.DEBUG).contains(project.getGradle().getStartParameter().getLogLevel()));
composeExtension.getUseDockerComposeV2().set(false);
composeExtension.getExecutable().set(this.providerFactory.provider(() -> {
String composePath = dockerSupport.get().getDockerAvailability().dockerComposePath();
LOGGER.debug("Docker Compose path: {}", composePath);
return composePath != null ? composePath : "/usr/bin/docker-compose";
}));
tasks.named("composeUp").configure(t -> {
// Avoid running docker-compose tasks in parallel in CI due to some issues on certain Linux distributions
if (BuildParams.isCi()) {
t.usesService(dockerComposeThrottle);
}
t.mustRunAfter(preProcessFixture);
});
tasks.named("composePull").configure(t -> t.mustRunAfter(preProcessFixture));
tasks.named("composeDown").configure(t -> t.doLast(t2 -> getFileSystemOperations().delete(d -> d.delete(testFixturesDir))));
} else {
project.afterEvaluate(spec -> {
if (extension.fixtures.isEmpty()) {
// if only one fixture is used, that's this one, but without a compose file that's not a valid configuration
throw new IllegalStateException(
"No " + DOCKER_COMPOSE_YML + " found for " + project.getPath() + " nor does it use other fixtures."
);
}
});
if (project.file(DOCKER_COMPOSE_YML).exists() == false) {
// if only one fixture is used, that's this one, but without a compose file that's not a valid configuration
throw new IllegalStateException("No " + DOCKER_COMPOSE_YML + " found for " + project.getPath() + ".");
}
project.getPluginManager().apply(BasePlugin.class);
project.getPluginManager().apply(DockerComposePlugin.class);
TaskProvider<TestFixtureTask> preProcessFixture = project.getTasks().register("preProcessFixture", TestFixtureTask.class, t -> {
t.getFixturesDir().set(testFixturesDir);
t.doFirst(task -> {
try {
Files.createDirectories(testFixturesDir.toPath());
} catch (IOException e) {
throw new UncheckedIOException(e);
}
});
});
TaskProvider<Task> buildFixture = project.getTasks()
.register("buildFixture", t -> t.dependsOn(preProcessFixture, tasks.named("composeUp")));
extension.fixtures.matching(fixtureProject -> fixtureProject.equals(project) == false)
.all(fixtureProject -> project.evaluationDependsOn(fixtureProject.getPath()));
TaskProvider<TestFixtureTask> postProcessFixture = project.getTasks()
.register("postProcessFixture", TestFixtureTask.class, task -> {
task.getFixturesDir().set(testFixturesDir);
task.dependsOn(buildFixture);
configureServiceInfoForTask(
task,
project,
false,
(name, port) -> task.getExtensions().getByType(ExtraPropertiesExtension.class).set(name, port)
);
});
maybeSkipTask(dockerSupport, preProcessFixture);
maybeSkipTask(dockerSupport, postProcessFixture);
maybeSkipTask(dockerSupport, buildFixture);
ComposeExtension composeExtension = project.getExtensions().getByType(ComposeExtension.class);
composeExtension.setProjectName(project.getName());
composeExtension.getUseComposeFiles().addAll(Collections.singletonList(DOCKER_COMPOSE_YML));
composeExtension.getRemoveContainers().set(true);
composeExtension.getCaptureContainersOutput()
.set(EnumSet.of(LogLevel.INFO, LogLevel.DEBUG).contains(project.getGradle().getStartParameter().getLogLevel()));
composeExtension.getUseDockerComposeV2().set(false);
composeExtension.getExecutable().set(this.providerFactory.provider(() -> {
String composePath = dockerSupport.get().getDockerAvailability().dockerComposePath();
LOGGER.debug("Docker Compose path: {}", composePath);
return composePath != null ? composePath : "/usr/bin/docker-compose";
}));
tasks.named("composeUp").configure(t -> {
// Avoid running docker-compose tasks in parallel in CI due to some issues on certain Linux distributions
if (BuildParams.isCi()) {
t.usesService(dockerComposeThrottle);
}
t.mustRunAfter(preProcessFixture);
});
tasks.named("composePull").configure(t -> t.mustRunAfter(preProcessFixture));
tasks.named("composeDown").configure(t -> t.doLast(t2 -> getFileSystemOperations().delete(d -> d.delete(testFixturesDir))));
// Skip docker compose tasks if it is unavailable
maybeSkipTasks(tasks, dockerSupport, Test.class);
@ -161,17 +150,18 @@ public class TestFixturesPlugin implements Plugin<Project> {
maybeSkipTasks(tasks, dockerSupport, ComposePull.class);
maybeSkipTasks(tasks, dockerSupport, ComposeDown.class);
tasks.withType(Test.class).configureEach(task -> extension.fixtures.all(fixtureProject -> {
task.dependsOn(fixtureProject.getTasks().named("postProcessFixture"));
task.finalizedBy(fixtureProject.getTasks().named("composeDown"));
tasks.withType(Test.class).configureEach(testTask -> {
testTask.dependsOn(postProcessFixture);
testTask.finalizedBy(tasks.named("composeDown"));
configureServiceInfoForTask(
task,
fixtureProject,
testTask,
project,
true,
(name, host) -> task.getExtensions().getByType(SystemPropertyCommandLineArgumentProvider.class).systemProperty(name, host)
(name, host) -> testTask.getExtensions()
.getByType(SystemPropertyCommandLineArgumentProvider.class)
.systemProperty(name, host)
);
}));
});
}
private void maybeSkipTasks(TaskContainer tasks, Provider<DockerSupportService> dockerSupport, Class<? extends DefaultTask> taskClass) {
@ -203,28 +193,20 @@ public class TestFixturesPlugin implements Plugin<Project> {
task.doFirst(new Action<Task>() {
@Override
public void execute(Task theTask) {
TestFixtureExtension extension = theTask.getProject().getExtensions().getByType(TestFixtureExtension.class);
fixtureProject.getExtensions()
.getByType(ComposeExtension.class)
.getServicesInfos()
.entrySet()
.stream()
.filter(entry -> enableFilter == false || extension.isServiceRequired(entry.getKey(), fixtureProject.getPath()))
.forEach(entry -> {
String service = entry.getKey();
ServiceInfo infos = entry.getValue();
infos.getTcpPorts().forEach((container, host) -> {
String name = "test.fixtures." + service + ".tcp." + container;
theTask.getLogger().info("port mapping property: {}={}", name, host);
consumer.accept(name, host);
});
infos.getUdpPorts().forEach((container, host) -> {
String name = "test.fixtures." + service + ".udp." + container;
theTask.getLogger().info("port mapping property: {}={}", name, host);
consumer.accept(name, host);
});
fixtureProject.getExtensions().getByType(ComposeExtension.class).getServicesInfos().entrySet().stream().forEach(entry -> {
String service = entry.getKey();
ServiceInfo infos = entry.getValue();
infos.getTcpPorts().forEach((container, host) -> {
String name = "test.fixtures." + service + ".tcp." + container;
theTask.getLogger().info("port mapping property: {}={}", name, host);
consumer.accept(name, host);
});
infos.getUdpPorts().forEach((container, host) -> {
String name = "test.fixtures." + service + ".udp." + container;
theTask.getLogger().info("port mapping property: {}={}", name, host);
consumer.accept(name, host);
});
});
}
});
}

View file

@ -72,8 +72,6 @@ if (useDra == false) {
}
}
testFixtures.useFixture()
configurations {
aarch64DockerSource {
attributes {

View file

@ -0,0 +1,5 @@
pr: 109044
summary: Enable fallback synthetic source for `token_count`
area: Mapping
type: feature
issues: []

View file

@ -64,6 +64,7 @@ types:
** <<search-as-you-type-synthetic-source,`search_as_you_type`>>
** <<numeric-synthetic-source,`short`>>
** <<text-synthetic-source,`text`>>
** <<token-count-synthetic-source,`token_count`>>
** <<version-synthetic-source,`version`>>
** <<wildcard-synthetic-source,`wildcard`>>

View file

@ -64,10 +64,10 @@ The following parameters are accepted by `token_count` fields:
value. Required. For best performance, use an analyzer without token
filters.
`enable_position_increments`::
`enable_position_increments`::
Indicates if position increments should be counted.
Set to `false` if you don't want to count tokens removed by analyzer filters (like <<analysis-stop-tokenfilter,`stop`>>).
Indicates if position increments should be counted.
Set to `false` if you don't want to count tokens removed by analyzer filters (like <<analysis-stop-tokenfilter,`stop`>>).
Defaults to `true`.
<<doc-values,`doc_values`>>::
@ -91,3 +91,17 @@ Defaults to `true`.
Whether the field value should be stored and retrievable separately from
the <<mapping-source-field,`_source`>> field. Accepts `true` or `false`
(default).
[[token-count-synthetic-source]]
===== Synthetic `_source`
IMPORTANT: Synthetic `_source` is Generally Available only for TSDB indices
(indices that have `index.mode` set to `time_series`). For other indices
synthetic `_source` is in technical preview. Features in technical preview may
be changed or removed in a future release. Elastic will work to fix
any issues, but features in technical preview are not subject to the support SLA
of official GA features.
`token_count` fields support <<synthetic-source,synthetic `_source`>> in their
default configuration. Synthetic `_source` cannot be used together with
<<copy-to,`copy_to`>>.

View file

@ -15,7 +15,7 @@ relevant considerations in this guide to improve performance. It also helps to
understand how {transforms} work as different considerations apply depending on
whether or not your transform is running in continuous mode or in batch.
In this guide, youll learn how to:
In this guide, you'll learn how to:
* Understand the impact of configuration options on the performance of
{transforms}.
@ -111,10 +111,17 @@ group of IPs, in order to calculate the total `bytes_sent`. If this second
search matches many shards, then this could be resource intensive. Consider
limiting the scope that the source index pattern and query will match.
Use an absolute time value as a date range filter in your source query (for
example, greater than `2020-01-01T00:00:00`) to limit which historical indices
are accessed. If you use a relative time value (for example, `now-30d`) then
this date range is re-evaluated at the point of each checkpoint execution.
To limit which historical indices are accessed, exclude certain tiers (for
example `"must_not": { "terms": { "_tier": [ "data_frozen", "data_cold" ] } }`
and/or use an absolute time value as a date range filter in your source query
(for example, greater than 2024-01-01T00:00:00). If you use a relative time
value (for example, gte now-30d/d) then ensure date rounding is applied to take
advantage of query caching and ensure that the relative time is much larger than
the largest of `frequency` or `time.sync.delay` or the date histogram bucket,
otherwise data may be missed. Do not use date filters which are less than a date
value (for example, `lt`: less than or `lte`: less than or equal to) as this
conflicts with the logic applied at each checkpoint execution and data may be
missed.
Consider using <<api-date-math-index-names,date math>> in your index names to
reduce the number of indices to resolve in your queries. Add a date pattern

View file

@ -30,7 +30,8 @@ collection.
**Capture a JVM heap dump**
To determine the exact reason for the high JVM memory pressure, capture a heap
dump of the JVM while its memory usage is high.
dump of the JVM while its memory usage is high, and also capture the
<<gc-logging,garbage collector logs>> covering the same time period.
[discrete]
[[reduce-jvm-memory-pressure]]

View file

@ -4,8 +4,8 @@ usually by the `JvmMonitorService` in the main node logs. Use these logs to
confirm whether or not the node is experiencing high heap usage with long GC
pauses. If so, <<high-jvm-memory-pressure,the troubleshooting guide for high
heap usage>> has some suggestions for further investigation but typically you
will need to capture a heap dump during a time of high heap usage to fully
understand the problem.
will need to capture a heap dump and the <<gc-logging,garbage collector logs>>
during a time of high heap usage to fully understand the problem.
* VM pauses also affect other processes on the same host. A VM pause also
typically causes a discontinuity in the system clock, which {es} will report in

View file

@ -215,4 +215,9 @@ public class TokenCountFieldMapper extends FieldMapper {
public FieldMapper.Builder getMergeBuilder() {
return new Builder(simpleName()).init(this);
}
@Override
protected SyntheticSourceMode syntheticSourceMode() {
return SyntheticSourceMode.FALLBACK;
}
}

View file

@ -33,7 +33,11 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.hamcrest.Matchers.equalTo;
@ -196,7 +200,66 @@ public class TokenCountFieldMapperTests extends MapperTestCase {
@Override
protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) {
throw new AssumptionViolatedException("not supported");
assertFalse(ignoreMalformed);
var nullValue = usually() ? null : randomNonNegativeInt();
return new SyntheticSourceSupport() {
@Override
public boolean preservesExactSource() {
return true;
}
public SyntheticSourceExample example(int maxValues) {
if (randomBoolean()) {
var value = generateValue();
return new SyntheticSourceExample(value.text, value.text, value.tokenCount, this::mapping);
}
var values = randomList(1, 5, this::generateValue);
var textArray = values.stream().map(Value::text).toList();
var blockExpectedList = values.stream().map(Value::tokenCount).filter(Objects::nonNull).toList();
var blockExpected = blockExpectedList.size() == 1 ? blockExpectedList.get(0) : blockExpectedList;
return new SyntheticSourceExample(textArray, textArray, blockExpected, this::mapping);
}
private record Value(String text, Integer tokenCount) {}
private Value generateValue() {
if (rarely()) {
return new Value(null, null);
}
var text = randomList(0, 10, () -> randomAlphaOfLengthBetween(0, 10)).stream().collect(Collectors.joining(" "));
// with keyword analyzer token count is always 1
return new Value(text, 1);
}
private void mapping(XContentBuilder b) throws IOException {
b.field("type", "token_count").field("analyzer", "keyword");
if (rarely()) {
b.field("index", false);
}
if (rarely()) {
b.field("store", true);
}
if (nullValue != null) {
b.field("null_value", nullValue);
}
}
@Override
public List<SyntheticSourceInvalidExample> invalidExample() throws IOException {
return List.of();
}
};
}
protected Function<Object, Object> loadBlockExpected() {
// we can get either a number from doc values or null
return v -> v != null ? (Number) v : null;
}
@Override

View file

@ -0,0 +1,65 @@
"Test token count":
- requires:
cluster_features: ["gte_v7.10.0"]
reason: "support for token_count was instroduced in 7.10"
- do:
indices.create:
index: test
body:
mappings:
properties:
count:
type: token_count
analyzer: standard
count_without_dv:
type: token_count
analyzer: standard
doc_values: false
- do:
index:
index: test
id: "1"
refresh: true
body:
count: "some text"
- do:
search:
index: test
body:
fields: [count, count_without_dv]
- is_true: hits.hits.0._id
- match: { hits.hits.0.fields.count: [2] }
- is_false: hits.hits.0.fields.count_without_dv
---
"Synthetic source":
- requires:
cluster_features: ["mapper.track_ignored_source"]
reason: requires tracking ignored source
- do:
indices.create:
index: test
body:
mappings:
_source:
mode: synthetic
properties:
count:
type: token_count
analyzer: standard
- do:
index:
index: test
id: "1"
refresh: true
body:
count: "quick brown fox jumps over a lazy dog"
- do:
get:
index: test
id: "1"
- match: { _source.count: "quick brown fox jumps over a lazy dog" }

View file

@ -164,6 +164,8 @@ tasks.named("processYamlRestTestResources").configure {
tasks.named("internalClusterTest").configure {
// this is tested explicitly in a separate test task
exclude '**/S3RepositoryThirdPartyTests.class'
// TODO: remove once https://github.com/elastic/elasticsearch/issues/101608 is fixed
systemProperty 'es.insecure_network_trace_enabled', 'true'
}
tasks.named("yamlRestTest").configure {

View file

@ -627,6 +627,8 @@ public class S3BlobStoreRepositoryTests extends ESMockAPIBasedRepositoryIntegTes
trackRequest("HeadObject");
metricsCount.computeIfAbsent(new S3BlobStore.StatsKey(S3BlobStore.Operation.HEAD_OBJECT, purpose), k -> new AtomicLong())
.incrementAndGet();
} else {
logger.info("--> rawRequest not tracked [{}] with parsed purpose [{}]", request, purpose.getKey());
}
}

View file

@ -60,7 +60,7 @@ import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestHandler;
@ -370,31 +370,31 @@ public class Netty4ChunkedContinuationsIT extends ESNetty4IntegTestCase {
TransportAction.localOnly();
}
public ChunkedRestResponseBody getChunkedBody() {
return getChunkBatch(0);
public ChunkedRestResponseBodyPart getFirstResponseBodyPart() {
return getResponseBodyPart(0);
}
private ChunkedRestResponseBody getChunkBatch(int batchIndex) {
private ChunkedRestResponseBodyPart getResponseBodyPart(int batchIndex) {
if (batchIndex == failIndex && randomBoolean()) {
throw new ElasticsearchException("simulated failure creating next batch");
}
return new ChunkedRestResponseBody() {
return new ChunkedRestResponseBodyPart() {
private final Iterator<String> lines = Iterators.forRange(0, 3, i -> "batch-" + batchIndex + "-chunk-" + i + "\n");
@Override
public boolean isDone() {
public boolean isPartComplete() {
return lines.hasNext() == false;
}
@Override
public boolean isEndOfResponse() {
public boolean isLastPart() {
return batchIndex == 2;
}
@Override
public void getContinuation(ActionListener<ChunkedRestResponseBody> listener) {
executor.execute(ActionRunnable.supply(listener, () -> getChunkBatch(batchIndex + 1)));
public void getNextPart(ActionListener<ChunkedRestResponseBodyPart> listener) {
executor.execute(ActionRunnable.supply(listener, () -> getResponseBodyPart(batchIndex + 1)));
}
@Override
@ -486,11 +486,12 @@ public class Netty4ChunkedContinuationsIT extends ESNetty4IntegTestCase {
@Override
protected void processResponse(Response response) {
try {
final var responseBody = response.getChunkedBody(); // might fail, so do this before acquiring ref
final var responseBody = response.getFirstResponseBodyPart();
// preceding line might fail, so needs to be done before acquiring the sendResponse ref
refs.mustIncRef();
channel.sendResponse(RestResponse.chunked(RestStatus.OK, responseBody, refs::decRef));
} finally {
refs.decRef();
refs.decRef(); // release the ref acquired at the top of accept()
}
}
});
@ -534,26 +535,26 @@ public class Netty4ChunkedContinuationsIT extends ESNetty4IntegTestCase {
TransportAction.localOnly();
}
public ChunkedRestResponseBody getChunkedBody() {
return new ChunkedRestResponseBody() {
public ChunkedRestResponseBodyPart getResponseBodyPart() {
return new ChunkedRestResponseBodyPart() {
private final Iterator<String> lines = Iterators.single("infinite response\n");
@Override
public boolean isDone() {
public boolean isPartComplete() {
return lines.hasNext() == false;
}
@Override
public boolean isEndOfResponse() {
public boolean isLastPart() {
return false;
}
@Override
public void getContinuation(ActionListener<ChunkedRestResponseBody> listener) {
public void getNextPart(ActionListener<ChunkedRestResponseBodyPart> listener) {
computingContinuation = true;
executor.execute(ActionRunnable.supply(listener, () -> {
computingContinuation = false;
return getChunkedBody();
return getResponseBodyPart();
}));
}
@ -628,7 +629,7 @@ public class Netty4ChunkedContinuationsIT extends ESNetty4IntegTestCase {
client.execute(TYPE, new Request(), new RestActionListener<>(channel) {
@Override
protected void processResponse(Response response) {
channel.sendResponse(RestResponse.chunked(RestStatus.OK, response.getChunkedBody(), () -> {
channel.sendResponse(RestResponse.chunked(RestStatus.OK, response.getResponseBodyPart(), () -> {
// cancellation notification only happens while processing a continuation, not while computing
// the next one; prompt cancellation requires use of something like RestCancellableNodeClient
assertFalse(response.computingContinuation);

View file

@ -37,7 +37,7 @@ import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestHandler;
@ -245,19 +245,19 @@ public class Netty4ChunkedEncodingIT extends ESNetty4IntegTestCase {
private static void sendChunksResponse(RestChannel channel, Iterator<BytesReference> chunkIterator) {
final var localRefs = refs; // single volatile read
if (localRefs != null && localRefs.tryIncRef()) {
channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBody() {
channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBodyPart() {
@Override
public boolean isDone() {
public boolean isPartComplete() {
return chunkIterator.hasNext() == false;
}
@Override
public boolean isEndOfResponse() {
public boolean isLastPart() {
return true;
}
@Override
public void getContinuation(ActionListener<ChunkedRestResponseBody> listener) {
public void getNextPart(ActionListener<ChunkedRestResponseBodyPart> listener) {
assert false : "no continuations";
}

View file

@ -34,7 +34,7 @@ import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.rest.RestRequest;
@ -243,21 +243,21 @@ public class Netty4PipeliningIT extends ESNetty4IntegTestCase {
throw new IllegalArgumentException("[" + FAIL_AFTER_BYTES_PARAM + "] must be present and non-negative");
}
return channel -> randomExecutor(client.threadPool()).execute(
() -> channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBody() {
() -> channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBodyPart() {
int bytesRemaining = failAfterBytes;
@Override
public boolean isDone() {
public boolean isPartComplete() {
return false;
}
@Override
public boolean isEndOfResponse() {
public boolean isLastPart() {
return true;
}
@Override
public void getContinuation(ActionListener<ChunkedRestResponseBody> listener) {
public void getNextPart(ActionListener<ChunkedRestResponseBodyPart> listener) {
fail("no continuations here");
}

View file

@ -10,16 +10,16 @@ package org.elasticsearch.http.netty4;
import io.netty.util.concurrent.PromiseCombiner;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
final class Netty4ChunkedHttpContinuation implements Netty4HttpResponse {
private final int sequence;
private final ChunkedRestResponseBody body;
private final ChunkedRestResponseBodyPart bodyPart;
private final PromiseCombiner combiner;
Netty4ChunkedHttpContinuation(int sequence, ChunkedRestResponseBody body, PromiseCombiner combiner) {
Netty4ChunkedHttpContinuation(int sequence, ChunkedRestResponseBodyPart bodyPart, PromiseCombiner combiner) {
this.sequence = sequence;
this.body = body;
this.bodyPart = bodyPart;
this.combiner = combiner;
}
@ -28,8 +28,8 @@ final class Netty4ChunkedHttpContinuation implements Netty4HttpResponse {
return sequence;
}
public ChunkedRestResponseBody body() {
return body;
public ChunkedRestResponseBodyPart bodyPart() {
return bodyPart;
}
public PromiseCombiner combiner() {

View file

@ -13,7 +13,7 @@ import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestStatus;
/**
@ -23,16 +23,16 @@ final class Netty4ChunkedHttpResponse extends DefaultHttpResponse implements Net
private final int sequence;
private final ChunkedRestResponseBody body;
private final ChunkedRestResponseBodyPart firstBodyPart;
Netty4ChunkedHttpResponse(int sequence, HttpVersion version, RestStatus status, ChunkedRestResponseBody body) {
Netty4ChunkedHttpResponse(int sequence, HttpVersion version, RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) {
super(version, HttpResponseStatus.valueOf(status.getStatus()));
this.sequence = sequence;
this.body = body;
this.firstBodyPart = firstBodyPart;
}
public ChunkedRestResponseBody body() {
return body;
public ChunkedRestResponseBodyPart firstBodyPart() {
return firstBodyPart;
}
@Override

View file

@ -34,7 +34,7 @@ import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.core.Booleans;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.transport.Transports;
import org.elasticsearch.transport.netty4.Netty4Utils;
import org.elasticsearch.transport.netty4.Netty4WriteThrottlingHandler;
@ -58,7 +58,7 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler {
private final int maxEventsHeld;
private final PriorityQueue<Tuple<? extends Netty4HttpResponse, ChannelPromise>> outboundHoldingQueue;
private record ChunkedWrite(PromiseCombiner combiner, ChannelPromise onDone, ChunkedRestResponseBody responseBody) {}
private record ChunkedWrite(PromiseCombiner combiner, ChannelPromise onDone, ChunkedRestResponseBodyPart responseBodyPart) {}
/**
* The current {@link ChunkedWrite} if a chunked write is executed at the moment.
@ -214,9 +214,9 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler {
final PromiseCombiner combiner = new PromiseCombiner(ctx.executor());
final ChannelPromise first = ctx.newPromise();
combiner.add((Future<Void>) first);
final var responseBody = readyResponse.body();
final var firstBodyPart = readyResponse.firstBodyPart();
assert currentChunkedWrite == null;
currentChunkedWrite = new ChunkedWrite(combiner, promise, responseBody);
currentChunkedWrite = new ChunkedWrite(combiner, promise, firstBodyPart);
if (enqueueWrite(ctx, readyResponse, first)) {
// We were able to write out the first chunk directly, try writing out subsequent chunks until the channel becomes unwritable.
// NB "writable" means there's space in the downstream ChannelOutboundBuffer, we aren't trying to saturate the physical channel.
@ -232,9 +232,10 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler {
private void doWriteChunkedContinuation(ChannelHandlerContext ctx, Netty4ChunkedHttpContinuation continuation, ChannelPromise promise) {
final PromiseCombiner combiner = continuation.combiner();
assert currentChunkedWrite == null;
final var responseBody = continuation.body();
assert responseBody.isDone() == false : "response with continuations must have at least one (possibly-empty) chunk in each part";
currentChunkedWrite = new ChunkedWrite(combiner, promise, responseBody);
final var bodyPart = continuation.bodyPart();
assert bodyPart.isPartComplete() == false
: "response with continuations must have at least one (possibly-empty) chunk in each part";
currentChunkedWrite = new ChunkedWrite(combiner, promise, bodyPart);
// NB "writable" means there's space in the downstream ChannelOutboundBuffer, we aren't trying to saturate the physical channel.
while (ctx.channel().isWritable()) {
if (writeChunk(ctx, currentChunkedWrite)) {
@ -251,9 +252,9 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler {
}
final var finishingWrite = currentChunkedWrite;
currentChunkedWrite = null;
final var finishingWriteBody = finishingWrite.responseBody();
assert finishingWriteBody.isDone();
final var endOfResponse = finishingWriteBody.isEndOfResponse();
final var finishingWriteBodyPart = finishingWrite.responseBodyPart();
assert finishingWriteBodyPart.isPartComplete();
final var endOfResponse = finishingWriteBodyPart.isLastPart();
if (endOfResponse) {
writeSequence++;
finishingWrite.combiner().finish(finishingWrite.onDone());
@ -261,7 +262,7 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler {
final var channel = finishingWrite.onDone().channel();
ActionListener.run(ActionListener.assertOnce(new ActionListener<>() {
@Override
public void onResponse(ChunkedRestResponseBody continuation) {
public void onResponse(ChunkedRestResponseBodyPart continuation) {
channel.writeAndFlush(
new Netty4ChunkedHttpContinuation(writeSequence, continuation, finishingWrite.combiner()),
finishingWrite.onDone() // pass the terminal listener/promise along the line
@ -296,7 +297,7 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler {
}
}
}), finishingWriteBody::getContinuation);
}), finishingWriteBodyPart::getNextPart);
}
}
@ -374,22 +375,22 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler {
}
private boolean writeChunk(ChannelHandlerContext ctx, ChunkedWrite chunkedWrite) {
final var body = chunkedWrite.responseBody();
final var bodyPart = chunkedWrite.responseBodyPart();
final var combiner = chunkedWrite.combiner();
assert body.isDone() == false : "should not continue to try and serialize once done";
assert bodyPart.isPartComplete() == false : "should not continue to try and serialize once done";
final ReleasableBytesReference bytes;
try {
bytes = body.encodeChunk(Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE, serverTransport.recycler());
bytes = bodyPart.encodeChunk(Netty4WriteThrottlingHandler.MAX_BYTES_PER_WRITE, serverTransport.recycler());
} catch (Exception e) {
return handleChunkingFailure(ctx, chunkedWrite, e);
}
final ByteBuf content = Netty4Utils.toByteBuf(bytes);
final boolean done = body.isDone();
final boolean lastChunk = done && body.isEndOfResponse();
final ChannelFuture f = ctx.write(lastChunk ? new DefaultLastHttpContent(content) : new DefaultHttpContent(content));
final boolean isPartComplete = bodyPart.isPartComplete();
final boolean isBodyComplete = isPartComplete && bodyPart.isLastPart();
final ChannelFuture f = ctx.write(isBodyComplete ? new DefaultLastHttpContent(content) : new DefaultHttpContent(content));
f.addListener(ignored -> bytes.close());
combiner.add(f);
return done;
return isPartComplete;
}
private boolean handleChunkingFailure(ChannelHandlerContext ctx, ChunkedWrite chunkedWrite, Exception e) {

View file

@ -22,7 +22,7 @@ import io.netty.handler.codec.http.cookie.ServerCookieEncoder;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.transport.netty4.Netty4Utils;
@ -176,8 +176,8 @@ public class Netty4HttpRequest implements HttpRequest {
}
@Override
public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) {
return new Netty4ChunkedHttpResponse(sequence, request.protocolVersion(), status, content);
public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) {
return new Netty4ChunkedHttpResponse(sequence, request.protocolVersion(), status, firstBodyPart);
}
@Override

View file

@ -36,7 +36,7 @@ import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.bytes.ZeroBytesReference;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.netty4.Netty4Utils;
@ -502,23 +502,23 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
};
}
private static ChunkedRestResponseBody getRepeatedChunkResponseBody(int chunkCount, BytesReference chunk) {
return new ChunkedRestResponseBody() {
private static ChunkedRestResponseBodyPart getRepeatedChunkResponseBody(int chunkCount, BytesReference chunk) {
return new ChunkedRestResponseBodyPart() {
private int remaining = chunkCount;
@Override
public boolean isDone() {
public boolean isPartComplete() {
return remaining == 0;
}
@Override
public boolean isEndOfResponse() {
public boolean isLastPart() {
return true;
}
@Override
public void getContinuation(ActionListener<ChunkedRestResponseBody> listener) {
public void getNextPart(ActionListener<ChunkedRestResponseBodyPart> listener) {
fail("no continuations here");
}

View file

@ -71,7 +71,7 @@ import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.http.NullDispatcher;
import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils;
import org.elasticsearch.http.netty4.internal.HttpValidator;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
@ -692,7 +692,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
try {
channel.sendResponse(
RestResponse.chunked(OK, ChunkedRestResponseBody.fromXContent(ignored -> Iterators.single((builder, params) -> {
RestResponse.chunked(OK, ChunkedRestResponseBodyPart.fromXContent(ignored -> Iterators.single((builder, params) -> {
throw new AssertionError("should not be called for HEAD REQUEST");
}), ToXContent.EMPTY_PARAMS, channel), null)
);
@ -1048,7 +1048,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
assertEquals(request.uri(), url);
final var response = RestResponse.chunked(
OK,
ChunkedRestResponseBody.fromTextChunks(RestResponse.TEXT_CONTENT_TYPE, Collections.emptyIterator()),
ChunkedRestResponseBodyPart.fromTextChunks(RestResponse.TEXT_CONTENT_TYPE, Collections.emptyIterator()),
responseReleasedLatch::countDown
);
transportClosedFuture.addListener(ActionListener.running(() -> channel.sendResponse(response)));

View file

@ -35,12 +35,6 @@ tests:
- class: "org.elasticsearch.upgrades.MlTrainedModelsUpgradeIT"
issue: "https://github.com/elastic/elasticsearch/issues/108993"
method: "testTrainedModelInference"
- class: "org.elasticsearch.xpack.security.authc.esnative.ReservedRealmElasticAutoconfigIntegTests"
issue: "https://github.com/elastic/elasticsearch/issues/109058"
method: "testAutoconfigSucceedsAfterPromotionFailure"
- class: "org.elasticsearch.xpack.security.authc.esnative.ReservedRealmElasticAutoconfigIntegTests"
issue: "https://github.com/elastic/elasticsearch/issues/109059"
method: "testAutoconfigFailedPasswordPromotion"
# Examples:
#
# Mute a single test case in a YAML test suite:

View file

@ -16,8 +16,6 @@ apply plugin: 'elasticsearch.standalone-rest-test'
apply plugin: 'elasticsearch.test.fixtures'
apply plugin: 'elasticsearch.internal-distribution-download'
testFixtures.useFixture()
dockerCompose {
environment.put 'STACK_VERSION', BuildParams.snapshotBuild ? VersionProperties.elasticsearch : VersionProperties.elasticsearch + "-SNAPSHOT"
}

View file

@ -15,8 +15,6 @@ apply plugin: 'elasticsearch.standalone-rest-test'
apply plugin: 'elasticsearch.test.fixtures'
apply plugin: 'elasticsearch.internal-distribution-download'
testFixtures.useFixture()
tasks.register("copyNodeKeyMaterial", Sync) {
from project(':x-pack:plugin:core')
.files(

View file

@ -262,41 +262,6 @@
- match: { hits.hits.0.fields.date.0: "1990/12/29" }
---
"Test token count":
- requires:
cluster_features: ["gte_v7.10.0"]
reason: "support for token_count was instroduced in 7.10"
- do:
indices.create:
index: test
body:
mappings:
properties:
count:
type: token_count
analyzer: standard
count_without_dv:
type: token_count
analyzer: standard
doc_values: false
- do:
index:
index: test
id: "1"
refresh: true
body:
count: "some text"
- do:
search:
index: test
body:
fields: [count, count_without_dv]
- is_true: hits.hits.0._id
- match: { hits.hits.0.fields.count: [2] }
- is_false: hits.hits.0.fields.count_without_dv
---
Test unmapped field:
- requires:
cluster_features: "gte_v7.11.0"

View file

@ -16,6 +16,7 @@ import org.elasticsearch.action.NoShardAvailableActionException;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.UnsafePlainActionFuture;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.action.shard.ShardStateAction;
@ -26,6 +27,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.routing.Murmur3HashFunction;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.cluster.routing.ShardRoutingState;
import org.elasticsearch.cluster.service.ClusterApplierService;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
@ -542,7 +544,7 @@ public class ClusterDisruptionIT extends AbstractDisruptionTestCase {
});
final ClusterService dataClusterService = internalCluster().getInstance(ClusterService.class, dataNode);
final PlainActionFuture<Void> failedLeader = new PlainActionFuture<>() {
final PlainActionFuture<Void> failedLeader = new UnsafePlainActionFuture<>(ClusterApplierService.CLUSTER_UPDATE_THREAD_NAME) {
@Override
protected boolean blockingAllowed() {
// we're deliberately blocking the cluster applier on the master until the data node starts to rejoin

View file

@ -82,7 +82,7 @@ public class RestControllerIT extends ESIntegTestCase {
return channel -> {
final var response = RestResponse.chunked(
RestStatus.OK,
ChunkedRestResponseBody.fromXContent(
ChunkedRestResponseBodyPart.fromXContent(
params -> Iterators.single((b, p) -> b.startObject().endObject()),
request,
channel

View file

@ -299,7 +299,8 @@ public class CorruptedBlobStoreRepositoryIT extends AbstractSnapshotIntegTestCas
final ThreadPool threadPool = internalCluster().getCurrentMasterNodeInstance(ThreadPool.class);
assertThat(
PlainActionFuture.get(
f -> threadPool.generic()
// any other executor than generic and management
f -> threadPool.executor(ThreadPool.Names.SNAPSHOT)
.execute(
ActionRunnable.supply(
f,

View file

@ -178,6 +178,7 @@ public class TransportVersions {
public static final TransportVersion GET_SHUTDOWN_STATUS_TIMEOUT = def(8_669_00_0);
public static final TransportVersion FAILURE_STORE_TELEMETRY = def(8_670_00_0);
public static final TransportVersion ADD_METADATA_FLATTENED_TO_ROLES = def(8_671_00_0);
public static final TransportVersion ML_INFERENCE_GOOGLE_AI_STUDIO_COMPLETION_ADDED = def(8_672_00_0);
/*
* STOP! READ THIS FIRST! No, really,

View file

@ -9,10 +9,12 @@
package org.elasticsearch.action.support;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.service.ClusterApplierService;
import org.elasticsearch.cluster.service.MasterService;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.FutureUtils;
import org.elasticsearch.common.util.concurrent.UncategorizedExecutionException;
import org.elasticsearch.core.CheckedConsumer;
@ -37,6 +39,7 @@ public class PlainActionFuture<T> implements ActionFuture<T>, ActionListener<T>
@Override
public void onFailure(Exception e) {
assert assertCompleteAllowed();
if (sync.setException(Objects.requireNonNull(e))) {
done(false);
}
@ -113,6 +116,7 @@ public class PlainActionFuture<T> implements ActionFuture<T>, ActionListener<T>
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
assert assertCompleteAllowed();
if (sync.cancel() == false) {
return false;
}
@ -130,6 +134,7 @@ public class PlainActionFuture<T> implements ActionFuture<T>, ActionListener<T>
* @return true if the state was successfully changed.
*/
protected final boolean set(@Nullable T value) {
assert assertCompleteAllowed();
boolean result = sync.set(value);
if (result) {
done(true);
@ -399,4 +404,27 @@ public class PlainActionFuture<T> implements ActionFuture<T>, ActionListener<T>
e.accept(fut);
return fut.actionGet(timeout, unit);
}
private boolean assertCompleteAllowed() {
Thread waiter = sync.getFirstQueuedThread();
// todo: reenable assertion once downstream code is updated
assert true || waiter == null || allowedExecutors(waiter, Thread.currentThread())
: "cannot complete future on thread "
+ Thread.currentThread()
+ " with waiter on thread "
+ waiter
+ ", could deadlock if pool was full\n"
+ ExceptionsHelper.formatStackTrace(waiter.getStackTrace());
return true;
}
// only used in assertions
boolean allowedExecutors(Thread thread1, Thread thread2) {
// this should only be used to validate thread interactions, like not waiting for a future completed on the same
// executor, hence calling it with the same thread indicates a bug in the assertion using this.
assert thread1 != thread2 : "only call this for different threads";
String thread1Name = EsExecutors.executorName(thread1);
String thread2Name = EsExecutors.executorName(thread2);
return thread1Name == null || thread2Name == null || thread1Name.equals(thread2Name) == false;
}
}

View file

@ -0,0 +1,52 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.action.support;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.CheckedConsumer;
import java.util.Objects;
/**
* An unsafe future. You should not need to use this for new code, rather you should be able to convert that code to be async
* or use a clear hierarchy of thread pool executors around the future.
*
* This future is unsafe, since it allows notifying the future on the same thread pool executor that it is being waited on. This
* is a common deadlock scenario, since all threads may be waiting and thus no thread may be able to complete the future.
*/
@Deprecated(forRemoval = true)
public class UnsafePlainActionFuture<T> extends PlainActionFuture<T> {
private final String unsafeExecutor;
private final String unsafeExecutor2;
public UnsafePlainActionFuture(String unsafeExecutor) {
this(unsafeExecutor, null);
}
public UnsafePlainActionFuture(String unsafeExecutor, String unsafeExecutor2) {
Objects.requireNonNull(unsafeExecutor);
this.unsafeExecutor = unsafeExecutor;
this.unsafeExecutor2 = unsafeExecutor2;
}
@Override
boolean allowedExecutors(Thread thread1, Thread thread2) {
return super.allowedExecutors(thread1, thread2)
|| unsafeExecutor.equals(EsExecutors.executorName(thread1))
|| unsafeExecutor2 == null
|| unsafeExecutor2.equals(EsExecutors.executorName(thread1));
}
public static <T, E extends Exception> T get(CheckedConsumer<PlainActionFuture<T>, E> e, String allowedExecutor) throws E {
PlainActionFuture<T> fut = new UnsafePlainActionFuture<>(allowedExecutor);
e.accept(fut);
return fut.actionGet();
}
}

View file

@ -59,6 +59,7 @@ import org.elasticsearch.action.search.TransportMultiSearchAction;
import org.elasticsearch.action.search.TransportSearchAction;
import org.elasticsearch.action.search.TransportSearchScrollAction;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.UnsafePlainActionFuture;
import org.elasticsearch.action.termvectors.MultiTermVectorsAction;
import org.elasticsearch.action.termvectors.MultiTermVectorsRequest;
import org.elasticsearch.action.termvectors.MultiTermVectorsRequestBuilder;
@ -410,7 +411,13 @@ public abstract class AbstractClient implements Client {
* on the result before it goes out of scope.
* @param <R> reference counted result type
*/
private static class RefCountedFuture<R extends RefCounted> extends PlainActionFuture<R> {
// todo: the use of UnsafePlainActionFuture here is quite broad, we should find a better way to be more specific
// (unless making all usages safe is easy).
private static class RefCountedFuture<R extends RefCounted> extends UnsafePlainActionFuture<R> {
private RefCountedFuture() {
super(ThreadPool.Names.GENERIC);
}
@Override
public final void onResponse(R result) {

View file

@ -53,6 +53,9 @@ public class DateFormatters {
* If a string cannot be parsed by the ISO parser, it then tries the java.time one.
* If there's lots of these strings, trying the ISO parser, then the java.time parser, might cause a performance drop.
* So provide a JVM option so that users can just use the java.time parsers, if they really need to.
* <p>
* Note that this property is sometimes set by {@code ESTestCase.setTestSysProps} to flip between implementations in tests,
* to ensure both are fully tested
*/
@UpdateForV9 // evaluate if we need to deprecate/remove this
private static final boolean JAVA_TIME_PARSERS_ONLY = Booleans.parseBoolean(System.getProperty("es.datetime.java_time_parsers"), false);

View file

@ -276,6 +276,10 @@ public class EsExecutors {
return threadName.substring(executorNameStart + 1, executorNameEnd);
}
public static String executorName(Thread thread) {
return executorName(thread.getName());
}
public static ThreadFactory daemonThreadFactory(Settings settings, String namePrefix) {
return daemonThreadFactory(threadName(settings, namePrefix));
}

View file

@ -21,8 +21,8 @@ import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.rest.AbstractRestChannel;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.LoggingChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.LoggingChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
@ -113,7 +113,7 @@ public class DefaultRestChannel extends AbstractRestChannel {
try {
final HttpResponse httpResponse;
if (isHeadRequest == false && restResponse.isChunked()) {
ChunkedRestResponseBody chunkedContent = restResponse.chunkedContent();
ChunkedRestResponseBodyPart chunkedContent = restResponse.chunkedContent();
if (httpLogger != null && httpLogger.isBodyTracerEnabled()) {
final var loggerStream = httpLogger.openResponseBodyLoggingStream(request.getRequestId());
toClose.add(() -> {
@ -123,7 +123,7 @@ public class DefaultRestChannel extends AbstractRestChannel {
assert false : e; // nothing much to go wrong here
}
});
chunkedContent = new LoggingChunkedRestResponseBody(chunkedContent, loggerStream);
chunkedContent = new LoggingChunkedRestResponseBodyPart(chunkedContent, loggerStream);
}
httpResponse = httpRequest.createResponse(restResponse.status(), chunkedContent);

View file

@ -10,7 +10,7 @@ package org.elasticsearch.http;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestStatus;
import java.util.List;
@ -40,7 +40,7 @@ public interface HttpRequest extends HttpPreRequest {
*/
HttpResponse createResponse(RestStatus status, BytesReference content);
HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content);
HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart);
@Nullable
Exception getInboundException();

View file

@ -15,10 +15,12 @@ import org.apache.lucene.search.ReferenceManager;
import org.apache.lucene.search.suggest.document.CompletionTerms;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.UnsafePlainActionFuture;
import org.elasticsearch.common.FieldMemoryStats;
import org.elasticsearch.common.regex.Regex;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.search.suggest.completion.CompletionStats;
import org.elasticsearch.threadpool.ThreadPool;
import java.util.HashMap;
import java.util.Map;
@ -42,7 +44,7 @@ public class CompletionStatsCache implements ReferenceManager.RefreshListener {
}
public CompletionStats get(String... fieldNamePatterns) {
final PlainActionFuture<CompletionStats> newFuture = new PlainActionFuture<>();
final PlainActionFuture<CompletionStats> newFuture = new UnsafePlainActionFuture<>(ThreadPool.Names.MANAGEMENT);
final PlainActionFuture<CompletionStats> oldFuture = completionStatsFutureRef.compareAndExchange(null, newFuture);
if (oldFuture != null) {

View file

@ -36,6 +36,7 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.action.support.UnsafePlainActionFuture;
import org.elasticsearch.cluster.service.ClusterApplierService;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.logging.Loggers;
@ -75,6 +76,7 @@ import org.elasticsearch.index.store.Store;
import org.elasticsearch.index.translog.Translog;
import org.elasticsearch.index.translog.TranslogStats;
import org.elasticsearch.search.suggest.completion.CompletionStats;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transports;
import java.io.Closeable;
@ -1956,7 +1958,7 @@ public abstract class Engine implements Closeable {
logger.debug("drainForClose(): draining ops");
releaseEnsureOpenRef.close();
final var future = new PlainActionFuture<Void>() {
final var future = new UnsafePlainActionFuture<Void>(ThreadPool.Names.GENERIC) {
@Override
protected boolean blockingAllowed() {
// TODO remove this blocking, or at least do it elsewhere, see https://github.com/elastic/elasticsearch/issues/89821

View file

@ -39,32 +39,32 @@ import java.util.Iterator;
* materializing the full response into on-heap buffers up front, instead serializing only as much of the response as can be flushed to the
* network right away.</p>
*
* <p>Each {@link ChunkedRestResponseBody} represents a sequence of chunks that are ready for immediate transmission: if {@link #isDone}
* returns {@code false} then {@link #encodeChunk} can be called at any time and must synchronously return the next chunk to be sent.
* Many HTTP responses will be a single such sequence. However, if an implementation's {@link #isEndOfResponse} returns {@code false} at the
* end of the sequence then the transmission is paused and {@link #getContinuation} is called to compute the next sequence of chunks
* <p>Each {@link ChunkedRestResponseBodyPart} represents a sequence of chunks that are ready for <i>immediate</i> transmission: if
* {@link #isPartComplete} returns {@code false} then {@link #encodeChunk} can be called at any time and must synchronously return the next
* chunk to be sent. Many HTTP responses will be a single part, but if an implementation's {@link #isLastPart} returns {@code false} at the
* end of the part then the transmission is paused and {@link #getNextPart} is called to compute the next sequence of chunks
* asynchronously.</p>
*/
public interface ChunkedRestResponseBody {
public interface ChunkedRestResponseBodyPart {
Logger logger = LogManager.getLogger(ChunkedRestResponseBody.class);
Logger logger = LogManager.getLogger(ChunkedRestResponseBodyPart.class);
/**
* @return {@code true} if this body contains no more chunks and the REST layer should check for a possible continuation by calling
* {@link #isEndOfResponse}, or {@code false} if the REST layer should request another chunk from this body using {@link #encodeChunk}.
* @return {@code true} if this body part contains no more chunks and the REST layer should check for a possible continuation by calling
* {@link #isLastPart}, or {@code false} if the REST layer should request another chunk from this body using {@link #encodeChunk}.
*/
boolean isDone();
boolean isPartComplete();
/**
* @return {@code true} if this is the last chunked body in the response, or {@code false} if the REST layer should request further
* chunked bodies by calling {@link #getContinuation}.
* @return {@code true} if this is the last chunked body part in the response, or {@code false} if the REST layer should request further
* chunked bodies by calling {@link #getNextPart}.
*/
boolean isEndOfResponse();
boolean isLastPart();
/**
* <p>Asynchronously retrieves the next part of the body. Called if {@link #isEndOfResponse} returns {@code false}.</p>
* <p>Asynchronously retrieves the next part of the response body. Called if {@link #isLastPart} returns {@code false}.</p>
*
* <p>Note that this is called on a transport thread, so implementations must take care to dispatch any nontrivial work elsewhere.</p>
* <p>Note that this is called on a transport thread: implementations must take care to dispatch any nontrivial work elsewhere.</p>
* <p>Note that the {@link Task} corresponding to any invocation of {@link Client#execute} completes as soon as the client action
* returns its response, so it no longer exists when this method is called and cannot be used to receive cancellation notifications.
@ -78,7 +78,7 @@ public interface ChunkedRestResponseBody {
* the body of the response, so there's no good ways to handle an exception here. Completing the listener exceptionally
* will log an error, abort sending the response, and close the HTTP connection.
*/
void getContinuation(ActionListener<ChunkedRestResponseBody> listener);
void getNextPart(ActionListener<ChunkedRestResponseBodyPart> listener);
/**
* Serializes approximately as many bytes of the response as request by {@code sizeHint} to a {@link ReleasableBytesReference} that
@ -97,17 +97,17 @@ public interface ChunkedRestResponseBody {
String getResponseContentTypeString();
/**
* Create a chunked response body to be written to a specific {@link RestChannel} from a {@link ChunkedToXContent}.
* Create a one-part chunked response body to be written to a specific {@link RestChannel} from a {@link ChunkedToXContent}.
*
* @param chunkedToXContent chunked x-content instance to serialize
* @param params parameters to use for serialization
* @param channel channel the response will be written to
* @return chunked rest response body
*/
static ChunkedRestResponseBody fromXContent(ChunkedToXContent chunkedToXContent, ToXContent.Params params, RestChannel channel)
static ChunkedRestResponseBodyPart fromXContent(ChunkedToXContent chunkedToXContent, ToXContent.Params params, RestChannel channel)
throws IOException {
return new ChunkedRestResponseBody() {
return new ChunkedRestResponseBodyPart() {
private final OutputStream out = new OutputStream() {
@Override
@ -135,17 +135,17 @@ public interface ChunkedRestResponseBody {
private BytesStream target;
@Override
public boolean isDone() {
public boolean isPartComplete() {
return serialization.hasNext() == false;
}
@Override
public boolean isEndOfResponse() {
public boolean isLastPart() {
return true;
}
@Override
public void getContinuation(ActionListener<ChunkedRestResponseBody> listener) {
public void getNextPart(ActionListener<ChunkedRestResponseBodyPart> listener) {
assert false : "no continuations";
listener.onFailure(new IllegalStateException("no continuations available"));
}
@ -191,11 +191,11 @@ public interface ChunkedRestResponseBody {
}
/**
* Create a chunked response body to be written to a specific {@link RestChannel} from a stream of text chunks, each represented as a
* consumer of a {@link Writer}.
* Create a one-part chunked response body to be written to a specific {@link RestChannel} from a stream of UTF-8-encoded text chunks,
* each represented as a consumer of a {@link Writer}.
*/
static ChunkedRestResponseBody fromTextChunks(String contentType, Iterator<CheckedConsumer<Writer, IOException>> chunkIterator) {
return new ChunkedRestResponseBody() {
static ChunkedRestResponseBodyPart fromTextChunks(String contentType, Iterator<CheckedConsumer<Writer, IOException>> chunkIterator) {
return new ChunkedRestResponseBodyPart() {
private RecyclerBytesStreamOutput currentOutput;
private final Writer writer = new OutputStreamWriter(new OutputStream() {
@Override
@ -224,17 +224,17 @@ public interface ChunkedRestResponseBody {
}, StandardCharsets.UTF_8);
@Override
public boolean isDone() {
public boolean isPartComplete() {
return chunkIterator.hasNext() == false;
}
@Override
public boolean isEndOfResponse() {
public boolean isLastPart() {
return true;
}
@Override
public void getContinuation(ActionListener<ChunkedRestResponseBody> listener) {
public void getNextPart(ActionListener<ChunkedRestResponseBodyPart> listener) {
assert false : "no continuations";
listener.onFailure(new IllegalStateException("no continuations available"));
}

View file

@ -16,29 +16,29 @@ import org.elasticsearch.common.recycler.Recycler;
import java.io.IOException;
import java.io.OutputStream;
public class LoggingChunkedRestResponseBody implements ChunkedRestResponseBody {
public class LoggingChunkedRestResponseBodyPart implements ChunkedRestResponseBodyPart {
private final ChunkedRestResponseBody inner;
private final ChunkedRestResponseBodyPart inner;
private final OutputStream loggerStream;
public LoggingChunkedRestResponseBody(ChunkedRestResponseBody inner, OutputStream loggerStream) {
public LoggingChunkedRestResponseBodyPart(ChunkedRestResponseBodyPart inner, OutputStream loggerStream) {
this.inner = inner;
this.loggerStream = loggerStream;
}
@Override
public boolean isDone() {
return inner.isDone();
public boolean isPartComplete() {
return inner.isPartComplete();
}
@Override
public boolean isEndOfResponse() {
return inner.isEndOfResponse();
public boolean isLastPart() {
return inner.isLastPart();
}
@Override
public void getContinuation(ActionListener<ChunkedRestResponseBody> listener) {
inner.getContinuation(listener.map(continuation -> new LoggingChunkedRestResponseBody(continuation, loggerStream)));
public void getNextPart(ActionListener<ChunkedRestResponseBodyPart> listener) {
inner.getNextPart(listener.map(continuation -> new LoggingChunkedRestResponseBodyPart(continuation, loggerStream)));
}
@Override

View file

@ -857,7 +857,7 @@ public class RestController implements HttpServerTransport.Dispatcher {
final var headers = response.getHeaders();
response = RestResponse.chunked(
response.status(),
new EncodedLengthTrackingChunkedRestResponseBody(response.chunkedContent(), responseLengthRecorder),
new EncodedLengthTrackingChunkedRestResponseBodyPart(response.chunkedContent(), responseLengthRecorder),
Releasables.wrap(responseLengthRecorder, response)
);
for (final var header : headers.entrySet()) {
@ -916,13 +916,13 @@ public class RestController implements HttpServerTransport.Dispatcher {
}
}
private static class EncodedLengthTrackingChunkedRestResponseBody implements ChunkedRestResponseBody {
private static class EncodedLengthTrackingChunkedRestResponseBodyPart implements ChunkedRestResponseBodyPart {
private final ChunkedRestResponseBody delegate;
private final ChunkedRestResponseBodyPart delegate;
private final ResponseLengthRecorder responseLengthRecorder;
private EncodedLengthTrackingChunkedRestResponseBody(
ChunkedRestResponseBody delegate,
private EncodedLengthTrackingChunkedRestResponseBodyPart(
ChunkedRestResponseBodyPart delegate,
ResponseLengthRecorder responseLengthRecorder
) {
this.delegate = delegate;
@ -930,19 +930,19 @@ public class RestController implements HttpServerTransport.Dispatcher {
}
@Override
public boolean isDone() {
return delegate.isDone();
public boolean isPartComplete() {
return delegate.isPartComplete();
}
@Override
public boolean isEndOfResponse() {
return delegate.isEndOfResponse();
public boolean isLastPart() {
return delegate.isLastPart();
}
@Override
public void getContinuation(ActionListener<ChunkedRestResponseBody> listener) {
delegate.getContinuation(
listener.map(continuation -> new EncodedLengthTrackingChunkedRestResponseBody(continuation, responseLengthRecorder))
public void getNextPart(ActionListener<ChunkedRestResponseBodyPart> listener) {
delegate.getNextPart(
listener.map(continuation -> new EncodedLengthTrackingChunkedRestResponseBodyPart(continuation, responseLengthRecorder))
);
}
@ -950,7 +950,7 @@ public class RestController implements HttpServerTransport.Dispatcher {
public ReleasableBytesReference encodeChunk(int sizeHint, Recycler<BytesRef> recycler) throws IOException {
final ReleasableBytesReference bytesReference = delegate.encodeChunk(sizeHint, recycler);
responseLengthRecorder.addChunkLength(bytesReference.length());
if (isDone() && isEndOfResponse()) {
if (isPartComplete() && isLastPart()) {
responseLengthRecorder.close();
}
return bytesReference;

View file

@ -48,7 +48,7 @@ public final class RestResponse implements Releasable {
private final BytesReference content;
@Nullable
private final ChunkedRestResponseBody chunkedResponseBody;
private final ChunkedRestResponseBodyPart chunkedResponseBody;
private final String responseMediaType;
private Map<String, List<String>> customHeaders;
@ -84,9 +84,9 @@ public final class RestResponse implements Releasable {
this(status, responseMediaType, content, null, releasable);
}
public static RestResponse chunked(RestStatus restStatus, ChunkedRestResponseBody content, @Nullable Releasable releasable) {
if (content.isDone()) {
assert content.isEndOfResponse() : "response with continuations must have at least one (possibly-empty) chunk in each part";
public static RestResponse chunked(RestStatus restStatus, ChunkedRestResponseBodyPart content, @Nullable Releasable releasable) {
if (content.isPartComplete()) {
assert content.isLastPart() : "response with continuations must have at least one (possibly-empty) chunk in each part";
return new RestResponse(restStatus, content.getResponseContentTypeString(), BytesArray.EMPTY, releasable);
} else {
return new RestResponse(restStatus, content.getResponseContentTypeString(), null, content, releasable);
@ -100,7 +100,7 @@ public final class RestResponse implements Releasable {
RestStatus status,
String responseMediaType,
@Nullable BytesReference content,
@Nullable ChunkedRestResponseBody chunkedResponseBody,
@Nullable ChunkedRestResponseBodyPart chunkedResponseBody,
@Nullable Releasable releasable
) {
this.status = status;
@ -162,7 +162,7 @@ public final class RestResponse implements Releasable {
}
@Nullable
public ChunkedRestResponseBody chunkedContent() {
public ChunkedRestResponseBodyPart chunkedContent() {
return chunkedResponseBody;
}

View file

@ -10,7 +10,7 @@ package org.elasticsearch.rest.action;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
@ -40,7 +40,7 @@ public class RestChunkedToXContentListener<Response extends ChunkedToXContent> e
channel.sendResponse(
RestResponse.chunked(
getRestStatus(response),
ChunkedRestResponseBody.fromXContent(response, params, channel),
ChunkedRestResponseBodyPart.fromXContent(response, params, channel),
releasableFromResponse(response)
)
);

View file

@ -27,7 +27,7 @@ import java.io.IOException;
import java.util.List;
import java.util.Locale;
import static org.elasticsearch.rest.ChunkedRestResponseBody.fromTextChunks;
import static org.elasticsearch.rest.ChunkedRestResponseBodyPart.fromTextChunks;
import static org.elasticsearch.rest.RestRequest.Method.GET;
import static org.elasticsearch.rest.RestResponse.TEXT_CONTENT_TYPE;
import static org.elasticsearch.rest.RestUtils.getTimeout;

View file

@ -17,7 +17,7 @@ import org.elasticsearch.common.unit.SizeValue;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.Booleans;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
@ -63,7 +63,7 @@ public class RestTable {
return RestResponse.chunked(
RestStatus.OK,
ChunkedRestResponseBody.fromXContent(
ChunkedRestResponseBodyPart.fromXContent(
ignored -> Iterators.concat(
Iterators.single((builder, params) -> builder.startArray()),
Iterators.map(rowOrder.iterator(), row -> (builder, params) -> {
@ -94,7 +94,7 @@ public class RestTable {
return RestResponse.chunked(
RestStatus.OK,
ChunkedRestResponseBody.fromTextChunks(
ChunkedRestResponseBodyPart.fromTextChunks(
RestResponse.TEXT_CONTENT_TYPE,
Iterators.concat(
// optional header

View file

@ -19,7 +19,7 @@ import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.http.HttpStats;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
@ -122,7 +122,7 @@ public class RestClusterInfoAction extends BaseRestHandler {
return RestResponse.chunked(
RestStatus.OK,
ChunkedRestResponseBody.fromXContent(
ChunkedRestResponseBodyPart.fromXContent(
outerParams -> Iterators.concat(
ChunkedToXContentHelper.startObject(),
Iterators.single((builder, params) -> builder.field("cluster_name", response.getClusterName().value())),

View file

@ -152,12 +152,14 @@ public class ThreadPool implements ReportingService<ThreadPoolInfo>, Scheduler {
public static final Map<String, ThreadPoolType> THREAD_POOL_TYPES = Map.ofEntries(
entry(Names.GENERIC, ThreadPoolType.SCALING),
entry(Names.CLUSTER_COORDINATION, ThreadPoolType.FIXED),
entry(Names.GET, ThreadPoolType.FIXED),
entry(Names.ANALYZE, ThreadPoolType.FIXED),
entry(Names.WRITE, ThreadPoolType.FIXED),
entry(Names.SEARCH, ThreadPoolType.FIXED),
entry(Names.SEARCH_WORKER, ThreadPoolType.FIXED),
entry(Names.SEARCH_COORDINATION, ThreadPoolType.FIXED),
entry(Names.AUTO_COMPLETE, ThreadPoolType.FIXED),
entry(Names.MANAGEMENT, ThreadPoolType.SCALING),
entry(Names.FLUSH, ThreadPoolType.SCALING),
entry(Names.REFRESH, ThreadPoolType.SCALING),

View file

@ -12,8 +12,10 @@ import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.common.util.LocaleUtils;
import org.elasticsearch.index.mapper.DateFieldMapper;
import org.elasticsearch.test.ESTestCase;
import org.hamcrest.Matcher;
import java.time.Clock;
import java.time.DateTimeException;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneId;
@ -39,12 +41,25 @@ import static org.hamcrest.Matchers.sameInstance;
public class DateFormattersTests extends ESTestCase {
private IllegalArgumentException assertParseException(String input, String format) {
private void assertParseException(String input, String format) {
DateFormatter javaTimeFormatter = DateFormatter.forPattern(format);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> javaTimeFormatter.parse(input));
assertThat(e.getMessage(), containsString(input));
assertThat(e.getMessage(), containsString(format));
return e;
assertThat(e.getCause(), instanceOf(DateTimeException.class));
}
private void assertParseException(String input, String format, int errorIndex) {
assertParseException(input, format, equalTo(errorIndex));
}
private void assertParseException(String input, String format, Matcher<Integer> indexMatcher) {
DateFormatter javaTimeFormatter = DateFormatter.forPattern(format);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> javaTimeFormatter.parse(input));
assertThat(e.getMessage(), containsString(input));
assertThat(e.getMessage(), containsString(format));
assertThat(e.getCause(), instanceOf(DateTimeParseException.class));
assertThat(((DateTimeParseException) e.getCause()).getErrorIndex(), indexMatcher);
}
private void assertParses(String input, String format) {
@ -698,7 +713,7 @@ public class DateFormattersTests extends ESTestCase {
ES java.time implementation does not suffer from this,
but we intentionally not allow parsing timezone without a time part as it is not allowed in iso8601
*/
assertParseException("2016-11-30T+01", "strict_date_optional_time");
assertParseException("2016-11-30T+01", "strict_date_optional_time", 11);
assertParses("2016-11-30T12+01", "strict_date_optional_time");
assertParses("2016-11-30T12:00+01", "strict_date_optional_time");
@ -792,8 +807,8 @@ public class DateFormattersTests extends ESTestCase {
assertParses("2001-01-01T00:00:00.123Z", javaFormatter);
assertParses("2001-01-01T00:00:00,123Z", javaFormatter);
assertParseException("2001-01-01T00:00:00.123,456Z", "strict_date_optional_time");
assertParseException("2001-01-01T00:00:00.123,456Z", "date_optional_time");
assertParseException("2001-01-01T00:00:00.123,456Z", "strict_date_optional_time", 23);
assertParseException("2001-01-01T00:00:00.123,456Z", "date_optional_time", 23);
// This should fail, but java is ok with this because the field has the same value
// assertJavaTimeParseException("2001-01-01T00:00:00.123,123Z", "strict_date_optional_time_nanos");
}
@ -911,7 +926,7 @@ public class DateFormattersTests extends ESTestCase {
assertParses("2018-12-31T12:12:12.123456789", "date_hour_minute_second_fraction");
assertParses("2018-12-31T12:12:12.1", "date_hour_minute_second_millis");
assertParses("2018-12-31T12:12:12.123", "date_hour_minute_second_millis");
assertParseException("2018-12-31T12:12:12.123456789", "date_hour_minute_second_millis");
assertParseException("2018-12-31T12:12:12.123456789", "date_hour_minute_second_millis", 23);
assertParses("2018-12-31T12:12:12.1", "date_hour_minute_second_millis");
assertParses("2018-12-31T12:12:12.1", "date_hour_minute_second_fraction");
@ -981,11 +996,11 @@ public class DateFormattersTests extends ESTestCase {
assertParses("12:12:12.123", "hour_minute_second_fraction");
assertParses("12:12:12.123456789", "hour_minute_second_fraction");
assertParses("12:12:12.1", "hour_minute_second_fraction");
assertParseException("12:12:12", "hour_minute_second_fraction");
assertParseException("12:12:12", "hour_minute_second_fraction", 8);
assertParses("12:12:12.123", "hour_minute_second_millis");
assertParseException("12:12:12.123456789", "hour_minute_second_millis");
assertParseException("12:12:12.123456789", "hour_minute_second_millis", 12);
assertParses("12:12:12.1", "hour_minute_second_millis");
assertParseException("12:12:12", "hour_minute_second_millis");
assertParseException("12:12:12", "hour_minute_second_millis", 8);
assertParses("2018-128", "ordinal_date");
assertParses("2018-1", "ordinal_date");
@ -1025,8 +1040,8 @@ public class DateFormattersTests extends ESTestCase {
assertParses("10:15:3.123Z", "time");
assertParses("10:15:3.123+0100", "time");
assertParses("10:15:3.123+01:00", "time");
assertParseException("10:15:3.1", "time");
assertParseException("10:15:3Z", "time");
assertParseException("10:15:3.1", "time", 9);
assertParseException("10:15:3Z", "time", 7);
assertParses("10:15:30Z", "time_no_millis");
assertParses("10:15:30+0100", "time_no_millis");
@ -1043,7 +1058,7 @@ public class DateFormattersTests extends ESTestCase {
assertParses("10:15:3Z", "time_no_millis");
assertParses("10:15:3+0100", "time_no_millis");
assertParses("10:15:3+01:00", "time_no_millis");
assertParseException("10:15:3", "time_no_millis");
assertParseException("10:15:3", "time_no_millis", 7);
assertParses("T10:15:30.1Z", "t_time");
assertParses("T10:15:30.123Z", "t_time");
@ -1061,8 +1076,8 @@ public class DateFormattersTests extends ESTestCase {
assertParses("T10:15:3.123Z", "t_time");
assertParses("T10:15:3.123+0100", "t_time");
assertParses("T10:15:3.123+01:00", "t_time");
assertParseException("T10:15:3.1", "t_time");
assertParseException("T10:15:3Z", "t_time");
assertParseException("T10:15:3.1", "t_time", 10);
assertParseException("T10:15:3Z", "t_time", 8);
assertParses("T10:15:30Z", "t_time_no_millis");
assertParses("T10:15:30+0100", "t_time_no_millis");
@ -1076,12 +1091,12 @@ public class DateFormattersTests extends ESTestCase {
assertParses("T10:15:3Z", "t_time_no_millis");
assertParses("T10:15:3+0100", "t_time_no_millis");
assertParses("T10:15:3+01:00", "t_time_no_millis");
assertParseException("T10:15:3", "t_time_no_millis");
assertParseException("T10:15:3", "t_time_no_millis", 8);
assertParses("2012-W48-6", "week_date");
assertParses("2012-W01-6", "week_date");
assertParses("2012-W1-6", "week_date");
assertParseException("2012-W1-8", "week_date");
assertParseException("2012-W1-8", "week_date", 0);
assertParses("2012-W48-6T10:15:30.1Z", "week_date_time");
assertParses("2012-W48-6T10:15:30.123Z", "week_date_time");
@ -1135,17 +1150,12 @@ public class DateFormattersTests extends ESTestCase {
}
public void testExceptionWhenCompositeParsingFails() {
assertParseException("2014-06-06T12:01:02.123", "yyyy-MM-dd'T'HH:mm:ss||yyyy-MM-dd'T'HH:mm:ss.SS");
}
public void testExceptionErrorIndex() {
Exception e = assertParseException("2024-01-01j", "iso8601||strict_date_optional_time");
assertThat(((DateTimeParseException) e.getCause()).getErrorIndex(), equalTo(10));
assertParseException("2014-06-06T12:01:02.123", "yyyy-MM-dd'T'HH:mm:ss||yyyy-MM-dd'T'HH:mm:ss.SS", 19);
}
public void testStrictParsing() {
assertParses("2018W313", "strict_basic_week_date");
assertParseException("18W313", "strict_basic_week_date");
assertParseException("18W313", "strict_basic_week_date", 0);
assertParses("2018W313T121212.1Z", "strict_basic_week_date_time");
assertParses("2018W313T121212.123Z", "strict_basic_week_date_time");
assertParses("2018W313T121212.123456789Z", "strict_basic_week_date_time");
@ -1153,52 +1163,52 @@ public class DateFormattersTests extends ESTestCase {
assertParses("2018W313T121212.123+0100", "strict_basic_week_date_time");
assertParses("2018W313T121212.1+01:00", "strict_basic_week_date_time");
assertParses("2018W313T121212.123+01:00", "strict_basic_week_date_time");
assertParseException("2018W313T12128.123Z", "strict_basic_week_date_time");
assertParseException("2018W313T12128.123456789Z", "strict_basic_week_date_time");
assertParseException("2018W313T81212.123Z", "strict_basic_week_date_time");
assertParseException("2018W313T12812.123Z", "strict_basic_week_date_time");
assertParseException("2018W313T12812.1Z", "strict_basic_week_date_time");
assertParseException("2018W313T12128.123Z", "strict_basic_week_date_time", 13);
assertParseException("2018W313T12128.123456789Z", "strict_basic_week_date_time", 13);
assertParseException("2018W313T81212.123Z", "strict_basic_week_date_time", 13);
assertParseException("2018W313T12812.123Z", "strict_basic_week_date_time", 13);
assertParseException("2018W313T12812.1Z", "strict_basic_week_date_time", 13);
assertParses("2018W313T121212Z", "strict_basic_week_date_time_no_millis");
assertParses("2018W313T121212+0100", "strict_basic_week_date_time_no_millis");
assertParses("2018W313T121212+01:00", "strict_basic_week_date_time_no_millis");
assertParseException("2018W313T12128Z", "strict_basic_week_date_time_no_millis");
assertParseException("2018W313T12128+0100", "strict_basic_week_date_time_no_millis");
assertParseException("2018W313T12128+01:00", "strict_basic_week_date_time_no_millis");
assertParseException("2018W313T81212Z", "strict_basic_week_date_time_no_millis");
assertParseException("2018W313T81212+0100", "strict_basic_week_date_time_no_millis");
assertParseException("2018W313T81212+01:00", "strict_basic_week_date_time_no_millis");
assertParseException("2018W313T12812Z", "strict_basic_week_date_time_no_millis");
assertParseException("2018W313T12812+0100", "strict_basic_week_date_time_no_millis");
assertParseException("2018W313T12812+01:00", "strict_basic_week_date_time_no_millis");
assertParseException("2018W313T12128Z", "strict_basic_week_date_time_no_millis", 13);
assertParseException("2018W313T12128+0100", "strict_basic_week_date_time_no_millis", 13);
assertParseException("2018W313T12128+01:00", "strict_basic_week_date_time_no_millis", 13);
assertParseException("2018W313T81212Z", "strict_basic_week_date_time_no_millis", 13);
assertParseException("2018W313T81212+0100", "strict_basic_week_date_time_no_millis", 13);
assertParseException("2018W313T81212+01:00", "strict_basic_week_date_time_no_millis", 13);
assertParseException("2018W313T12812Z", "strict_basic_week_date_time_no_millis", 13);
assertParseException("2018W313T12812+0100", "strict_basic_week_date_time_no_millis", 13);
assertParseException("2018W313T12812+01:00", "strict_basic_week_date_time_no_millis", 13);
assertParses("2018-12-31", "strict_date");
assertParseException("10000-12-31", "strict_date");
assertParseException("2018-8-31", "strict_date");
assertParseException("10000-12-31", "strict_date", 0);
assertParseException("2018-8-31", "strict_date", 5);
assertParses("2018-12-31T12", "strict_date_hour");
assertParseException("2018-12-31T8", "strict_date_hour");
assertParseException("2018-12-31T8", "strict_date_hour", 11);
assertParses("2018-12-31T12:12", "strict_date_hour_minute");
assertParseException("2018-12-31T8:3", "strict_date_hour_minute");
assertParseException("2018-12-31T8:3", "strict_date_hour_minute", 11);
assertParses("2018-12-31T12:12:12", "strict_date_hour_minute_second");
assertParseException("2018-12-31T12:12:1", "strict_date_hour_minute_second");
assertParseException("2018-12-31T12:12:1", "strict_date_hour_minute_second", 17);
assertParses("2018-12-31T12:12:12.1", "strict_date_hour_minute_second_fraction");
assertParses("2018-12-31T12:12:12.123", "strict_date_hour_minute_second_fraction");
assertParses("2018-12-31T12:12:12.123456789", "strict_date_hour_minute_second_fraction");
assertParses("2018-12-31T12:12:12.123", "strict_date_hour_minute_second_millis");
assertParses("2018-12-31T12:12:12.1", "strict_date_hour_minute_second_millis");
assertParses("2018-12-31T12:12:12.1", "strict_date_hour_minute_second_fraction");
assertParseException("2018-12-31T12:12:12", "strict_date_hour_minute_second_millis");
assertParseException("2018-12-31T12:12:12", "strict_date_hour_minute_second_fraction");
assertParseException("2018-12-31T12:12:12", "strict_date_hour_minute_second_millis", 19);
assertParseException("2018-12-31T12:12:12", "strict_date_hour_minute_second_fraction", 19);
assertParses("2018-12-31", "strict_date_optional_time");
assertParseException("2018-12-1", "strict_date_optional_time");
assertParseException("2018-1-31", "strict_date_optional_time");
assertParseException("10000-01-31", "strict_date_optional_time");
assertParseException("2018-12-1", "strict_date_optional_time", 7);
assertParseException("2018-1-31", "strict_date_optional_time", 4);
assertParseException("10000-01-31", "strict_date_optional_time", 4);
assertParses("2010-01-05T02:00", "strict_date_optional_time");
assertParses("2018-12-31T10:15:30", "strict_date_optional_time");
assertParses("2018-12-31T10:15:30Z", "strict_date_optional_time");
assertParses("2018-12-31T10:15:30+0100", "strict_date_optional_time");
assertParses("2018-12-31T10:15:30+01:00", "strict_date_optional_time");
assertParseException("2018-12-31T10:15:3", "strict_date_optional_time");
assertParseException("2018-12-31T10:5:30", "strict_date_optional_time");
assertParseException("2018-12-31T9:15:30", "strict_date_optional_time");
assertParseException("2018-12-31T10:15:3", "strict_date_optional_time", 16);
assertParseException("2018-12-31T10:5:30", "strict_date_optional_time", 13);
assertParseException("2018-12-31T9:15:30", "strict_date_optional_time", 11);
assertParses("2015-01-04T00:00Z", "strict_date_optional_time");
assertParses("2018-12-31T10:15:30.1Z", "strict_date_time");
assertParses("2018-12-31T10:15:30.123Z", "strict_date_time");
@ -1210,33 +1220,33 @@ public class DateFormattersTests extends ESTestCase {
assertParses("2018-12-31T10:15:30.11Z", "strict_date_time");
assertParses("2018-12-31T10:15:30.11+0100", "strict_date_time");
assertParses("2018-12-31T10:15:30.11+01:00", "strict_date_time");
assertParseException("2018-12-31T10:15:3.123Z", "strict_date_time");
assertParseException("2018-12-31T10:5:30.123Z", "strict_date_time");
assertParseException("2018-12-31T1:15:30.123Z", "strict_date_time");
assertParseException("2018-12-31T10:15:3.123Z", "strict_date_time", 17);
assertParseException("2018-12-31T10:5:30.123Z", "strict_date_time", 14);
assertParseException("2018-12-31T1:15:30.123Z", "strict_date_time", 11);
assertParses("2018-12-31T10:15:30Z", "strict_date_time_no_millis");
assertParses("2018-12-31T10:15:30+0100", "strict_date_time_no_millis");
assertParses("2018-12-31T10:15:30+01:00", "strict_date_time_no_millis");
assertParseException("2018-12-31T10:5:30Z", "strict_date_time_no_millis");
assertParseException("2018-12-31T10:15:3Z", "strict_date_time_no_millis");
assertParseException("2018-12-31T1:15:30Z", "strict_date_time_no_millis");
assertParseException("2018-12-31T10:5:30Z", "strict_date_time_no_millis", 14);
assertParseException("2018-12-31T10:15:3Z", "strict_date_time_no_millis", 17);
assertParseException("2018-12-31T1:15:30Z", "strict_date_time_no_millis", 11);
assertParses("12", "strict_hour");
assertParses("01", "strict_hour");
assertParseException("1", "strict_hour");
assertParseException("1", "strict_hour", 0);
assertParses("12:12", "strict_hour_minute");
assertParses("12:01", "strict_hour_minute");
assertParseException("12:1", "strict_hour_minute");
assertParseException("12:1", "strict_hour_minute", 3);
assertParses("12:12:12", "strict_hour_minute_second");
assertParses("12:12:01", "strict_hour_minute_second");
assertParseException("12:12:1", "strict_hour_minute_second");
assertParseException("12:12:1", "strict_hour_minute_second", 6);
assertParses("12:12:12.123", "strict_hour_minute_second_fraction");
assertParses("12:12:12.123456789", "strict_hour_minute_second_fraction");
assertParses("12:12:12.1", "strict_hour_minute_second_fraction");
assertParseException("12:12:12", "strict_hour_minute_second_fraction");
assertParseException("12:12:12", "strict_hour_minute_second_fraction", 8);
assertParses("12:12:12.123", "strict_hour_minute_second_millis");
assertParses("12:12:12.1", "strict_hour_minute_second_millis");
assertParseException("12:12:12", "strict_hour_minute_second_millis");
assertParseException("12:12:12", "strict_hour_minute_second_millis", 8);
assertParses("2018-128", "strict_ordinal_date");
assertParseException("2018-1", "strict_ordinal_date");
assertParseException("2018-1", "strict_ordinal_date", 5);
assertParses("2018-128T10:15:30.1Z", "strict_ordinal_date_time");
assertParses("2018-128T10:15:30.123Z", "strict_ordinal_date_time");
@ -1245,23 +1255,23 @@ public class DateFormattersTests extends ESTestCase {
assertParses("2018-128T10:15:30.123+0100", "strict_ordinal_date_time");
assertParses("2018-128T10:15:30.1+01:00", "strict_ordinal_date_time");
assertParses("2018-128T10:15:30.123+01:00", "strict_ordinal_date_time");
assertParseException("2018-1T10:15:30.123Z", "strict_ordinal_date_time");
assertParseException("2018-1T10:15:30.123Z", "strict_ordinal_date_time", 5);
assertParses("2018-128T10:15:30Z", "strict_ordinal_date_time_no_millis");
assertParses("2018-128T10:15:30+0100", "strict_ordinal_date_time_no_millis");
assertParses("2018-128T10:15:30+01:00", "strict_ordinal_date_time_no_millis");
assertParseException("2018-1T10:15:30Z", "strict_ordinal_date_time_no_millis");
assertParseException("2018-1T10:15:30Z", "strict_ordinal_date_time_no_millis", 5);
assertParses("10:15:30.1Z", "strict_time");
assertParses("10:15:30.123Z", "strict_time");
assertParses("10:15:30.123456789Z", "strict_time");
assertParses("10:15:30.123+0100", "strict_time");
assertParses("10:15:30.123+01:00", "strict_time");
assertParseException("1:15:30.123Z", "strict_time");
assertParseException("10:1:30.123Z", "strict_time");
assertParseException("10:15:3.123Z", "strict_time");
assertParseException("10:15:3.1", "strict_time");
assertParseException("10:15:3Z", "strict_time");
assertParseException("1:15:30.123Z", "strict_time", 0);
assertParseException("10:1:30.123Z", "strict_time", 3);
assertParseException("10:15:3.123Z", "strict_time", 6);
assertParseException("10:15:3.1", "strict_time", 6);
assertParseException("10:15:3Z", "strict_time", 6);
assertParses("10:15:30Z", "strict_time_no_millis");
assertParses("10:15:30+0100", "strict_time_no_millis");
@ -1269,10 +1279,10 @@ public class DateFormattersTests extends ESTestCase {
assertParses("01:15:30Z", "strict_time_no_millis");
assertParses("01:15:30+0100", "strict_time_no_millis");
assertParses("01:15:30+01:00", "strict_time_no_millis");
assertParseException("1:15:30Z", "strict_time_no_millis");
assertParseException("10:5:30Z", "strict_time_no_millis");
assertParseException("10:15:3Z", "strict_time_no_millis");
assertParseException("10:15:3", "strict_time_no_millis");
assertParseException("1:15:30Z", "strict_time_no_millis", 0);
assertParseException("10:5:30Z", "strict_time_no_millis", 3);
assertParseException("10:15:3Z", "strict_time_no_millis", 6);
assertParseException("10:15:3", "strict_time_no_millis", 6);
assertParses("T10:15:30.1Z", "strict_t_time");
assertParses("T10:15:30.123Z", "strict_t_time");
@ -1281,28 +1291,28 @@ public class DateFormattersTests extends ESTestCase {
assertParses("T10:15:30.123+0100", "strict_t_time");
assertParses("T10:15:30.1+01:00", "strict_t_time");
assertParses("T10:15:30.123+01:00", "strict_t_time");
assertParseException("T1:15:30.123Z", "strict_t_time");
assertParseException("T10:1:30.123Z", "strict_t_time");
assertParseException("T10:15:3.123Z", "strict_t_time");
assertParseException("T10:15:3.1", "strict_t_time");
assertParseException("T10:15:3Z", "strict_t_time");
assertParseException("T1:15:30.123Z", "strict_t_time", 1);
assertParseException("T10:1:30.123Z", "strict_t_time", 4);
assertParseException("T10:15:3.123Z", "strict_t_time", 7);
assertParseException("T10:15:3.1", "strict_t_time", 7);
assertParseException("T10:15:3Z", "strict_t_time", 7);
assertParses("T10:15:30Z", "strict_t_time_no_millis");
assertParses("T10:15:30+0100", "strict_t_time_no_millis");
assertParses("T10:15:30+01:00", "strict_t_time_no_millis");
assertParseException("T1:15:30Z", "strict_t_time_no_millis");
assertParseException("T10:1:30Z", "strict_t_time_no_millis");
assertParseException("T10:15:3Z", "strict_t_time_no_millis");
assertParseException("T10:15:3", "strict_t_time_no_millis");
assertParseException("T1:15:30Z", "strict_t_time_no_millis", 1);
assertParseException("T10:1:30Z", "strict_t_time_no_millis", 4);
assertParseException("T10:15:3Z", "strict_t_time_no_millis", 7);
assertParseException("T10:15:3", "strict_t_time_no_millis", 7);
assertParses("2012-W48-6", "strict_week_date");
assertParses("2012-W01-6", "strict_week_date");
assertParseException("2012-W1-6", "strict_week_date");
assertParseException("2012-W1-8", "strict_week_date");
assertParseException("2012-W1-6", "strict_week_date", 6);
assertParseException("2012-W1-8", "strict_week_date", 6);
assertParses("2012-W48-6", "strict_week_date");
assertParses("2012-W01-6", "strict_week_date");
assertParseException("2012-W1-6", "strict_week_date");
assertParseException("2012-W1-6", "strict_week_date", 6);
assertParseException("2012-W01-8", "strict_week_date");
assertParses("2012-W48-6T10:15:30.1Z", "strict_week_date_time");
@ -1312,38 +1322,38 @@ public class DateFormattersTests extends ESTestCase {
assertParses("2012-W48-6T10:15:30.123+0100", "strict_week_date_time");
assertParses("2012-W48-6T10:15:30.1+01:00", "strict_week_date_time");
assertParses("2012-W48-6T10:15:30.123+01:00", "strict_week_date_time");
assertParseException("2012-W1-6T10:15:30.123Z", "strict_week_date_time");
assertParseException("2012-W1-6T10:15:30.123Z", "strict_week_date_time", 6);
assertParses("2012-W48-6T10:15:30Z", "strict_week_date_time_no_millis");
assertParses("2012-W48-6T10:15:30+0100", "strict_week_date_time_no_millis");
assertParses("2012-W48-6T10:15:30+01:00", "strict_week_date_time_no_millis");
assertParseException("2012-W1-6T10:15:30Z", "strict_week_date_time_no_millis");
assertParseException("2012-W1-6T10:15:30Z", "strict_week_date_time_no_millis", 6);
assertParses("2012", "strict_year");
assertParseException("1", "strict_year");
assertParseException("1", "strict_year", 0);
assertParses("-2000", "strict_year");
assertParses("2012-12", "strict_year_month");
assertParseException("1-1", "strict_year_month");
assertParseException("1-1", "strict_year_month", 0);
assertParses("2012-12-31", "strict_year_month_day");
assertParseException("1-12-31", "strict_year_month_day");
assertParseException("2012-1-31", "strict_year_month_day");
assertParseException("2012-12-1", "strict_year_month_day");
assertParseException("1-12-31", "strict_year_month_day", 0);
assertParseException("2012-1-31", "strict_year_month_day", 4);
assertParseException("2012-12-1", "strict_year_month_day", 7);
assertParses("2018", "strict_weekyear");
assertParseException("1", "strict_weekyear");
assertParseException("1", "strict_weekyear", 0);
assertParses("2018", "strict_weekyear");
assertParses("2017", "strict_weekyear");
assertParseException("1", "strict_weekyear");
assertParseException("1", "strict_weekyear", 0);
assertParses("2018-W29", "strict_weekyear_week");
assertParses("2018-W01", "strict_weekyear_week");
assertParseException("2018-W1", "strict_weekyear_week");
assertParseException("2018-W1", "strict_weekyear_week", 6);
assertParses("2012-W31-5", "strict_weekyear_week_day");
assertParseException("2012-W1-1", "strict_weekyear_week_day");
assertParseException("2012-W1-1", "strict_weekyear_week_day", 6);
}
public void testDateFormatterWithLocale() {

View file

@ -13,6 +13,7 @@ import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.Processors;
import org.elasticsearch.node.Node;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.hamcrest.Matcher;
@ -22,6 +23,7 @@ import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
@ -628,7 +630,19 @@ public class EsExecutorsTests extends ESTestCase {
public void testParseExecutorName() throws InterruptedException {
final var executorName = randomAlphaOfLength(10);
final var threadFactory = EsExecutors.daemonThreadFactory(rarely() ? null : randomAlphaOfLength(10), executorName);
final String nodeName = rarely() ? null : randomIdentifier();
final ThreadFactory threadFactory;
if (nodeName == null) {
threadFactory = EsExecutors.daemonThreadFactory(Settings.EMPTY, executorName);
} else if (randomBoolean()) {
threadFactory = EsExecutors.daemonThreadFactory(
Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), nodeName).build(),
executorName
);
} else {
threadFactory = EsExecutors.daemonThreadFactory(nodeName, executorName);
}
final var thread = threadFactory.newThread(() -> {});
try {
assertThat(EsExecutors.executorName(thread.getName()), equalTo(executorName));

View file

@ -31,7 +31,7 @@ import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
@ -530,20 +530,20 @@ public class DefaultRestChannelTests extends ESTestCase {
{
// chunked response
final var isClosed = new AtomicBoolean();
channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBody() {
channel.sendResponse(RestResponse.chunked(RestStatus.OK, new ChunkedRestResponseBodyPart() {
@Override
public boolean isDone() {
public boolean isPartComplete() {
return false;
}
@Override
public boolean isEndOfResponse() {
public boolean isLastPart() {
throw new AssertionError("should not check for end-of-response for HEAD request");
}
@Override
public void getContinuation(ActionListener<ChunkedRestResponseBody> listener) {
public void getNextPart(ActionListener<ChunkedRestResponseBodyPart> listener) {
throw new AssertionError("should not get any continuations for HEAD request");
}
@ -688,25 +688,25 @@ public class DefaultRestChannelTests extends ESTestCase {
HttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/") {
@Override
public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) {
public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) {
try (var bso = new BytesStreamOutput()) {
writeContent(bso, content);
writeContent(bso, firstBodyPart);
return new TestHttpResponse(status, bso.bytes());
} catch (IOException e) {
return fail(e);
}
}
private static void writeContent(OutputStream bso, ChunkedRestResponseBody content) throws IOException {
while (content.isDone() == false) {
private static void writeContent(OutputStream bso, ChunkedRestResponseBodyPart content) throws IOException {
while (content.isPartComplete() == false) {
try (var bytes = content.encodeChunk(1 << 14, BytesRefRecycler.NON_RECYCLING_INSTANCE)) {
bytes.writeTo(bso);
}
}
if (content.isEndOfResponse()) {
if (content.isLastPart()) {
return;
}
writeContent(bso, PlainActionFuture.get(content::getContinuation));
writeContent(bso, PlainActionFuture.get(content::getNextPart));
}
};
@ -735,14 +735,14 @@ public class DefaultRestChannelTests extends ESTestCase {
)
);
final var parts = new ArrayList<ChunkedRestResponseBody>();
class TestBody implements ChunkedRestResponseBody {
final var parts = new ArrayList<ChunkedRestResponseBodyPart>();
class TestBodyPart implements ChunkedRestResponseBodyPart {
boolean isDone;
final BytesReference thisChunk;
final BytesReference remainingChunks;
final int remainingContinuations;
TestBody(BytesReference content, int remainingContinuations) {
TestBodyPart(BytesReference content, int remainingContinuations) {
if (remainingContinuations == 0) {
thisChunk = content;
remainingChunks = BytesArray.EMPTY;
@ -755,18 +755,18 @@ public class DefaultRestChannelTests extends ESTestCase {
}
@Override
public boolean isDone() {
public boolean isPartComplete() {
return isDone;
}
@Override
public boolean isEndOfResponse() {
public boolean isLastPart() {
return remainingContinuations == 0;
}
@Override
public void getContinuation(ActionListener<ChunkedRestResponseBody> listener) {
final var continuation = new TestBody(remainingChunks, remainingContinuations - 1);
public void getNextPart(ActionListener<ChunkedRestResponseBodyPart> listener) {
final var continuation = new TestBodyPart(remainingChunks, remainingContinuations - 1);
parts.add(continuation);
listener.onResponse(continuation);
}
@ -785,7 +785,7 @@ public class DefaultRestChannelTests extends ESTestCase {
}
final var isClosed = new AtomicBoolean();
final var firstPart = new TestBody(responseBody, between(0, 3));
final var firstPart = new TestBodyPart(responseBody, between(0, 3));
parts.add(firstPart);
assertEquals(
responseBody,
@ -797,8 +797,8 @@ public class DefaultRestChannelTests extends ESTestCase {
() -> channel.sendResponse(RestResponse.chunked(RestStatus.OK, firstPart, () -> {
assertTrue(isClosed.compareAndSet(false, true));
for (int i = 0; i < parts.size(); i++) {
assertTrue("isDone " + i, parts.get(i).isDone());
assertEquals("isEndOfResponse " + i, i == parts.size() - 1, parts.get(i).isEndOfResponse());
assertTrue("isPartComplete " + i, parts.get(i).isPartComplete());
assertEquals("isLastPart " + i, i == parts.size() - 1, parts.get(i).isLastPart());
}
}))
)

View file

@ -10,7 +10,7 @@ package org.elasticsearch.http;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
@ -78,7 +78,7 @@ class TestHttpRequest implements HttpRequest {
}
@Override
public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) {
public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) {
throw new UnsupportedOperationException("chunked responses not supported");
}

View file

@ -30,7 +30,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
public class ChunkedRestResponseBodyTests extends ESTestCase {
public class ChunkedRestResponseBodyPartTests extends ESTestCase {
public void testEncodesChunkedXContentCorrectly() throws IOException {
final ChunkedToXContent chunkedToXContent = (ToXContent.Params outerParams) -> Iterators.forArray(
@ -50,7 +50,7 @@ public class ChunkedRestResponseBodyTests extends ESTestCase {
}
final var bytesDirect = BytesReference.bytes(builderDirect);
var chunkedResponse = ChunkedRestResponseBody.fromXContent(
var firstBodyPart = ChunkedRestResponseBodyPart.fromXContent(
chunkedToXContent,
ToXContent.EMPTY_PARAMS,
new FakeRestChannel(
@ -61,20 +61,25 @@ public class ChunkedRestResponseBodyTests extends ESTestCase {
);
final List<BytesReference> refsGenerated = new ArrayList<>();
while (chunkedResponse.isDone() == false) {
refsGenerated.add(chunkedResponse.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE));
while (firstBodyPart.isPartComplete() == false) {
refsGenerated.add(firstBodyPart.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE));
}
assertTrue(firstBodyPart.isLastPart());
assertEquals(bytesDirect, CompositeBytesReference.of(refsGenerated.toArray(new BytesReference[0])));
}
public void testFromTextChunks() throws IOException {
final var chunks = randomList(1000, () -> randomUnicodeOfLengthBetween(1, 100));
var body = ChunkedRestResponseBody.fromTextChunks("text/plain", Iterators.map(chunks.iterator(), s -> w -> w.write(s)));
var firstBodyPart = ChunkedRestResponseBodyPart.fromTextChunks(
"text/plain",
Iterators.map(chunks.iterator(), s -> w -> w.write(s))
);
final List<BytesReference> refsGenerated = new ArrayList<>();
while (body.isDone() == false) {
refsGenerated.add(body.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE));
while (firstBodyPart.isPartComplete() == false) {
refsGenerated.add(firstBodyPart.encodeChunk(randomIntBetween(2, 10), BytesRefRecycler.NON_RECYCLING_INSTANCE));
}
assertTrue(firstBodyPart.isLastPart());
final BytesReference chunkedBytes = CompositeBytesReference.of(refsGenerated.toArray(new BytesReference[0]));
try (var outputStream = new ByteArrayOutputStream(); var writer = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8)) {

View file

@ -733,7 +733,7 @@ public class RestControllerTests extends ESTestCase {
}
@Override
public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) {
public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) {
throw new AssertionError("should not be called");
}

View file

@ -97,7 +97,7 @@ public class RestResponseTests extends ESTestCase {
public void testEmptyChunkedBody() {
RestResponse response = RestResponse.chunked(
RestStatus.OK,
ChunkedRestResponseBody.fromTextChunks(RestResponse.TEXT_CONTENT_TYPE, Collections.emptyIterator()),
ChunkedRestResponseBodyPart.fromTextChunks(RestResponse.TEXT_CONTENT_TYPE, Collections.emptyIterator()),
null
);
assertFalse(response.isChunked());

View file

@ -432,14 +432,15 @@ public class RestTableTests extends ESTestCase {
};
final var bodyChunks = new ArrayList<String>();
final var chunkedRestResponseBody = response.chunkedContent();
final var firstBodyPart = response.chunkedContent();
while (chunkedRestResponseBody.isDone() == false) {
try (var chunk = chunkedRestResponseBody.encodeChunk(pageSize, recycler)) {
while (firstBodyPart.isPartComplete() == false) {
try (var chunk = firstBodyPart.encodeChunk(pageSize, recycler)) {
assertThat(chunk.length(), greaterThan(0));
bodyChunks.add(chunk.utf8ToString());
}
}
assertTrue(firstBodyPart.isLastPart());
assertEquals(0, openPages.get());
return bodyChunks;
}

View file

@ -13,6 +13,7 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.indices.flush.FlushRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.UnsafePlainActionFuture;
import org.elasticsearch.action.support.replication.TransportReplicationAction;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.MappingMetadata;
@ -869,7 +870,7 @@ public abstract class IndexShardTestCase extends ESTestCase {
routingTable
);
try {
PlainActionFuture<RecoveryResponse> future = new PlainActionFuture<>();
PlainActionFuture<RecoveryResponse> future = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC);
recovery.recoverToTarget(future);
future.actionGet();
recoveryTarget.markAsDone();

View file

@ -28,8 +28,8 @@ public class RestResponseUtils {
return restResponse.content();
}
final var chunkedRestResponseBody = restResponse.chunkedContent();
assert chunkedRestResponseBody.isDone() == false;
final var firstResponseBodyPart = restResponse.chunkedContent();
assert firstResponseBodyPart.isPartComplete() == false;
final int pageSize;
try (var page = NON_RECYCLING_INSTANCE.obtain()) {
@ -37,12 +37,12 @@ public class RestResponseUtils {
}
try (var out = new BytesStreamOutput()) {
while (chunkedRestResponseBody.isDone() == false) {
try (var chunk = chunkedRestResponseBody.encodeChunk(pageSize, NON_RECYCLING_INSTANCE)) {
while (firstResponseBodyPart.isPartComplete() == false) {
try (var chunk = firstResponseBodyPart.encodeChunk(pageSize, NON_RECYCLING_INSTANCE)) {
chunk.writeTo(out);
}
}
assert chunkedRestResponseBody.isEndOfResponse() : "RestResponseUtils#getBodyContent does not support continuations (yet)";
assert firstResponseBodyPart.isLastPart() : "RestResponseUtils#getBodyContent does not support continuations (yet)";
out.flush();
return out.bytes();

View file

@ -374,6 +374,15 @@ public abstract class ESTestCase extends LuceneTestCase {
// We have to disable setting the number of available processors as tests in the same JVM randomize processors and will step on each
// other if we allow them to set the number of available processors as it's set-once in Netty.
System.setProperty("es.set.netty.runtime.available.processors", "false");
// sometimes use the java.time date formatters
// we can't use randomBoolean here, the random context isn't set properly
// so read it directly from the test seed in an unfortunately hacky way
String testSeed = System.getProperty("tests.seed", "0");
boolean firstBit = (Integer.parseInt(testSeed.substring(testSeed.length() - 1), 16) & 1) == 1;
if (firstBit) {
System.setProperty("es.datetime.java_time_parsers", "true");
}
}
protected final Logger logger = LogManager.getLogger(getClass());

View file

@ -16,7 +16,7 @@ import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.NamedXContentRegistry;
@ -129,7 +129,7 @@ public class FakeRestRequest extends RestRequest {
}
@Override
public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBody content) {
public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) {
return createResponse(status, BytesArray.EMPTY);
}

View file

@ -20,6 +20,7 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.UnsafePlainActionFuture;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
import org.elasticsearch.cluster.node.VersionInformation;
@ -960,7 +961,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
protected void doRun() throws Exception {
safeAwait(go);
for (int iter = 0; iter < 10; iter++) {
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
PlainActionFuture<TestResponse> listener = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC);
final String info = sender + "_B_" + iter;
serviceB.sendRequest(
nodeA,
@ -996,7 +997,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
protected void doRun() throws Exception {
go.await();
for (int iter = 0; iter < 10; iter++) {
PlainActionFuture<TestResponse> listener = new PlainActionFuture<>();
PlainActionFuture<TestResponse> listener = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC);
final String info = sender + "_" + iter;
final DiscoveryNode node = nodeB; // capture now
try {
@ -3464,7 +3465,7 @@ public abstract class AbstractSimpleTransportTestCase extends ESTestCase {
* @param connectionProfile the connection profile to use when connecting to this node
*/
public static void connectToNode(TransportService service, DiscoveryNode node, ConnectionProfile connectionProfile) {
PlainActionFuture.get(fut -> service.connectToNode(node, connectionProfile, fut.map(x -> null)));
UnsafePlainActionFuture.get(fut -> service.connectToNode(node, connectionProfile, fut.map(x -> null)), ThreadPool.Names.GENERIC);
}
/**

View file

@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.action.support.UnsafePlainActionFuture;
import org.elasticsearch.blobcache.BlobCacheMetrics;
import org.elasticsearch.blobcache.BlobCacheUtils;
import org.elasticsearch.blobcache.common.ByteRange;
@ -36,6 +37,7 @@ import org.elasticsearch.env.Environment;
import org.elasticsearch.env.NodeEnvironment;
import org.elasticsearch.monitor.fs.FsProbe;
import org.elasticsearch.node.NodeRoleSettings;
import org.elasticsearch.repositories.blobstore.BlobStoreRepository;
import org.elasticsearch.threadpool.ThreadPool;
import java.io.IOException;
@ -1136,7 +1138,9 @@ public class SharedBlobCacheService<KeyType> implements Releasable {
int startRegion,
int endRegion
) throws InterruptedException, ExecutionException {
final PlainActionFuture<Void> readsComplete = new PlainActionFuture<>();
final PlainActionFuture<Void> readsComplete = new UnsafePlainActionFuture<>(
BlobStoreRepository.STATELESS_SHARD_PREWARMING_THREAD_NAME
);
final AtomicInteger bytesRead = new AtomicInteger();
try (var listeners = new RefCountingListener(1, readsComplete)) {
for (int region = startRegion; region <= endRegion; region++) {

View file

@ -26,6 +26,7 @@ import org.elasticsearch.action.admin.indices.stats.ShardStats;
import org.elasticsearch.action.support.ListenerTimeouts;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.action.support.UnsafePlainActionFuture;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.RemoteClusterClient;
import org.elasticsearch.cluster.ClusterName;
@ -599,7 +600,11 @@ public class CcrRepository extends AbstractLifecycleComponent implements Reposit
Client followerClient,
Index followerIndex
) {
final PlainActionFuture<IndexMetadata> indexMetadataFuture = new PlainActionFuture<>();
// todo: this could manifest in production and seems we could make this async easily.
final PlainActionFuture<IndexMetadata> indexMetadataFuture = new UnsafePlainActionFuture<>(
Ccr.CCR_THREAD_POOL_NAME,
ThreadPool.Names.GENERIC
);
final long startTimeInNanos = System.nanoTime();
final Supplier<TimeValue> timeout = () -> {
final long elapsedInNanos = System.nanoTime() - startTimeInNanos;

View file

@ -15,6 +15,7 @@ import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.delete.DeleteRequest;
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.UnsafePlainActionFuture;
import org.elasticsearch.action.support.replication.PostWriteRefresh;
import org.elasticsearch.action.support.replication.ReplicationResponse;
import org.elasticsearch.action.support.replication.TransportWriteAction;
@ -802,7 +803,7 @@ public class ShardFollowTaskReplicationTests extends ESIndexLevelReplicationTest
@Override
protected void performOnPrimary(IndexShard primary, BulkShardOperationsRequest request, ActionListener<PrimaryResult> listener) {
final PlainActionFuture<Releasable> permitFuture = new PlainActionFuture<>();
final PlainActionFuture<Releasable> permitFuture = new UnsafePlainActionFuture<>(ThreadPool.Names.GENERIC);
primary.acquirePrimaryOperationPermit(permitFuture, EsExecutors.DIRECT_EXECUTOR_SERVICE);
final TransportWriteAction.WritePrimaryResult<BulkShardOperationsRequest, BulkShardOperationsResponse> ccrResult;
final var threadpool = mock(ThreadPool.class);

View file

@ -12,7 +12,7 @@ import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
@ -132,13 +132,13 @@ public final class EsqlResponseListener extends RestRefCountedChunkedToXContentL
if (mediaType instanceof TextFormat format) {
restResponse = RestResponse.chunked(
RestStatus.OK,
ChunkedRestResponseBody.fromTextChunks(format.contentType(restRequest), format.format(restRequest, esqlResponse)),
ChunkedRestResponseBodyPart.fromTextChunks(format.contentType(restRequest), format.format(restRequest, esqlResponse)),
releasable
);
} else {
restResponse = RestResponse.chunked(
RestStatus.OK,
ChunkedRestResponseBody.fromXContent(esqlResponse, channel.request(), channel),
ChunkedRestResponseBodyPart.fromXContent(esqlResponse, channel.request(), channel),
releasable
);
}

View file

@ -36,13 +36,22 @@ public class RestMoveToStepAction extends BaseRestHandler {
@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
String index = restRequest.param("name");
TransportMoveToStepAction.Request request;
final var masterNodeTimeout = getMasterNodeTimeout(restRequest);
final var ackTimeout = getAckTimeout(restRequest);
final var index = restRequest.param("name");
final TransportMoveToStepAction.Request request;
try (XContentParser parser = restRequest.contentParser()) {
request = TransportMoveToStepAction.Request.parseRequest(index, parser);
request = TransportMoveToStepAction.Request.parseRequest(
(currentStepKey, nextStepKey) -> new TransportMoveToStepAction.Request(
masterNodeTimeout,
ackTimeout,
index,
currentStepKey,
nextStepKey
),
parser
);
}
request.ackTimeout(getAckTimeout(restRequest));
request.masterNodeTimeout(getMasterNodeTimeout(restRequest));
return channel -> client.execute(ILMActions.MOVE_TO_STEP, request, new RestToXContentListener<>(channel));
}
}

View file

@ -36,10 +36,8 @@ public class RestRetryAction extends BaseRestHandler {
@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) {
String[] indices = Strings.splitStringByCommaToArray(restRequest.param("index"));
TransportRetryAction.Request request = new TransportRetryAction.Request(indices);
request.ackTimeout(getAckTimeout(restRequest));
request.masterNodeTimeout(getMasterNodeTimeout(restRequest));
final var indices = Strings.splitStringByCommaToArray(restRequest.param("index"));
final var request = new TransportRetryAction.Request(getMasterNodeTimeout(restRequest), getAckTimeout(restRequest), indices);
request.indices(indices);
request.indicesOptions(IndicesOptions.fromRequest(restRequest, IndicesOptions.strictExpandOpen()));
return channel -> client.execute(ILMActions.RETRY, request, new RestToXContentListener<>(channel));

View file

@ -32,6 +32,7 @@ import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
@ -188,15 +189,20 @@ public class TransportMoveToStepAction extends TransportMasterNodeAction<Transpo
}
public static class Request extends AcknowledgedRequest<Request> implements ToXContentObject {
public interface Factory {
Request create(Step.StepKey currentStepKey, PartialStepKey nextStepKey);
}
static final ParseField CURRENT_KEY_FIELD = new ParseField("current_step");
static final ParseField NEXT_KEY_FIELD = new ParseField("next_step");
private static final ConstructingObjectParser<Request, String> PARSER = new ConstructingObjectParser<>(
private static final ConstructingObjectParser<Request, Factory> PARSER = new ConstructingObjectParser<>(
"move_to_step_request",
false,
(a, index) -> {
(a, factory) -> {
Step.StepKey currentStepKey = (Step.StepKey) a[0];
PartialStepKey nextStepKey = (PartialStepKey) a[1];
return new Request(index, currentStepKey, nextStepKey);
return factory.create(currentStepKey, nextStepKey);
}
);
@ -207,12 +213,18 @@ public class TransportMoveToStepAction extends TransportMasterNodeAction<Transpo
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, name) -> PartialStepKey.parse(p), NEXT_KEY_FIELD);
}
private String index;
private Step.StepKey currentStepKey;
private PartialStepKey nextStepKey;
private final String index;
private final Step.StepKey currentStepKey;
private final PartialStepKey nextStepKey;
public Request(String index, Step.StepKey currentStepKey, PartialStepKey nextStepKey) {
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
public Request(
TimeValue masterNodeTimeout,
TimeValue ackTimeout,
String index,
Step.StepKey currentStepKey,
PartialStepKey nextStepKey
) {
super(masterNodeTimeout, ackTimeout);
this.index = index;
this.currentStepKey = currentStepKey;
this.nextStepKey = nextStepKey;
@ -225,10 +237,6 @@ public class TransportMoveToStepAction extends TransportMasterNodeAction<Transpo
this.nextStepKey = new PartialStepKey(in);
}
public Request() {
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
}
public String getIndex() {
return index;
}
@ -246,8 +254,8 @@ public class TransportMoveToStepAction extends TransportMasterNodeAction<Transpo
return null;
}
public static Request parseRequest(String name, XContentParser parser) {
return PARSER.apply(parser, name);
public static Request parseRequest(Factory factory, XContentParser parser) {
return PARSER.apply(parser, factory);
}
@Override

View file

@ -26,12 +26,12 @@ import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.metadata.LifecycleExecutionState;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
@ -114,11 +114,11 @@ public class TransportRetryAction extends TransportMasterNodeAction<TransportRet
}
public static class Request extends AcknowledgedRequest<Request> implements IndicesRequest.Replaceable {
private String[] indices = Strings.EMPTY_ARRAY;
private String[] indices;
private IndicesOptions indicesOptions = IndicesOptions.strictExpandOpen();
public Request(String... indices) {
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
public Request(TimeValue masterNodeTimeout, TimeValue ackTimeout, String... indices) {
super(masterNodeTimeout, ackTimeout);
this.indices = indices;
}
@ -128,10 +128,6 @@ public class TransportRetryAction extends TransportMasterNodeAction<TransportRet
this.indicesOptions = IndicesOptions.readIndicesOptions(in);
}
public Request() {
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
}
@Override
public Request indices(String... indices) {
this.indices = indices;

View file

@ -26,7 +26,13 @@ public class MoveToStepRequestTests extends AbstractXContentSerializingTestCase<
@Override
protected TransportMoveToStepAction.Request createTestInstance() {
return new TransportMoveToStepAction.Request(index, stepKeyTests.createTestInstance(), randomStepSpecification());
return new TransportMoveToStepAction.Request(
TEST_REQUEST_TIMEOUT,
TEST_REQUEST_TIMEOUT,
index,
stepKeyTests.createTestInstance(),
randomStepSpecification()
);
}
@Override
@ -36,7 +42,16 @@ public class MoveToStepRequestTests extends AbstractXContentSerializingTestCase<
@Override
protected TransportMoveToStepAction.Request doParseInstance(XContentParser parser) {
return TransportMoveToStepAction.Request.parseRequest(index, parser);
return TransportMoveToStepAction.Request.parseRequest(
(currentStepKey, nextStepKey) -> new TransportMoveToStepAction.Request(
TEST_REQUEST_TIMEOUT,
TEST_REQUEST_TIMEOUT,
index,
currentStepKey,
nextStepKey
),
parser
);
}
@Override
@ -52,7 +67,7 @@ public class MoveToStepRequestTests extends AbstractXContentSerializingTestCase<
default -> throw new AssertionError("Illegal randomisation branch");
}
return new TransportMoveToStepAction.Request(indexName, currentStepKey, nextStepKey);
return new TransportMoveToStepAction.Request(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, indexName, currentStepKey, nextStepKey);
}
private static TransportMoveToStepAction.Request.PartialStepKey randomStepSpecification() {

View file

@ -8,6 +8,7 @@
package org.elasticsearch.xpack.ilm.action;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
@ -17,10 +18,11 @@ public class RetryRequestTests extends AbstractWireSerializingTestCase<Transport
@Override
protected TransportRetryAction.Request createTestInstance() {
TransportRetryAction.Request request = new TransportRetryAction.Request();
if (randomBoolean()) {
request.indices(generateRandomStringArray(20, 20, false));
}
final var request = new TransportRetryAction.Request(
TEST_REQUEST_TIMEOUT,
TEST_REQUEST_TIMEOUT,
randomBoolean() ? Strings.EMPTY_ARRAY : generateRandomStringArray(20, 20, false)
);
if (randomBoolean()) {
IndicesOptions indicesOptions = IndicesOptions.fromOptions(
randomBoolean(),
@ -66,8 +68,7 @@ public class RetryRequestTests extends AbstractWireSerializingTestCase<Transport
);
default -> throw new AssertionError("Illegal randomisation branch");
}
TransportRetryAction.Request newRequest = new TransportRetryAction.Request();
newRequest.indices(indices);
final var newRequest = new TransportRetryAction.Request(TEST_REQUEST_TIMEOUT, TEST_REQUEST_TIMEOUT, indices);
newRequest.indicesOptions(indicesOptions);
return newRequest;
}

View file

@ -44,6 +44,7 @@ import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInt
import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings;
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserSecretSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
@ -106,6 +107,7 @@ public class InferenceNamedWriteablesProvider {
addCohereNamedWriteables(namedWriteables);
addAzureOpenAiNamedWriteables(namedWriteables);
addAzureAiStudioNamedWriteables(namedWriteables);
addGoogleAiStudioNamedWritables(namedWriteables);
return namedWriteables;
}
@ -254,6 +256,16 @@ public class InferenceNamedWriteablesProvider {
);
}
private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
GoogleAiStudioCompletionServiceSettings.NAME,
GoogleAiStudioCompletionServiceSettings::new
)
);
}
private static void addInternalElserNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, ElserInternalServiceSettings.NAME, ElserInternalServiceSettings::new)
@ -318,4 +330,5 @@ public class InferenceNamedWriteablesProvider {
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, RankedDocsResults.NAME, RankedDocsResults::new)
);
}
}

View file

@ -68,6 +68,7 @@ import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
import org.elasticsearch.xpack.inference.services.elser.ElserInternalService;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
@ -194,6 +195,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
context -> new CohereService(httpFactory.get(), serviceComponents.get()),
context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()),
context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()),
context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}

View file

@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.action.googleaistudio;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel;
import java.util.Map;
import java.util.Objects;
public class GoogleAiStudioActionCreator implements GoogleAiStudioActionVisitor {
private final Sender sender;
private final ServiceComponents serviceComponents;
public GoogleAiStudioActionCreator(Sender sender, ServiceComponents serviceComponents) {
this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents);
}
@Override
public ExecutableAction create(GoogleAiStudioCompletionModel model, Map<String, Object> taskSettings) {
// no overridden model as task settings are always empty for Google AI Studio completion model
return new GoogleAiStudioCompletionAction(sender, model, serviceComponents.threadPool());
}
}

View file

@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.action.googleaistudio;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel;
import java.util.Map;
public interface GoogleAiStudioActionVisitor {
ExecutableAction create(GoogleAiStudioCompletionModel model, Map<String, Object> taskSettings);
}

View file

@ -0,0 +1,73 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.action.googleaistudio;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.GoogleAiStudioCompletionRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
public class GoogleAiStudioCompletionAction implements ExecutableAction {
private final String failedToSendRequestErrorMessage;
private final GoogleAiStudioCompletionRequestManager requestManager;
private final Sender sender;
public GoogleAiStudioCompletionAction(Sender sender, GoogleAiStudioCompletionModel model, ThreadPool threadPool) {
Objects.requireNonNull(threadPool);
Objects.requireNonNull(model);
this.sender = Objects.requireNonNull(sender);
this.requestManager = new GoogleAiStudioCompletionRequestManager(model, threadPool);
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "Google AI Studio completion");
}
@Override
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
if (inferenceInputs instanceof DocumentsOnlyInput == false) {
listener.onFailure(new ElasticsearchStatusException("Invalid inference input type", RestStatus.INTERNAL_SERVER_ERROR));
return;
}
var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs;
if (docsOnlyInput.getInputs().size() > 1) {
listener.onFailure(
new ElasticsearchStatusException("Google AI Studio completion only accepts 1 input", RestStatus.BAD_REQUEST)
);
return;
}
try {
ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(
failedToSendRequestErrorMessage,
listener
);
sender.send(requestManager, inferenceInputs, timeout, wrappedListener);
} catch (ElasticsearchException e) {
listener.onFailure(e);
} catch (Exception e) {
listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage));
}
}
}

View file

@ -0,0 +1,75 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.googleaistudio;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.googleaistudio.GoogleAiStudioErrorResponseEntity;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
public class GoogleAiStudioResponseHandler extends BaseResponseHandler {
static final String GOOGLE_AI_STUDIO_UNAVAILABLE = "The Google AI Studio service may be temporarily overloaded or down";
public GoogleAiStudioResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, GoogleAiStudioErrorResponseEntity::fromResponse);
}
@Override
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
checkForFailureStatusCode(request, result);
checkForEmptyBody(throttlerManager, logger, request, result);
}
/**
* Validates the status code and throws a RetryException if not in the range [200, 300).
*
* The Google AI Studio error codes are documented <a href="https://ai.google.dev/gemini-api/docs/troubleshooting">here</a>.
* @param request The originating request
* @param result The http response and body
* @throws RetryException Throws if status code is {@code >= 300 or < 200 }
*/
void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
int statusCode = result.response().getStatusLine().getStatusCode();
if (statusCode >= 200 && statusCode < 300) {
return;
}
// handle error codes
if (statusCode == 500) {
throw new RetryException(true, buildError(SERVER_ERROR, request, result));
} else if (statusCode == 503) {
throw new RetryException(true, buildError(GOOGLE_AI_STUDIO_UNAVAILABLE, request, result));
} else if (statusCode > 500) {
throw new RetryException(false, buildError(SERVER_ERROR, request, result));
} else if (statusCode == 429) {
throw new RetryException(true, buildError(RATE_LIMIT, request, result));
} else if (statusCode == 404) {
throw new RetryException(false, buildError(resourceNotFoundError(request), request, result));
} else if (statusCode == 403) {
throw new RetryException(false, buildError(PERMISSION_DENIED, request, result));
} else if (statusCode >= 300 && statusCode < 400) {
throw new RetryException(false, buildError(REDIRECTION, request, result));
} else {
throw new RetryException(false, buildError(UNSUCCESSFUL, request, result));
}
}
private static String resourceNotFoundError(Request request) {
return format("Resource not found at [%s]", request.getURI());
}
}

View file

@ -23,6 +23,7 @@ public abstract class BaseResponseHandler implements ResponseHandler {
public static final String SERVER_ERROR = "Received a server error status code";
public static final String RATE_LIMIT = "Received a rate limit status code";
public static final String AUTHENTICATION = "Received an authentication error status code";
public static final String PERMISSION_DENIED = "Received a permission denied error status code";
public static final String REDIRECTION = "Unhandled redirection";
public static final String CONTENT_TOO_LARGE = "Received a content too large status code";
public static final String UNSUCCESSFUL = "Received an unsuccessful status code";

View file

@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http.sender;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.googleaistudio.GoogleAiStudioResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioCompletionRequest;
import org.elasticsearch.xpack.inference.external.response.googleaistudio.GoogleAiStudioCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
public class GoogleAiStudioCompletionRequestManager extends GoogleAiStudioRequestManager {
private static final Logger logger = LogManager.getLogger(GoogleAiStudioCompletionRequestManager.class);
private static final ResponseHandler HANDLER = createCompletionHandler();
private final GoogleAiStudioCompletionModel model;
private static ResponseHandler createCompletionHandler() {
return new GoogleAiStudioResponseHandler("google ai studio completion", GoogleAiStudioCompletionResponseEntity::fromResponse);
}
public GoogleAiStudioCompletionRequestManager(GoogleAiStudioCompletionModel model, ThreadPool threadPool) {
super(threadPool, model);
this.model = Objects.requireNonNull(model);
}
@Override
public Runnable create(
String query,
List<String> input,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
HttpClientContext context,
ActionListener<InferenceServiceResults> listener
) {
GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(input, model);
return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
}
}

View file

@ -0,0 +1,27 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http.sender;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel;
import java.util.Objects;
public abstract class GoogleAiStudioRequestManager extends BaseRequestManager {
GoogleAiStudioRequestManager(ThreadPool threadPool, GoogleAiStudioModel model) {
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
}
record RateLimitGrouping(int modelIdHash) {
public static RateLimitGrouping of(GoogleAiStudioModel model) {
Objects.requireNonNull(model);
return new RateLimitGrouping(model.rateLimitServiceSettings().modelId().hashCode());
}
}
}

View file

@ -0,0 +1,72 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.googleaistudio;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;
public class GoogleAiStudioCompletionRequest implements GoogleAiStudioRequest {
private final List<String> input;
private final URI uri;
private final GoogleAiStudioCompletionModel model;
public GoogleAiStudioCompletionRequest(List<String> input, GoogleAiStudioCompletionModel model) {
this.input = input;
this.model = Objects.requireNonNull(model);
this.uri = model.uri();
}
@Override
public HttpRequest createHttpRequest() {
var httpPost = new HttpPost(uri);
var requestEntity = Strings.toString(new GoogleAiStudioCompletionRequestEntity(input));
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
GoogleAiStudioRequest.decorateWithApiKeyParameter(httpPost, model.getSecretSettings());
return new HttpRequest(httpPost, getInferenceEntityId());
}
@Override
public URI getURI() {
return this.uri;
}
@Override
public Request truncate() {
// No truncation for Google AI Studio completion
return this;
}
@Override
public boolean[] getTruncationInfo() {
// No truncation for Google AI Studio completion
return null;
}
@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
}
}

View file

@ -0,0 +1,79 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.googleaistudio;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
public record GoogleAiStudioCompletionRequestEntity(List<String> input) implements ToXContentObject {
private static final String CONTENTS_FIELD = "contents";
private static final String PARTS_FIELD = "parts";
private static final String TEXT_FIELD = "text";
private static final String GENERATION_CONFIG_FIELD = "generationConfig";
private static final String CANDIDATE_COUNT_FIELD = "candidateCount";
private static final String ROLE_FIELD = "role";
private static final String ROLE_USER = "user";
public GoogleAiStudioCompletionRequestEntity {
Objects.requireNonNull(input);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.startArray(CONTENTS_FIELD);
{
for (String content : input) {
builder.startObject();
{
builder.startArray(PARTS_FIELD);
builder.startObject();
{
builder.field(TEXT_FIELD, content);
}
builder.endObject();
builder.endArray();
}
builder.field(ROLE_FIELD, ROLE_USER);
builder.endObject();
}
}
builder.endArray();
builder.startObject(GENERATION_CONFIG_FIELD);
{
// default is already 1, but we want to guard ourselves against API changes so setting it explicitly
builder.field(CANDIDATE_COUNT_FIELD, 1);
}
builder.endObject();
builder.endObject();
return builder;
}
}

View file

@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.googleaistudio;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioSecretSettings;
public interface GoogleAiStudioRequest extends Request {
String API_KEY_PARAMETER = "key";
static void decorateWithApiKeyParameter(HttpPost httpPost, GoogleAiStudioSecretSettings secretSettings) {
try {
var uri = httpPost.getURI();
var uriWithApiKey = new URIBuilder().setScheme(uri.getScheme())
.setHost(uri.getHost())
.setPort(uri.getPort())
.setPath(uri.getPath())
.addParameter(API_KEY_PARAMETER, secretSettings.apiKey().toString())
.build();
httpPost.setURI(uriWithApiKey);
} catch (Exception e) {
ValidationException validationException = new ValidationException(e);
validationException.addValidationError(e.getMessage());
throw validationException;
}
}
}

View file

@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.request.googleaistudio;
public class GoogleAiStudioUtils {
public static final String HOST_SUFFIX = "generativelanguage.googleapis.com";
public static final String V1 = "v1";
public static final String MODELS = "models";
public static final String GENERATE_CONTENT_ACTION = "generateContent";
private GoogleAiStudioUtils() {}
}

View file

@ -0,0 +1,109 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.response.googleaistudio;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import java.io.IOException;
import java.util.List;
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
public class GoogleAiStudioCompletionResponseEntity {
private static final String FAILED_TO_FIND_FIELD_TEMPLATE =
"Failed to find required field [%s] in Google AI Studio completion response";
/**
* Parses the Google AI Studio completion response.
*
* For a request like:
*
* <pre>
* <code>
* {
* "contents": [
* {
* "parts": [{
* "text": "input"
* }]
* }
* ]
* }
* </code>
* </pre>
*
* The response would look like:
*
* <pre>
* <code>
* {
* "candidates": [
* {
* "content": {
* "parts": [
* {
* "text": "response"
* }
* ],
* "role": "model"
* },
* "finishReason": "STOP",
* "index": 0,
* "safetyRatings": [...]
* }
* ],
* "usageMetadata": { ... }
* }
* </code>
* </pre>
*
*/
public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
moveToFirstToken(jsonParser);
XContentParser.Token token = jsonParser.currentToken();
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
positionParserAtTokenAfterField(jsonParser, "candidates", FAILED_TO_FIND_FIELD_TEMPLATE);
jsonParser.nextToken();
ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser);
positionParserAtTokenAfterField(jsonParser, "content", FAILED_TO_FIND_FIELD_TEMPLATE);
token = jsonParser.currentToken();
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
positionParserAtTokenAfterField(jsonParser, "parts", FAILED_TO_FIND_FIELD_TEMPLATE);
jsonParser.nextToken();
ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
positionParserAtTokenAfterField(jsonParser, "text", FAILED_TO_FIND_FIELD_TEMPLATE);
XContentParser.Token contentToken = jsonParser.currentToken();
ensureExpectedToken(XContentParser.Token.VALUE_STRING, contentToken, jsonParser);
String content = jsonParser.text();
return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(content)));
}
}
}

View file

@ -0,0 +1,78 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.response.googleaistudio;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorMessage;
import java.util.Map;
public class GoogleAiStudioErrorResponseEntity implements ErrorMessage {
private final String errorMessage;
private GoogleAiStudioErrorResponseEntity(String errorMessage) {
this.errorMessage = errorMessage;
}
@Override
public String getErrorMessage() {
return errorMessage;
}
/**
* An example error response for invalid auth would look like
* <code>
* {
* "error": {
* "code": 400,
* "message": "API key not valid. Please pass a valid API key.",
* "status": "INVALID_ARGUMENT",
* "details": [
* {
* "@type": "type.googleapis.com/google.rpc.ErrorInfo",
* "reason": "API_KEY_INVALID",
* "domain": "googleapis.com",
* "metadata": {
* "service": "generativelanguage.googleapis.com"
* }
* }
* ]
* }
* }
* </code>
* @param response The error response
* @return An error entity if the response is JSON with the above structure
* or null if the response does not contain the `error.message` field
*/
@SuppressWarnings("unchecked")
public static GoogleAiStudioErrorResponseEntity fromResponse(HttpResult response) {
try (
XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response.body())
) {
var responseMap = jsonParser.map();
var error = (Map<String, Object>) responseMap.get("error");
if (error != null) {
var message = (String) error.get("message");
if (message != null) {
return new GoogleAiStudioErrorResponseEntity(message);
}
}
} catch (Exception e) {
// swallow the error
}
return null;
}
}

View file

@ -196,8 +196,8 @@ public class ServiceUtils {
);
}
public static String invalidUrlErrorMsg(String url, String settingName, String settingScope) {
return Strings.format("[%s] Invalid url [%s] received for field [%s]", settingScope, url, settingName);
public static String invalidUrlErrorMsg(String url, String settingName, String settingScope, String error) {
return Strings.format("[%s] Invalid url [%s] received for field [%s]. Error: %s", settingScope, url, settingName, error);
}
public static String mustBeNonEmptyString(String settingName, String scope) {
@ -231,7 +231,6 @@ public class ServiceUtils {
return Strings.format("[%s] does not allow the setting [%s]", scope, settingName);
}
// TODO improve URI validation logic
public static URI convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) {
try {
if (url == null) {
@ -239,8 +238,8 @@ public class ServiceUtils {
}
return createUri(url);
} catch (IllegalArgumentException ignored) {
validationException.addValidationError(ServiceUtils.invalidUrlErrorMsg(url, settingName, settingScope));
} catch (IllegalArgumentException cause) {
validationException.addValidationError(ServiceUtils.invalidUrlErrorMsg(url, settingName, settingScope, cause.getMessage()));
return null;
}
}
@ -251,7 +250,7 @@ public class ServiceUtils {
try {
return new URI(url);
} catch (URISyntaxException e) {
throw new IllegalArgumentException(format("unable to parse url [%s]", url), e);
throw new IllegalArgumentException(format("unable to parse url [%s]. Reason: %s", url, e.getReason()), e);
}
}

View file

@ -66,7 +66,7 @@ public class CustomElandRerankTaskSettings implements TaskSettings {
}
/**
* Return either the request or orignal settings by preferring non-null fields
* Return either the request or original settings by preferring non-null fields
* from the request settings over the original settings.
*
* @param originalSettings the settings stored as part of the inference entity configuration

View file

@ -0,0 +1,39 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.googleaistudio;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor;
import java.util.Map;
import java.util.Objects;
public abstract class GoogleAiStudioModel extends Model {
private final GoogleAiStudioRateLimitServiceSettings rateLimitServiceSettings;
public GoogleAiStudioModel(
ModelConfigurations configurations,
ModelSecrets secrets,
GoogleAiStudioRateLimitServiceSettings rateLimitServiceSettings
) {
super(configurations, secrets);
this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings);
}
public abstract ExecutableAction accept(GoogleAiStudioActionVisitor creator, Map<String, Object> taskSettings, InputType inputType);
public GoogleAiStudioRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}
}

View file

@ -0,0 +1,18 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.googleaistudio;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
public interface GoogleAiStudioRateLimitServiceSettings {
String modelId();
RateLimitSettings rateLimitSettings();
}

View file

@ -0,0 +1,106 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.googleaistudio;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.SecretSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalSecureString;
public class GoogleAiStudioSecretSettings implements SecretSettings {
public static final String NAME = "google_ai_studio_secret_settings";
public static final String API_KEY = "api_key";
private final SecureString apiKey;
public static GoogleAiStudioSecretSettings fromMap(@Nullable Map<String, Object> map) {
if (map == null) {
return null;
}
ValidationException validationException = new ValidationException();
SecureString secureApiKey = extractOptionalSecureString(map, API_KEY, ModelSecrets.SECRET_SETTINGS, validationException);
if (secureApiKey == null) {
validationException.addValidationError(format("[secret_settings] must have [%s] set", API_KEY));
}
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
return new GoogleAiStudioSecretSettings(secureApiKey);
}
public GoogleAiStudioSecretSettings(SecureString apiKey) {
Objects.requireNonNull(apiKey);
this.apiKey = apiKey;
}
public GoogleAiStudioSecretSettings(StreamInput in) throws IOException {
this(in.readOptionalSecureString());
}
public SecureString apiKey() {
return apiKey;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (apiKey != null) {
builder.field(API_KEY, apiKey.toString());
}
builder.endObject();
return builder;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_GOOGLE_AI_STUDIO_COMPLETION_ADDED;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalSecureString(apiKey);
}
@Override
public boolean equals(Object object) {
if (this == object) return true;
if (object == null || getClass() != object.getClass()) return false;
GoogleAiStudioSecretSettings that = (GoogleAiStudioSecretSettings) object;
return Objects.equals(apiKey, that.apiKey);
}
@Override
public int hashCode() {
return Objects.hash(apiKey);
}
}

View file

@ -0,0 +1,218 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.googleaistudio;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
public class GoogleAiStudioService extends SenderService {
public static final String NAME = "googleaistudio";
public GoogleAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
}
@Override
public String name() {
return NAME;
}
@Override
public void parseRequestConfig(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> config,
Set<String> platfromArchitectures,
ActionListener<Model> parsedModelListener
) {
try {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
GoogleAiStudioModel model = createModel(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
serviceSettingsMap,
TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
ConfigurationParseContext.REQUEST
);
throwIfNotEmptyMap(config, NAME);
throwIfNotEmptyMap(serviceSettingsMap, NAME);
throwIfNotEmptyMap(taskSettingsMap, NAME);
parsedModelListener.onResponse(model);
} catch (Exception e) {
parsedModelListener.onFailure(e);
}
}
private static GoogleAiStudioModel createModel(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
@Nullable Map<String, Object> secretSettings,
String failureMessage,
ConfigurationParseContext context
) {
return switch (taskType) {
case COMPLETION -> new GoogleAiStudioCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
taskSettings,
secretSettings
);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
@Override
public GoogleAiStudioModel parsePersistedConfigWithSecrets(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> config,
Map<String, Object> secrets
) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
return createModelFromPersistent(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
secretSettingsMap,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
}
private static GoogleAiStudioModel createModelFromPersistent(
String inferenceEntityId,
TaskType taskType,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
Map<String, Object> secretSettings,
String failureMessage
) {
return createModel(
inferenceEntityId,
taskType,
serviceSettings,
taskSettings,
secretSettings,
failureMessage,
ConfigurationParseContext.PERSISTENT
);
}
@Override
public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
return createModelFromPersistent(
inferenceEntityId,
taskType,
serviceSettingsMap,
taskSettingsMap,
null,
parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
);
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_GOOGLE_AI_STUDIO_COMPLETION_ADDED;
}
@Override
protected void doInfer(
Model model,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
if (model instanceof GoogleAiStudioModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}
GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model;
var actionCreator = new GoogleAiStudioActionCreator(getSender(), getServiceComponents());
var action = googleAiStudioModel.accept(actionCreator, taskSettings, inputType);
action.execute(new DocumentsOnlyInput(input), timeout, listener);
}
@Override
protected void doInfer(
Model model,
String query,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
throw new UnsupportedOperationException("Query input not supported for Google AI Studio");
}
@Override
protected void doChunkedInfer(
Model model,
String query,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
ChunkingOptions chunkingOptions,
TimeValue timeout,
ActionListener<List<ChunkedInferenceServiceResults>> listener
) {
throw new UnsupportedOperationException("Chunked inference not supported yet for Google AI Studio");
}
}

View file

@ -0,0 +1,124 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.googleaistudio.completion;
import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor;
import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioSecretSettings;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import static org.elasticsearch.core.Strings.format;
public class GoogleAiStudioCompletionModel extends GoogleAiStudioModel {
private URI uri;
public GoogleAiStudioCompletionModel(
String inferenceEntityId,
TaskType taskType,
String service,
Map<String, Object> serviceSettings,
Map<String, Object> taskSettings,
Map<String, Object> secrets
) {
this(
inferenceEntityId,
taskType,
service,
GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings),
EmptyTaskSettings.INSTANCE,
GoogleAiStudioSecretSettings.fromMap(secrets)
);
}
// Should only be used directly for testing
GoogleAiStudioCompletionModel(
String inferenceEntityId,
TaskType taskType,
String service,
GoogleAiStudioCompletionServiceSettings serviceSettings,
TaskSettings taskSettings,
@Nullable GoogleAiStudioSecretSettings secrets
) {
super(
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
new ModelSecrets(secrets),
serviceSettings
);
try {
this.uri = buildUri(serviceSettings.modelId());
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
}
// Should only be used directly for testing
GoogleAiStudioCompletionModel(
String inferenceEntityId,
TaskType taskType,
String service,
String url,
GoogleAiStudioCompletionServiceSettings serviceSettings,
TaskSettings taskSettings,
@Nullable GoogleAiStudioSecretSettings secrets
) {
super(
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
new ModelSecrets(secrets),
serviceSettings
);
try {
this.uri = new URI(url);
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
}
public URI uri() {
return uri;
}
@Override
public GoogleAiStudioCompletionServiceSettings getServiceSettings() {
return (GoogleAiStudioCompletionServiceSettings) super.getServiceSettings();
}
@Override
public GoogleAiStudioSecretSettings getSecretSettings() {
return (GoogleAiStudioSecretSettings) super.getSecretSettings();
}
public static URI buildUri(String model) throws URISyntaxException {
return new URIBuilder().setScheme("https")
.setHost(GoogleAiStudioUtils.HOST_SUFFIX)
.setPathSegments(
GoogleAiStudioUtils.V1,
GoogleAiStudioUtils.MODELS,
format("%s:%s", model, GoogleAiStudioUtils.GENERATE_CONTENT_ACTION)
)
.build();
}
@Override
public ExecutableAction accept(GoogleAiStudioActionVisitor visitor, Map<String, Object> taskSettings, InputType inputType) {
return visitor.create(this, taskSettings);
}
}

View file

@ -0,0 +1,126 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.googleaistudio.completion;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
public class GoogleAiStudioCompletionServiceSettings extends FilteredXContentObject
implements
ServiceSettings,
GoogleAiStudioRateLimitServiceSettings {
public static final String NAME = "google_ai_studio_completion_service_settings";
/**
* Rate limits are defined at <a href="https://ai.google.dev/pricing">Google Gemini API Pricing</a>.
* For pay-as-you-go you've 360 requests per minute.
*/
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360);
public static GoogleAiStudioCompletionServiceSettings fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();
String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
return new GoogleAiStudioCompletionServiceSettings(model, rateLimitSettings);
}
private final String modelId;
private final RateLimitSettings rateLimitSettings;
public GoogleAiStudioCompletionServiceSettings(String modelId, @Nullable RateLimitSettings rateLimitSettings) {
this.modelId = modelId;
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
}
public GoogleAiStudioCompletionServiceSettings(StreamInput in) throws IOException {
modelId = in.readString();
rateLimitSettings = new RateLimitSettings(in);
}
@Override
public String modelId() {
return modelId;
}
@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
toXContentFragmentOfExposedFields(builder, params);
rateLimitSettings.toXContent(builder, params);
builder.endObject();
return builder;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_GOOGLE_AI_STUDIO_COMPLETION_ADDED;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
rateLimitSettings.writeTo(out);
}
@Override
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
builder.field(MODEL_ID, modelId);
return builder;
}
@Override
public boolean equals(Object object) {
if (this == object) return true;
if (object == null || getClass() != object.getClass()) return false;
GoogleAiStudioCompletionServiceSettings that = (GoogleAiStudioCompletionServiceSettings) object;
return Objects.equals(modelId, that.modelId) && Objects.equals(rateLimitSettings, that.rateLimitSettings);
}
@Override
public int hashCode() {
return Objects.hash(modelId, rateLimitSettings);
}
}

View file

@ -0,0 +1,83 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.hamcrest.TypeSafeMatcher;
import java.util.regex.Pattern;
/**
* Utility class containing custom hamcrest {@link Matcher} implementations or other utility functionality related to hamcrest.
*/
public class MatchersUtils {
/**
* Custom matcher implementing a matcher operating on json strings ignoring whitespaces, which are not inside a key or a value.
*
* Example:
* {
* "key": "value"
* }
*
* will match
*
* {"key":"value"}
*
* as both json strings are equal ignoring the whitespace, which does not reside in a key or a value.
*
*/
protected static class IsEqualIgnoreWhitespaceInJsonString extends TypeSafeMatcher<String> {
protected static final Pattern WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN = createPattern();
private static Pattern createPattern() {
String regex = "(?<=[:,\\[{])\\s+|\\s+(?=[\\]}:,])|^\\s+|\\s+$";
return Pattern.compile(regex);
}
private final String string;
IsEqualIgnoreWhitespaceInJsonString(String string) {
if (string == null) {
throw new IllegalArgumentException("Non-null value required");
}
this.string = string;
}
@Override
protected boolean matchesSafely(String item) {
java.util.regex.Matcher itemMatcher = WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN.matcher(item);
java.util.regex.Matcher stringMatcher = WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN.matcher(string);
String itemReplacedWhitespaces = itemMatcher.replaceAll("");
String stringReplacedWhitespaces = stringMatcher.replaceAll("");
return itemReplacedWhitespaces.equals(stringReplacedWhitespaces);
}
@Override
public void describeTo(Description description) {
java.util.regex.Matcher stringMatcher = WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN.matcher(string);
String stringReplacedWhitespaces = stringMatcher.replaceAll("");
description.appendText("a string equal to (when all whitespaces are ignored expect in keys and values): ")
.appendValue(stringReplacedWhitespaces);
}
public static Matcher<String> equalToIgnoringWhitespaceInJsonString(String expectedString) {
return new IsEqualIgnoreWhitespaceInJsonString(expectedString);
}
}
public static Matcher<String> equalToIgnoringWhitespaceInJsonString(String expectedString) {
return IsEqualIgnoreWhitespaceInJsonString.equalToIgnoringWhitespaceInJsonString(expectedString);
}
}

View file

@ -0,0 +1,186 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference;
import org.elasticsearch.test.ESTestCase;
import org.hamcrest.Description;
import org.hamcrest.SelfDescribing;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static org.hamcrest.Matchers.is;
public class MatchersUtilsTests extends ESTestCase {
public void testIsEqualIgnoreWhitespaceInJsonString_Pattern() {
var json = """
{
"key": "value"
}
""";
Pattern pattern = MatchersUtils.IsEqualIgnoreWhitespaceInJsonString.WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN;
Matcher matcher = pattern.matcher(json);
String jsonWithRemovedWhitespaces = matcher.replaceAll("");
assertThat(jsonWithRemovedWhitespaces, is("""
{"key":"value"}"""));
}
public void testIsEqualIgnoreWhitespaceInJsonString_Pattern_DoesNotRemoveWhitespaceInKeysAndValues() {
var json = """
{
"key 1": "value 1"
}
""";
Pattern pattern = MatchersUtils.IsEqualIgnoreWhitespaceInJsonString.WHITESPACE_IN_JSON_EXCEPT_KEYS_AND_VALUES_PATTERN;
Matcher matcher = pattern.matcher(json);
String jsonWithRemovedWhitespaces = matcher.replaceAll("");
assertThat(jsonWithRemovedWhitespaces, is("""
{"key 1":"value 1"}"""));
}
public void testIsEqualIgnoreWhitespaceInJsonString_MatchesSafely_DoesMatch() {
var json = """
{
"key 1": "value 1",
"key 2: {
"key 3: "value 3"
},
"key 4": [
"value 4", "value 5"
]
}
""";
var jsonWithDifferentSpacing = """
{"key 1": "value 1",
"key 2: {
"key 3: "value 3"
},
"key 4": [
"value 4", "value 5"
]
}
""";
var typeSafeMatcher = new MatchersUtils.IsEqualIgnoreWhitespaceInJsonString(json);
boolean matches = typeSafeMatcher.matchesSafely(jsonWithDifferentSpacing);
assertTrue(matches);
}
public void testIsEqualIgnoreWhitespaceInJsonString_MatchesSafely_DoesNotMatch() {
var json = """
{
"key 1": "value 1",
"key 2: {
"key 3: "value 3"
},
"key 4": [
"value 4", "value 5"
]
}
""";
// one value missing in array
var jsonWithDifferentSpacing = """
{"key 1": "value 1",
"key 2: {
"key 3: "value 3"
},
"key 4": [
"value 4"
]
}
""";
var typeSafeMatcher = new MatchersUtils.IsEqualIgnoreWhitespaceInJsonString(json);
boolean matches = typeSafeMatcher.matchesSafely(jsonWithDifferentSpacing);
assertFalse(matches);
}
public void testIsEqualIgnoreWhitespaceInJsonString_DescribeTo() {
var jsonOne = """
{
"key": "value"
}
""";
var typeSafeMatcher = new MatchersUtils.IsEqualIgnoreWhitespaceInJsonString(jsonOne);
var description = new TestDescription("");
typeSafeMatcher.describeTo(description);
assertThat(description.toString(), is("""
a string equal to (when all whitespaces are ignored expect in keys and values): {"key":"value"}"""));
}
private static class TestDescription implements Description {
private String descriptionContent;
TestDescription(String descriptionContent) {
Objects.requireNonNull(descriptionContent);
this.descriptionContent = descriptionContent;
}
@Override
public Description appendText(String text) {
descriptionContent += text;
return this;
}
@Override
public Description appendDescriptionOf(SelfDescribing value) {
throw new UnsupportedOperationException();
}
@Override
public Description appendValue(Object value) {
descriptionContent += value;
return this;
}
@SafeVarargs
@Override
public final <T> Description appendValueList(String start, String separator, String end, T... values) {
throw new UnsupportedOperationException();
}
@Override
public <T> Description appendValueList(String start, String separator, String end, Iterable<T> values) {
throw new UnsupportedOperationException();
}
@Override
public Description appendList(String start, String separator, String end, Iterable<? extends SelfDescribing> values) {
throw new UnsupportedOperationException();
}
@Override
public String toString() {
return descriptionContent;
}
}
}

View file

@ -20,7 +20,6 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.common.TruncatorTests;
import org.elasticsearch.xpack.inference.external.action.openai.OpenAiChatCompletionActionTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@ -46,6 +45,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
import static org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioRequestFields.API_KEY_HEADER;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
@ -163,10 +163,7 @@ public class AzureAiStudioActionAndCreatorTests extends ESTestCase {
var result = listener.actionGet(TIMEOUT);
assertThat(
result.asMap(),
is(OpenAiChatCompletionActionTests.buildExpectedChatCompletionResultMap(List.of("test input string")))
);
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test input string"))));
assertThat(webServer.requests(), hasSize(1));
MockRequest request = webServer.requests().get(0);

View file

@ -40,9 +40,9 @@ import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.cohere.CohereCompletionActionTests.buildExpectedChatCompletionResultMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.equalTo;
@ -200,7 +200,7 @@ public class CohereActionCreatorTests extends ESTestCase {
var result = listener.actionGet(TIMEOUT);
assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result"))));
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result"))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType()));
@ -260,7 +260,7 @@ public class CohereActionCreatorTests extends ESTestCase {
var result = listener.actionGet(TIMEOUT);
assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result"))));
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result"))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType()));

View file

@ -23,7 +23,6 @@ import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@ -44,6 +43,8 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
@ -119,7 +120,7 @@ public class CohereCompletionActionTests extends ESTestCase {
var result = listener.actionGet(TIMEOUT);
assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result"))));
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result"))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
@ -180,7 +181,7 @@ public class CohereCompletionActionTests extends ESTestCase {
var result = listener.actionGet(TIMEOUT);
assertThat(result.asMap(), is(buildExpectedChatCompletionResultMap(List.of("result"))));
assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result"))));
assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
@ -198,7 +199,7 @@ public class CohereCompletionActionTests extends ESTestCase {
public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOException {
try (var sender = mock(Sender.class)) {
var thrownException = expectThrows(IllegalArgumentException.class, () -> createAction("a^b", "api key", "model", sender));
assertThat(thrownException.getMessage(), is("unable to parse url [a^b]"));
assertThat(thrownException.getMessage(), containsString("unable to parse url [a^b]"));
}
}
@ -338,13 +339,6 @@ public class CohereCompletionActionTests extends ESTestCase {
}
}
public static Map<String, Object> buildExpectedChatCompletionResultMap(List<String> results) {
return Map.of(
ChatCompletionResults.COMPLETION,
results.stream().map(result -> Map.of(ChatCompletionResults.Result.RESULT, result)).toList()
);
}
private CohereCompletionAction createAction(String url, String apiKey, @Nullable String modelName, Sender sender) {
var model = CohereCompletionModelTests.createModel(url, apiKey, modelName);

Some files were not shown because too many files have changed in this diff Show more