[ML] Retry on streaming errors (#123076)

We now always retry based on the provider's configured retry logic
rather than the HTTP status code. Some providers (e.g. Cohere,
Anthropic) return 200 status codes with error bodies, while others
(e.g. OpenAI, Azure) return non-200 status codes with non-streaming
bodies. A condensed sketch of the new branch follows the notes below.

Notes:
- Refactored HttpResult into StreamingHttpResult: the byte body is now
  the streaming element, while the HTTP response lives outside the
  stream.
- Refactored StreamingHttpResultPublisher so that it only pushes the
  byte body into a queue.
- All tests now have to wait for the response to be fully consumed
  before closing the service; otherwise the close method shuts down
  the mock web server and Apache throws an error.
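
For orientation, the core of the change is the new branch in
RetryingHttpSender (condensed from the hunk further down): a successful
response is streamed to the parser as before, while an error response is
buffered in full and pushed through the same validateResponse/parseResult
path used for non-streaming responses.

    httpClient.stream(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
        if (r.isSuccessfulResponse()) {
            // 2xx: hand the byte stream straight to the streaming parser
            l.onResponse(responseHandler.parseResult(request, r.toHttpResult()));
        } else {
            // non-2xx: buffer the whole body, then validate it like a non-streaming response
            r.readFullResponse(l.delegateFailureAndWrap((ll, httpResult) -> {
                try {
                    responseHandler.validateResponse(throttlerManager, logger, request, httpResult);
                    ll.onResponse(responseHandler.parseResult(request, httpResult));
                } catch (Exception e) {
                    logException(logger, request, httpResult, responseHandler.getRequestType(), e);
                    listener.onFailure(e); // skip retrying
                }
            }));
        }
    }));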
Pat Whelan 2025-03-04 09:33:00 -05:00 committed by GitHub
parent 6a36f31f41
commit dfe2adb592
16 changed files with 548 additions and 553 deletions

@ -0,0 +1,5 @@
pr: 123076
summary: Retry on streaming errors
area: Machine Learning
type: bug
issues: []

@ -27,7 +27,6 @@ import java.io.Closeable;
import java.io.IOException;
import java.util.Objects;
import java.util.concurrent.CancellationException;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicReference;
import static org.elasticsearch.core.Strings.format;
@ -154,7 +153,7 @@ public class HttpClient implements Closeable {
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> listener.onFailure(exception));
}
public void stream(HttpRequest request, HttpContext context, ActionListener<Flow.Publisher<HttpResult>> listener) throws IOException {
public void stream(HttpRequest request, HttpContext context, ActionListener<StreamingHttpResult> listener) throws IOException {
// The caller must call start() first before attempting to send a request
assert status.get() == Status.STARTED : "call start() before attempting to send a request";
@ -162,7 +161,7 @@ public class HttpClient implements Closeable {
SocketAccess.doPrivileged(() -> client.execute(request.requestProducer(), streamingProcessor, context, new FutureCallback<>() {
@Override
public void completed(HttpResponse response) {
public void completed(Void response) {
streamingProcessor.close();
}

@ -0,0 +1,76 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http;
import org.apache.http.HttpResponse;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.rest.RestStatus;
import java.io.ByteArrayOutputStream;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicReference;
public record StreamingHttpResult(HttpResponse response, Flow.Publisher<byte[]> body) {
public boolean isSuccessfulResponse() {
return RestStatus.isSuccessful(response.getStatusLine().getStatusCode());
}
public Flow.Publisher<HttpResult> toHttpResult() {
return subscriber -> body().subscribe(new Flow.Subscriber<>() {
@Override
public void onSubscribe(Flow.Subscription subscription) {
subscriber.onSubscribe(subscription);
}
@Override
public void onNext(byte[] item) {
subscriber.onNext(new HttpResult(response(), item));
}
@Override
public void onError(Throwable throwable) {
subscriber.onError(throwable);
}
@Override
public void onComplete() {
subscriber.onComplete();
}
});
}
public void readFullResponse(ActionListener<HttpResult> fullResponse) {
var stream = new ByteArrayOutputStream();
AtomicReference<Flow.Subscription> upstream = new AtomicReference<>(null);
body.subscribe(new Flow.Subscriber<>() {
@Override
public void onSubscribe(Flow.Subscription subscription) {
upstream.set(subscription);
upstream.get().request(1);
}
@Override
public void onNext(byte[] item) {
stream.writeBytes(item);
upstream.get().request(1);
}
@Override
public void onError(Throwable throwable) {
ExceptionsHelper.maybeError(throwable).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
fullResponse.onFailure(new RuntimeException("Fatal while fully consuming stream", throwable));
}
@Override
public void onComplete() {
fullResponse.onResponse(new HttpResult(response, stream.toByteArray()));
}
});
}
}
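
A minimal sketch of how a caller might consume this record (chunkSubscriber,
handleErrorBody and handleFailure are hypothetical placeholders; the real call
site is the RetryingHttpSender hunk further down):

    ActionListener<StreamingHttpResult> resultListener = ActionListener.wrap(result -> {
        if (result.isSuccessfulResponse()) {
            // stream the body, re-attaching the HttpResponse to every chunk
            result.toHttpResult().subscribe(chunkSubscriber); // hypothetical Flow.Subscriber<HttpResult>
        } else {
            // buffer the error body so it can be handled like a non-streaming response
            result.readFullResponse(ActionListener.wrap(
                httpResult -> handleErrorBody(httpResult), // hypothetical
                e -> handleFailure(e)                      // hypothetical
            ));
        }
    }, e -> handleFailure(e));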

@ -13,6 +13,7 @@ import org.apache.http.nio.IOControl;
import org.apache.http.nio.protocol.HttpAsyncResponseConsumer;
import org.apache.http.nio.util.SimpleInputBuffer;
import org.apache.http.protocol.HttpContext;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.threadpool.ThreadPool;
@ -39,51 +40,31 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_P
* so this publisher will send a single HttpResult. If the HttpResponse is healthy, Apache will send an HttpResponse with or without
* the HttpEntity.</p>
*/
class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResponse>, Flow.Publisher<HttpResult> {
private final HttpSettings settings;
private final ActionListener<Flow.Publisher<HttpResult>> listener;
class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<Void> {
private final ActionListener<StreamingHttpResult> listener;
private final AtomicBoolean listenerCalled = new AtomicBoolean(false);
// used to manage the HTTP response
private volatile HttpResponse response;
private volatile Exception ex;
// used to control the state of this publisher (Apache) and its interaction with its subscriber
private final AtomicBoolean isDone = new AtomicBoolean(false);
private final AtomicBoolean subscriptionCanceled = new AtomicBoolean(false);
private volatile Flow.Subscriber<? super HttpResult> subscriber;
private final RequestBasedTaskRunner taskRunner;
private final AtomicBoolean pendingRequest = new AtomicBoolean(false);
private final Deque<Runnable> queue = new ConcurrentLinkedDeque<>();
// used to control the flow of data from the Apache client, if we're producing more bytes than we can consume then we'll pause
private final SimpleInputBuffer inputBuffer = new SimpleInputBuffer(4096);
private final AtomicLong bytesInQueue = new AtomicLong(0);
private final Object ioLock = new Object();
private volatile IOControl savedIoControl;
private final DataPublisher publisher;
private final ApacheClientBackpressure backpressure;
StreamingHttpResultPublisher(ThreadPool threadPool, HttpSettings settings, ActionListener<Flow.Publisher<HttpResult>> listener) {
this.settings = Objects.requireNonNull(settings);
private volatile Exception exception;
StreamingHttpResultPublisher(ThreadPool threadPool, HttpSettings settings, ActionListener<StreamingHttpResult> listener) {
this.listener = ActionListener.notifyOnce(Objects.requireNonNull(listener));
this.taskRunner = new RequestBasedTaskRunner(new OffloadThread(), threadPool, UTILITY_THREAD_POOL_NAME);
this.publisher = new DataPublisher(threadPool);
this.backpressure = new ApacheClientBackpressure(Objects.requireNonNull(settings));
}
@Override
public void responseReceived(HttpResponse httpResponse) {
this.response = httpResponse;
if (listenerCalled.compareAndSet(false, true)) {
listener.onResponse(new StreamingHttpResult(httpResponse, publisher));
}
@Override
public void subscribe(Flow.Subscriber<? super HttpResult> subscriber) {
if (this.subscriber != null) {
subscriber.onError(new IllegalStateException("Only one subscriber is allowed for this Publisher."));
return;
}
this.subscriber = subscriber;
subscriber.onSubscribe(new HttpSubscription());
}
@Override
@ -100,49 +81,20 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResp
if (consumed > 0) {
var allBytes = new byte[consumed];
inputBuffer.read(allBytes);
queue.offer(() -> {
subscriber.onNext(new HttpResult(response, allBytes));
var currentBytesInQueue = bytesInQueue.updateAndGet(current -> Long.max(0, current - allBytes.length));
if (savedIoControl != null) {
var maxBytes = settings.getMaxResponseSize().getBytes() * 0.5;
if (currentBytesInQueue <= maxBytes) {
resumeProducer();
}
}
});
// always check if totalByteSize > the configured setting in case the settings change
if (bytesInQueue.accumulateAndGet(allBytes.length, Long::sum) >= settings.getMaxResponseSize().getBytes()) {
pauseProducer(ioControl);
}
taskRunner.requestNextRun();
if (listenerCalled.compareAndSet(false, true)) {
listener.onResponse(this);
}
backpressure.addBytesAndMaybePause(consumed, ioControl);
publisher.onNext(allBytes);
}
} catch (Exception e) {
// if the provider closes the connection in the middle of the stream,
// the contentDecoder will throw an exception trying to read the payload,
// we should catch that and forward it downstream so we can properly handle it
exception = e;
publisher.onError(e);
} finally {
inputBuffer.reset();
}
}
private void pauseProducer(IOControl ioControl) {
ioControl.suspendInput();
synchronized (ioLock) {
savedIoControl = ioControl;
}
}
private void resumeProducer() {
synchronized (ioLock) {
if (savedIoControl != null) {
savedIoControl.requestInput();
savedIoControl = null;
}
}
}
@Override
public void responseCompleted(HttpContext httpContext) {}
@ -153,9 +105,8 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResp
if (listenerCalled.compareAndSet(false, true)) {
listener.onFailure(e);
} else {
ex = e;
queue.offer(() -> subscriber.onError(e));
taskRunner.requestNextRun();
exception = e;
publisher.onError(e);
}
}
}
@ -164,8 +115,7 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResp
@Override
public void close() {
if (isDone.compareAndSet(false, true)) {
queue.offer(() -> subscriber.onComplete());
taskRunner.requestNextRun();
publisher.onComplete();
}
}
@ -178,12 +128,12 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResp
@Override
public Exception getException() {
return ex;
return exception;
}
@Override
public HttpResponse getResult() {
return response;
public Void getResult() {
return null;
}
@Override
@ -191,15 +141,58 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResp
return isDone.get();
}
private class HttpSubscription implements Flow.Subscription {
/**
* We only want to push payload data when the client is ready to receive it, so the client will use
* {@link Flow.Subscription#request(long)} to request more data. We collect the payload bytes in a queue and process them on a
* separate thread from both the Apache IO thread reading from the provider and the client's transport thread requesting more data.
* Clients use {@link Flow.Subscription#cancel()} to exit early, and we'll forward that cancellation to the provider.
*/
private class DataPublisher implements Flow.Processor<byte[], byte[]> {
private final RequestBasedTaskRunner taskRunner;
private final Deque<byte[]> contentQueue = new ConcurrentLinkedDeque<>();
private final AtomicLong pendingRequests = new AtomicLong(0);
private volatile Exception pendingError = null;
private volatile boolean completed = false;
private volatile Flow.Subscriber<? super byte[]> downstream;
private DataPublisher(ThreadPool threadPool) {
this.taskRunner = new RequestBasedTaskRunner(this::sendToSubscriber, threadPool, UTILITY_THREAD_POOL_NAME);
}
private void sendToSubscriber() {
if (downstream == null) {
return;
}
if (pendingRequests.get() > 0 && pendingError != null) {
pendingRequests.decrementAndGet();
downstream.onError(pendingError);
return;
}
byte[] nextBytes;
while (pendingRequests.get() > 0 && (nextBytes = contentQueue.poll()) != null) {
pendingRequests.decrementAndGet();
backpressure.subtractBytesAndMaybeUnpause(nextBytes.length);
downstream.onNext(nextBytes);
}
if (pendingRequests.get() > 0 && contentQueue.isEmpty() && completed) {
pendingRequests.decrementAndGet();
downstream.onComplete();
}
}
@Override
public void request(long n) {
if (subscriptionCanceled.get()) {
public void subscribe(Flow.Subscriber<? super byte[]> subscriber) {
if (this.downstream != null) {
subscriber.onError(new IllegalStateException("Only one subscriber is allowed for this Publisher."));
return;
}
this.downstream = subscriber;
downstream.onSubscribe(new Flow.Subscription() {
@Override
public void request(long n) {
if (n > 0) {
pendingRequest.set(true);
pendingRequests.addAndGet(n);
taskRunner.requestNextRun();
} else {
// per Subscription's spec, fail the subscriber and stop the processor
@ -214,21 +207,82 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResp
taskRunner.cancel();
}
}
});
}
private class OffloadThread implements Runnable {
@Override
public void run() {
if (subscriptionCanceled.get()) {
return;
public void onNext(byte[] item) {
contentQueue.offer(item);
taskRunner.requestNextRun();
}
if (queue.isEmpty() == false && pendingRequest.compareAndSet(true, false)) {
var next = queue.poll();
if (next != null) {
next.run();
@Override
public void onError(Throwable throwable) {
if (throwable instanceof Exception e) {
pendingError = e;
} else {
pendingRequest.set(true);
ExceptionsHelper.maybeError(throwable).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
pendingError = new RuntimeException("Unhandled error while streaming");
}
taskRunner.requestNextRun();
}
@Override
public void onComplete() {
completed = true;
taskRunner.requestNextRun();
}
@Override
public void onSubscribe(Flow.Subscription subscription) {
assert false : "Apache never calls this.";
throw new UnsupportedOperationException("Apache never calls this.");
}
}
/**
* We want to keep track of how much memory we are consuming while reading the payload from the external provider. Apache continuously
* pushes payload data to us, whereas the client only requests the next set of bytes when they are ready, so we want to track how much
* data we are holding in memory and potentially pause the Apache client if we have reached our limit.
*/
private static class ApacheClientBackpressure {
private final HttpSettings settings;
private final AtomicLong bytesInQueue = new AtomicLong(0);
private final Object ioLock = new Object();
private volatile IOControl savedIoControl;
private ApacheClientBackpressure(HttpSettings settings) {
this.settings = settings;
}
private void addBytesAndMaybePause(long count, IOControl ioControl) {
if (bytesInQueue.addAndGet(count) >= settings.getMaxResponseSize().getBytes()) {
pauseProducer(ioControl);
}
}
private void pauseProducer(IOControl ioControl) {
ioControl.suspendInput();
synchronized (ioLock) {
savedIoControl = ioControl;
}
}
private void subtractBytesAndMaybeUnpause(long count) {
var currentBytesInQueue = bytesInQueue.updateAndGet(current -> Long.max(0, current - count));
if (savedIoControl != null) {
var maxBytes = settings.getMaxResponseSize().getBytes() * 0.5;
if (currentBytesInQueue <= maxBytes) {
resumeProducer();
}
}
}
private void resumeProducer() {
synchronized (ioLock) {
if (savedIoControl != null) {
savedIoControl.requestInput();
savedIoControl = null;
}
}
}
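
The pause/resume policy above, restated as a self-contained sketch (the
Producer interface stands in for Apache's IOControl, and the watermarks mirror
the max-response-size and 50% thresholds in the diff; an illustration, not
code from this commit):

    final class WatermarkBackpressure {
        interface Producer { void pause(); void resume(); }

        private final long maxBytes;
        private long bytesInQueue;
        private boolean paused;

        WatermarkBackpressure(long maxBytes) { this.maxBytes = maxBytes; }

        // producer side: called for every chunk buffered from the wire
        synchronized void onBytesBuffered(long count, Producer producer) {
            bytesInQueue += count;
            if (paused == false && bytesInQueue >= maxBytes) {
                paused = true;
                producer.pause(); // e.g. IOControl.suspendInput()
            }
        }

        // consumer side: called as the subscriber drains the queue
        synchronized void onBytesConsumed(long count, Producer producer) {
            bytesInQueue = Math.max(0, bytesInQueue - count);
            if (paused && bytesInQueue <= maxBytes / 2) {
                paused = false;
                producer.resume(); // e.g. IOControl.requestInput()
            }
        }
    }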

@ -116,9 +116,20 @@ public class RetryingHttpSender implements RequestSender {
try {
if (request.isStreaming() && responseHandler.canHandleStreamingResponses()) {
httpClient.stream(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
var streamingResponseHandler = new StreamingResponseHandler(throttlerManager, logger, request, responseHandler);
r.subscribe(streamingResponseHandler);
l.onResponse(responseHandler.parseResult(request, streamingResponseHandler));
if (r.isSuccessfulResponse()) {
l.onResponse(responseHandler.parseResult(request, r.toHttpResult()));
} else {
r.readFullResponse(l.delegateFailureAndWrap((ll, httpResult) -> {
try {
responseHandler.validateResponse(throttlerManager, logger, request, httpResult);
InferenceServiceResults inferenceResults = responseHandler.parseResult(request, httpResult);
ll.onResponse(inferenceResults);
} catch (Exception e) {
logException(logger, request, httpResult, responseHandler.getRequestType(), e);
listener.onFailure(e); // skip retrying
}
}));
}
}));
} else {
httpClient.send(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {

@ -1,140 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http.retry;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.elasticsearch.core.Strings.format;
class StreamingResponseHandler implements Flow.Processor<HttpResult, HttpResult> {
private static final Logger log = LogManager.getLogger(StreamingResponseHandler.class);
private final ThrottlerManager throttlerManager;
private final Logger throttlerLogger;
private final Request request;
private final ResponseHandler responseHandler;
private final AtomicBoolean upstreamIsClosed = new AtomicBoolean(false);
private final AtomicBoolean processedFirstItem = new AtomicBoolean(false);
private volatile Flow.Subscription upstream;
private volatile Flow.Subscriber<? super HttpResult> downstream;
StreamingResponseHandler(ThrottlerManager throttlerManager, Logger throttlerLogger, Request request, ResponseHandler responseHandler) {
this.throttlerManager = throttlerManager;
this.throttlerLogger = throttlerLogger;
this.request = request;
this.responseHandler = responseHandler;
}
@Override
public void subscribe(Flow.Subscriber<? super HttpResult> subscriber) {
if (downstream != null) {
subscriber.onError(
new IllegalStateException("Failed to initialize streaming response. Another subscriber is already subscribed.")
);
return;
}
downstream = subscriber;
subscriber.onSubscribe(forwardingSubscription());
}
private Flow.Subscription forwardingSubscription() {
return new Flow.Subscription() {
@Override
public void request(long n) {
if (upstreamIsClosed.get()) {
downstream.onComplete(); // shouldn't happen, but reinforce that we're no longer listening
} else if (upstream != null) {
upstream.request(n);
} else {
// this shouldn't happen, the expected call pattern is onNext -> subscribe after the listener is invoked
var errorMessage = "Failed to initialize streaming response. onSubscribe must be called first to set the upstream";
assert false : errorMessage;
downstream.onError(new IllegalStateException(errorMessage));
}
}
@Override
public void cancel() {
if (upstreamIsClosed.compareAndSet(false, true) && upstream != null) {
upstream.cancel();
}
}
};
}
@Override
public void onSubscribe(Flow.Subscription subscription) {
upstream = subscription;
}
@Override
public void onNext(HttpResult item) {
if (processedFirstItem.compareAndSet(false, true)) {
try {
responseHandler.validateResponse(throttlerManager, throttlerLogger, request, item);
} catch (Exception e) {
logException(throttlerLogger, request, item, responseHandler.getRequestType(), e);
upstream.cancel();
onError(e);
return;
}
}
downstream.onNext(item);
}
@Override
public void onError(Throwable throwable) {
if (upstreamIsClosed.compareAndSet(false, true)) {
if (downstream != null) {
downstream.onError(throwable);
} else {
log.warn(
"Flow failed before the InferenceServiceResults were generated. The error should go to the listener directly.",
throwable
);
}
}
}
@Override
public void onComplete() {
if (upstreamIsClosed.compareAndSet(false, true)) {
if (downstream != null) {
downstream.onComplete();
} else {
log.debug("Flow completed before the InferenceServiceResults were generated. Shutting down this Processor.");
}
}
}
private void logException(Logger logger, Request request, HttpResult result, String requestType, Exception exception) {
var causeException = ExceptionsHelper.unwrapCause(exception);
throttlerManager.warn(
logger,
format(
"Failed to process the stream connection for request from inference entity id [%s] of type [%s] with status [%s] [%s]",
request.getInferenceEntityId(),
requestType,
result.response().getStatusLine().getStatusCode(),
result.response().getStatusLine().getReasonPhrase()
),
causeException
);
}
}

@ -41,7 +41,6 @@ import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CancellationException;
import java.util.concurrent.Flow;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
@ -174,7 +173,7 @@ public class HttpClientTests extends ESTestCase {
try (var client = new HttpClient(emptyHttpSettings(), asyncClient, threadPool, mockThrottlerManager())) {
client.start();
PlainActionFuture<Flow.Publisher<HttpResult>> listener = new PlainActionFuture<>();
PlainActionFuture<StreamingHttpResult> listener = new PlainActionFuture<>();
client.stream(httpPost, HttpClientContext.create(), listener);
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
@ -196,7 +195,7 @@ public class HttpClientTests extends ESTestCase {
try (var client = new HttpClient(emptyHttpSettings(), asyncClient, threadPool, mockThrottlerManager())) {
client.start();
PlainActionFuture<Flow.Publisher<HttpResult>> listener = new PlainActionFuture<>();
PlainActionFuture<StreamingHttpResult> listener = new PlainActionFuture<>();
client.stream(httpPost, HttpClientContext.create(), listener);
var thrownException = expectThrows(CancellationException.class, () -> listener.actionGet(TIMEOUT));

@ -14,8 +14,10 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
@ -53,7 +55,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
private static final long maxBytes = message.length;
private ThreadPool threadPool;
private HttpSettings settings;
private ActionListener<Flow.Publisher<HttpResult>> listener;
private final AtomicReference<Tuple<StreamingHttpResult, Exception>> result = new AtomicReference<>(null);
private StreamingHttpResultPublisher publisher;
@Before
@ -61,12 +63,21 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
super.setUp();
threadPool = mock(ThreadPool.class);
settings = mock(HttpSettings.class);
listener = spy(ActionListener.noop());
when(threadPool.executor(UTILITY_THREAD_POOL_NAME)).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
when(settings.getMaxResponseSize()).thenReturn(ByteSizeValue.ofBytes(maxBytes));
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener());
}
private ActionListener<StreamingHttpResult> listener() {
return ActionListener.wrap(r -> result.set(Tuple.tuple(r, null)), e -> result.set(Tuple.tuple(null, e)));
}
@After
public void tearDown() throws Exception {
super.tearDown();
result.set(null);
}
/**
@ -76,7 +87,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
*/
public void testFirstResponseCallsListener() throws IOException {
var latch = new CountDownLatch(1);
var listener = ActionTestUtils.<Flow.Publisher<HttpResult>>assertNoFailureListener(r -> latch.countDown());
var listener = ActionTestUtils.<StreamingHttpResult>assertNoFailureListener(r -> latch.countDown());
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
publisher.responseReceived(mock(HttpResponse.class));
@ -92,7 +103,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
*/
public void testNonEmptyFirstResponseCallsListener() throws IOException {
var latch = new CountDownLatch(1);
var listener = ActionTestUtils.<Flow.Publisher<HttpResult>>assertNoFailureListener(r -> latch.countDown());
var listener = ActionTestUtils.<StreamingHttpResult>assertNoFailureListener(r -> latch.countDown());
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
when(settings.getMaxResponseSize()).thenReturn(ByteSizeValue.ofBytes(9000));
@ -127,7 +138,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
// subscribe
publisher.subscribe(subscriber);
testPublisher().subscribe(subscriber);
assertThat("subscribe must call onSubscribe", subscriber.subscription, notNullValue());
assertThat("onNext should only be called once we have requested data", subscriber.httpResult, nullValue());
@ -142,7 +153,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
// publisher sends data
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
assertThat("onNext was called with " + new String(message, StandardCharsets.UTF_8), subscriber.httpResult.body(), equalTo(message));
assertThat("onNext was called with " + new String(message, StandardCharsets.UTF_8), subscriber.httpResult, equalTo(message));
}
/**
@ -157,7 +168,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
publisher.close();
// subscriber requests data
publisher.subscribe(subscriber);
testPublisher().subscribe(subscriber);
assertThat("subscribe must call onSubscribe", subscriber.subscription, notNullValue());
subscriber.requestData();
assertThat("onNext was called with the initial HttpResponse", subscriber.httpResult, notNullValue());
@ -187,7 +198,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
public void testResumeApache() throws IOException {
var subscriber = new TestSubscriber();
publisher.responseReceived(mock(HttpResponse.class));
publisher.subscribe(subscriber);
testPublisher().subscribe(subscriber);
subscriber.requestData();
subscriber.httpResult = null;
@ -212,7 +223,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
var subscriber = new TestSubscriber();
publisher.responseReceived(mock(HttpResponse.class));
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
publisher.subscribe(subscriber);
testPublisher().subscribe(subscriber);
subscriber.requestData();
subscriber.httpResult = null;
@ -243,7 +254,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
public void testErrorBeforeRequest() {
var exception = new NullPointerException("test");
publisher.failed(exception);
verify(listener).onFailure(exception);
assertThat(result.get().v2(), equalTo(exception));
}
/**
@ -361,21 +372,17 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
* When cancel is called
* Then we only send onComplete once
*/
public void testCancelIsIdempotent() throws IOException {
Flow.Subscriber<HttpResult> subscriber = mock();
var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
publisher.subscribe(subscriber);
verify(subscriber).onSubscribe(subscription.capture());
public void testCancelIsIdempotent() {
Flow.Subscriber<byte[]> subscriber = mock();
publisher.responseReceived(mock());
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
subscription.getValue().request(1);
subscription.getValue().request(1);
var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
testPublisher().subscribe(subscriber);
verify(subscriber).onSubscribe(subscription.capture());
subscription.getValue().request(2);
publisher.cancel();
verify(subscriber, times(1)).onComplete();
subscription.getValue().request(1);
publisher.cancel();
verify(subscriber, times(1)).onComplete();
}
@ -384,21 +391,17 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
* When close is called
* Then we only send onComplete once
*/
public void testCloseIsIdempotent() throws IOException {
Flow.Subscriber<HttpResult> subscriber = mock();
var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
publisher.subscribe(subscriber);
verify(subscriber).onSubscribe(subscription.capture());
public void testCloseIsIdempotent() {
Flow.Subscriber<byte[]> subscriber = mock();
publisher.responseReceived(mock());
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
subscription.getValue().request(1);
subscription.getValue().request(1);
var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
testPublisher().subscribe(subscriber);
verify(subscriber).onSubscribe(subscription.capture());
subscription.getValue().request(2);
publisher.close();
verify(subscriber, times(1)).onComplete();
subscription.getValue().request(1);
publisher.close();
verify(subscriber, times(1)).onComplete();
}
@ -409,20 +412,16 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
*/
public void testFailedIsIdempotent() throws IOException {
var expectedException = new IllegalStateException("wow");
Flow.Subscriber<HttpResult> subscriber = mock();
var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
publisher.subscribe(subscriber);
verify(subscriber).onSubscribe(subscription.capture());
Flow.Subscriber<byte[]> subscriber = mock();
publisher.responseReceived(mock());
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
subscription.getValue().request(1);
subscription.getValue().request(1);
var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
testPublisher().subscribe(subscriber);
verify(subscriber).onSubscribe(subscription.capture());
subscription.getValue().request(2);
publisher.failed(expectedException);
verify(subscriber, times(1)).onError(eq(expectedException));
subscription.getValue().request(1);
publisher.failed(expectedException);
verify(subscriber, times(1)).onError(eq(expectedException));
}
@ -492,10 +491,11 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
* Then that subscriber should receive an IllegalStateException
*/
public void testDoubleSubscribeFails() {
publisher.subscribe(mock());
publisher.responseReceived(mock());
testPublisher().subscribe(mock());
var subscriber = new TestSubscriber();
publisher.subscribe(subscriber);
testPublisher().subscribe(subscriber);
assertThat(subscriber.throwable, notNullValue());
assertThat(subscriber.throwable, instanceOf(IllegalStateException.class));
}
@ -508,10 +508,10 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
public void testReuseMlThread() throws ExecutionException, InterruptedException, TimeoutException {
try {
threadPool = spy(createThreadPool(inferenceUtilityPool()));
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener());
var subscriber = new TestSubscriber();
publisher.responseReceived(mock(HttpResponse.class));
publisher.subscribe(subscriber);
testPublisher().subscribe(subscriber);
CompletableFuture.runAsync(() -> {
try {
@ -523,7 +523,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
}, threadPool.executor(UTILITY_THREAD_POOL_NAME)).get(5, TimeUnit.SECONDS);
verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME);
assertThat("onNext was called with the initial HttpResponse", subscriber.httpResult, notNullValue());
assertFalse("Expected HttpResult to have data", subscriber.httpResult.isBodyEmpty());
assertNotEquals("Expected HttpResult to have data", 0, subscriber.httpResult.length);
} finally {
terminate(threadPool);
}
@ -548,11 +548,11 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
return executorServiceSpy;
}).when(threadPool).executor(UTILITY_THREAD_POOL_NAME);
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener());
publisher.responseReceived(mock(HttpResponse.class));
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
// create an infinitely running Subscriber
var subscriber = new Flow.Subscriber<HttpResult>() {
var subscriber = new Flow.Subscriber<byte[]>() {
Flow.Subscription subscription;
boolean completed = false;
@ -563,7 +563,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
}
@Override
public void onNext(HttpResult item) {
public void onNext(byte[] item) {
try {
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
} catch (IOException e) {
@ -582,7 +582,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
completed = true;
}
};
publisher.subscribe(subscriber);
testPublisher().subscribe(subscriber);
// verify the thread has started
assertThat("Thread should have started on subscribe", futureHolder.get(), notNullValue());
@ -604,7 +604,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
// start with a message published
publisher.responseReceived(mock(HttpResponse.class));
TestSubscriber subscriber = new TestSubscriber() {
public void onNext(HttpResult item) {
public void onNext(byte[] item) {
try {
// publish a second message
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
@ -617,7 +617,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
super.onNext(item);
}
};
publisher.subscribe(subscriber);
testPublisher().subscribe(subscriber);
verify(threadPool, times(0)).executor(UTILITY_THREAD_POOL_NAME);
subscriber.requestData();
@ -646,8 +646,9 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
}
private TestSubscriber subscribe() {
publisher.responseReceived(mock());
var subscriber = new TestSubscriber();
publisher.subscribe(subscriber);
testPublisher().subscribe(subscriber);
return subscriber;
}
@ -655,12 +656,12 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
publisher.responseReceived(mock(HttpResponse.class));
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
TestSubscriber subscriber = new TestSubscriber() {
public void onNext(HttpResult item) {
public void onNext(byte[] item) {
runDuringOnNext.run();
super.onNext(item);
}
};
publisher.subscribe(subscriber);
testPublisher().subscribe(subscriber);
return subscriber;
}
@ -668,19 +669,23 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
publisher.responseReceived(mock(HttpResponse.class));
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
TestSubscriber subscriber = new TestSubscriber() {
public void onNext(HttpResult item) {
public void onNext(byte[] item) {
runDuringOnNext.run();
super.requestData();
super.onNext(item);
}
};
publisher.subscribe(subscriber);
testPublisher().subscribe(subscriber);
return subscriber;
}
private static class TestSubscriber implements Flow.Subscriber<HttpResult> {
private Flow.Publisher<byte[]> testPublisher() {
return result.get().v1().body();
}
private static class TestSubscriber implements Flow.Subscriber<byte[]> {
private Flow.Subscription subscription;
private HttpResult httpResult;
private byte[] httpResult;
private Throwable throwable;
private boolean completed;
@ -690,7 +695,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
}
@Override
public void onNext(HttpResult item) {
public void onNext(byte[] item) {
this.httpResult = item;
}

@ -24,15 +24,18 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.xpack.inference.external.http.HttpClient;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.StreamingHttpResult;
import org.elasticsearch.xpack.inference.external.request.HttpRequestTests;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.junit.Before;
import org.mockito.ArgumentMatchers;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.net.UnknownHostException;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.createDefaultRetrySettings;
import static org.hamcrest.Matchers.instanceOf;
@ -42,7 +45,6 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.only;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@ -459,12 +461,12 @@ public class RetryingHttpSenderTests extends ESTestCase {
verifyNoMoreInteractions(httpClient);
}
public void testStream() throws IOException {
public void testStreamSuccess() throws IOException {
var httpClient = mock(HttpClient.class);
Flow.Publisher<HttpResult> publisher = mock();
StreamingHttpResult streamingHttpResult = new StreamingHttpResult(mockHttpResponse(), randomPublisher());
doAnswer(ans -> {
ActionListener<Flow.Publisher<HttpResult>> listener = ans.getArgument(2);
listener.onResponse(publisher);
ActionListener<StreamingHttpResult> listener = ans.getArgument(2);
listener.onResponse(streamingHttpResult);
return null;
}).when(httpClient).stream(any(), any(), any());
@ -479,7 +481,28 @@ public class RetryingHttpSenderTests extends ESTestCase {
verify(httpClient, times(1)).stream(any(), any(), any());
verifyNoMoreInteractions(httpClient);
verify(publisher, only()).subscribe(any(StreamingResponseHandler.class));
verify(responseHandler, times(1)).parseResult(any(), ArgumentMatchers.<Flow.Publisher<HttpResult>>any());
}
private Flow.Publisher<byte[]> randomPublisher() {
var calls = new AtomicInteger(randomIntBetween(1, 4));
return subscriber -> {
subscriber.onSubscribe(new Flow.Subscription() {
@Override
public void request(long n) {
if (calls.getAndDecrement() > 0) {
subscriber.onNext(randomByteArrayOfLength(3));
} else {
subscriber.onComplete();
}
}
@Override
public void cancel() {
}
});
};
}
public void testStream_ResponseHandlerDoesNotHandleStreams() throws IOException {
@ -549,6 +572,44 @@ public class RetryingHttpSenderTests extends ESTestCase {
}
}
public void testStream_DoesNotRetryIndefinitely() throws IOException {
var threadPool = new TestThreadPool(getTestName());
try {
var httpClient = mock(HttpClient.class);
doAnswer(ans -> {
ActionListener<StreamingHttpResult> listener = ans.getArgument(2);
listener.onFailure(new ConnectionClosedException("failed"));
return null;
}).when(httpClient).stream(any(), any(), any());
var handler = mock(ResponseHandler.class);
when(handler.canHandleStreamingResponses()).thenReturn(true);
var retrier = new RetryingHttpSender(
httpClient,
mock(ThrottlerManager.class),
createDefaultRetrySettings(),
threadPool,
EsExecutors.DIRECT_EXECUTOR_SERVICE
);
var listener = new PlainActionFuture<InferenceServiceResults>();
var request = mockRequest();
when(request.isStreaming()).thenReturn(true);
retrier.send(mock(Logger.class), request, () -> false, handler, listener);
// Assert that the retrying sender stopped after max retries even though the exception is retryable
var thrownException = expectThrows(UncategorizedExecutionException.class, () -> listener.actionGet(TIMEOUT));
assertThat(thrownException.getCause(), instanceOf(ConnectionClosedException.class));
assertThat(thrownException.getMessage(), is("Failed execution"));
assertThat(thrownException.getSuppressed().length, is(0));
verify(httpClient, times(RetryingHttpSender.MAX_RETIES)).stream(any(), any(), any());
verifyNoMoreInteractions(httpClient);
} finally {
terminate(threadPool);
}
}
public void testSend_DoesNotRetryIndefinitely_WithAlwaysRetryingResponseHandler() throws IOException {
var threadPool = new TestThreadPool(getTestName());
try {

@ -1,127 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.external.http.retry;
import org.apache.http.HttpResponse;
import org.apache.http.StatusLine;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import java.util.concurrent.Flow;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class StreamingResponseHandlerTests extends ESTestCase {
@Mock
private HttpResponse response;
@Mock
private ThrottlerManager throttlerManager;
@Mock
private Logger logger;
@Mock
private Request request;
@Mock
private ResponseHandler responseHandler;
@Mock
private Flow.Subscriber<HttpResult> downstreamSubscriber;
@InjectMocks
private StreamingResponseHandler streamingResponseHandler;
private AutoCloseable mocks;
private HttpResult item;
@Before
public void setUp() throws Exception {
super.setUp();
mocks = MockitoAnnotations.openMocks(this);
item = new HttpResult(response, new byte[0]);
}
@After
public void tearDown() throws Exception {
super.tearDown();
mocks.close();
}
public void testResponseHandlerFailureIsForwardedToSubscriber() {
var upstreamSubscription = upstreamSubscription();
var expectedException = new RetryException(true, "ah");
doThrow(expectedException).when(responseHandler).validateResponse(any(), any(), any(), any());
var statusLine = mock(StatusLine.class);
when(statusLine.getStatusCode()).thenReturn(404);
when(statusLine.getReasonPhrase()).thenReturn("not found");
when(response.getStatusLine()).thenReturn(statusLine);
streamingResponseHandler.onNext(item);
verify(upstreamSubscription, times(1)).cancel();
verify(downstreamSubscriber, times(1)).onError(expectedException);
}
@SuppressWarnings("unchecked")
private Flow.Subscription upstreamSubscription() {
var upstreamSubscription = mock(Flow.Subscription.class);
streamingResponseHandler.onSubscribe(upstreamSubscription);
streamingResponseHandler.subscribe(downstreamSubscriber);
return upstreamSubscription;
}
public void testOnNextCallsDownstream() {
upstreamSubscription();
streamingResponseHandler.onNext(item);
verify(downstreamSubscriber, times(1)).onNext(item);
}
public void testCompleteForwardsComplete() {
upstreamSubscription();
streamingResponseHandler.onComplete();
verify(downstreamSubscriber, times(1)).onSubscribe(any());
verify(downstreamSubscriber, times(1)).onComplete();
}
public void testErrorForwardsError() {
var expectedError = new RetryException(false, "ah");
upstreamSubscription();
streamingResponseHandler.onError(expectedError);
verify(downstreamSubscriber, times(1)).onSubscribe(any());
verify(downstreamSubscriber, times(1)).onError(same(expectedError));
}
public void testSubscriptionForwardsRequest() {
var upstreamSubscription = upstreamSubscription();
var downstream = ArgumentCaptor.forClass(Flow.Subscription.class);
verify(downstreamSubscriber, times(1)).onSubscribe(downstream.capture());
var downstreamSubscription = downstream.getValue();
var requestCount = randomIntBetween(2, 200);
downstreamSubscription.request(requestCount);
verify(upstreamSubscription, times(1)).request(requestCount);
}
}

@ -554,13 +554,11 @@ public class AnthropicServiceTests extends ESTestCase {
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var result = streamChatCompletion();
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
streamChatCompletion().hasNoErrors().hasEvent("""
{"completion":[{"delta":"Hello"},{"delta":", World"}]}""");
}
private InferenceServiceResults streamChatCompletion() throws IOException {
private InferenceEventsAssertion streamChatCompletion() throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool))) {
var model = AnthropicChatCompletionModelTests.createChatCompletionModel(
@ -581,7 +579,7 @@ public class AnthropicServiceTests extends ESTestCase {
listener
);
return listener.actionGet(TIMEOUT);
return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
}
}
@ -592,11 +590,7 @@ public class AnthropicServiceTests extends ESTestCase {
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var result = streamChatCompletion();
InferenceEventsAssertion.assertThat(result)
.hasFinishedStream()
.hasNoEvents()
streamChatCompletion().hasNoEvents()
.hasErrorWithStatusCode(RestStatus.REQUEST_ENTITY_TOO_LARGE.getStatus())
.hasErrorContaining("blah");
}

@ -61,7 +61,6 @@ import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
@ -1345,13 +1344,11 @@ public class AzureAiStudioServiceTests extends ESTestCase {
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var result = streamChatCompletion();
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
streamChatCompletion().hasNoErrors().hasEvent("""
{"completion":[{"delta":"hello, world"}]}""");
}
private InferenceServiceResults streamChatCompletion() throws IOException, URISyntaxException {
private InferenceEventsAssertion streamChatCompletion() throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
var model = AzureAiStudioChatCompletionModelTests.createModel(
@ -1373,7 +1370,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
listener
);
return listener.actionGet(TIMEOUT);
return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
}
}
@ -1389,13 +1386,14 @@ public class AzureAiStudioServiceTests extends ESTestCase {
}""";
webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));
var result = streamChatCompletion();
InferenceEventsAssertion.assertThat(result)
.hasFinishedStream()
.hasNoEvents()
.hasErrorWithStatusCode(401)
.hasErrorContaining("You didn't provide an API key...");
var e = assertThrows(ElasticsearchStatusException.class, this::streamChatCompletion);
assertThat(
e.getMessage(),
equalTo(
"Received an authentication error status code for request from inference entity id [id] status [401]. "
+ "Error message: [You didn't provide an API key...]"
)
);
}
@SuppressWarnings("checkstyle:LineLength")

@ -1410,13 +1410,11 @@ public class AzureOpenAiServiceTests extends ESTestCase {
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var result = streamChatCompletion();
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
streamChatCompletion().hasNoErrors().hasEvent("""
{"completion":[{"delta":"hello, world"}]}""");
}
private InferenceServiceResults streamChatCompletion() throws IOException, URISyntaxException {
private InferenceEventsAssertion streamChatCompletion() throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
var model = AzureOpenAiCompletionModelTests.createCompletionModel(
@ -1441,7 +1439,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
listener
);
return listener.actionGet(TIMEOUT);
return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
}
}
@ -1457,13 +1455,14 @@ public class AzureOpenAiServiceTests extends ESTestCase {
}""";
webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));
var result = streamChatCompletion();
InferenceEventsAssertion.assertThat(result)
.hasFinishedStream()
.hasNoEvents()
.hasErrorWithStatusCode(401)
.hasErrorContaining("You didn't provide an API key...");
var e = assertThrows(ElasticsearchStatusException.class, this::streamChatCompletion);
assertThat(
e.getMessage(),
equalTo(
"Received an authentication error status code for request from inference entity id [id] status [401]."
+ " Error message: [You didn't provide an API key...]"
)
);
}
@SuppressWarnings("checkstyle:LineLength")

@ -1611,13 +1611,11 @@ public class CohereServiceTests extends ESTestCase {
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var result = streamChatCompletion();
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
streamChatCompletion().hasNoErrors().hasEvent("""
{"completion":[{"delta":"hello"},{"delta":"there"}]}""");
}
private InferenceServiceResults streamChatCompletion() throws IOException {
private InferenceEventsAssertion streamChatCompletion() throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) {
var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", "model");
@ -1633,7 +1631,7 @@ public class CohereServiceTests extends ESTestCase {
listener
);
return listener.actionGet(TIMEOUT);
return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
}
}
@ -1643,13 +1641,7 @@ public class CohereServiceTests extends ESTestCase {
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var result = streamChatCompletion();
InferenceEventsAssertion.assertThat(result)
.hasFinishedStream()
.hasNoEvents()
.hasErrorWithStatusCode(500)
.hasErrorContaining("how dare you");
streamChatCompletion().hasNoEvents().hasErrorWithStatusCode(500).hasErrorContaining("how dare you");
}
@SuppressWarnings("checkstyle:LineLength")

@ -1012,18 +1012,18 @@ public class ElasticInferenceServiceTests extends ESTestCase {
}
}
public void testUnifiedCompletionError() throws Exception {
testUnifiedStreamError(404, """
public void testUnifiedCompletionError() {
var e = assertThrows(UnifiedChatCompletionException.class, () -> testUnifiedStream(404, """
{
"error": "The model `rainbow-sprinkles` does not exist or you do not have access to it."
}""", """
{\
"error":{\
"code":"not_found",\
"message":"Received an unsuccessful status code for request from inference entity id [id] status \
[404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]",\
"type":"error"\
}}""");
}"""));
assertThat(
e.getMessage(),
equalTo(
"Received an unsuccessful status code for request from inference entity id [id] status "
+ "[404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]"
)
);
}
public void testUnifiedCompletionErrorMidStream() throws Exception {
@ -1054,6 +1054,25 @@ public class ElasticInferenceServiceTests extends ESTestCase {
}
private void testUnifiedStreamError(int responseCode, String responseJson, String expectedJson) throws Exception {
testUnifiedStream(responseCode, responseJson).hasNoEvents().hasErrorMatching(e -> {
e = unwrapCause(e);
assertThat(e, isA(UnifiedChatCompletionException.class));
try (var builder = XContentFactory.jsonBuilder()) {
((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
try {
xContent.toXContent(builder, EMPTY_PARAMS);
} catch (IOException ex) {
throw new RuntimeException(ex);
}
});
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
assertThat(json, is(expectedJson));
}
});
}
private InferenceEventsAssertion testUnifiedStream(int responseCode, String responseJson) throws Exception {
var eisGatewayUrl = getUrl(webServer);
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createService(senderFactory, eisGatewayUrl)) {
@ -1077,24 +1096,7 @@ public class ElasticInferenceServiceTests extends ESTestCase {
listener
);
var result = listener.actionGet(TIMEOUT);
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> {
e = unwrapCause(e);
assertThat(e, isA(UnifiedChatCompletionException.class));
try (var builder = XContentFactory.jsonBuilder()) {
((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
try {
xContent.toXContent(builder, EMPTY_PARAMS);
} catch (IOException ex) {
throw new RuntimeException(ex);
}
});
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
assertThat(json, is(expectedJson));
}
});
return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
}
}

@ -13,6 +13,7 @@ import org.apache.http.HttpHeaders;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
@ -28,6 +29,7 @@ import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
@ -63,6 +65,7 @@ import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
@ -1079,13 +1082,60 @@ public class OpenAiServiceTests extends ESTestCase {
}
}""";
webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
testStreamError("""
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user");
var latch = new CountDownLatch(1);
service.unifiedCompletionInfer(
model,
UnifiedCompletionRequest.of(
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
),
InferenceAction.Request.DEFAULT_TIMEOUT,
ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> {
try (var builder = XContentFactory.jsonBuilder()) {
var t = unwrapCause(e);
assertThat(t, isA(UnifiedChatCompletionException.class));
((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
try {
xContent.toXContent(builder, EMPTY_PARAMS);
} catch (IOException ex) {
throw new RuntimeException(ex);
}
});
var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
assertThat(json, is("""
{\
"error":{\
"code":"model_not_found",\
"message":"Received an unsuccessful status code for request from inference entity id [id] status \
[404]. Error message: [The model `gpt-4awero` does not exist or you do not have access to it.]",\
"type":"invalid_request_error"\
}}"""));
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}), latch::countDown)
);
assertTrue(latch.await(30, TimeUnit.SECONDS));
}
}
public void testMidStreamUnifiedCompletionError() throws Exception {
String responseJson = """
event: error
data: { "error": { "message": "Timed out waiting for more data", "type": "timeout" } }
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
testStreamError("""
{\
"error":{\
"message":"Received an error response for request from inference entity id [id]. Error message: \
[Timed out waiting for more data]",\
"type":"timeout"\
}}""");
}
@ -1124,22 +1174,6 @@ public class OpenAiServiceTests extends ESTestCase {
}
}
public void testMidStreamUnifiedCompletionError() throws Exception {
String responseJson = """
event: error
data: { "error": { "message": "Timed out waiting for more data", "type": "timeout" } }
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
testStreamError("""
{\
"error":{\
"message":"Received an error response for request from inference entity id [id]. Error message: \
[Timed out waiting for more data]",\
"type":"timeout"\
}}""");
}
public void testUnifiedCompletionMalformedError() throws Exception {
String responseJson = """
data: { invalid json }
@ -1179,13 +1213,11 @@ public class OpenAiServiceTests extends ESTestCase {
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
var result = streamCompletion();
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
streamCompletion().hasNoErrors().hasEvent("""
{"completion":[{"delta":"hello, world"}]}""");
}
private InferenceServiceResults streamCompletion() throws IOException {
private InferenceEventsAssertion streamCompletion() throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
var model = OpenAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), "org", "secret", "model", "user");
@ -1201,7 +1233,7 @@ public class OpenAiServiceTests extends ESTestCase {
listener
);
return listener.actionGet(TIMEOUT);
return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
}
}
@ -1217,13 +1249,48 @@ public class OpenAiServiceTests extends ESTestCase {
}""";
webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));
var result = streamCompletion();
var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion);
assertThat(e.status(), equalTo(RestStatus.UNAUTHORIZED));
assertThat(
e.getMessage(),
equalTo(
"Received an authentication error status code for request from inference entity id [id] status [401]. "
+ "Error message: [You didn't provide an API key...]"
)
);
}
InferenceEventsAssertion.assertThat(result)
.hasFinishedStream()
.hasNoEvents()
.hasErrorWithStatusCode(401)
.hasErrorContaining("You didn't provide an API key...");
public void testInfer_StreamRequestRetry() throws Exception {
webServer.enqueue(new MockResponse().setResponseCode(503).setBody("""
{
"error": {
"message": "server busy",
"type": "server_busy"
}
}"""));
webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
data: {\
"id":"12345",\
"object":"chat.completion.chunk",\
"created":123456789,\
"model":"gpt-4o-mini",\
"system_fingerprint": "123456789",\
"choices":[\
{\
"index":0,\
"delta":{\
"content":"hello, world"\
},\
"logprobs":null,\
"finish_reason":null\
}\
]\
}
"""));
streamCompletion().hasNoErrors().hasEvent("""
{"completion":[{"delta":"hello, world"}]}""");
}
public void testSupportsStreaming() throws IOException {