Move HTTP content aggregation from Netty into RestController (#129302)

This commit is contained in:
Mikhail Berezovskiy 2025-06-19 09:05:17 -07:00 committed by GitHub
parent 083326e658
commit eeca493860
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
35 changed files with 687 additions and 286 deletions

View file

@ -0,0 +1,6 @@
pr: 129302
summary: Move HTTP content aggregation from Netty into `RestController`
area: Network
type: enhancement
issues:
- 120746

View file

@ -920,6 +920,11 @@ public class Netty4IncrementalRequestHandlingIT extends ESNetty4IntegTestCase {
return ROUTE;
}
@Override
public boolean supportsContentStream() {
return true;
}
@Override
public List<Route> routes() {
return List.of(new Route(RestRequest.Method.POST, ROUTE));

View file

@ -47,6 +47,8 @@ class MissingReadDetector extends ChannelDuplexHandler {
if (pendingRead == false) {
long now = timer.absoluteTimeInMillis();
if (now >= lastRead + interval) {
// if you encounter this warning during test make sure you consume content of RestRequest if it's a stream
// or use AggregatingDispatcher that will consume stream fully and produce RestRequest with full content.
logger.warn("chan-id={} haven't read from channel for [{}ms]", ctx.channel().id(), (now - lastRead));
}
}

View file

@ -1,59 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.http.netty4;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpObject;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.LastHttpContent;
import org.elasticsearch.http.HttpPreRequest;
import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils;
import java.util.function.Predicate;
/**
* A wrapper around {@link HttpObjectAggregator}. Provides optional content aggregation based on
* predicate. {@link HttpObjectAggregator} also handles Expect: 100-continue and oversized content.
* Provides content size handling for non-aggregated requests too.
*/
public class Netty4HttpAggregator extends HttpObjectAggregator {
private static final Predicate<HttpPreRequest> IGNORE_TEST = (req) -> req.uri().startsWith("/_test/request-stream") == false;
private final Predicate<HttpPreRequest> decider;
private final Netty4HttpContentSizeHandler streamContentSizeHandler;
private boolean aggregating = true;
public Netty4HttpAggregator(int maxContentLength, Predicate<HttpPreRequest> decider, HttpRequestDecoder decoder) {
super(maxContentLength);
this.decider = decider;
this.streamContentSizeHandler = new Netty4HttpContentSizeHandler(decoder, maxContentLength);
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
assert msg instanceof HttpObject;
if (msg instanceof HttpRequest request) {
var preReq = HttpHeadersAuthenticatorUtils.asHttpPreRequest(request);
aggregating = (decider.test(preReq) && IGNORE_TEST.test(preReq)) || request.decoderResult().isFailure();
}
if (aggregating || msg instanceof FullHttpRequest) {
super.channelRead(ctx, msg);
if (msg instanceof LastHttpContent == false) {
ctx.read(); // HttpObjectAggregator is tricky with auto-read off, it might not call read again, calling on its behalf
}
} else {
streamContentSizeHandler.channelRead(ctx, msg);
}
}
}

View file

@ -19,7 +19,6 @@ import io.netty.handler.codec.compression.JdkZlibEncoder;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.DefaultLastHttpContent;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpObject;
import io.netty.handler.codec.http.HttpRequest;
@ -131,18 +130,13 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler {
} else {
nonError = (Exception) cause;
}
netty4HttpRequest = new Netty4HttpRequest(readSequence++, (FullHttpRequest) request, nonError);
netty4HttpRequest = new Netty4HttpRequest(readSequence++, request, nonError);
} else {
assert currentRequestStream == null : "current stream must be null for new request";
if (request instanceof FullHttpRequest fullHttpRequest) {
netty4HttpRequest = new Netty4HttpRequest(readSequence++, fullHttpRequest);
currentRequestStream = null;
} else {
var contentStream = new Netty4HttpRequestBodyStream(ctx, serverTransport.getThreadPool().getThreadContext());
currentRequestStream = contentStream;
netty4HttpRequest = new Netty4HttpRequest(readSequence++, request, contentStream);
shouldRead = false;
}
var contentStream = new Netty4HttpRequestBodyStream(ctx, serverTransport.getThreadPool().getThreadContext());
currentRequestStream = contentStream;
netty4HttpRequest = new Netty4HttpRequest(readSequence++, request, contentStream);
shouldRead = false;
}
handlePipelinedRequest(ctx, netty4HttpRequest);
} else {
@ -150,11 +144,11 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler {
assert currentRequestStream != null : "current stream must exists before handling http content";
shouldRead = false;
currentRequestStream.handleNettyContent((HttpContent) msg);
if (msg instanceof LastHttpContent) {
currentRequestStream = null;
}
}
} finally {
if (msg instanceof LastHttpContent) {
currentRequestStream = null;
}
if (shouldRead) {
ctx.channel().eventLoop().execute(ctx::read);
}
@ -167,7 +161,7 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler {
final Netty4HttpChannel channel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get();
boolean success = false;
assert Transports.assertDefaultThreadContext(serverTransport.getThreadPool().getThreadContext());
assert Transports.assertTransportThread();
assert ctx.channel().eventLoop().inEventLoop();
try {
serverTransport.incomingRequest(pipelinedRequest, channel);
success = true;

View file

@ -9,13 +9,11 @@
package org.elasticsearch.http.netty4;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.EmptyHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.ServerCookieDecoder;
@ -28,7 +26,6 @@ import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.transport.netty4.Netty4Utils;
import java.util.AbstractMap;
import java.util.Collection;
@ -41,71 +38,57 @@ import java.util.stream.Collectors;
public class Netty4HttpRequest implements HttpRequest {
private final FullHttpRequest request;
private final HttpBody content;
private final int sequence;
private final io.netty.handler.codec.http.HttpRequest nettyRequest;
private boolean hasContent;
private HttpBody content;
private final Map<String, List<String>> headers;
private final AtomicBoolean released;
private final Exception inboundException;
private final boolean pooled;
private final int sequence;
private final QueryStringDecoder queryStringDecoder;
Netty4HttpRequest(int sequence, io.netty.handler.codec.http.HttpRequest request, Netty4HttpRequestBodyStream contentStream) {
this(
sequence,
new DefaultFullHttpRequest(
request.protocolVersion(),
request.method(),
request.uri(),
Unpooled.EMPTY_BUFFER,
request.headers(),
EmptyHttpHeaders.INSTANCE
),
new AtomicBoolean(false),
true,
contentStream,
null
);
public Netty4HttpRequest(int sequence, io.netty.handler.codec.http.HttpRequest nettyRequest, Exception exception) {
this(sequence, nettyRequest, HttpBody.empty(), new AtomicBoolean(false), exception);
}
Netty4HttpRequest(int sequence, FullHttpRequest request) {
this(sequence, request, new AtomicBoolean(false), true, Netty4Utils.fullHttpBodyFrom(request.content()));
}
Netty4HttpRequest(int sequence, FullHttpRequest request, Exception inboundException) {
this(sequence, request, new AtomicBoolean(false), true, Netty4Utils.fullHttpBodyFrom(request.content()), inboundException);
}
private Netty4HttpRequest(int sequence, FullHttpRequest request, AtomicBoolean released, boolean pooled, HttpBody content) {
this(sequence, request, released, pooled, content, null);
public Netty4HttpRequest(int sequence, io.netty.handler.codec.http.HttpRequest nettyRequest, HttpBody content) {
this(sequence, nettyRequest, content, new AtomicBoolean(false), null);
}
private Netty4HttpRequest(
int sequence,
FullHttpRequest request,
AtomicBoolean released,
boolean pooled,
io.netty.handler.codec.http.HttpRequest nettyRequest,
HttpBody content,
AtomicBoolean released,
Exception inboundException
) {
this.sequence = sequence;
this.request = request;
this.headers = getHttpHeadersAsMap(request.headers());
this.nettyRequest = nettyRequest;
this.hasContent = hasContentHeader(nettyRequest);
this.content = content;
this.pooled = pooled;
this.headers = getHttpHeadersAsMap(nettyRequest.headers());
this.released = released;
this.inboundException = inboundException;
this.queryStringDecoder = new QueryStringDecoder(request.uri());
this.queryStringDecoder = new QueryStringDecoder(nettyRequest.uri());
}
private static boolean hasContentHeader(io.netty.handler.codec.http.HttpRequest nettyRequest) {
return HttpUtil.isTransferEncodingChunked(nettyRequest) || HttpUtil.getContentLength(nettyRequest, 0L) > 0;
}
@Override
public boolean hasContent() {
return hasContent;
}
@Override
public RestRequest.Method method() {
return translateRequestMethod(request.method());
return translateRequestMethod(nettyRequest.method());
}
@Override
public String uri() {
return request.uri();
return nettyRequest.uri();
}
@Override
@ -119,10 +102,17 @@ public class Netty4HttpRequest implements HttpRequest {
return content;
}
@Override
public void setBody(HttpBody body) {
assert this.content.isStream() : "only stream content can be replaced";
assert body.isFull() : "only full content can replace stream";
this.content = body;
this.hasContent = body.isEmpty() == false;
}
@Override
public void release() {
if (pooled && released.compareAndSet(false, true)) {
request.release();
if (released.compareAndSet(false, true)) {
content.close();
}
}
@ -134,7 +124,7 @@ public class Netty4HttpRequest implements HttpRequest {
@Override
public List<String> strictCookies() {
String cookieString = request.headers().get(HttpHeaderNames.COOKIE);
String cookieString = nettyRequest.headers().get(HttpHeaderNames.COOKIE);
if (cookieString != null) {
Set<Cookie> cookies = ServerCookieDecoder.STRICT.decode(cookieString);
if (cookies.isEmpty() == false) {
@ -146,40 +136,36 @@ public class Netty4HttpRequest implements HttpRequest {
@Override
public HttpVersion protocolVersion() {
if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_0)) {
return HttpRequest.HttpVersion.HTTP_1_0;
} else if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_1)) {
return HttpRequest.HttpVersion.HTTP_1_1;
if (nettyRequest.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_0)) {
return HttpVersion.HTTP_1_0;
} else if (nettyRequest.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_1)) {
return HttpVersion.HTTP_1_1;
} else {
throw new IllegalArgumentException("Unexpected http protocol version: " + request.protocolVersion());
throw new IllegalArgumentException("Unexpected http protocol version: " + nettyRequest.protocolVersion());
}
}
@Override
public HttpRequest removeHeader(String header) {
HttpHeaders copiedHeadersWithout = request.headers().copy();
HttpHeaders copiedHeadersWithout = nettyRequest.headers().copy();
copiedHeadersWithout.remove(header);
HttpHeaders copiedTrailingHeadersWithout = request.trailingHeaders().copy();
copiedTrailingHeadersWithout.remove(header);
FullHttpRequest requestWithoutHeader = new DefaultFullHttpRequest(
request.protocolVersion(),
request.method(),
request.uri(),
request.content(),
copiedHeadersWithout,
copiedTrailingHeadersWithout
var requestWithoutHeader = new DefaultHttpRequest(
nettyRequest.protocolVersion(),
nettyRequest.method(),
nettyRequest.uri(),
copiedHeadersWithout
);
return new Netty4HttpRequest(sequence, requestWithoutHeader, released, pooled, content);
return new Netty4HttpRequest(sequence, requestWithoutHeader, content, released, null);
}
@Override
public Netty4FullHttpResponse createResponse(RestStatus status, BytesReference contentRef) {
return new Netty4FullHttpResponse(sequence, request.protocolVersion(), status, contentRef);
return new Netty4FullHttpResponse(sequence, nettyRequest.protocolVersion(), status, contentRef);
}
@Override
public HttpResponse createResponse(RestStatus status, ChunkedRestResponseBodyPart firstBodyPart) {
return new Netty4ChunkedHttpResponse(sequence, request.protocolVersion(), status, firstBodyPart);
return new Netty4ChunkedHttpResponse(sequence, nettyRequest.protocolVersion(), status, firstBodyPart);
}
@Override
@ -188,7 +174,7 @@ public class Netty4HttpRequest implements HttpRequest {
}
public io.netty.handler.codec.http.HttpRequest getNettyRequest() {
return request;
return nettyRequest;
}
public static RestRequest.Method translateRequestMethod(HttpMethod httpMethod) {

View file

@ -73,11 +73,11 @@ public class Netty4HttpRequestBodyStream implements HttpBody.Stream {
public void handleNettyContent(HttpContent httpContent) {
assert ctx.channel().eventLoop().inEventLoop() : Thread.currentThread().getName();
assert readLastChunk == false;
if (closing) {
httpContent.release();
read();
} else {
assert readLastChunk == false;
try (var ignored = threadContext.restoreExistingContext(requestContext)) {
var isLast = httpContent instanceof LastHttpContent;
var buf = Netty4Utils.toReleasableBytesReference(httpContent.content());
@ -105,17 +105,19 @@ public class Netty4HttpRequestBodyStream implements HttpBody.Stream {
private void doClose() {
assert ctx.channel().eventLoop().inEventLoop() : Thread.currentThread().getName();
closing = true;
try (var ignored = threadContext.restoreExistingContext(requestContext)) {
for (var tracer : tracingHandlers) {
Releasables.closeExpectNoException(tracer);
if (closing == false) {
closing = true;
try (var ignored = threadContext.restoreExistingContext(requestContext)) {
for (var tracer : tracingHandlers) {
Releasables.closeExpectNoException(tracer);
}
if (handler != null) {
handler.close();
}
}
if (handler != null) {
handler.close();
if (readLastChunk == false) {
read();
}
}
if (readLastChunk == false) {
read();
}
}
}

View file

@ -24,7 +24,6 @@ import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.http.HttpContentCompressor;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseEncoder;
@ -39,7 +38,6 @@ import io.netty.util.ResourceLeakDetector;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.bulk.IncrementalBulkService;
import org.elasticsearch.common.network.CloseableChannel;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.network.ThreadWatchdog;
@ -101,7 +99,6 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
private final TLSConfig tlsConfig;
private final AcceptChannelHandler.AcceptPredicate acceptChannelPredicate;
private final HttpValidator httpValidator;
private final IncrementalBulkService.Enabled enabled;
private final ThreadWatchdog threadWatchdog;
private final int readTimeoutMillis;
@ -140,7 +137,6 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
this.acceptChannelPredicate = acceptChannelPredicate;
this.httpValidator = httpValidator;
this.threadWatchdog = networkService.getThreadWatchdog();
this.enabled = new IncrementalBulkService.Enabled(clusterSettings);
this.pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings);
@ -286,7 +282,7 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
}
public ChannelHandler configureServerChannelHandler() {
return new HttpChannelHandler(this, handlingSettings, tlsConfig, acceptChannelPredicate, httpValidator, enabled);
return new HttpChannelHandler(this, handlingSettings, tlsConfig, acceptChannelPredicate, httpValidator);
}
static final AttributeKey<Netty4HttpChannel> HTTP_CHANNEL_KEY = AttributeKey.newInstance("es-http-channel");
@ -299,22 +295,19 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
private final TLSConfig tlsConfig;
private final BiPredicate<String, InetSocketAddress> acceptChannelPredicate;
private final HttpValidator httpValidator;
private final IncrementalBulkService.Enabled enabled;
protected HttpChannelHandler(
final Netty4HttpServerTransport transport,
final HttpHandlingSettings handlingSettings,
final TLSConfig tlsConfig,
@Nullable final BiPredicate<String, InetSocketAddress> acceptChannelPredicate,
@Nullable final HttpValidator httpValidator,
IncrementalBulkService.Enabled enabled
@Nullable final HttpValidator httpValidator
) {
this.transport = transport;
this.handlingSettings = handlingSettings;
this.tlsConfig = tlsConfig;
this.acceptChannelPredicate = acceptChannelPredicate;
this.httpValidator = httpValidator;
this.enabled = enabled;
}
@Override
@ -389,15 +382,7 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
)
);
}
// combines the HTTP message pieces into a single full HTTP request (with headers and body)
final HttpObjectAggregator aggregator = new Netty4HttpAggregator(
handlingSettings.maxContentLength(),
httpPreRequest -> enabled.get() == false
|| ((httpPreRequest.rawPath().endsWith("/_bulk") == false)
|| httpPreRequest.rawPath().startsWith("/_xpack/monitoring/_bulk")),
decoder
);
aggregator.setMaxCumulationBufferComponents(transport.maxCompositeBufferComponents);
ch.pipeline()
.addLast("decoder_compress", new HttpContentDecompressor()) // this handles request body decompression
.addLast("encoder", new HttpResponseEncoder() {
@ -412,7 +397,8 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
return super.isContentAlwaysEmpty(msg);
}
})
.addLast("aggregator", aggregator);
.addLast(new Netty4HttpContentSizeHandler(decoder, handlingSettings.maxContentLength()));
if (handlingSettings.compression()) {
ch.pipeline().addLast("encoder_compress", new HttpContentCompressor(handlingSettings.compressionLevel()) {
@Override

View file

@ -16,11 +16,9 @@ import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.DefaultLastHttpContent;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
@ -178,24 +176,6 @@ public class Netty4HttpHeaderValidatorTests extends ESTestCase {
asInstanceOf(LastHttpContent.class, channel.readInbound()).release();
}
public void testWithAggregator() {
channel.pipeline().addLast(new Netty4HttpAggregator(8192, (req) -> true, new HttpRequestDecoder()));
channel.writeInbound(newHttpRequest());
channel.writeInbound(newHttpContent());
channel.writeInbound(newLastHttpContent());
channel.read();
assertNull("should ignore read while validating", channel.readInbound());
var validationRequest = validatorRequestQueue.poll();
assertNotNull(validationRequest);
validationRequest.listener.onResponse(null);
channel.runPendingTasks();
asInstanceOf(FullHttpRequest.class, channel.readInbound()).release();
}
public void testBufferPipelinedRequestsWhenValidating() {
final var expectedChunks = new ArrayDeque<HttpContent>();
expectedChunks.addLast(newHttpContent());

View file

@ -19,6 +19,7 @@ import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.DefaultLastHttpContent;
import io.netty.handler.codec.http.FullHttpResponse;
@ -35,15 +36,24 @@ import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.bytes.ZeroBytesReference;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.network.ThreadWatchdog;
import org.elasticsearch.common.network.ThreadWatchdogHelper;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.http.AggregatingDispatcher;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.telemetry.tracing.Tracer;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.netty4.Netty4Utils;
import org.elasticsearch.transport.netty4.NettyAllocator;
import org.elasticsearch.transport.netty4.SharedGroupFactory;
import org.elasticsearch.transport.netty4.TLSConfig;
import org.junit.After;
import java.nio.channels.ClosedChannelException;
@ -70,7 +80,6 @@ import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.sameInstance;
import static org.hamcrest.core.Is.is;
import static org.mockito.Mockito.mock;
public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
@ -79,11 +88,14 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
private final Map<String, CountDownLatch> waitingRequests = new ConcurrentHashMap<>();
private final Map<String, CountDownLatch> finishingRequests = new ConcurrentHashMap<>();
private final ThreadPool threadPool = new TestThreadPool("pipelining test");
@After
public void tearDown() throws Exception {
waitingRequests.keySet().forEach(this::finishRequest);
terminateExecutorService(handlerService);
terminateExecutorService(eventLoopService);
threadPool.shutdownNow();
super.tearDown();
}
@ -126,12 +138,31 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
}
private EmbeddedChannel makeEmbeddedChannelWithSimulatedWork(int numberOfRequests) {
return new EmbeddedChannel(new Netty4HttpPipeliningHandler(numberOfRequests, null, new ThreadWatchdog.ActivityTracker()) {
@Override
protected void handlePipelinedRequest(ChannelHandlerContext ctx, Netty4HttpRequest pipelinedRequest) {
ctx.fireChannelRead(pipelinedRequest);
}
}, new WorkEmulatorHandler());
return new EmbeddedChannel(
new Netty4HttpPipeliningHandler(numberOfRequests, httpServerTransport(), new ThreadWatchdog.ActivityTracker()) {
@Override
protected void handlePipelinedRequest(ChannelHandlerContext ctx, Netty4HttpRequest pipelinedRequest) {
ctx.fireChannelRead(pipelinedRequest);
}
},
new WorkEmulatorHandler()
);
}
private Netty4HttpServerTransport httpServerTransport() {
return new Netty4HttpServerTransport(
Settings.EMPTY,
new NetworkService(List.of()),
threadPool,
xContentRegistry(),
new AggregatingDispatcher(),
ClusterSettings.createBuiltInClusterSettings(),
new SharedGroupFactory(Settings.EMPTY),
Tracer.NOOP,
TLSConfig.noTLS(),
null,
null
);
}
public void testThatPipeliningWorksWhenSlowRequestsInDifferentOrder() throws InterruptedException {
@ -193,7 +224,7 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
public void testPipeliningRequestsAreReleased() {
final int numberOfRequests = 10;
final EmbeddedChannel embeddedChannel = new EmbeddedChannel(
new Netty4HttpPipeliningHandler(numberOfRequests + 1, null, new ThreadWatchdog.ActivityTracker())
new Netty4HttpPipeliningHandler(numberOfRequests + 1, httpServerTransport(), new ThreadWatchdog.ActivityTracker())
);
for (int i = 0; i < numberOfRequests; i++) {
@ -485,7 +516,7 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
final var watchdog = new ThreadWatchdog();
final var activityTracker = watchdog.getActivityTrackerForCurrentThread();
final var requestHandled = new AtomicBoolean();
final var handler = new Netty4HttpPipeliningHandler(Integer.MAX_VALUE, mock(Netty4HttpServerTransport.class), activityTracker) {
final var handler = new Netty4HttpPipeliningHandler(Integer.MAX_VALUE, httpServerTransport(), activityTracker) {
@Override
protected void handlePipelinedRequest(ChannelHandlerContext ctx, Netty4HttpRequest pipelinedRequest) {
// thread is not idle while handling the request
@ -526,11 +557,7 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
}
private Netty4HttpPipeliningHandler getTestHttpHandler() {
return new Netty4HttpPipeliningHandler(
Integer.MAX_VALUE,
mock(Netty4HttpServerTransport.class),
new ThreadWatchdog.ActivityTracker()
) {
return new Netty4HttpPipeliningHandler(Integer.MAX_VALUE, httpServerTransport(), new ThreadWatchdog.ActivityTracker()) {
@Override
protected void handlePipelinedRequest(ChannelHandlerContext ctx, Netty4HttpRequest pipelinedRequest) {
ctx.fireChannelRead(pipelinedRequest);
@ -591,8 +618,8 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
assertThat(data, is(expectedContent));
}
private DefaultFullHttpRequest createHttpRequest(String uri) {
return new DefaultFullHttpRequest(HTTP_1_1, HttpMethod.GET, uri);
private Object[] createHttpRequest(String uri) {
return new Object[] { new DefaultHttpRequest(HTTP_1_1, HttpMethod.GET, uri), LastHttpContent.EMPTY_LAST_CONTENT };
}
private class WorkEmulatorHandler extends SimpleChannelInboundHandler<Netty4HttpRequest> {

View file

@ -0,0 +1,77 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.http.netty4;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.http.HttpBody;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeHttpBodyStream;
public class Netty4HttpRequestTests extends ESTestCase {
public void testEmptyFullContent() {
final var request = new Netty4HttpRequest(0, new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"), HttpBody.empty());
assertFalse(request.hasContent());
}
public void testEmptyStreamContent() {
final var request = new Netty4HttpRequest(
0,
new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"),
new FakeHttpBodyStream()
);
assertFalse(request.hasContent());
}
public void testNonEmptyFullContent() {
final var len = between(1, 1024);
final var request = new Netty4HttpRequest(
0,
new DefaultHttpRequest(
HttpVersion.HTTP_1_1,
HttpMethod.GET,
"/",
new DefaultHttpHeaders().add(HttpHeaderNames.CONTENT_LENGTH, len)
),
HttpBody.fromBytesReference(new BytesArray(new byte[len]))
);
assertTrue(request.hasContent());
}
public void testNonEmptyStreamContent() {
final var len = between(1, 1024);
final var nettyRequestWithLen = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
HttpUtil.setContentLength(nettyRequestWithLen, len);
final var requestWithLen = new Netty4HttpRequest(0, nettyRequestWithLen, new FakeHttpBodyStream());
assertTrue(requestWithLen.hasContent());
final var nettyChunkedRequest = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/", new DefaultHttpHeaders());
HttpUtil.setTransferEncodingChunked(nettyChunkedRequest, true);
final var chunkedRequest = new Netty4HttpRequest(0, nettyChunkedRequest, new FakeHttpBodyStream());
assertTrue(chunkedRequest.hasContent());
}
public void testReplaceContent() {
final var len = between(1, 1024);
final var nettyRequestWithLen = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
HttpUtil.setContentLength(nettyRequestWithLen, len);
final var streamRequest = new Netty4HttpRequest(0, nettyRequestWithLen, new FakeHttpBodyStream());
streamRequest.setBody(HttpBody.fromBytesReference(randomBytesReference(len)));
assertTrue(streamRequest.hasContent());
}
}

View file

@ -20,11 +20,11 @@ import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.http.AggregatingDispatcher;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.http.NullDispatcher;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.telemetry.tracing.Tracer;
import org.elasticsearch.test.ESTestCase;
@ -103,7 +103,7 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase {
Netty4HttpServerPipeliningTests.this.networkService,
Netty4HttpServerPipeliningTests.this.threadPool,
xContentRegistry(),
new NullDispatcher(),
new AggregatingDispatcher(),
new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS),
new SharedGroupFactory(settings),
Tracer.NOOP,

View file

@ -47,7 +47,6 @@ import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.ElasticsearchWrapperException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.bulk.IncrementalBulkService;
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.SubscribableListener;
@ -68,12 +67,12 @@ import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.http.AbstractHttpServerTransportTestCase;
import org.elasticsearch.http.AggregatingDispatcher;
import org.elasticsearch.http.BindHttpException;
import org.elasticsearch.http.CorsHandler;
import org.elasticsearch.http.HttpHeadersValidationException;
import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.http.NullDispatcher;
import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils;
import org.elasticsearch.http.netty4.internal.HttpValidator;
import org.elasticsearch.rest.ChunkedRestResponseBodyPart;
@ -193,9 +192,9 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
final int contentLength,
final HttpResponseStatus expectedStatus
) throws InterruptedException {
final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
final HttpServerTransport.Dispatcher dispatcher = new AggregatingDispatcher() {
@Override
public void dispatchRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) {
public void dispatchAggregatedRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) {
channel.sendResponse(new RestResponse(OK, RestResponse.TEXT_CONTENT_TYPE, new BytesArray("done")));
}
@ -263,7 +262,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
networkService,
threadPool,
xContentRegistry(),
new NullDispatcher(),
new AggregatingDispatcher(),
clusterSettings,
new SharedGroupFactory(Settings.EMPTY),
Tracer.NOOP,
@ -284,7 +283,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
networkService,
threadPool,
xContentRegistry(),
new NullDispatcher(),
new AggregatingDispatcher(),
clusterSettings,
new SharedGroupFactory(settings),
Tracer.NOOP,
@ -425,8 +424,7 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
handlingSettings,
TLSConfig.noTLS(),
null,
randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null),
new IncrementalBulkService.Enabled(clusterSettings)
randomFrom((httpPreRequest, channel, listener) -> listener.onResponse(null), null)
) {
@Override
protected void initChannel(Channel ch) throws Exception {
@ -852,9 +850,9 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
final Settings settings = createBuilderWithPort().put(HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH.getKey(), "1mb")
.build();
final String requestString = randomAlphaOfLength(2 * 1024 * 1024); // request size is twice the limit
final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
final HttpServerTransport.Dispatcher dispatcher = new AggregatingDispatcher() {
@Override
public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
public void dispatchAggregatedRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
throw new AssertionError("Request dispatched but shouldn't");
}
@ -863,20 +861,11 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
throw new AssertionError("Request dispatched but shouldn't");
}
};
final HttpValidator httpValidator = (httpRequest, channel, validationListener) -> {
// assert that the validator sees the request unaltered
assertThat(httpRequest.uri(), is(uri));
if (randomBoolean()) {
validationListener.onResponse(null);
} else {
validationListener.onFailure(new ElasticsearchException("Boom"));
}
};
try (
Netty4HttpServerTransport transport = getTestNetty4HttpServerTransport(
settings,
dispatcher,
httpValidator,
(r, c, l) -> l.onResponse(null),
(restRequest, threadContext) -> {
throw new AssertionError("Request dispatched but shouldn't");
}
@ -1060,9 +1049,9 @@ public class Netty4HttpServerTransportTests extends AbstractHttpServerTransportT
final SubscribableListener<Void> transportClosedFuture = new SubscribableListener<>();
final CountDownLatch handlingRequestLatch = new CountDownLatch(1);
final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
final HttpServerTransport.Dispatcher dispatcher = new AggregatingDispatcher() {
@Override
public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
public void dispatchAggregatedRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
assertEquals(request.uri(), url);
final var response = RestResponse.chunked(
OK,

View file

@ -931,7 +931,7 @@ public class ActionModule extends AbstractModule {
registerHandler.accept(new RestCountAction());
registerHandler.accept(new RestTermVectorsAction());
registerHandler.accept(new RestMultiTermVectorsAction());
registerHandler.accept(new RestBulkAction(settings, bulkService));
registerHandler.accept(new RestBulkAction(settings, clusterSettings, bulkService));
registerHandler.accept(new RestUpdateAction());
registerHandler.accept(new RestSearchAction(restController.getSearchUsageHolder(), clusterSupportsFeature));

View file

@ -31,6 +31,13 @@ public sealed interface HttpBody extends Releasable permits HttpBody.Full, HttpB
return this instanceof Full;
}
default boolean isEmpty() {
if (isFull()) {
return asFull().bytes().length() == 0;
}
return false;
}
default boolean isStream() {
return this instanceof Stream;
}
@ -113,5 +120,10 @@ public sealed interface HttpBody extends Releasable permits HttpBody.Full, HttpB
default void close() {}
}
record ByteRefHttpBody(ReleasableBytesReference bytes) implements Full {}
record ByteRefHttpBody(ReleasableBytesReference bytes) implements Full {
@Override
public void close() {
bytes.close();
}
}
}

View file

@ -64,4 +64,5 @@ public interface HttpPreRequest {
}
return null;
}
}

View file

@ -30,12 +30,16 @@ public interface HttpRequest extends HttpPreRequest {
HttpBody body();
void setBody(HttpBody body);
List<String> strictCookies();
HttpVersion protocolVersion();
HttpRequest removeHeader(String header);
boolean hasContent();
/**
* Create an http response from this request and the supplied status and content.
*/

View file

@ -0,0 +1,94 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.rest;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.http.HttpBody;
import java.util.ArrayList;
import java.util.function.Consumer;
public class RestContentAggregator {
private static void replaceBody(RestRequest restRequest, ReleasableBytesReference aggregate) {
restRequest.getHttpRequest().setBody(new HttpBody.ByteRefHttpBody(aggregate));
}
/**
* Aggregates content of the RestRequest and notifies consumer with updated, in-place, RestRequest.
* If content is already aggregated then passes through same request.
*/
public static void aggregate(RestRequest restRequest, Consumer<RestRequest> resultConsumer) {
final var httpRequest = restRequest.getHttpRequest();
switch (httpRequest.body()) {
case HttpBody.Full full -> resultConsumer.accept(restRequest);
case HttpBody.Stream stream -> {
final var aggregationHandler = new AggregationChunkHandler(restRequest, resultConsumer);
stream.setHandler(aggregationHandler);
stream.next();
}
}
}
private static class AggregationChunkHandler implements HttpBody.ChunkHandler {
final RestRequest restRequest;
final Consumer<RestRequest> resultConsumer;
final HttpBody.Stream stream;
boolean closing;
ArrayList<ReleasableBytesReference> chunks;
private AggregationChunkHandler(RestRequest restRequest, Consumer<RestRequest> resultConsumer) {
this.restRequest = restRequest;
this.resultConsumer = resultConsumer;
this.stream = restRequest.getHttpRequest().body().asStream();
}
@Override
public void onNext(ReleasableBytesReference chunk, boolean isLast) {
if (closing) {
chunk.close();
return;
}
if (isLast == false) {
if (chunks == null) {
chunks = new ArrayList<>(); // allocate array only when there is more than one chunk
}
chunks.add(chunk);
stream.next();
} else {
if (chunks == null) {
replaceBody(restRequest, chunk);
} else {
chunks.add(chunk);
var comp = CompositeBytesReference.of(chunks.toArray(new ReleasableBytesReference[0]));
var relComp = new ReleasableBytesReference(comp, Releasables.wrap(chunks));
replaceBody(restRequest, relComp);
}
chunks = null;
closing = true;
resultConsumer.accept(restRequest);
}
}
@Override
public void close() {
if (closing == false) {
closing = true;
if (chunks != null) {
Releasables.close(chunks);
chunks = null;
}
}
}
}
}

View file

@ -389,6 +389,26 @@ public class RestController implements HttpServerTransport.Dispatcher {
return Collections.unmodifiableSortedMap(allStats);
}
private void maybeAggregateAndDispatchRequest(
RestRequest restRequest,
RestChannel restChannel,
RestHandler handler,
MethodHandlers methodHandlers,
ThreadContext threadContext
) throws Exception {
if (handler.supportsContentStream()) {
dispatchRequest(restRequest, restChannel, handler, methodHandlers, threadContext);
} else {
RestContentAggregator.aggregate(restRequest, (aggregatedRequest) -> {
try {
dispatchRequest(aggregatedRequest, restChannel, handler, methodHandlers, threadContext);
} catch (Exception e) {
throw new ElasticsearchException(e);
}
});
}
}
private void dispatchRequest(
RestRequest request,
RestChannel channel,
@ -424,8 +444,6 @@ public class RestController implements HttpServerTransport.Dispatcher {
return;
}
}
// TODO: estimate streamed content size for circuit breaker,
// something like http_max_chunk_size * avg_compression_ratio(for compressed content)
final int contentLength = request.isFullContent() ? request.contentLength() : 0;
try {
if (handler.canTripCircuitBreaker()) {
@ -623,7 +641,7 @@ public class RestController implements HttpServerTransport.Dispatcher {
} else {
startTrace(threadContext, channel, handlers.getPath());
var decoratedChannel = new MeteringRestChannelDecorator(channel, requestsCounter, handler.getConcreteRestHandler());
dispatchRequest(request, decoratedChannel, handler, handlers, threadContext);
maybeAggregateAndDispatchRequest(request, decoratedChannel, handler, handlers, threadContext);
return;
}
}

View file

@ -40,6 +40,10 @@ public interface RestHandler {
return true;
}
default boolean supportsContentStream() {
return false;
}
/**
* Indicates if the RestHandler supports bulk content. A bulk request contains multiple objects
* delineated by {@link XContent#bulkSeparator()}. If a handler returns true this will affect

View file

@ -291,7 +291,7 @@ public class RestRequest implements ToXContent.Params, Traceable {
}
public boolean hasContent() {
return isStreamedContent() || contentLength() > 0;
return httpRequest.hasContent();
}
public int contentLength() {
@ -325,6 +325,7 @@ public class RestRequest implements ToXContent.Params, Traceable {
}
public HttpBody.Stream contentStream() {
this.contentConsumed = true;
return httpRequest.body().asStream();
}

View file

@ -21,6 +21,7 @@ import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
@ -63,12 +64,14 @@ public class RestBulkAction extends BaseRestHandler {
private final boolean allowExplicitIndex;
private final IncrementalBulkService bulkHandler;
private final IncrementalBulkService.Enabled incrementalEnabled;
private final Set<String> capabilities;
public RestBulkAction(Settings settings, IncrementalBulkService bulkHandler) {
public RestBulkAction(Settings settings, ClusterSettings clusterSettings, IncrementalBulkService bulkHandler) {
this.allowExplicitIndex = MULTI_ALLOW_EXPLICIT_INDEX.get(settings);
this.bulkHandler = bulkHandler;
this.capabilities = Set.of(FAILURE_STORE_STATUS_CAPABILITY);
this.incrementalEnabled = new IncrementalBulkService.Enabled(clusterSettings);
}
@Override
@ -86,6 +89,11 @@ public class RestBulkAction extends BaseRestHandler {
return "bulk_action";
}
@Override
public boolean supportsContentStream() {
return incrementalEnabled.get();
}
@Override
public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException {
if (request.isStreamedContent() == false) {

View file

@ -17,11 +17,11 @@ import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.AggregatingDispatcher;
import org.elasticsearch.http.HttpInfo;
import org.elasticsearch.http.HttpPreRequest;
import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.http.HttpStats;
import org.elasticsearch.http.NullDispatcher;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.plugins.NetworkPlugin;
import org.elasticsearch.telemetry.tracing.Tracer;
@ -299,7 +299,7 @@ public class NetworkModuleTests extends ESTestCase {
null,
xContentRegistry(),
null,
new NullDispatcher(),
new AggregatingDispatcher(),
(preRequest, threadContext) -> {},
new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS),
Tracer.NOOP

View file

@ -959,13 +959,12 @@ public class AbstractHttpServerTransportTests extends ESTestCase {
public void testStopWorksWithNoOpenRequests() {
var grace = SHORT_GRACE_PERIOD_MS;
try (var noWait = LogExpectation.unexpectedTimeout(grace); var transport = new TestHttpServerTransport(gracePeriod(grace))) {
final TestHttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/") {
@Override
public Map<String, List<String>> getHeaders() {
// close connection before shutting down
return Map.of(CONNECTION, List.of(CLOSE));
}
};
final TestHttpRequest httpRequest = new TestHttpRequest(
HttpRequest.HttpVersion.HTTP_1_1,
RestRequest.Method.GET,
"/",
Map.of(CONNECTION, List.of(CLOSE))
);
TestHttpChannel httpChannel = new TestHttpChannel();
transport.serverAcceptedChannel(httpChannel);
transport.incomingRequest(httpRequest, httpChannel);

View file

@ -20,23 +20,48 @@ import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
class TestHttpRequest implements HttpRequest {
public class TestHttpRequest implements HttpRequest {
private final Supplier<HttpVersion> version;
private final RestRequest.Method method;
private final String uri;
private final HashMap<String, List<String>> headers = new HashMap<>();
private final Map<String, List<String>> headers;
private final HttpBody body;
TestHttpRequest(Supplier<HttpVersion> versionSupplier, RestRequest.Method method, String uri) {
public TestHttpRequest(
Supplier<HttpVersion> versionSupplier,
RestRequest.Method method,
String uri,
Map<String, List<String>> headers,
HttpBody body
) {
this.version = versionSupplier;
this.method = method;
this.uri = uri;
this.headers = headers;
this.body = body;
}
TestHttpRequest(HttpVersion version, RestRequest.Method method, String uri) {
public TestHttpRequest(RestRequest.Method method, String uri, Map<String, List<String>> headers, HttpBody body) {
this(() -> HttpVersion.HTTP_1_1, method, uri, headers, body);
}
public TestHttpRequest(RestRequest.Method method, String uri, Map<String, List<String>> headers, BytesReference body) {
this(() -> HttpVersion.HTTP_1_1, method, uri, headers, HttpBody.fromBytesReference(body));
}
public TestHttpRequest(Supplier<HttpVersion> versionSupplier, RestRequest.Method method, String uri) {
this(versionSupplier, method, uri, new HashMap<>(), HttpBody.empty());
}
public TestHttpRequest(HttpVersion version, RestRequest.Method method, String uri) {
this(() -> version, method, uri);
}
public TestHttpRequest(HttpVersion version, RestRequest.Method method, String uri, Map<String, List<String>> headers) {
this(() -> version, method, uri, headers, HttpBody.empty());
}
@Override
public RestRequest.Method method() {
return method;
@ -49,7 +74,12 @@ class TestHttpRequest implements HttpRequest {
@Override
public HttpBody body() {
return HttpBody.empty();
return body;
}
@Override
public void setBody(HttpBody body) {
throw new IllegalStateException("not allowed");
}
@Override
@ -72,6 +102,11 @@ class TestHttpRequest implements HttpRequest {
throw new UnsupportedOperationException("Do not support removing header on test request.");
}
@Override
public boolean hasContent() {
return body.isEmpty() == false;
}
@Override
public HttpResponse createResponse(RestStatus status, BytesReference content) {
return new TestHttpResponse(status, content);

View file

@ -0,0 +1,116 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.rest;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.http.HttpBody;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeHttpBodyStream;
import org.elasticsearch.test.rest.FakeRestRequest.FakeHttpChannel;
import org.elasticsearch.test.rest.FakeRestRequest.FakeHttpRequest;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import static java.util.stream.IntStream.range;
import static org.elasticsearch.rest.RestContentAggregator.aggregate;
public class RestContentAggregatorTests extends ESTestCase {
RestRequest newRestRequest(int size) {
return RestRequest.request(
parserConfig(),
new FakeHttpRequest(
RestRequest.Method.POST,
"/",
Map.of("Content-Length", List.of(Integer.toString(size))),
HttpBody.fromBytesReference(randomBytesReference(size))
),
new FakeHttpChannel(null)
);
}
public void testFullBodyPassThrough() {
var fullRequest = newRestRequest(between(1, 1024));
var aggRef = new AtomicReference<RestRequest>();
aggregate(fullRequest, aggRef::set);
var aggRequest = aggRef.get();
assertSame(fullRequest, aggRequest);
assertSame(fullRequest.content(), aggRequest.content());
}
public void testZeroLengthStream() {
var stream = new FakeHttpBodyStream();
var request = newRestRequest(0);
request.getHttpRequest().setBody(stream);
var aggregatedRef = new AtomicReference<RestRequest>();
aggregate(request, aggregatedRef::set);
stream.sendNext(ReleasableBytesReference.empty(), true);
assertEquals(0, aggregatedRef.get().contentLength());
}
public void testAggregateRandomSize() {
var chunkSize = between(1, 1024);
var nChunks = between(1, 1000);
var stream = new FakeHttpBodyStream();
var streamChunks = range(0, nChunks).mapToObj(i -> randomReleasableBytesReference(chunkSize)).toList();
var request = newRestRequest(chunkSize * nChunks);
request.getHttpRequest().setBody(stream);
var aggregatedRef = new AtomicReference<RestRequest>();
aggregate(request, aggregatedRef::set);
for (var i = 0; i < nChunks - 1; i++) {
assertTrue(stream.isRequested());
stream.sendNext(streamChunks.get(i), false);
}
assertTrue(stream.isRequested());
stream.sendNext(streamChunks.getLast(), true);
var aggregated = aggregatedRef.get();
var expectedBytes = CompositeBytesReference.of(streamChunks.toArray(new ReleasableBytesReference[0]));
assertEquals(expectedBytes, aggregated.content());
aggregated.content().close();
}
public void testReleaseChunksOnClose() {
var chunkSize = between(1, 1024);
var nChunks = between(1, 100);
var stream = new FakeHttpBodyStream();
var request = newRestRequest(chunkSize * nChunks * 2);
request.getHttpRequest().setBody(stream);
AtomicReference<RestRequest> aggregatedRef = new AtomicReference<>();
aggregate(request, aggregatedRef::set);
// buffered chunks, must be released after close()
var chunksBeforeClose = range(0, nChunks).mapToObj(i -> randomReleasableBytesReference(chunkSize)).toList();
for (var chunk : chunksBeforeClose) {
assertTrue(stream.isRequested());
stream.sendNext(chunk, false);
}
stream.close();
assertFalse(chunksBeforeClose.stream().anyMatch(ReleasableBytesReference::hasReferences));
// non-buffered, must be released on arrival
var chunksAfterClose = range(0, nChunks).mapToObj(i -> randomReleasableBytesReference(chunkSize)).toList();
for (var chunk : chunksAfterClose) {
assertTrue(stream.isRequested());
stream.sendNext(chunk, false);
}
assertFalse(chunksAfterClose.stream().anyMatch(ReleasableBytesReference::hasReferences));
assertNull(aggregatedRef.get());
}
}

View file

@ -879,6 +879,11 @@ public class RestControllerTests extends ESTestCase {
return HttpBody.empty();
}
@Override
public void setBody(HttpBody body) {
throw new IllegalStateException("not allowed");
}
@Override
public Map<String, List<String>> getHeaders() {
Map<String, List<String>> headers = new HashMap<>();
@ -903,6 +908,11 @@ public class RestControllerTests extends ESTestCase {
return this;
}
@Override
public boolean hasContent() {
return hasContent;
}
@Override
public HttpResponse createResponse(RestStatus status, BytesReference content) {
return null;

View file

@ -14,9 +14,8 @@ import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.http.HttpBody;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.http.TestHttpRequest;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeRestRequest;
import org.elasticsearch.xcontent.NamedXContentRegistry;
@ -41,7 +40,6 @@ import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class RestRequestTests extends ESTestCase {
@ -86,11 +84,11 @@ public class RestRequestTests extends ESTestCase {
}
private <T extends Exception> void runConsumesContentTest(final CheckedConsumer<RestRequest, T> consumer, final boolean expected) {
final HttpRequest httpRequest = mock(HttpRequest.class);
when(httpRequest.uri()).thenReturn("");
when(httpRequest.body()).thenReturn(HttpBody.fromBytesReference(new BytesArray(new byte[1])));
when(httpRequest.getHeaders()).thenReturn(
Collections.singletonMap("Content-Type", Collections.singletonList(randomFrom("application/json", "application/x-ndjson")))
final var httpRequest = new TestHttpRequest(
RestRequest.Method.GET,
"/",
Map.of("Content-Type", List.of(randomFrom("application/json", "application/x-ndjson"))),
new BytesArray(new byte[1])
);
final RestRequest request = RestRequest.request(XContentParserConfiguration.EMPTY, httpRequest, mock(HttpChannel.class));
assertFalse(request.isContentConsumed());

View file

@ -20,6 +20,7 @@ import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.http.HttpBody;
@ -66,6 +67,7 @@ public class RestBulkActionTests extends ESTestCase {
params.put("pipeline", "timestamps");
new RestBulkAction(
settings(IndexVersion.current()).build(),
ClusterSettings.createBuiltInClusterSettings(),
new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class))
).handleRequest(
new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk").withParams(params).withContent(new BytesArray("""
@ -101,6 +103,7 @@ public class RestBulkActionTests extends ESTestCase {
{
new RestBulkAction(
settings(IndexVersion.current()).build(),
ClusterSettings.createBuiltInClusterSettings(),
new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class))
).handleRequest(
new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk")
@ -125,6 +128,7 @@ public class RestBulkActionTests extends ESTestCase {
bulkCalled.set(false);
new RestBulkAction(
settings(IndexVersion.current()).build(),
ClusterSettings.createBuiltInClusterSettings(),
new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class))
).handleRequest(
new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk")
@ -148,6 +152,7 @@ public class RestBulkActionTests extends ESTestCase {
bulkCalled.set(false);
new RestBulkAction(
settings(IndexVersion.current()).build(),
ClusterSettings.createBuiltInClusterSettings(),
new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class))
).handleRequest(
new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk")
@ -172,6 +177,7 @@ public class RestBulkActionTests extends ESTestCase {
bulkCalled.set(false);
new RestBulkAction(
settings(IndexVersion.current()).build(),
ClusterSettings.createBuiltInClusterSettings(),
new IncrementalBulkService(mock(Client.class), mock(IndexingPressure.class))
).handleRequest(
new FakeRestRequest.Builder(xContentRegistry()).withPath("my_index/_bulk")

View file

@ -11,13 +11,18 @@ package org.elasticsearch.http;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestContentAggregator;
import org.elasticsearch.rest.RestRequest;
public class NullDispatcher implements HttpServerTransport.Dispatcher {
public class AggregatingDispatcher implements HttpServerTransport.Dispatcher {
public void dispatchAggregatedRequest(RestRequest restRequest, RestChannel restChannel, ThreadContext threadContext) {
assert restRequest.isStreamedContent();
}
@Override
public void dispatchRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) {
public final void dispatchRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) {
RestContentAggregator.aggregate(request, (r) -> dispatchAggregatedRequest(r, channel, threadContext));
}
@Override

View file

@ -70,6 +70,7 @@ import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
@ -98,6 +99,7 @@ import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.Booleans;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.CheckedRunnable;
@ -998,6 +1000,13 @@ public abstract class ESTestCase extends LuceneTestCase {
return CompositeBytesReference.of(slices.toArray(BytesReference[]::new));
}
public ReleasableBytesReference randomReleasableBytesReference(int length) {
return new ReleasableBytesReference(randomBytesReference(length), LeakTracker.wrap(new AbstractRefCounted() {
@Override
protected void closeInternal() {}
}));
}
public static short randomShort() {
return (short) random().nextInt();
}

View file

@ -0,0 +1,78 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.test.rest;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.http.HttpBody;
import java.util.ArrayList;
import java.util.List;
public class FakeHttpBodyStream implements HttpBody.Stream {
private final List<ChunkHandler> tracingHandlers = new ArrayList<>();
private ChunkHandler handler;
private boolean requested;
private boolean closed;
public boolean isClosed() {
return closed;
}
public boolean isRequested() {
return requested;
}
@Override
public ChunkHandler handler() {
return handler;
}
@Override
public void addTracingHandler(ChunkHandler chunkHandler) {
tracingHandlers.add(chunkHandler);
}
@Override
public void setHandler(ChunkHandler chunkHandler) {
this.handler = chunkHandler;
}
@Override
public void next() {
if (closed) {
return;
}
requested = true;
}
public void sendNext(ReleasableBytesReference chunk, boolean isLast) {
if (requested) {
for (var h : tracingHandlers) {
h.onNext(chunk, isLast);
}
handler.onNext(chunk, isLast);
} else {
throw new IllegalStateException("chunk is not requested");
}
}
@Override
public void close() {
if (closed == false) {
closed = true;
for (var h : tracingHandlers) {
h.close();
}
if (handler != null) {
handler.close();
}
}
}
}

View file

@ -55,24 +55,22 @@ public class FakeRestRequest extends RestRequest {
private final Method method;
private final String uri;
private final HttpBody content;
private final Map<String, List<String>> headers;
private HttpBody body;
private final Exception inboundException;
public FakeHttpRequest(Method method, String uri, BytesReference content, Map<String, List<String>> headers) {
this(method, uri, content == null ? HttpBody.empty() : HttpBody.fromBytesReference(content), headers, null);
public FakeHttpRequest(Method method, String uri, BytesReference body, Map<String, List<String>> headers) {
this(method, uri, body == null ? HttpBody.empty() : HttpBody.fromBytesReference(body), headers, null);
}
private FakeHttpRequest(
Method method,
String uri,
HttpBody content,
Map<String, List<String>> headers,
Exception inboundException
) {
public FakeHttpRequest(Method method, String uri, Map<String, List<String>> headers, HttpBody body) {
this(method, uri, body, headers, null);
}
private FakeHttpRequest(Method method, String uri, HttpBody body, Map<String, List<String>> headers, Exception inboundException) {
this.method = method;
this.uri = uri;
this.content = content;
this.body = body;
this.headers = headers;
this.inboundException = inboundException;
}
@ -89,7 +87,12 @@ public class FakeRestRequest extends RestRequest {
@Override
public HttpBody body() {
return content;
return body;
}
@Override
public void setBody(HttpBody body) {
this.body = body;
}
@Override
@ -111,7 +114,12 @@ public class FakeRestRequest extends RestRequest {
public HttpRequest removeHeader(String header) {
final var filteredHeaders = new HashMap<>(headers);
filteredHeaders.remove(header);
return new FakeHttpRequest(method, uri, content, filteredHeaders, inboundException);
return new FakeHttpRequest(method, uri, body, filteredHeaders, inboundException);
}
@Override
public boolean hasContent() {
return body.isEmpty() == false;
}
@Override

View file

@ -32,7 +32,7 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.http.AbstractHttpServerTransportTestCase;
import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.http.AggregatingDispatcher;
import org.elasticsearch.http.netty4.Netty4HttpServerTransport;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
@ -255,13 +255,13 @@ public class SecurityNetty4HttpServerTransportCloseNotifyTests extends AbstractH
}
}
private static class QueuedDispatcher implements HttpServerTransport.Dispatcher {
private static class QueuedDispatcher extends AggregatingDispatcher {
BlockingQueue<ReqCtx> reqQueue = new LinkedBlockingDeque<>();
BlockingDeque<ErrCtx> errQueue = new LinkedBlockingDeque<>();
@Override
public void dispatchRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) {
reqQueue.add(new ReqCtx(request, channel, threadContext));
public void dispatchAggregatedRequest(RestRequest restRequest, RestChannel restChannel, ThreadContext threadContext) {
reqQueue.add(new ReqCtx(restRequest, restChannel, threadContext));
}
@Override

View file

@ -31,11 +31,11 @@ import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.Environment;
import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.http.AbstractHttpServerTransportTestCase;
import org.elasticsearch.http.AggregatingDispatcher;
import org.elasticsearch.http.HttpHeadersValidationException;
import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.http.NullDispatcher;
import org.elasticsearch.http.netty4.Netty4FullHttpResponse;
import org.elasticsearch.http.netty4.Netty4HttpServerTransport;
import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils;
@ -111,7 +111,7 @@ public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTr
new NetworkService(Collections.emptyList()),
mock(ThreadPool.class),
xContentRegistry(),
new NullDispatcher(),
new AggregatingDispatcher(),
randomClusterSettings(),
new SharedGroupFactory(settings),
Tracer.NOOP,
@ -138,7 +138,7 @@ public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTr
new NetworkService(Collections.emptyList()),
mock(ThreadPool.class),
xContentRegistry(),
new NullDispatcher(),
new AggregatingDispatcher(),
randomClusterSettings(),
new SharedGroupFactory(settings),
Tracer.NOOP,
@ -165,7 +165,7 @@ public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTr
new NetworkService(Collections.emptyList()),
mock(ThreadPool.class),
xContentRegistry(),
new NullDispatcher(),
new AggregatingDispatcher(),
randomClusterSettings(),
new SharedGroupFactory(settings),
Tracer.NOOP,
@ -192,7 +192,7 @@ public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTr
new NetworkService(Collections.emptyList()),
mock(ThreadPool.class),
xContentRegistry(),
new NullDispatcher(),
new AggregatingDispatcher(),
randomClusterSettings(),
new SharedGroupFactory(settings),
Tracer.NOOP,
@ -214,7 +214,7 @@ public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTr
new NetworkService(Collections.emptyList()),
mock(ThreadPool.class),
xContentRegistry(),
new NullDispatcher(),
new AggregatingDispatcher(),
randomClusterSettings(),
new SharedGroupFactory(settings),
Tracer.NOOP,
@ -237,7 +237,7 @@ public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTr
new NetworkService(Collections.emptyList()),
mock(ThreadPool.class),
xContentRegistry(),
new NullDispatcher(),
new AggregatingDispatcher(),
randomClusterSettings(),
new SharedGroupFactory(settings),
Tracer.NOOP,
@ -269,7 +269,7 @@ public class SecurityNetty4HttpServerTransportTests extends AbstractHttpServerTr
new NetworkService(Collections.emptyList()),
mock(ThreadPool.class),
xContentRegistry(),
new NullDispatcher(),
new AggregatingDispatcher(),
randomClusterSettings(),
new SharedGroupFactory(settings),
Tracer.NOOP,