Ensure partial bulks released if channel closes (#112724)

Previously, the close pipeline was not fully hooked up when a channel
closed while a request was still being buffered or executed. This commit
resolves the issue by wiring channel closure through to a stream-close
notification so that buffered resources are released.
This commit is contained in:
Tim Brooks 2024-09-11 19:20:15 -06:00
parent 2dbbd7dd45
commit 58e3a39392
7 changed files with 122 additions and 32 deletions

View file

@ -111,9 +111,11 @@ public class Netty4IncrementalRequestHandlingIT extends ESNetty4IntegTestCase {
assertTrue(recvChunk.isLast);
assertEquals(0, recvChunk.chunk.length());
recvChunk.chunk.close();
assertFalse(handler.streamClosed);
// send response to process following request
handler.sendResponse(new RestResponse(RestStatus.OK, ""));
assertBusy(() -> assertTrue(handler.streamClosed));
}
assertBusy(() -> assertEquals("should receive all server responses", totalRequests, ctx.clientRespQueue.size()));
}
@ -146,14 +148,16 @@ public class Netty4IncrementalRequestHandlingIT extends ESNetty4IntegTestCase {
}
}
assertFalse(handler.streamClosed);
assertEquals("sent and received payloads are not the same", sendData, recvData);
handler.sendResponse(new RestResponse(RestStatus.OK, ""));
assertBusy(() -> assertTrue(handler.streamClosed));
}
assertBusy(() -> assertEquals("should receive all server responses", totalRequests, ctx.clientRespQueue.size()));
}
}
// ensures that all received chunks are released when connection closed
// ensures that all received chunks are released when connection closed and handler notified
public void testClientConnectionCloseMidStream() throws Exception {
try (var ctx = setupClientCtx()) {
var opaqueId = opaqueId(0);
@ -168,10 +172,14 @@ public class Netty4IncrementalRequestHandlingIT extends ESNetty4IntegTestCase {
// enable auto-read to receive channel close event
handler.stream.channel().config().setAutoRead(true);
assertFalse(handler.streamClosed);
// terminate connection and wait resources are released
ctx.clientChannel.close();
assertBusy(() -> assertNull(handler.stream.buf()));
assertBusy(() -> {
assertNull(handler.stream.buf());
assertTrue(handler.streamClosed);
});
}
}
@ -187,10 +195,14 @@ public class Netty4IncrementalRequestHandlingIT extends ESNetty4IntegTestCase {
// await stream handler is ready and request full content
var handler = ctx.awaitRestChannelAccepted(opaqueId);
assertBusy(() -> assertNotNull(handler.stream.buf()));
assertFalse(handler.streamClosed);
// terminate connection on server and wait resources are released
handler.channel.request().getHttpChannel().close();
assertBusy(() -> assertNull(handler.stream.buf()));
assertBusy(() -> {
assertNull(handler.stream.buf());
assertTrue(handler.streamClosed);
});
}
}
@ -471,6 +483,7 @@ public class Netty4IncrementalRequestHandlingIT extends ESNetty4IntegTestCase {
final Netty4HttpRequestBodyStream stream;
RestChannel channel;
boolean recvLast = false;
volatile boolean streamClosed = false;
ServerRequestHandler(String opaqueId, Netty4HttpRequestBodyStream stream) {
this.opaqueId = opaqueId;
@ -488,6 +501,11 @@ public class Netty4IncrementalRequestHandlingIT extends ESNetty4IntegTestCase {
channelAccepted.onResponse(null);
}
@Override
// Test hook: records that the server-side request stream was closed.
// streamClosed is volatile because tests poll it from another thread via assertBusy.
public void streamClose() {
streamClosed = true;
}
void sendResponse(RestResponse response) {
channel.sendResponse(response);
}

View file

@ -133,6 +133,9 @@ public class Netty4HttpRequestBodyStream implements HttpBody.Stream {
private void doClose() {
closing = true;
if (handler != null) {
handler.close();
}
if (buf != null) {
buf.release();
buf = null;

View file

@ -90,6 +90,29 @@ public class IncrementalBulkIT extends ESIntegTestCase {
assertFalse(refCounted.hasReferences());
}
// Verifies that a bulk request whose items are still buffered (never dispatched)
// releases both its per-item refs and its indexing-pressure coordinating bytes
// when the handler is closed early, e.g. because the HTTP channel went away.
public void testBufferedResourcesReleasedOnClose() {
String index = "test";
createIndex(index);
String nodeName = internalCluster().getRandomNodeName();
IncrementalBulkService incrementalBulkService = internalCluster().getInstance(IncrementalBulkService.class, nodeName);
IndexingPressure indexingPressure = internalCluster().getInstance(IndexingPressure.class, nodeName);
IncrementalBulkService.Handler handler = incrementalBulkService.newBulkRequest();
IndexRequest indexRequest = indexRequest(index);
// Tracks whether the releasable passed to addItems has been closed.
AbstractRefCounted refCounted = AbstractRefCounted.of(() -> {});
// Buffer one item without triggering a flush; nothing is sent yet.
handler.addItems(List.of(indexRequest), refCounted::decRef, () -> {});
// While buffered, the ref is held and coordinating bytes are accounted.
assertTrue(refCounted.hasReferences());
assertThat(indexingPressure.stats().getCurrentCoordinatingBytes(), greaterThan(0L));
handler.close();
// Close must release the buffered ref and return the accounted bytes to zero.
assertFalse(refCounted.hasReferences());
assertThat(indexingPressure.stats().getCurrentCoordinatingBytes(), equalTo(0L));
}
public void testIndexingPressureRejection() {
String index = "test";
createIndex(index);
@ -303,14 +326,20 @@ public class IncrementalBulkIT extends ESIntegTestCase {
String secondShardNode = findShard(resolveIndex(index), 1);
IndexingPressure primaryPressure = internalCluster().getInstance(IndexingPressure.class, node);
long memoryLimit = primaryPressure.stats().getMemoryLimit();
long primaryRejections = primaryPressure.stats().getPrimaryRejections();
try (Releasable releasable = primaryPressure.markPrimaryOperationStarted(10, memoryLimit, false)) {
while (nextRequested.get()) {
nextRequested.set(false);
refCounted.incRef();
handler.addItems(List.of(indexRequest(index)), refCounted::decRef, () -> nextRequested.set(true));
while (primaryPressure.stats().getPrimaryRejections() == primaryRejections) {
while (nextRequested.get()) {
nextRequested.set(false);
refCounted.incRef();
List<DocWriteRequest<?>> requests = new ArrayList<>();
for (int i = 0; i < 20; ++i) {
requests.add(indexRequest(index));
}
handler.addItems(requests, refCounted::decRef, () -> nextRequested.set(true));
}
assertBusy(() -> assertTrue(nextRequested.get()));
}
assertBusy(() -> assertTrue(nextRequested.get()));
}
while (nextRequested.get()) {

View file

@ -102,6 +102,7 @@ public class IncrementalBulkService {
private final ArrayList<Releasable> releasables = new ArrayList<>(4);
private final ArrayList<BulkResponse> responses = new ArrayList<>(2);
private boolean closed = false;
private boolean globalFailure = false;
private boolean incrementalRequestSubmitted = false;
private ThreadContext.StoredContext requestContext;
@ -127,6 +128,7 @@ public class IncrementalBulkService {
}
public void addItems(List<DocWriteRequest<?>> items, Releasable releasable, Runnable nextItems) {
assert closed == false;
if (bulkActionLevelFailure != null) {
shortCircuitDueToTopLevelFailure(items, releasable);
nextItems.run();
@ -138,12 +140,13 @@ public class IncrementalBulkService {
incrementalRequestSubmitted = true;
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
requestContext.restore();
final ArrayList<Releasable> toRelease = new ArrayList<>(releasables);
releasables.clear();
client.bulk(bulkRequest, ActionListener.runAfter(new ActionListener<>() {
@Override
public void onResponse(BulkResponse bulkResponse) {
responses.add(bulkResponse);
releaseCurrentReferences();
handleBulkSuccess(bulkResponse);
createNewBulkRequest(
new BulkRequest.IncrementalState(bulkResponse.getIncrementalState().shardLevelFailures(), true)
);
@ -155,6 +158,7 @@ public class IncrementalBulkService {
}
}, () -> {
requestContext = threadContext.newStoredContext();
toRelease.forEach(Releasable::close);
nextItems.run();
}));
}
@ -180,14 +184,15 @@ public class IncrementalBulkService {
if (internalAddItems(items, releasable)) {
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
requestContext.restore();
client.bulk(bulkRequest, new ActionListener<>() {
final ArrayList<Releasable> toRelease = new ArrayList<>(releasables);
releasables.clear();
client.bulk(bulkRequest, ActionListener.runBefore(new ActionListener<>() {
private final boolean isFirstRequest = incrementalRequestSubmitted == false;
@Override
public void onResponse(BulkResponse bulkResponse) {
responses.add(bulkResponse);
releaseCurrentReferences();
handleBulkSuccess(bulkResponse);
listener.onResponse(combineResponses());
}
@ -196,7 +201,7 @@ public class IncrementalBulkService {
handleBulkFailure(isFirstRequest, e);
errorResponse(listener);
}
});
}, () -> toRelease.forEach(Releasable::close)));
}
} else {
errorResponse(listener);
@ -204,6 +209,13 @@ public class IncrementalBulkService {
}
}
@Override
// Marks the handler closed and releases any releasables still buffered for
// not-yet-submitted items, so a channel close cannot leak them.
public void close() {
closed = true;
releasables.forEach(Releasable::close);
releasables.clear();
}
private void shortCircuitDueToTopLevelFailure(List<DocWriteRequest<?>> items, Releasable releasable) {
assert releasables.isEmpty();
assert bulkRequest == null;
@ -221,12 +233,17 @@ public class IncrementalBulkService {
}
}
// Records a successful partial bulk response and clears the in-flight request
// so a fresh one can be created for the next increment.
private void handleBulkSuccess(BulkResponse bulkResponse) {
responses.add(bulkResponse);
bulkRequest = null;
}
// Records a top-level bulk failure: remembers whether the whole operation
// failed on its very first request (globalFailure), converts the pending
// items to item-level failures, and releases held references.
private void handleBulkFailure(boolean isFirstRequest, Exception e) {
assert bulkActionLevelFailure == null;
globalFailure = isFirstRequest;
bulkActionLevelFailure = e;
addItemLevelFailures(bulkRequest.requests());
releaseCurrentReferences();
bulkRequest = null;
}
private void addItemLevelFailures(List<DocWriteRequest<?>> items) {
@ -254,6 +271,8 @@ public class IncrementalBulkService {
return true;
} catch (EsRejectedExecutionException e) {
handleBulkFailure(incrementalRequestSubmitted == false, e);
releasables.forEach(Releasable::close);
releasables.clear();
return false;
}
}
@ -298,10 +317,5 @@ public class IncrementalBulkService {
return new BulkResponse(bulkItemResponses, tookInMillis, ingestTookInMillis);
}
@Override
public void close() {
// TODO: Implement
}
}
}

View file

@ -100,8 +100,11 @@ public sealed interface HttpBody extends Releasable permits HttpBody.Full, HttpB
}
@FunctionalInterface
interface ChunkHandler {
interface ChunkHandler extends Releasable {
void onNext(ReleasableBytesReference chunk, boolean isLast);
@Override
default void close() {}
}
record ByteRefHttpBody(BytesReference bytes) implements Full {}

View file

@ -20,6 +20,7 @@ import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.RestApiVersion;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.http.HttpBody;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.rest.action.admin.cluster.RestNodesUsageAction;
@ -127,7 +128,17 @@ public abstract class BaseRestHandler implements RestHandler {
if (request.isStreamedContent()) {
assert action instanceof RequestBodyChunkConsumer;
var chunkConsumer = (RequestBodyChunkConsumer) action;
request.contentStream().setHandler((chunk, isLast) -> chunkConsumer.handleChunk(channel, chunk, isLast));
request.contentStream().setHandler(new HttpBody.ChunkHandler() {
@Override
public void onNext(ReleasableBytesReference chunk, boolean isLast) {
chunkConsumer.handleChunk(channel, chunk, isLast);
}
@Override
public void close() {
chunkConsumer.streamClose();
}
});
}
usageCount.increment();
@ -189,6 +200,13 @@ public abstract class BaseRestHandler implements RestHandler {
public interface RequestBodyChunkConsumer extends RestChannelConsumer {
/**
 * Handles the next chunk of a streamed request body. The implementor takes
 * ownership of {@code chunk} and is responsible for eventually closing it.
 */
void handleChunk(RestChannel channel, ReleasableBytesReference chunk, boolean isLast);
/**
 * Called when the stream closes. This could happen prior to the completion of the request if the underlying channel was closed.
 * Implementors should do their best to clean up resources and early terminate request processing if it is triggered before a
 * response is generated.
 */
default void streamClose() {}
}
/**

View file

@ -32,6 +32,7 @@ import org.elasticsearch.rest.ServerlessScope;
import org.elasticsearch.rest.action.RestRefCountedChunkedToXContentListener;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.transport.Transports;
import java.io.IOException;
import java.util.ArrayDeque;
@ -148,7 +149,7 @@ public class RestBulkAction extends BaseRestHandler {
private IncrementalBulkService.Handler handler;
private volatile RestChannel restChannel;
private boolean isException;
private boolean shortCircuited;
private final ArrayDeque<ReleasableBytesReference> unParsedChunks = new ArrayDeque<>(4);
private final ArrayList<DocWriteRequest<?>> items = new ArrayList<>(4);
@ -178,7 +179,7 @@ public class RestBulkAction extends BaseRestHandler {
public void handleChunk(RestChannel channel, ReleasableBytesReference chunk, boolean isLast) {
assert handler != null;
assert channel == restChannel;
if (isException) {
if (shortCircuited) {
chunk.close();
return;
}
@ -215,12 +216,8 @@ public class RestBulkAction extends BaseRestHandler {
);
} catch (Exception e) {
// TODO: This needs to be better
Releasables.close(handler);
Releasables.close(unParsedChunks);
unParsedChunks.clear();
shortCircuit();
new RestToXContentListener<>(channel).onFailure(e);
isException = true;
return;
}
@ -242,8 +239,16 @@ public class RestBulkAction extends BaseRestHandler {
}
@Override
public void close() {
RequestBodyChunkConsumer.super.close();
// Stream-close notification: may arrive before a response was produced
// (e.g. client disconnect mid-stream), so release everything we hold.
public void streamClose() {
// NOTE(review): expected to be delivered on a transport thread — enforced here.
assert Transports.assertTransportThread();
shortCircuit();
}
// Aborts further processing: once shortCircuited is set, subsequent chunks are
// dropped, and the bulk handler plus any unparsed buffered chunks are released.
private void shortCircuit() {
shortCircuited = true;
Releasables.close(handler);
Releasables.close(unParsedChunks);
unParsedChunks.clear();
}
private ArrayList<Releasable> accountParsing(int bytesConsumed) {