From 91724a6f0e345da8b26b9fb85e79976caad9126c Mon Sep 17 00:00:00 2001 From: David Turner Date: Wed, 21 Dec 2022 13:47:44 +0000 Subject: [PATCH] Reject connection attempts while closing (#92465) Today if there is a constant stream of connection attempts then it's possible for the `ClusterConnectionManager` to wait forever in `close()` for `connectingRefCounter` to be fully released. With this commit we reject connection attempts while closing, avoiding this starvation situation. --- docs/changelog/92465.yaml | 5 + .../action/support/PlainActionFuture.java | 8 ++ .../transport/ClusterConnectionManager.java | 8 +- .../ClusterConnectionManagerTests.java | 128 ++++++++++++++++++ 4 files changed, 147 insertions(+), 2 deletions(-) create mode 100644 docs/changelog/92465.yaml diff --git a/docs/changelog/92465.yaml b/docs/changelog/92465.yaml new file mode 100644 index 000000000000..5c02ddff1e17 --- /dev/null +++ b/docs/changelog/92465.yaml @@ -0,0 +1,5 @@ +pr: 92465 +summary: Reject connection attempts while closing +area: Network +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java b/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java index 3d5c2c027393..20483b8ff440 100644 --- a/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java +++ b/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java @@ -10,6 +10,8 @@ package org.elasticsearch.action.support; import org.elasticsearch.core.CheckedConsumer; +import java.util.concurrent.TimeUnit; + public class PlainActionFuture extends AdapterActionFuture { public static PlainActionFuture newFuture() { @@ -22,6 +24,12 @@ public class PlainActionFuture extends AdapterActionFuture { return fut.actionGet(); } + public static T get(CheckedConsumer, E> e, long timeout, TimeUnit unit) throws E { + PlainActionFuture fut = newFuture(); + e.accept(fut); + return fut.actionGet(timeout, unit); + } + @Override protected T convert(T listenerResponse) { return listenerResponse; diff --git a/server/src/main/java/org/elasticsearch/transport/ClusterConnectionManager.java b/server/src/main/java/org/elasticsearch/transport/ClusterConnectionManager.java index 95a1dddd94da..4eef70a53ef0 100644 --- a/server/src/main/java/org/elasticsearch/transport/ClusterConnectionManager.java +++ b/server/src/main/java/org/elasticsearch/transport/ClusterConnectionManager.java @@ -123,8 +123,8 @@ public class ClusterConnectionManager implements ConnectionManager { return; } - if (connectingRefCounter.tryIncRef() == false) { - listener.onFailure(new IllegalStateException("connection manager is closed")); + if (acquireConnectingRef() == false) { + listener.onFailure(new ConnectTransportException(node, "connection manager is closed")); return; } @@ -378,4 +378,8 @@ public class ClusterConnectionManager implements ConnectionManager { return defaultProfile; } + private boolean acquireConnectingRef() { + return closing.get() == false && connectingRefCounter.tryIncRef(); + } + } diff --git a/server/src/test/java/org/elasticsearch/transport/ClusterConnectionManagerTests.java b/server/src/test/java/org/elasticsearch/transport/ClusterConnectionManagerTests.java index 41ff27deca61..bf639a9776ef 100644 --- a/server/src/test/java/org/elasticsearch/transport/ClusterConnectionManagerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/ClusterConnectionManagerTests.java @@ -19,6 +19,9 @@ import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.common.util.concurrent.RunOnce; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; @@ -34,6 +37,7 @@ import java.net.InetAddress; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Queue; import java.util.Set; import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.ConcurrentHashMap; @@ -41,12 +45,15 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import static org.elasticsearch.test.ActionListenerUtils.anyActionListener; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -344,6 +351,127 @@ public class ClusterConnectionManagerTests extends ESTestCase { } } + public void testConcurrentConnectsDuringClose() throws Exception { + + // This test ensures that closing the connection manager doesn't block forever, even if there's a constant stream of attempts to + // open connections. Note that closing the connection manager _does_ block while there are in-flight connection attempts, and in + // practice each attempt will (eventually) finish, so we're just trying to test that constant open attempts do not cause starvation. + // + // It works by spawning connection-open attempts in several concurrent loops, putting a Runnable to complete each attempt into a + // queue, and then consuming and completing the enqueued runnables in a separate thread. The consuming thread is throttled via a + // Semaphore, from which the main thread steals a permit which ensures that there's always at least one pending connection while the + // close is ongoing even though no connection attempt blocks forever. + + final Semaphore pendingConnectionPermits = new Semaphore(0); + final Queue pendingConnections = ConcurrentCollections.newQueue(); + + // transport#openConnection enqueues a Runnable to complete the connection attempt + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + final ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + final DiscoveryNode targetNode = (DiscoveryNode) invocationOnMock.getArguments()[0]; + pendingConnections.add(() -> listener.onResponse(new TestConnect(targetNode))); + pendingConnectionPermits.release(); + return null; + }).when(transport).openConnection(any(), eq(connectionProfile), anyActionListener()); + + final ConnectionManager.ConnectionValidator validator = (c, p, l) -> l.onResponse(null); + + // Once we start to see connections being rejected, we give back the stolen permit so that the last connection can complete + final Runnable onConnectException = new RunOnce(pendingConnectionPermits::release); + + // Create a few threads which open connections in a loop. Must be at least 2 so that there's always more connections incoming. + final int connectionLoops = between(2, 4); + final CountDownLatch connectionLoopCountDown = new CountDownLatch(connectionLoops); + final AtomicBoolean expectConnectionFailures = new AtomicBoolean(); // unexpected failures would make this test pass vacuously + + class ConnectionLoop extends AbstractRunnable { + + @Override + public void onFailure(Exception e) { + assert false : e; + } + + @Override + protected void doRun() throws Exception { + final DiscoveryNode discoveryNode = new DiscoveryNode( + "", + new TransportAddress(InetAddress.getLoopbackAddress(), 0), + Version.CURRENT + ); + final ActionListener listener = new ActionListener() { + @Override + public void onResponse(Releasable releasable) { + releasable.close(); + threadPool.generic().execute(ConnectionLoop.this); + } + + @Override + public void onFailure(Exception e) { + assertTrue(expectConnectionFailures.get()); + assertThat(e, instanceOf(ConnectTransportException.class)); + assertThat(e.getMessage(), containsString("connection manager is closed")); + onConnectException.run(); + connectionLoopCountDown.countDown(); + } + }; + + connectionManager.connectToNode(discoveryNode, connectionProfile, validator, listener); + } + } + + for (int i = 0; i < connectionLoops; i++) { + threadPool.generic().execute(new ConnectionLoop()); + } + + // Create a separate thread to complete pending connection attempts, throttled by the pendingConnectionPermits semaphore + final Thread completionThread = new Thread(() -> { + while (true) { + try { + assertTrue(pendingConnectionPermits.tryAcquire(10, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + // There could still be items in the queue when we are interrupted, so drain the queue before exiting: + while (pendingConnectionPermits.tryAcquire()) { + // noinspection ConstantConditions + pendingConnections.poll().run(); + } + return; + } + // noinspection ConstantConditions + pendingConnections.poll().run(); + } + }); + completionThread.start(); + + // Steal a permit so that the consumer lags behind the producers ... + assertTrue(pendingConnectionPermits.tryAcquire(10, TimeUnit.SECONDS)); + // ... and then send a connection attempt through the system to ensure that the lagging has started + Releasables.closeExpectNoException( + PlainActionFuture.get( + fut -> connectionManager.connectToNode( + new DiscoveryNode("", new TransportAddress(InetAddress.getLoopbackAddress(), 0), Version.CURRENT), + connectionProfile, + validator, + fut + ), + 30, + TimeUnit.SECONDS + ) + ); + + // Now close the connection manager + expectConnectionFailures.set(true); + connectionManager.close(); + // Success! The close call returned + + // Clean up and check everything completed properly + assertTrue(connectionLoopCountDown.await(10, TimeUnit.SECONDS)); + completionThread.interrupt(); + completionThread.join(); + assertTrue(pendingConnections.isEmpty()); + } + public void testConcurrentConnectsAndDisconnects() throws Exception { final DiscoveryNode node = new DiscoveryNode("", new TransportAddress(InetAddress.getLoopbackAddress(), 0), Version.CURRENT); doAnswer(invocationOnMock -> {