diff --git a/docs/changelog/126686.yaml b/docs/changelog/126686.yaml new file mode 100644 index 000000000000..802ec538e5c1 --- /dev/null +++ b/docs/changelog/126686.yaml @@ -0,0 +1,6 @@ +pr: 126686 +summary: Fix race condition in `RestCancellableNodeClient` +area: Task Management +type: bug +issues: + - 88201 diff --git a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java index 92fde6d7765c..a90b04d54649 100644 --- a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java +++ b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/IndicesSegmentsRestCancellationIT.java @@ -12,23 +12,12 @@ package org.elasticsearch.http; import org.apache.http.client.methods.HttpGet; import org.elasticsearch.action.admin.indices.segments.IndicesSegmentsAction; import org.elasticsearch.client.Request; -import org.elasticsearch.test.junit.annotations.TestIssueLogging; public class IndicesSegmentsRestCancellationIT extends BlockedSearcherRestCancellationTestCase { - @TestIssueLogging( - issueUrl = "https://github.com/elastic/elasticsearch/issues/88201", - value = "org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG" - + ",org.elasticsearch.transport.TransportService:TRACE" - ) public void testIndicesSegmentsRestCancellation() throws Exception { runTest(new Request(HttpGet.METHOD_NAME, "/_segments"), IndicesSegmentsAction.NAME); } - @TestIssueLogging( - issueUrl = "https://github.com/elastic/elasticsearch/issues/88201", - value = "org.elasticsearch.http.BlockedSearcherRestCancellationTestCase:DEBUG" - + ",org.elasticsearch.transport.TransportService:TRACE" - ) public void testCatSegmentsRestCancellation() throws Exception { runTest(new Request(HttpGet.METHOD_NAME, "/_cat/segments"), IndicesSegmentsAction.NAME); } diff --git a/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java b/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java index 33b3ef35671e..e4e8378e4355 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java +++ b/server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java @@ -18,14 +18,14 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.FilterClient; import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.core.Nullable; import org.elasticsearch.http.HttpChannel; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; -import java.util.ArrayList; +import java.util.Collection; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -112,12 +112,14 @@ public class RestCancellableNodeClient extends FilterClient { private class CloseListener implements ActionListener { private final AtomicReference channel = new AtomicReference<>(); - private final Set tasks = new HashSet<>(); + + @Nullable // if already drained + private Set tasks = new HashSet<>(); CloseListener() {} synchronized int getNumTasks() { - return tasks.size(); + return tasks == null ? 0 : tasks.size(); } void maybeRegisterChannel(HttpChannel httpChannel) { @@ -130,16 +132,23 @@ public class RestCancellableNodeClient extends FilterClient { } } - synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) { - taskHolder.taskId = taskId; - if (taskHolder.completed == false) { - this.tasks.add(taskId); + void registerTask(TaskHolder taskHolder, TaskId taskId) { + synchronized (this) { + taskHolder.taskId = taskId; + if (tasks != null) { + if (taskHolder.completed == false) { + tasks.add(taskId); + } + return; + } } + // else tasks == null so the channel is already closed + cancelTask(taskId); } synchronized void unregisterTask(TaskHolder taskHolder) { - if (taskHolder.taskId != null) { - this.tasks.remove(taskHolder.taskId); + if (taskHolder.taskId != null && tasks != null) { + tasks.remove(taskHolder.taskId); } taskHolder.completed = true; } @@ -149,18 +158,20 @@ public class RestCancellableNodeClient extends FilterClient { final HttpChannel httpChannel = channel.get(); assert httpChannel != null : "channel not registered"; // when the channel gets closed it won't be reused: we can remove it from the map and forget about it. - CloseListener closeListener = httpChannels.remove(httpChannel); - assert closeListener != null : "channel not found in the map of tracked channels"; - final List toCancel; - synchronized (this) { - toCancel = new ArrayList<>(tasks); - tasks.clear(); - } - for (TaskId taskId : toCancel) { + final CloseListener closeListener = httpChannels.remove(httpChannel); + assert closeListener != null : "channel not found in the map of tracked channels: " + httpChannel; + assert closeListener == CloseListener.this : "channel had a different CloseListener registered: " + httpChannel; + for (final var taskId : drainTasks()) { cancelTask(taskId); } } + private synchronized Collection drainTasks() { + final var drained = tasks; + tasks = null; + return drained; + } + @Override public void onFailure(Exception e) { onResponse(null); diff --git a/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java b/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java index 74c6fceddf71..c58621d03ce8 100644 --- a/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java +++ b/server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.action.search.TransportSearchAction; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.http.HttpChannel; import org.elasticsearch.http.HttpResponse; import org.elasticsearch.tasks.Task; @@ -44,6 +45,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.LongSupplier; public class RestCancellableNodeClientTests extends ESTestCase { @@ -148,8 +150,42 @@ public class RestCancellableNodeClientTests extends ESTestCase { assertEquals(totalSearches, testClient.cancelledTasks.size()); } + public void testConcurrentExecuteAndClose() throws Exception { + final var testClient = new TestClient(Settings.EMPTY, threadPool, true); + int initialHttpChannels = RestCancellableNodeClient.getNumChannels(); + int numTasks = randomIntBetween(1, 30); + TestHttpChannel channel = new TestHttpChannel(); + final var startLatch = new CountDownLatch(1); + final var doneLatch = new CountDownLatch(numTasks + 1); + final var expectedTasks = Sets.newHashSetWithExpectedSize(numTasks); + for (int j = 0; j < numTasks; j++) { + RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel); + threadPool.generic().execute(() -> { + client.execute(TransportSearchAction.TYPE, new SearchRequest(), ActionListener.running(ESTestCase::fail)); + startLatch.countDown(); + doneLatch.countDown(); + }); + expectedTasks.add(new TaskId(testClient.getLocalNodeId(), j)); + } + threadPool.generic().execute(() -> { + try { + safeAwait(startLatch); + channel.awaitClose(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new AssertionError(e); + } finally { + doneLatch.countDown(); + } + }); + safeAwait(doneLatch); + assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels()); + assertEquals(expectedTasks, testClient.cancelledTasks); + } + private static class TestClient extends NodeClient { - private final AtomicLong counter = new AtomicLong(0); + private final LongSupplier searchTaskIdGenerator = new AtomicLong(0)::getAndIncrement; + private final LongSupplier cancelTaskIdGenerator = new AtomicLong(1000)::getAndIncrement; private final Set cancelledTasks = new CopyOnWriteArraySet<>(); private final AtomicInteger searchRequests = new AtomicInteger(0); private final boolean timeout; @@ -167,9 +203,17 @@ public class RestCancellableNodeClientTests extends ESTestCase { ) { switch (action.name()) { case TransportCancelTasksAction.NAME -> { - CancelTasksRequest cancelTasksRequest = (CancelTasksRequest) request; - assertTrue("tried to cancel the same task more than once", cancelledTasks.add(cancelTasksRequest.getTargetTaskId())); - Task task = request.createTask(counter.getAndIncrement(), "cancel_task", action.name(), null, Collections.emptyMap()); + assertTrue( + "tried to cancel the same task more than once", + cancelledTasks.add(asInstanceOf(CancelTasksRequest.class, request).getTargetTaskId()) + ); + Task task = request.createTask( + cancelTaskIdGenerator.getAsLong(), + "cancel_task", + action.name(), + null, + Collections.emptyMap() + ); if (randomBoolean()) { listener.onResponse(null); } else { @@ -180,7 +224,13 @@ public class RestCancellableNodeClientTests extends ESTestCase { } case TransportSearchAction.NAME -> { searchRequests.incrementAndGet(); - Task searchTask = request.createTask(counter.getAndIncrement(), "search", action.name(), null, Collections.emptyMap()); + Task searchTask = request.createTask( + searchTaskIdGenerator.getAsLong(), + "search", + action.name(), + null, + Collections.emptyMap() + ); if (timeout == false) { if (rarely()) { // make sure that search is sometimes also called from the same thread before the task is returned @@ -191,7 +241,7 @@ public class RestCancellableNodeClientTests extends ESTestCase { } return searchTask; } - default -> throw new UnsupportedOperationException(); + default -> throw new AssertionError("unexpected action " + action.name()); } } @@ -222,10 +272,7 @@ public class RestCancellableNodeClientTests extends ESTestCase { @Override public void close() { - if (open.compareAndSet(true, false) == false) { - assert false : "HttpChannel is already closed"; - return; // nothing to do - } + assertTrue("HttpChannel is already closed", open.compareAndSet(true, false)); ActionListener listener = closeListener.get(); if (listener != null) { boolean failure = randomBoolean(); @@ -241,6 +288,7 @@ public class RestCancellableNodeClientTests extends ESTestCase { } private void awaitClose() throws InterruptedException { + assertNotNull("must set closeListener before calling awaitClose", closeListener.get()); close(); closeLatch.await(); } @@ -257,7 +305,7 @@ public class RestCancellableNodeClientTests extends ESTestCase { listener.onResponse(null); } else { if (closeListener.compareAndSet(null, listener) == false) { - throw new IllegalStateException("close listener already set, only one is allowed!"); + throw new AssertionError("close listener already set, only one is allowed!"); } } }