Remove reference counting from InboundMessage and make it Releasable (#126138)

There is no actual need to reference-count InboundMessage instances. Their lifecycle is completely linear and we can simplify it away. This saves a little work directly but more importantly, it enables more eager releasing of the underlying buffers in a follow-up.
---------

Co-authored-by: David Turner <david.turner@elastic.co>
This commit is contained in:
Armin Braun 2025-04-17 15:10:08 +02:00 committed by GitHub
parent a72883e8e3
commit 149ff93789
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 85 additions and 42 deletions

View file

@ -120,7 +120,7 @@ public class InboundAggregator implements Releasable {
checkBreaker(aggregated.getHeader(), aggregated.getContentLength(), breakerControl);
}
if (isShortCircuited()) {
aggregated.decRef();
aggregated.close();
success = true;
return new InboundMessage(aggregated.getHeader(), aggregationException);
} else {
@ -131,7 +131,7 @@ public class InboundAggregator implements Releasable {
} finally {
resetCurrentAggregation();
if (success == false) {
aggregated.decRef();
aggregated.close();
}
}
}

View file

@ -22,7 +22,6 @@ import org.elasticsearch.common.network.HandlingTimeTracker;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.threadpool.ThreadPool;
@ -87,21 +86,31 @@ public class InboundHandler {
this.slowLogThresholdMs = slowLogThreshold.getMillis();
}
/**
* @param message the transport message received, guaranteed to be closed by this method if it returns without exception.
* Callers must ensure that {@code message} is closed if this method throws an exception but must not release
* the message themselves otherwise
*/
void inboundMessage(TcpChannel channel, InboundMessage message) throws Exception {
final long startTime = threadPool.rawRelativeTimeInMillis();
channel.getChannelStats().markAccessed(startTime);
TransportLogger.logInboundMessage(channel, message);
if (message.isPing()) {
keepAlive.receiveKeepAlive(channel);
keepAlive.receiveKeepAlive(channel); // pings hold no resources, no need to close
} else {
messageReceived(channel, message, startTime);
messageReceived(channel, /* autocloses absent exception */ message, startTime);
}
}
// Empty stream constant to avoid instantiating a new stream for empty messages.
private static final StreamInput EMPTY_STREAM_INPUT = new ByteBufferStreamInput(ByteBuffer.wrap(BytesRef.EMPTY_BYTES));
/**
* @param message the transport message received, guaranteed to be closed by this method if it returns without exception.
* Callers must ensure that {@code message} is closed if this method throws an exception but must not release
* the message themselves otherwise
*/
private void messageReceived(TcpChannel channel, InboundMessage message, long startTime) throws IOException {
final InetSocketAddress remoteAddress = channel.getRemoteAddress();
final Header header = message.getHeader();
@ -115,14 +124,16 @@ public class InboundHandler {
threadContext.setHeaders(header.getHeaders());
threadContext.putTransient("_remote_address", remoteAddress);
if (header.isRequest()) {
handleRequest(channel, message);
handleRequest(channel, /* autocloses absent exception */ message);
} else {
// Responses do not support short circuiting currently
assert message.isShortCircuit() == false;
responseHandler = findResponseHandler(header);
// ignore if its null, the service logs it
if (responseHandler != null) {
executeResponseHandler(message, responseHandler, remoteAddress);
executeResponseHandler( /* autocloses absent exception */ message, responseHandler, remoteAddress);
} else {
message.close();
}
}
} finally {
@ -135,6 +146,11 @@ public class InboundHandler {
}
}
/**
* @param message the transport message received, guaranteed to be closed by this method if it returns without exception.
* Callers must ensure that {@code message} is closed if this method throws an exception but must not release
* the message themselves otherwise
*/
private void executeResponseHandler(
InboundMessage message,
TransportResponseHandler<?> responseHandler,
@ -145,13 +161,13 @@ public class InboundHandler {
final StreamInput streamInput = namedWriteableStream(message.openOrGetStreamInput());
assert assertRemoteVersion(streamInput, header.getVersion());
if (header.isError()) {
handlerResponseError(streamInput, message, responseHandler);
handlerResponseError(streamInput, /* autocloses */ message, responseHandler);
} else {
handleResponse(remoteAddress, streamInput, responseHandler, message);
handleResponse(remoteAddress, streamInput, responseHandler, /* autocloses */ message);
}
} else {
assert header.isError() == false;
handleResponse(remoteAddress, EMPTY_STREAM_INPUT, responseHandler, message);
handleResponse(remoteAddress, EMPTY_STREAM_INPUT, responseHandler, /* autocloses */ message);
}
}
@ -220,10 +236,15 @@ public class InboundHandler {
}
}
/**
* @param message the transport message received, guaranteed to be closed by this method if it returns without exception.
* Callers must ensure that {@code message} is closed if this method throws an exception but must not release
* the message themselves otherwise
*/
private <T extends TransportRequest> void handleRequest(TcpChannel channel, InboundMessage message) throws IOException {
final Header header = message.getHeader();
if (header.isHandshake()) {
handleHandshakeRequest(channel, message);
handleHandshakeRequest(channel, /* autocloses */ message);
return;
}
@ -243,7 +264,7 @@ public class InboundHandler {
Releasables.assertOnce(message.takeBreakerReleaseControl())
);
try {
try (message) {
messageListener.onRequestReceived(requestId, action);
if (reg != null) {
reg.addRequestStats(header.getNetworkMessageSize() + TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE);
@ -331,6 +352,9 @@ public class InboundHandler {
}
}
/**
* @param message guaranteed to get closed by this method
*/
private void handleHandshakeRequest(TcpChannel channel, InboundMessage message) throws IOException {
var header = message.getHeader();
assert header.actionName.equals(TransportHandshaker.HANDSHAKE_ACTION_NAME);
@ -351,7 +375,7 @@ public class InboundHandler {
true,
Releasables.assertOnce(message.takeBreakerReleaseControl())
);
try {
try (message) {
handshaker.handleHandshake(transportChannel, requestId, stream);
} catch (Exception e) {
logger.warn(
@ -371,29 +395,30 @@ public class InboundHandler {
}
}
/**
* @param message guaranteed to get closed by this method
*/
private <T extends TransportResponse> void handleResponse(
InetSocketAddress remoteAddress,
final StreamInput stream,
final TransportResponseHandler<T> handler,
final InboundMessage inboundMessage
final InboundMessage message
) {
final var executor = handler.executor();
if (executor == EsExecutors.DIRECT_EXECUTOR_SERVICE) {
// no need to provide a buffer release here, we never escape the buffer when handling directly
doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), () -> {});
doHandleResponse(handler, remoteAddress, stream, /* autocloses */ message);
} else {
inboundMessage.mustIncRef();
// release buffer once we deserialize the message, but have a fail-safe in #onAfter below in case that didn't work out
final Releasable releaseBuffer = Releasables.releaseOnce(inboundMessage::decRef);
executor.execute(new ForkingResponseHandlerRunnable(handler, null) {
@Override
protected void doRun() {
doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), releaseBuffer);
doHandleResponse(handler, remoteAddress, stream, /* autocloses */ message);
}
@Override
public void onAfter() {
Releasables.closeExpectNoException(releaseBuffer);
message.close();
}
});
}
@ -404,20 +429,19 @@ public class InboundHandler {
* @param handler response handler
* @param remoteAddress remote address that the message was sent from
* @param stream bytes stream for reading the message
* @param header message header
* @param releaseResponseBuffer releasable that will be released once the message has been read from the {@code stream}
* @param inboundMessage inbound message, guaranteed to get closed by this method
* @param <T> response message type
*/
private <T extends TransportResponse> void doHandleResponse(
TransportResponseHandler<T> handler,
InetSocketAddress remoteAddress,
final StreamInput stream,
final Header header,
Releasable releaseResponseBuffer
InboundMessage inboundMessage
) {
final T response;
try (releaseResponseBuffer) {
try (inboundMessage) {
response = handler.read(stream);
verifyResponseReadFully(inboundMessage.getHeader(), handler, stream);
} catch (Exception e) {
final TransportException serializationException = new TransportSerializationException(
"Failed to deserialize response from handler [" + handler + "]",
@ -429,7 +453,6 @@ public class InboundHandler {
return;
}
try {
verifyResponseReadFully(header, handler, stream);
handler.handleResponse(response);
} catch (Exception e) {
doHandleException(handler, new ResponseHandlerFailureTransportException(e));
@ -438,9 +461,12 @@ public class InboundHandler {
}
}
/**
* @param message guaranteed to get closed by this method
*/
private void handlerResponseError(StreamInput stream, InboundMessage message, final TransportResponseHandler<?> handler) {
Exception error;
try {
try (message) {
error = stream.readException();
verifyResponseReadFully(message.getHeader(), handler, stream);
} catch (Exception e) {

View file

@ -12,14 +12,15 @@ package org.elasticsearch.transport;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Releasable;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.util.Objects;
public class InboundMessage extends AbstractRefCounted {
public class InboundMessage implements Releasable {
private final Header header;
private final ReleasableBytesReference content;
@ -28,6 +29,19 @@ public class InboundMessage extends AbstractRefCounted {
private Releasable breakerRelease;
private StreamInput streamInput;
@SuppressWarnings("unused") // updated via CLOSED (and _only_ via CLOSED)
private boolean closed;
private static final VarHandle CLOSED;
static {
try {
CLOSED = MethodHandles.lookup().findVarHandle(InboundMessage.class, "closed", boolean.class);
} catch (Exception e) {
throw new ExceptionInInitializerError(e);
}
}
public InboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) {
this.header = header;
this.content = content;
@ -84,7 +98,7 @@ public class InboundMessage extends AbstractRefCounted {
public StreamInput openOrGetStreamInput() throws IOException {
assert isPing == false && content != null;
assert hasReferences();
assert (boolean) CLOSED.getAcquire(this) == false;
if (streamInput == null) {
streamInput = content.streamInput();
streamInput.setTransportVersion(header.getVersion());
@ -98,7 +112,10 @@ public class InboundMessage extends AbstractRefCounted {
}
@Override
protected void closeInternal() {
public void close() {
if (CLOSED.compareAndSet(this, false, true) == false) {
return;
}
try {
IOUtils.close(streamInput, content, breakerRelease);
} catch (Exception e) {

View file

@ -112,13 +112,8 @@ public class InboundPipeline implements Releasable {
messageHandler.accept(channel, PING_MESSAGE);
} else if (fragment == InboundDecoder.END_CONTENT) {
assert aggregator.isAggregating();
InboundMessage aggregated = aggregator.finishAggregation();
try {
statsTracker.markMessageReceived();
messageHandler.accept(channel, aggregated);
} finally {
aggregated.decRef();
}
statsTracker.markMessageReceived();
messageHandler.accept(channel, /* autocloses */ aggregator.finishAggregation());
} else {
assert aggregator.isAggregating();
assert fragment instanceof ReleasableBytesReference;

View file

@ -813,9 +813,14 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
*/
public void inboundMessage(TcpChannel channel, InboundMessage message) {
try {
inboundHandler.inboundMessage(channel, message);
inboundHandler.inboundMessage(channel, /* autocloses absent exception */ message);
message = null;
} catch (Exception e) {
onException(channel, e);
} finally {
if (message != null) {
message.close();
}
}
}

View file

@ -95,7 +95,7 @@ public class InboundAggregatorTests extends ESTestCase {
for (ReleasableBytesReference reference : references) {
assertTrue(reference.hasReferences());
}
aggregated.decRef();
aggregated.close();
for (ReleasableBytesReference reference : references) {
assertFalse(reference.hasReferences());
}

View file

@ -50,7 +50,7 @@ public class InboundPipelineTests extends ESTestCase {
final List<Tuple<MessageData, Exception>> actual = new ArrayList<>();
final List<ReleasableBytesReference> toRelease = new ArrayList<>();
final BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> {
try {
try (m) {
final Header header = m.getHeader();
final MessageData actualData;
final TransportVersion version = header.getVersion();
@ -204,7 +204,7 @@ public class InboundPipelineTests extends ESTestCase {
}
public void testDecodeExceptionIsPropagated() throws IOException {
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> {};
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> m.close();
final StatsTracker statsTracker = new StatsTracker();
final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime());
final InboundDecoder decoder = new InboundDecoder(recycler);
@ -245,7 +245,7 @@ public class InboundPipelineTests extends ESTestCase {
}
public void testEnsureBodyIsNotPrematurelyReleased() throws IOException {
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> {};
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> m.close();
final StatsTracker statsTracker = new StatsTracker();
final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime());
final InboundDecoder decoder = new InboundDecoder(recycler);