handle 100-continue and oversized streaming request (#112179)

This commit is contained in:
Mikhail Berezovskiy 2024-08-26 13:10:52 -07:00 committed by Tim Brooks
parent 478baf1459
commit 1b77421cf8
2 changed files with 167 additions and 24 deletions

View file

@ -11,6 +11,7 @@ package org.elasticsearch.http.netty4;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
@ -25,11 +26,16 @@ import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.DefaultLastHttpContent;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpChunkedInput;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.stream.ChunkedStream;
import io.netty.handler.stream.ChunkedWriteHandler;
import org.elasticsearch.ESNetty4IntegTestCase;
import org.elasticsearch.action.support.SubscribableListener;
@ -42,9 +48,13 @@ import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.IndexScopedSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.SettingsFilter;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.BaseRestHandler;
@ -62,9 +72,7 @@ import java.util.Collection;
import java.util.List;
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
@ -79,6 +87,13 @@ import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
@ESIntegTestCase.ClusterScope(numDataNodes = 1)
public class Netty4IncrementalRequestHandlingIT extends ESNetty4IntegTestCase {
@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
    // Cap HTTP request bodies at 50MB so the oversized-request tests below have a known limit.
    return Settings.builder()
        .put(super.nodeSettings(nodeOrdinal, otherSettings))
        .put(HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH.getKey(), new ByteSizeValue(50, ByteSizeUnit.MB))
        .build();
}
// ensures that empty HTTP content arrives as a single zero-size chunk
public void testEmptyContent() throws Exception {
try (var ctx = setupClientCtx()) {
@ -112,7 +127,7 @@ public class Netty4IncrementalRequestHandlingIT extends ESNetty4IntegTestCase {
var opaqueId = opaqueId(reqNo);
// this dataset will be compared with one on server side
var dataSize = randomIntBetween(1024, 10 * 1024 * 1024);
var dataSize = randomIntBetween(1024, maxContentLength());
var sendData = Unpooled.wrappedBuffer(randomByteArrayOfLength(dataSize));
sendData.retain();
ctx.clientChannel.writeAndFlush(fullHttpRequest(opaqueId, sendData));
@ -213,12 +228,98 @@ public class Netty4IncrementalRequestHandlingIT extends ESNetty4IntegTestCase {
bufSize >= minBufSize && bufSize <= maxBufSize
);
});
handler.consumeBytes(MBytes(10));
handler.readBytes(MBytes(10));
}
assertTrue(handler.stream.hasLast());
}
}
// ensures that the server replies with 100-continue when the request size is acceptable
public void test100Continue() throws Exception {
    try (var ctx = setupClientCtx()) {
        for (int reqNo = 0; reqNo < randomIntBetween(2, 10); reqNo++) {
            var requestId = opaqueId(reqNo);
            var contentLength = randomIntBetween(0, maxContentLength());

            // Send only the request header carrying "Expect: 100-continue" and wait
            // for the server's interim response before transmitting any body bytes.
            var request = httpRequest(requestId, contentLength);
            HttpUtil.set100ContinueExpected(request, true);
            ctx.clientChannel.writeAndFlush(request);
            var interim = (FullHttpResponse) safePoll(ctx.clientRespQueue);
            assertEquals(HttpResponseStatus.CONTINUE, interim.status());
            interim.release();

            // Server agreed; stream the body and verify the handler receives all of it.
            ctx.clientChannel.writeAndFlush(randomContent(contentLength, true));
            var handler = ctx.awaitRestChannelAccepted(requestId);
            assertEquals(contentLength, handler.readAllBytes());

            // Complete the exchange with a 200 and confirm the client observes it.
            handler.sendResponse(new RestResponse(RestStatus.OK, ""));
            var finalResp = (FullHttpResponse) safePoll(ctx.clientRespQueue);
            assertEquals(HttpResponseStatus.OK, finalResp.status());
            finalResp.release();
        }
    }
}
// ensures that the server replies with 413 Request Entity Too Large to an oversized request carrying Expect: 100-continue
public void test413TooLargeOnExpect100Continue() throws Exception {
    try (var ctx = setupClientCtx()) {
        for (int reqNo = 0; reqNo < randomIntBetween(2, 10); reqNo++) {
            var requestId = opaqueId(reqNo);
            // Advertise one byte more than the server is configured to accept.
            var oversizedLength = maxContentLength() + 1;

            // Send the header with "Expect: 100-continue"; the server must refuse
            // up front with 413 instead of inviting the client to send the body.
            var request = httpRequest(requestId, oversizedLength);
            HttpUtil.set100ContinueExpected(request, true);
            ctx.clientChannel.writeAndFlush(request);
            var response = (FullHttpResponse) safePoll(ctx.clientRespQueue);
            assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status());
            response.release();

            // Terminate the (bodyless) request so the next iteration starts clean.
            ctx.clientChannel.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT);
        }
    }
}
// ensures that an oversized chunked-encoded request is not limited at the HTTP layer
// rest handler is responsible for oversized requests
public void testOversizedChunkedEncodingNoLimits() throws Exception {
    try (var ctx = setupClientCtx()) {
        // ChunkedWriteHandler is required to write HttpChunkedInput. Install it ONCE:
        // the previous code called pipeline().addLast(...) inside the request loop,
        // stacking a redundant new handler onto the pipeline on every iteration.
        ctx.clientChannel.pipeline().addLast(new ChunkedWriteHandler());
        for (var reqNo = 0; reqNo < randomIntBetween(2, 10); reqNo++) {
            var id = opaqueId(reqNo);
            // One byte over the aggregation limit: chunked requests must still go through,
            // since enforcing size limits is the rest handler's responsibility here.
            var contentSize = maxContentLength() + 1;
            var content = randomByteArrayOfLength(contentSize);
            var is = new ByteBufInputStream(Unpooled.wrappedBuffer(content));
            var chunkedIs = new ChunkedStream(is);
            var httpChunkedIs = new HttpChunkedInput(chunkedIs, LastHttpContent.EMPTY_LAST_CONTENT);
            var req = httpRequest(id, 0);
            HttpUtil.setTransferEncodingChunked(req, true);
            ctx.clientChannel.writeAndFlush(req);
            ctx.clientChannel.writeAndFlush(httpChunkedIs);
            // The server-side handler must receive the full oversized body.
            var handler = ctx.awaitRestChannelAccepted(id);
            var consumed = handler.readAllBytes();
            assertEquals(contentSize, consumed);
            handler.sendResponse(new RestResponse(RestStatus.OK, ""));
            var resp = (FullHttpResponse) safePoll(ctx.clientRespQueue);
            assertEquals(HttpResponseStatus.OK, resp.status());
            resp.release();
        }
    }
}
/** Returns the HTTP max content length configured on the test cluster node. */
private int maxContentLength() {
    var nodeSettings = internalCluster().getInstance(Settings.class);
    return HttpHandlingSettings.fromSettings(nodeSettings).maxContentLength();
}
/** Builds a per-request opaque id, unique within this test run. */
private String opaqueId(int reqNo) {
    return new StringBuilder(getTestName()).append('-').append(reqNo).toString();
}
@ -369,24 +470,25 @@ public class Netty4IncrementalRequestHandlingIT extends ESNetty4IntegTestCase {
channel.sendResponse(response);
}
void consumeBytes(int bytes) {
if (recvLast) {
return;
}
while (bytes > 0) {
stream.next();
var recvChunk = safePoll(recvChunks);
bytes -= recvChunk.chunk.length();
recvChunk.chunk.close();
if (recvChunk.isLast) {
recvLast = true;
break;
/**
 * Pulls chunks from the stream until at least {@code bytes} bytes have been consumed
 * or the last chunk arrives. Returns the number of bytes actually consumed
 * (0 once the last chunk has already been seen).
 */
int readBytes(int bytes) {
    int total = 0;
    while (recvLast == false && total < bytes) {
        stream.next();
        var received = safePoll(recvChunks);
        total += received.chunk.length();
        received.chunk.close();
        if (received.isLast) {
            // Remember EOF so subsequent calls return immediately.
            recvLast = true;
        }
    }
    return total;
}
Future<?> onChannelThread(Callable<?> task) {
return this.stream.channel().eventLoop().submit(task);
// Drains the request body completely; Integer.MAX_VALUE is effectively "no limit"
// since test payloads are far smaller.
int readAllBytes() {
return readBytes(Integer.MAX_VALUE);
}
record Chunk(ReleasableBytesReference chunk, boolean isLast) {}

View file

@ -9,19 +9,32 @@
package org.elasticsearch.http.netty4;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpObject;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import org.elasticsearch.http.HttpPreRequest;
import org.elasticsearch.http.netty4.internal.HttpHeadersAuthenticatorUtils;
import java.util.function.Predicate;
/**
* A wrapper around {@link HttpObjectAggregator}. Provides optional content aggregation based on
* predicate. {@link HttpObjectAggregator} also handles Expect: 100-continue and oversized content.
* Unfortunately, Netty does not provide handlers for oversized messages beyond HttpObjectAggregator.
*/
public class Netty4HttpAggregator extends HttpObjectAggregator {
private static final Predicate<HttpPreRequest> IGNORE_TEST = (req) -> req.uri().startsWith("/_test/request-stream") == false;
private final Predicate<HttpPreRequest> decider;
private boolean shouldAggregate;
private boolean aggregating = true;
private boolean ignoreContentAfterContinueResponse = false;
public Netty4HttpAggregator(int maxContentLength) {
this(maxContentLength, IGNORE_TEST);
@ -33,15 +46,43 @@ public class Netty4HttpAggregator extends HttpObjectAggregator {
}
@Override
public boolean acceptInboundMessage(Object msg) throws Exception {
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
assert msg instanceof HttpObject;
if (msg instanceof HttpRequest request) {
var preReq = HttpHeadersAuthenticatorUtils.asHttpPreRequest(request);
shouldAggregate = decider.test(preReq);
aggregating = decider.test(preReq);
}
if (shouldAggregate) {
return super.acceptInboundMessage(msg);
if (aggregating || msg instanceof FullHttpRequest) {
super.channelRead(ctx, msg);
} else {
return false;
handle(ctx, (HttpObject) msg);
}
}
// Streaming-path replacement for HttpObjectAggregator's Expect/oversize handling:
// answer the request header with 100/413/417 as appropriate, then forward or drop
// the subsequent body content accordingly.
private void handle(ChannelHandlerContext ctx, HttpObject msg) {
if (msg instanceof HttpRequest request) {
// Non-null only when the request demands an interim response (Expect header);
// the aggregator decides between 100, 413 and 417 based on maxContentLength().
var continueResponse = newContinueResponse(request, maxContentLength(), ctx.pipeline());
if (continueResponse != null) {
// there are 3 responses expected: 100, 413, 417
// on 100 we pass request further and reply to client to continue
// on 413/417 we ignore following content
ctx.writeAndFlush(continueResponse);
var resp = (FullHttpResponse) continueResponse;
if (resp.status() != HttpResponseStatus.CONTINUE) {
// Rejected (413/417): swallow the body chunks the client may still send.
ignoreContentAfterContinueResponse = true;
return;
}
// Accepted: strip the Expect header so downstream handlers don't re-answer it.
HttpUtil.set100ContinueExpected(request, false);
}
// New request accepted — reset the drop flag left over from a prior rejection.
ignoreContentAfterContinueResponse = false;
ctx.fireChannelRead(msg);
} else {
var httpContent = (HttpContent) msg;
if (ignoreContentAfterContinueResponse) {
// Body of a rejected request: release the buffer instead of forwarding it.
httpContent.release();
} else {
ctx.fireChannelRead(msg);
}
}
}
}