From dfe2adb5922c58b04a97f676749ffb45a627937c Mon Sep 17 00:00:00 2001
From: Pat Whelan
Date: Tue, 4 Mar 2025 09:33:00 -0500
Subject: [PATCH] [ML] Retry on streaming errors (#123076)
We now always retry based on the provider's configured retry logic
rather than the HTTP status code. Some providers (e.g. Cohere,
Anthropic) will return 200 status codes with error bodies, others (e.g.
OpenAI, Azure) will return non-200 status codes with non-streaming
bodies.
Notes:
- Refactored from HttpResult to StreamingHttpResult, the byte body is
now the streaming element while the http response lives outside the
stream.
- Refactored StreamingHttpResultPublisher so that it only pushes byte
body into a queue.
- Tests all now have to wait for the response to be fully consumed
before closing the service, otherwise the close method will shut down
the mock web server and apache will throw an error.
---
docs/changelog/123076.yaml | 5 +
.../inference/external/http/HttpClient.java | 5 +-
.../external/http/StreamingHttpResult.java | 76 +++++
.../http/StreamingHttpResultPublisher.java | 260 +++++++++++-------
.../http/retry/RetryingHttpSender.java | 17 +-
.../http/retry/StreamingResponseHandler.java | 140 ----------
.../external/http/HttpClientTests.java | 5 +-
.../StreamingHttpResultPublisherTests.java | 129 ++++-----
.../http/retry/RetryingHttpSenderTests.java | 73 ++++-
.../retry/StreamingResponseHandlerTests.java | 127 ---------
.../anthropic/AnthropicServiceTests.java | 14 +-
.../AzureAiStudioServiceTests.java | 24 +-
.../azureopenai/AzureOpenAiServiceTests.java | 23 +-
.../services/cohere/CohereServiceTests.java | 16 +-
.../elastic/ElasticInferenceServiceTests.java | 58 ++--
.../services/openai/OpenAiServiceTests.java | 129 ++++++---
16 files changed, 548 insertions(+), 553 deletions(-)
create mode 100644 docs/changelog/123076.yaml
create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java
delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/StreamingResponseHandler.java
delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/StreamingResponseHandlerTests.java
diff --git a/docs/changelog/123076.yaml b/docs/changelog/123076.yaml
new file mode 100644
index 000000000000..270c202f3bbd
--- /dev/null
+++ b/docs/changelog/123076.yaml
@@ -0,0 +1,5 @@
+pr: 123076
+summary: Retry on streaming errors
+area: Machine Learning
+type: bug
+issues: []
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java
index f890b6fce13d..7936e6779c8d 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpClient.java
@@ -27,7 +27,6 @@ import java.io.Closeable;
import java.io.IOException;
import java.util.Objects;
import java.util.concurrent.CancellationException;
-import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicReference;
import static org.elasticsearch.core.Strings.format;
@@ -154,7 +153,7 @@ public class HttpClient implements Closeable {
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> listener.onFailure(exception));
}
- public void stream(HttpRequest request, HttpContext context, ActionListener> listener) throws IOException {
+ public void stream(HttpRequest request, HttpContext context, ActionListener listener) throws IOException {
// The caller must call start() first before attempting to send a request
assert status.get() == Status.STARTED : "call start() before attempting to send a request";
@@ -162,7 +161,7 @@ public class HttpClient implements Closeable {
SocketAccess.doPrivileged(() -> client.execute(request.requestProducer(), streamingProcessor, context, new FutureCallback<>() {
@Override
- public void completed(HttpResponse response) {
+ public void completed(Void response) {
streamingProcessor.close();
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java
new file mode 100644
index 000000000000..1786ee98fcd8
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResult.java
@@ -0,0 +1,76 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.http;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.ExceptionsHelper;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.rest.RestStatus;
+
+import java.io.ByteArrayOutputStream;
+import java.util.concurrent.Flow;
+import java.util.concurrent.atomic.AtomicReference;
+
+public record StreamingHttpResult(HttpResponse response, Flow.Publisher body) {
+ public boolean isSuccessfulResponse() {
+ return RestStatus.isSuccessful(response.getStatusLine().getStatusCode());
+ }
+
+ public Flow.Publisher toHttpResult() {
+ return subscriber -> body().subscribe(new Flow.Subscriber<>() {
+ @Override
+ public void onSubscribe(Flow.Subscription subscription) {
+ subscriber.onSubscribe(subscription);
+ }
+
+ @Override
+ public void onNext(byte[] item) {
+ subscriber.onNext(new HttpResult(response(), item));
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ subscriber.onError(throwable);
+ }
+
+ @Override
+ public void onComplete() {
+ subscriber.onComplete();
+ }
+ });
+ }
+
+ public void readFullResponse(ActionListener fullResponse) {
+ var stream = new ByteArrayOutputStream();
+ AtomicReference upstream = new AtomicReference<>(null);
+ body.subscribe(new Flow.Subscriber<>() {
+ @Override
+ public void onSubscribe(Flow.Subscription subscription) {
+ upstream.set(subscription);
+ upstream.get().request(1);
+ }
+
+ @Override
+ public void onNext(byte[] item) {
+ stream.writeBytes(item);
+ upstream.get().request(1);
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ ExceptionsHelper.maybeError(throwable).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
+ fullResponse.onFailure(new RuntimeException("Fatal while fully consuming stream", throwable));
+ }
+
+ @Override
+ public void onComplete() {
+ fullResponse.onResponse(new HttpResult(response, stream.toByteArray()));
+ }
+ });
+ }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisher.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisher.java
index 0b2268a448c8..62ac1ac8a56b 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisher.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisher.java
@@ -13,6 +13,7 @@ import org.apache.http.nio.IOControl;
import org.apache.http.nio.protocol.HttpAsyncResponseConsumer;
import org.apache.http.nio.util.SimpleInputBuffer;
import org.apache.http.protocol.HttpContext;
+import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.threadpool.ThreadPool;
@@ -39,51 +40,31 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_P
* so this publisher will send a single HttpResult. If the HttpResponse is healthy, Apache will send an HttpResponse with or without
* the HttpEntity.
*/
-class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer, Flow.Publisher {
- private final HttpSettings settings;
- private final ActionListener> listener;
+class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer {
+ private final ActionListener listener;
private final AtomicBoolean listenerCalled = new AtomicBoolean(false);
- // used to manage the HTTP response
- private volatile HttpResponse response;
- private volatile Exception ex;
-
- // used to control the state of this publisher (Apache) and its interaction with its subscriber
private final AtomicBoolean isDone = new AtomicBoolean(false);
private final AtomicBoolean subscriptionCanceled = new AtomicBoolean(false);
- private volatile Flow.Subscriber super HttpResult> subscriber;
- private final RequestBasedTaskRunner taskRunner;
- private final AtomicBoolean pendingRequest = new AtomicBoolean(false);
- private final Deque queue = new ConcurrentLinkedDeque<>();
-
- // used to control the flow of data from the Apache client, if we're producing more bytes than we can consume then we'll pause
private final SimpleInputBuffer inputBuffer = new SimpleInputBuffer(4096);
- private final AtomicLong bytesInQueue = new AtomicLong(0);
- private final Object ioLock = new Object();
- private volatile IOControl savedIoControl;
+ private final DataPublisher publisher;
+ private final ApacheClientBackpressure backpressure;
- StreamingHttpResultPublisher(ThreadPool threadPool, HttpSettings settings, ActionListener> listener) {
- this.settings = Objects.requireNonNull(settings);
+ private volatile Exception exception;
+
+ StreamingHttpResultPublisher(ThreadPool threadPool, HttpSettings settings, ActionListener listener) {
this.listener = ActionListener.notifyOnce(Objects.requireNonNull(listener));
- this.taskRunner = new RequestBasedTaskRunner(new OffloadThread(), threadPool, UTILITY_THREAD_POOL_NAME);
+ this.publisher = new DataPublisher(threadPool);
+ this.backpressure = new ApacheClientBackpressure(Objects.requireNonNull(settings));
}
@Override
public void responseReceived(HttpResponse httpResponse) {
- this.response = httpResponse;
- }
-
- @Override
- public void subscribe(Flow.Subscriber super HttpResult> subscriber) {
- if (this.subscriber != null) {
- subscriber.onError(new IllegalStateException("Only one subscriber is allowed for this Publisher."));
- return;
+ if (listenerCalled.compareAndSet(false, true)) {
+ listener.onResponse(new StreamingHttpResult(httpResponse, publisher));
}
-
- this.subscriber = subscriber;
- subscriber.onSubscribe(new HttpSubscription());
}
@Override
@@ -100,49 +81,20 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer 0) {
var allBytes = new byte[consumed];
inputBuffer.read(allBytes);
- queue.offer(() -> {
- subscriber.onNext(new HttpResult(response, allBytes));
- var currentBytesInQueue = bytesInQueue.updateAndGet(current -> Long.max(0, current - allBytes.length));
- if (savedIoControl != null) {
- var maxBytes = settings.getMaxResponseSize().getBytes() * 0.5;
- if (currentBytesInQueue <= maxBytes) {
- resumeProducer();
- }
- }
- });
-
- // always check if totalByteSize > the configured setting in case the settings change
- if (bytesInQueue.accumulateAndGet(allBytes.length, Long::sum) >= settings.getMaxResponseSize().getBytes()) {
- pauseProducer(ioControl);
- }
-
- taskRunner.requestNextRun();
-
- if (listenerCalled.compareAndSet(false, true)) {
- listener.onResponse(this);
- }
+ backpressure.addBytesAndMaybePause(consumed, ioControl);
+ publisher.onNext(allBytes);
}
+ } catch (Exception e) {
+ // if the provider closes the connection in the middle of the stream,
+ // the contentDecoder will throw an exception trying to read the payload,
+ // we should catch that and forward it downstream so we can properly handle it
+ exception = e;
+ publisher.onError(e);
} finally {
inputBuffer.reset();
}
}
- private void pauseProducer(IOControl ioControl) {
- ioControl.suspendInput();
- synchronized (ioLock) {
- savedIoControl = ioControl;
- }
- }
-
- private void resumeProducer() {
- synchronized (ioLock) {
- if (savedIoControl != null) {
- savedIoControl.requestInput();
- savedIoControl = null;
- }
- }
- }
-
@Override
public void responseCompleted(HttpContext httpContext) {}
@@ -153,9 +105,8 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer subscriber.onError(e));
- taskRunner.requestNextRun();
+ exception = e;
+ publisher.onError(e);
}
}
}
@@ -164,8 +115,7 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer subscriber.onComplete());
- taskRunner.requestNextRun();
+ publisher.onComplete();
}
}
@@ -178,12 +128,12 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer {
+ private final RequestBasedTaskRunner taskRunner;
+ private final Deque contentQueue = new ConcurrentLinkedDeque<>();
+ private final AtomicLong pendingRequests = new AtomicLong(0);
+ private volatile Exception pendingError = null;
+ private volatile boolean completed = false;
+ private volatile Flow.Subscriber super byte[]> downstream;
+
+ private DataPublisher(ThreadPool threadPool) {
+ this.taskRunner = new RequestBasedTaskRunner(this::sendToSubscriber, threadPool, UTILITY_THREAD_POOL_NAME);
+ }
+
+ private void sendToSubscriber() {
+ if (downstream == null) {
return;
}
-
- if (n > 0) {
- pendingRequest.set(true);
- taskRunner.requestNextRun();
- } else {
- // per Subscription's spec, fail the subscriber and stop the processor
- cancel();
- subscriber.onError(new IllegalArgumentException("Subscriber requested a non-positive number " + n));
+ if (pendingRequests.get() > 0 && pendingError != null) {
+ pendingRequests.decrementAndGet();
+ downstream.onError(pendingError);
+ return;
+ }
+ byte[] nextBytes;
+ while (pendingRequests.get() > 0 && (nextBytes = contentQueue.poll()) != null) {
+ pendingRequests.decrementAndGet();
+ backpressure.subtractBytesAndMaybeUnpause(nextBytes.length);
+ downstream.onNext(nextBytes);
+ }
+ if (pendingRequests.get() > 0 && contentQueue.isEmpty() && completed) {
+ pendingRequests.decrementAndGet();
+ downstream.onComplete();
}
}
@Override
- public void cancel() {
- if (subscriptionCanceled.compareAndSet(false, true)) {
- taskRunner.cancel();
+ public void subscribe(Flow.Subscriber super byte[]> subscriber) {
+ if (this.downstream != null) {
+ subscriber.onError(new IllegalStateException("Only one subscriber is allowed for this Publisher."));
+ return;
}
+
+ this.downstream = subscriber;
+ downstream.onSubscribe(new Flow.Subscription() {
+ @Override
+ public void request(long n) {
+ if (n > 0) {
+ pendingRequests.addAndGet(n);
+ taskRunner.requestNextRun();
+ } else {
+ // per Subscription's spec, fail the subscriber and stop the processor
+ cancel();
+ subscriber.onError(new IllegalArgumentException("Subscriber requested a non-positive number " + n));
+ }
+ }
+
+ @Override
+ public void cancel() {
+ if (subscriptionCanceled.compareAndSet(false, true)) {
+ taskRunner.cancel();
+ }
+ }
+ });
+ }
+
+ @Override
+ public void onNext(byte[] item) {
+ contentQueue.offer(item);
+ taskRunner.requestNextRun();
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ if (throwable instanceof Exception e) {
+ pendingError = e;
+ } else {
+ ExceptionsHelper.maybeError(throwable).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
+ pendingError = new RuntimeException("Unhandled error while streaming");
+ }
+ taskRunner.requestNextRun();
+ }
+
+ @Override
+ public void onComplete() {
+ completed = true;
+ taskRunner.requestNextRun();
+ }
+
+ @Override
+ public void onSubscribe(Flow.Subscription subscription) {
+ assert false : "Apache never calls this.";
+ throw new UnsupportedOperationException("Apache never calls this.");
}
}
- private class OffloadThread implements Runnable {
- @Override
- public void run() {
- if (subscriptionCanceled.get()) {
- return;
- }
+ /**
+ * We want to keep track of how much memory we are consuming while reading the payload from the external provider. Apache continuously
+ * pushes payload data to us, whereas the client only requests the next set of bytes when they are ready, so we want to track how much
+ * data we are holding in memory and potentially pause the Apache client if we have reached our limit.
+ */
+ private static class ApacheClientBackpressure {
+ private final HttpSettings settings;
+ private final AtomicLong bytesInQueue = new AtomicLong(0);
+ private final Object ioLock = new Object();
+ private volatile IOControl savedIoControl;
- if (queue.isEmpty() == false && pendingRequest.compareAndSet(true, false)) {
- var next = queue.poll();
- if (next != null) {
- next.run();
- } else {
- pendingRequest.set(true);
+ private ApacheClientBackpressure(HttpSettings settings) {
+ this.settings = settings;
+ }
+
+ private void addBytesAndMaybePause(long count, IOControl ioControl) {
+ if (bytesInQueue.addAndGet(count) >= settings.getMaxResponseSize().getBytes()) {
+ pauseProducer(ioControl);
+ }
+ }
+
+ private void pauseProducer(IOControl ioControl) {
+ ioControl.suspendInput();
+ synchronized (ioLock) {
+ savedIoControl = ioControl;
+ }
+ }
+
+ private void subtractBytesAndMaybeUnpause(long count) {
+ var currentBytesInQueue = bytesInQueue.updateAndGet(current -> Long.max(0, current - count));
+ if (savedIoControl != null) {
+ var maxBytes = settings.getMaxResponseSize().getBytes() * 0.5;
+ if (currentBytesInQueue <= maxBytes) {
+ resumeProducer();
+ }
+ }
+ }
+
+ private void resumeProducer() {
+ synchronized (ioLock) {
+ if (savedIoControl != null) {
+ savedIoControl.requestInput();
+ savedIoControl = null;
}
}
}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java
index 1c303f6e965c..b71887ce6018 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java
@@ -116,9 +116,20 @@ public class RetryingHttpSender implements RequestSender {
try {
if (request.isStreaming() && responseHandler.canHandleStreamingResponses()) {
httpClient.stream(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
- var streamingResponseHandler = new StreamingResponseHandler(throttlerManager, logger, request, responseHandler);
- r.subscribe(streamingResponseHandler);
- l.onResponse(responseHandler.parseResult(request, streamingResponseHandler));
+ if (r.isSuccessfulResponse()) {
+ l.onResponse(responseHandler.parseResult(request, r.toHttpResult()));
+ } else {
+ r.readFullResponse(l.delegateFailureAndWrap((ll, httpResult) -> {
+ try {
+ responseHandler.validateResponse(throttlerManager, logger, request, httpResult);
+ InferenceServiceResults inferenceResults = responseHandler.parseResult(request, httpResult);
+ ll.onResponse(inferenceResults);
+ } catch (Exception e) {
+ logException(logger, request, httpResult, responseHandler.getRequestType(), e);
+ listener.onFailure(e); // skip retrying
+ }
+ }));
+ }
}));
} else {
httpClient.send(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/StreamingResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/StreamingResponseHandler.java
deleted file mode 100644
index 44e04ae28751..000000000000
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/StreamingResponseHandler.java
+++ /dev/null
@@ -1,140 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.external.http.retry;
-
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
-import org.elasticsearch.ExceptionsHelper;
-import org.elasticsearch.xpack.inference.external.http.HttpResult;
-import org.elasticsearch.xpack.inference.external.request.Request;
-import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
-
-import java.util.concurrent.Flow;
-import java.util.concurrent.atomic.AtomicBoolean;
-
-import static org.elasticsearch.core.Strings.format;
-
-class StreamingResponseHandler implements Flow.Processor {
- private static final Logger log = LogManager.getLogger(StreamingResponseHandler.class);
- private final ThrottlerManager throttlerManager;
- private final Logger throttlerLogger;
- private final Request request;
- private final ResponseHandler responseHandler;
-
- private final AtomicBoolean upstreamIsClosed = new AtomicBoolean(false);
- private final AtomicBoolean processedFirstItem = new AtomicBoolean(false);
-
- private volatile Flow.Subscription upstream;
- private volatile Flow.Subscriber super HttpResult> downstream;
-
- StreamingResponseHandler(ThrottlerManager throttlerManager, Logger throttlerLogger, Request request, ResponseHandler responseHandler) {
- this.throttlerManager = throttlerManager;
- this.throttlerLogger = throttlerLogger;
- this.request = request;
- this.responseHandler = responseHandler;
- }
-
- @Override
- public void subscribe(Flow.Subscriber super HttpResult> subscriber) {
- if (downstream != null) {
- subscriber.onError(
- new IllegalStateException("Failed to initialize streaming response. Another subscriber is already subscribed.")
- );
- return;
- }
-
- downstream = subscriber;
- subscriber.onSubscribe(forwardingSubscription());
- }
-
- private Flow.Subscription forwardingSubscription() {
- return new Flow.Subscription() {
- @Override
- public void request(long n) {
- if (upstreamIsClosed.get()) {
- downstream.onComplete(); // shouldn't happen, but reinforce that we're no longer listening
- } else if (upstream != null) {
- upstream.request(n);
- } else {
- // this shouldn't happen, the expected call pattern is onNext -> subscribe after the listener is invoked
- var errorMessage = "Failed to initialize streaming response. onSubscribe must be called first to set the upstream";
- assert false : errorMessage;
- downstream.onError(new IllegalStateException(errorMessage));
- }
- }
-
- @Override
- public void cancel() {
- if (upstreamIsClosed.compareAndSet(false, true) && upstream != null) {
- upstream.cancel();
- }
- }
- };
- }
-
- @Override
- public void onSubscribe(Flow.Subscription subscription) {
- upstream = subscription;
- }
-
- @Override
- public void onNext(HttpResult item) {
- if (processedFirstItem.compareAndSet(false, true)) {
- try {
- responseHandler.validateResponse(throttlerManager, throttlerLogger, request, item);
- } catch (Exception e) {
- logException(throttlerLogger, request, item, responseHandler.getRequestType(), e);
- upstream.cancel();
- onError(e);
- return;
- }
- }
- downstream.onNext(item);
- }
-
- @Override
- public void onError(Throwable throwable) {
- if (upstreamIsClosed.compareAndSet(false, true)) {
- if (downstream != null) {
- downstream.onError(throwable);
- } else {
- log.warn(
- "Flow failed before the InferenceServiceResults were generated. The error should go to the listener directly.",
- throwable
- );
- }
- }
- }
-
- @Override
- public void onComplete() {
- if (upstreamIsClosed.compareAndSet(false, true)) {
- if (downstream != null) {
- downstream.onComplete();
- } else {
- log.debug("Flow completed before the InferenceServiceResults were generated. Shutting down this Processor.");
- }
- }
- }
-
- private void logException(Logger logger, Request request, HttpResult result, String requestType, Exception exception) {
- var causeException = ExceptionsHelper.unwrapCause(exception);
-
- throttlerManager.warn(
- logger,
- format(
- "Failed to process the stream connection for request from inference entity id [%s] of type [%s] with status [%s] [%s]",
- request.getInferenceEntityId(),
- requestType,
- result.response().getStatusLine().getStatusCode(),
- result.response().getStatusLine().getReasonPhrase()
- ),
- causeException
- );
- }
-}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientTests.java
index aa27bf0d2fc8..fc293b5f5566 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/HttpClientTests.java
@@ -41,7 +41,6 @@ import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CancellationException;
-import java.util.concurrent.Flow;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
@@ -174,7 +173,7 @@ public class HttpClientTests extends ESTestCase {
try (var client = new HttpClient(emptyHttpSettings(), asyncClient, threadPool, mockThrottlerManager())) {
client.start();
- PlainActionFuture> listener = new PlainActionFuture<>();
+ PlainActionFuture listener = new PlainActionFuture<>();
client.stream(httpPost, HttpClientContext.create(), listener);
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
@@ -196,7 +195,7 @@ public class HttpClientTests extends ESTestCase {
try (var client = new HttpClient(emptyHttpSettings(), asyncClient, threadPool, mockThrottlerManager())) {
client.start();
- PlainActionFuture> listener = new PlainActionFuture<>();
+ PlainActionFuture listener = new PlainActionFuture<>();
client.stream(httpPost, HttpClientContext.create(), listener);
var thrownException = expectThrows(CancellationException.class, () -> listener.actionGet(TIMEOUT));
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisherTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisherTests.java
index a400b67b3761..672dd05abc91 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisherTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/StreamingHttpResultPublisherTests.java
@@ -14,8 +14,10 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.core.Tuple;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
+import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
@@ -53,7 +55,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
private static final long maxBytes = message.length;
private ThreadPool threadPool;
private HttpSettings settings;
- private ActionListener> listener;
+ private final AtomicReference> result = new AtomicReference<>(null);
private StreamingHttpResultPublisher publisher;
@Before
@@ -61,12 +63,21 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
super.setUp();
threadPool = mock(ThreadPool.class);
settings = mock(HttpSettings.class);
- listener = spy(ActionListener.noop());
when(threadPool.executor(UTILITY_THREAD_POOL_NAME)).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
when(settings.getMaxResponseSize()).thenReturn(ByteSizeValue.ofBytes(maxBytes));
- publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
+ publisher = new StreamingHttpResultPublisher(threadPool, settings, listener());
+ }
+
+ private ActionListener listener() {
+ return ActionListener.wrap(r -> result.set(Tuple.tuple(r, null)), e -> result.set(Tuple.tuple(null, e)));
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ super.tearDown();
+ result.set(null);
}
/**
@@ -76,7 +87,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
*/
public void testFirstResponseCallsListener() throws IOException {
var latch = new CountDownLatch(1);
- var listener = ActionTestUtils.>assertNoFailureListener(r -> latch.countDown());
+ var listener = ActionTestUtils.assertNoFailureListener(r -> latch.countDown());
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
publisher.responseReceived(mock(HttpResponse.class));
@@ -92,7 +103,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
*/
public void testNonEmptyFirstResponseCallsListener() throws IOException {
var latch = new CountDownLatch(1);
- var listener = ActionTestUtils.>assertNoFailureListener(r -> latch.countDown());
+ var listener = ActionTestUtils.assertNoFailureListener(r -> latch.countDown());
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
when(settings.getMaxResponseSize()).thenReturn(ByteSizeValue.ofBytes(9000));
@@ -127,7 +138,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
// subscribe
- publisher.subscribe(subscriber);
+ testPublisher().subscribe(subscriber);
assertThat("subscribe must call onSubscribe", subscriber.subscription, notNullValue());
assertThat("onNext should only be called once we have requested data", subscriber.httpResult, nullValue());
@@ -142,7 +153,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
// publisher sends data
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
- assertThat("onNext was called with " + new String(message, StandardCharsets.UTF_8), subscriber.httpResult.body(), equalTo(message));
+ assertThat("onNext was called with " + new String(message, StandardCharsets.UTF_8), subscriber.httpResult, equalTo(message));
}
/**
@@ -157,7 +168,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
publisher.close();
// subscriber requests data
- publisher.subscribe(subscriber);
+ testPublisher().subscribe(subscriber);
assertThat("subscribe must call onSubscribe", subscriber.subscription, notNullValue());
subscriber.requestData();
assertThat("onNext was called with the initial HttpResponse", subscriber.httpResult, notNullValue());
@@ -187,7 +198,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
public void testResumeApache() throws IOException {
var subscriber = new TestSubscriber();
publisher.responseReceived(mock(HttpResponse.class));
- publisher.subscribe(subscriber);
+ testPublisher().subscribe(subscriber);
subscriber.requestData();
subscriber.httpResult = null;
@@ -212,7 +223,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
var subscriber = new TestSubscriber();
publisher.responseReceived(mock(HttpResponse.class));
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
- publisher.subscribe(subscriber);
+ testPublisher().subscribe(subscriber);
subscriber.requestData();
subscriber.httpResult = null;
@@ -243,7 +254,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
public void testErrorBeforeRequest() {
var exception = new NullPointerException("test");
publisher.failed(exception);
- verify(listener).onFailure(exception);
+ assertThat(result.get().v2(), equalTo(exception));
}
/**
@@ -361,21 +372,17 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
* When cancel is called
* Then we only send onComplete once
*/
- public void testCancelIsIdempotent() throws IOException {
- Flow.Subscriber subscriber = mock();
-
- var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
- publisher.subscribe(subscriber);
- verify(subscriber).onSubscribe(subscription.capture());
+ public void testCancelIsIdempotent() {
+ Flow.Subscriber subscriber = mock();
publisher.responseReceived(mock());
- publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
- subscription.getValue().request(1);
- subscription.getValue().request(1);
+ var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
+ testPublisher().subscribe(subscriber);
+ verify(subscriber).onSubscribe(subscription.capture());
+
+ subscription.getValue().request(2);
publisher.cancel();
- verify(subscriber, times(1)).onComplete();
- subscription.getValue().request(1);
publisher.cancel();
verify(subscriber, times(1)).onComplete();
}
@@ -384,21 +391,17 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
* When close is called
* Then we only send onComplete once
*/
- public void testCloseIsIdempotent() throws IOException {
- Flow.Subscriber subscriber = mock();
-
- var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
- publisher.subscribe(subscriber);
- verify(subscriber).onSubscribe(subscription.capture());
+ public void testCloseIsIdempotent() {
+ Flow.Subscriber subscriber = mock();
publisher.responseReceived(mock());
- publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
- subscription.getValue().request(1);
- subscription.getValue().request(1);
+ var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
+ testPublisher().subscribe(subscriber);
+ verify(subscriber).onSubscribe(subscription.capture());
+
+ subscription.getValue().request(2);
publisher.close();
- verify(subscriber, times(1)).onComplete();
- subscription.getValue().request(1);
publisher.close();
verify(subscriber, times(1)).onComplete();
}
@@ -409,20 +412,16 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
*/
public void testFailedIsIdempotent() throws IOException {
var expectedException = new IllegalStateException("wow");
- Flow.Subscriber subscriber = mock();
-
- var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
- publisher.subscribe(subscriber);
- verify(subscriber).onSubscribe(subscription.capture());
+ Flow.Subscriber subscriber = mock();
publisher.responseReceived(mock());
- publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
- subscription.getValue().request(1);
- subscription.getValue().request(1);
+ var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
+ testPublisher().subscribe(subscriber);
+ verify(subscriber).onSubscribe(subscription.capture());
+
+ subscription.getValue().request(2);
publisher.failed(expectedException);
- verify(subscriber, times(1)).onError(eq(expectedException));
- subscription.getValue().request(1);
publisher.failed(expectedException);
verify(subscriber, times(1)).onError(eq(expectedException));
}
@@ -492,10 +491,11 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
* Then that subscriber should receive an IllegalStateException
*/
public void testDoubleSubscribeFails() {
- publisher.subscribe(mock());
+ publisher.responseReceived(mock());
+ testPublisher().subscribe(mock());
var subscriber = new TestSubscriber();
- publisher.subscribe(subscriber);
+ testPublisher().subscribe(subscriber);
assertThat(subscriber.throwable, notNullValue());
assertThat(subscriber.throwable, instanceOf(IllegalStateException.class));
}
@@ -508,10 +508,10 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
public void testReuseMlThread() throws ExecutionException, InterruptedException, TimeoutException {
try {
threadPool = spy(createThreadPool(inferenceUtilityPool()));
- publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
+ publisher = new StreamingHttpResultPublisher(threadPool, settings, listener());
var subscriber = new TestSubscriber();
publisher.responseReceived(mock(HttpResponse.class));
- publisher.subscribe(subscriber);
+ testPublisher().subscribe(subscriber);
CompletableFuture.runAsync(() -> {
try {
@@ -523,7 +523,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
}, threadPool.executor(UTILITY_THREAD_POOL_NAME)).get(5, TimeUnit.SECONDS);
verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME);
assertThat("onNext was called with the initial HttpResponse", subscriber.httpResult, notNullValue());
- assertFalse("Expected HttpResult to have data", subscriber.httpResult.isBodyEmpty());
+ assertNotEquals("Expected HttpResult to have data", 0, subscriber.httpResult.length);
} finally {
terminate(threadPool);
}
@@ -548,11 +548,11 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
return executorServiceSpy;
}).when(threadPool).executor(UTILITY_THREAD_POOL_NAME);
- publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
+ publisher = new StreamingHttpResultPublisher(threadPool, settings, listener());
publisher.responseReceived(mock(HttpResponse.class));
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
// create an infinitely running Subscriber
- var subscriber = new Flow.Subscriber() {
+ var subscriber = new Flow.Subscriber() {
Flow.Subscription subscription;
boolean completed = false;
@@ -563,7 +563,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
}
@Override
- public void onNext(HttpResult item) {
+ public void onNext(byte[] item) {
try {
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
} catch (IOException e) {
@@ -582,7 +582,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
completed = true;
}
};
- publisher.subscribe(subscriber);
+ testPublisher().subscribe(subscriber);
// verify the thread has started
assertThat("Thread should have started on subscribe", futureHolder.get(), notNullValue());
@@ -604,7 +604,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
// start with a message published
publisher.responseReceived(mock(HttpResponse.class));
TestSubscriber subscriber = new TestSubscriber() {
- public void onNext(HttpResult item) {
+ public void onNext(byte[] item) {
try {
// publish a second message
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
@@ -617,7 +617,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
super.onNext(item);
}
};
- publisher.subscribe(subscriber);
+ testPublisher().subscribe(subscriber);
verify(threadPool, times(0)).executor(UTILITY_THREAD_POOL_NAME);
subscriber.requestData();
@@ -646,8 +646,9 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
}
private TestSubscriber subscribe() {
+ publisher.responseReceived(mock());
var subscriber = new TestSubscriber();
- publisher.subscribe(subscriber);
+ testPublisher().subscribe(subscriber);
return subscriber;
}
@@ -655,12 +656,12 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
publisher.responseReceived(mock(HttpResponse.class));
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
TestSubscriber subscriber = new TestSubscriber() {
- public void onNext(HttpResult item) {
+ public void onNext(byte[] item) {
runDuringOnNext.run();
super.onNext(item);
}
};
- publisher.subscribe(subscriber);
+ testPublisher().subscribe(subscriber);
return subscriber;
}
@@ -668,19 +669,23 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
publisher.responseReceived(mock(HttpResponse.class));
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
TestSubscriber subscriber = new TestSubscriber() {
- public void onNext(HttpResult item) {
+ public void onNext(byte[] item) {
runDuringOnNext.run();
super.requestData();
super.onNext(item);
}
};
- publisher.subscribe(subscriber);
+ testPublisher().subscribe(subscriber);
return subscriber;
}
- private static class TestSubscriber implements Flow.Subscriber {
+ private Flow.Publisher testPublisher() {
+ return result.get().v1().body();
+ }
+
+ private static class TestSubscriber implements Flow.Subscriber {
private Flow.Subscription subscription;
- private HttpResult httpResult;
+ private byte[] httpResult;
private Throwable throwable;
private boolean completed;
@@ -690,7 +695,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
}
@Override
- public void onNext(HttpResult item) {
+ public void onNext(byte[] item) {
this.httpResult = item;
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java
index 7e5b8e680836..96401f014081 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java
@@ -24,15 +24,18 @@ import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.xpack.inference.external.http.HttpClient;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.StreamingHttpResult;
import org.elasticsearch.xpack.inference.external.request.HttpRequestTests;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.junit.Before;
+import org.mockito.ArgumentMatchers;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.net.UnknownHostException;
import java.util.concurrent.Flow;
+import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.createDefaultRetrySettings;
import static org.hamcrest.Matchers.instanceOf;
@@ -42,7 +45,6 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.only;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@@ -459,12 +461,12 @@ public class RetryingHttpSenderTests extends ESTestCase {
verifyNoMoreInteractions(httpClient);
}
- public void testStream() throws IOException {
+ public void testStreamSuccess() throws IOException {
var httpClient = mock(HttpClient.class);
- Flow.Publisher publisher = mock();
+ StreamingHttpResult streamingHttpResult = new StreamingHttpResult(mockHttpResponse(), randomPublisher());
doAnswer(ans -> {
- ActionListener> listener = ans.getArgument(2);
- listener.onResponse(publisher);
+ ActionListener listener = ans.getArgument(2);
+ listener.onResponse(streamingHttpResult);
return null;
}).when(httpClient).stream(any(), any(), any());
@@ -479,7 +481,28 @@ public class RetryingHttpSenderTests extends ESTestCase {
verify(httpClient, times(1)).stream(any(), any(), any());
verifyNoMoreInteractions(httpClient);
- verify(publisher, only()).subscribe(any(StreamingResponseHandler.class));
+ verify(responseHandler, times(1)).parseResult(any(), ArgumentMatchers.>any());
+ }
+
+ private Flow.Publisher randomPublisher() {
+ var calls = new AtomicInteger(randomIntBetween(1, 4));
+ return subscriber -> {
+ subscriber.onSubscribe(new Flow.Subscription() {
+ @Override
+ public void request(long n) {
+ if (calls.getAndDecrement() > 0) {
+ subscriber.onNext(randomByteArrayOfLength(3));
+ } else {
+ subscriber.onComplete();
+ }
+ }
+
+ @Override
+ public void cancel() {
+
+ }
+ });
+ };
}
public void testStream_ResponseHandlerDoesNotHandleStreams() throws IOException {
@@ -549,6 +572,44 @@ public class RetryingHttpSenderTests extends ESTestCase {
}
}
+ public void testStream_DoesNotRetryIndefinitely() throws IOException {
+ var threadPool = new TestThreadPool(getTestName());
+ try {
+ var httpClient = mock(HttpClient.class);
+ doAnswer(ans -> {
+ ActionListener listener = ans.getArgument(2);
+ listener.onFailure(new ConnectionClosedException("failed"));
+ return null;
+ }).when(httpClient).stream(any(), any(), any());
+
+ var handler = mock(ResponseHandler.class);
+ when(handler.canHandleStreamingResponses()).thenReturn(true);
+
+ var retrier = new RetryingHttpSender(
+ httpClient,
+ mock(ThrottlerManager.class),
+ createDefaultRetrySettings(),
+ threadPool,
+ EsExecutors.DIRECT_EXECUTOR_SERVICE
+ );
+
+ var listener = new PlainActionFuture();
+ var request = mockRequest();
+ when(request.isStreaming()).thenReturn(true);
+ retrier.send(mock(Logger.class), request, () -> false, handler, listener);
+
+ // Assert that the retrying sender stopped after max retires even though the exception is retryable
+ var thrownException = expectThrows(UncategorizedExecutionException.class, () -> listener.actionGet(TIMEOUT));
+ assertThat(thrownException.getCause(), instanceOf(ConnectionClosedException.class));
+ assertThat(thrownException.getMessage(), is("Failed execution"));
+ assertThat(thrownException.getSuppressed().length, is(0));
+ verify(httpClient, times(RetryingHttpSender.MAX_RETIES)).stream(any(), any(), any());
+ verifyNoMoreInteractions(httpClient);
+ } finally {
+ terminate(threadPool);
+ }
+ }
+
public void testSend_DoesNotRetryIndefinitely_WithAlwaysRetryingResponseHandler() throws IOException {
var threadPool = new TestThreadPool(getTestName());
try {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/StreamingResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/StreamingResponseHandlerTests.java
deleted file mode 100644
index 6894c9a715f2..000000000000
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/StreamingResponseHandlerTests.java
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.external.http.retry;
-
-import org.apache.http.HttpResponse;
-import org.apache.http.StatusLine;
-import org.apache.logging.log4j.Logger;
-import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.inference.external.http.HttpResult;
-import org.elasticsearch.xpack.inference.external.request.Request;
-import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
-import org.junit.After;
-import org.junit.Before;
-import org.mockito.ArgumentCaptor;
-import org.mockito.InjectMocks;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-import java.util.concurrent.Flow;
-
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.same;
-import static org.mockito.Mockito.doThrow;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
-public class StreamingResponseHandlerTests extends ESTestCase {
- @Mock
- private HttpResponse response;
- @Mock
- private ThrottlerManager throttlerManager;
- @Mock
- private Logger logger;
- @Mock
- private Request request;
- @Mock
- private ResponseHandler responseHandler;
- @Mock
- private Flow.Subscriber downstreamSubscriber;
- @InjectMocks
- private StreamingResponseHandler streamingResponseHandler;
- private AutoCloseable mocks;
- private HttpResult item;
-
- @Before
- public void setUp() throws Exception {
- super.setUp();
- mocks = MockitoAnnotations.openMocks(this);
- item = new HttpResult(response, new byte[0]);
- }
-
- @After
- public void tearDown() throws Exception {
- super.tearDown();
- mocks.close();
- }
-
- public void testResponseHandlerFailureIsForwardedToSubscriber() {
- var upstreamSubscription = upstreamSubscription();
- var expectedException = new RetryException(true, "ah");
- doThrow(expectedException).when(responseHandler).validateResponse(any(), any(), any(), any());
-
- var statusLine = mock(StatusLine.class);
- when(statusLine.getStatusCode()).thenReturn(404);
- when(statusLine.getReasonPhrase()).thenReturn("not found");
- when(response.getStatusLine()).thenReturn(statusLine);
-
- streamingResponseHandler.onNext(item);
-
- verify(upstreamSubscription, times(1)).cancel();
- verify(downstreamSubscriber, times(1)).onError(expectedException);
- }
-
- @SuppressWarnings("unchecked")
- private Flow.Subscription upstreamSubscription() {
- var upstreamSubscription = mock(Flow.Subscription.class);
- streamingResponseHandler.onSubscribe(upstreamSubscription);
- streamingResponseHandler.subscribe(downstreamSubscriber);
- return upstreamSubscription;
- }
-
- public void testOnNextCallsDownstream() {
- upstreamSubscription();
-
- streamingResponseHandler.onNext(item);
-
- verify(downstreamSubscriber, times(1)).onNext(item);
- }
-
- public void testCompleteForwardsComplete() {
- upstreamSubscription();
-
- streamingResponseHandler.onComplete();
-
- verify(downstreamSubscriber, times(1)).onSubscribe(any());
- verify(downstreamSubscriber, times(1)).onComplete();
- }
-
- public void testErrorForwardsError() {
- var expectedError = new RetryException(false, "ah");
- upstreamSubscription();
-
- streamingResponseHandler.onError(expectedError);
-
- verify(downstreamSubscriber, times(1)).onSubscribe(any());
- verify(downstreamSubscriber, times(1)).onError(same(expectedError));
- }
-
- public void testSubscriptionForwardsRequest() {
- var upstreamSubscription = upstreamSubscription();
-
- var downstream = ArgumentCaptor.forClass(Flow.Subscription.class);
- verify(downstreamSubscriber, times(1)).onSubscribe(downstream.capture());
- var downstreamSubscription = downstream.getValue();
-
- var requestCount = randomIntBetween(2, 200);
- downstreamSubscription.request(requestCount);
- verify(upstreamSubscription, times(1)).request(requestCount);
- }
-}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java
index f48cf3b9f485..405aba35e8d3 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java
@@ -554,13 +554,11 @@ public class AnthropicServiceTests extends ESTestCase {
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
- var result = streamChatCompletion();
-
- InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
+ streamChatCompletion().hasNoErrors().hasEvent("""
{"completion":[{"delta":"Hello"},{"delta":", World"}]}""");
}
- private InferenceServiceResults streamChatCompletion() throws IOException {
+ private InferenceEventsAssertion streamChatCompletion() throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool))) {
var model = AnthropicChatCompletionModelTests.createChatCompletionModel(
@@ -581,7 +579,7 @@ public class AnthropicServiceTests extends ESTestCase {
listener
);
- return listener.actionGet(TIMEOUT);
+ return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
}
}
@@ -592,11 +590,7 @@ public class AnthropicServiceTests extends ESTestCase {
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
- var result = streamChatCompletion();
-
- InferenceEventsAssertion.assertThat(result)
- .hasFinishedStream()
- .hasNoEvents()
+ streamChatCompletion().hasNoEvents()
.hasErrorWithStatusCode(RestStatus.REQUEST_ENTITY_TOO_LARGE.getStatus())
.hasErrorContaining("blah");
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
index 045789a92bf3..cdd8494c9b34 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
@@ -61,7 +61,6 @@ import org.junit.After;
import org.junit.Before;
import java.io.IOException;
-import java.net.URISyntaxException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
@@ -1345,13 +1344,11 @@ public class AzureAiStudioServiceTests extends ESTestCase {
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
- var result = streamChatCompletion();
-
- InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
+ streamChatCompletion().hasNoErrors().hasEvent("""
{"completion":[{"delta":"hello, world"}]}""");
}
- private InferenceServiceResults streamChatCompletion() throws IOException, URISyntaxException {
+ private InferenceEventsAssertion streamChatCompletion() throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
var model = AzureAiStudioChatCompletionModelTests.createModel(
@@ -1373,7 +1370,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
listener
);
- return listener.actionGet(TIMEOUT);
+ return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
}
}
@@ -1389,13 +1386,14 @@ public class AzureAiStudioServiceTests extends ESTestCase {
}""";
webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));
- var result = streamChatCompletion();
-
- InferenceEventsAssertion.assertThat(result)
- .hasFinishedStream()
- .hasNoEvents()
- .hasErrorWithStatusCode(401)
- .hasErrorContaining("You didn't provide an API key...");
+ var e = assertThrows(ElasticsearchStatusException.class, this::streamChatCompletion);
+ assertThat(
+ e.getMessage(),
+ equalTo(
+ "Received an authentication error status code for request from inference entity id [id] status [401]. "
+ + "Error message: [You didn't provide an API key...]"
+ )
+ );
}
@SuppressWarnings("checkstyle:LineLength")
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
index e58a7049ef87..7ee595cddf08 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
@@ -1410,13 +1410,11 @@ public class AzureOpenAiServiceTests extends ESTestCase {
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
- var result = streamChatCompletion();
-
- InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
+ streamChatCompletion().hasNoErrors().hasEvent("""
{"completion":[{"delta":"hello, world"}]}""");
}
- private InferenceServiceResults streamChatCompletion() throws IOException, URISyntaxException {
+ private InferenceEventsAssertion streamChatCompletion() throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
var model = AzureOpenAiCompletionModelTests.createCompletionModel(
@@ -1441,7 +1439,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
listener
);
- return listener.actionGet(TIMEOUT);
+ return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
}
}
@@ -1457,13 +1455,14 @@ public class AzureOpenAiServiceTests extends ESTestCase {
}""";
webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));
- var result = streamChatCompletion();
-
- InferenceEventsAssertion.assertThat(result)
- .hasFinishedStream()
- .hasNoEvents()
- .hasErrorWithStatusCode(401)
- .hasErrorContaining("You didn't provide an API key...");
+ var e = assertThrows(ElasticsearchStatusException.class, this::streamChatCompletion);
+ assertThat(
+ e.getMessage(),
+ equalTo(
+ "Received an authentication error status code for request from inference entity id [id] status [401]."
+ + " Error message: [You didn't provide an API key...]"
+ )
+ );
}
@SuppressWarnings("checkstyle:LineLength")
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
index bd1dbc201f52..7d959b9bff0a 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
@@ -1611,13 +1611,11 @@ public class CohereServiceTests extends ESTestCase {
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
- var result = streamChatCompletion();
-
- InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
+ streamChatCompletion().hasNoErrors().hasEvent("""
{"completion":[{"delta":"hello"},{"delta":"there"}]}""");
}
- private InferenceServiceResults streamChatCompletion() throws IOException {
+ private InferenceEventsAssertion streamChatCompletion() throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) {
var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", "model");
@@ -1633,7 +1631,7 @@ public class CohereServiceTests extends ESTestCase {
listener
);
- return listener.actionGet(TIMEOUT);
+ return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
}
}
@@ -1643,13 +1641,7 @@ public class CohereServiceTests extends ESTestCase {
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
- var result = streamChatCompletion();
-
- InferenceEventsAssertion.assertThat(result)
- .hasFinishedStream()
- .hasNoEvents()
- .hasErrorWithStatusCode(500)
- .hasErrorContaining("how dare you");
+ streamChatCompletion().hasNoEvents().hasErrorWithStatusCode(500).hasErrorContaining("how dare you");
}
@SuppressWarnings("checkstyle:LineLength")
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
index b2ff028750e2..2ecd39b3991b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
@@ -1012,18 +1012,18 @@ public class ElasticInferenceServiceTests extends ESTestCase {
}
}
- public void testUnifiedCompletionError() throws Exception {
- testUnifiedStreamError(404, """
+ public void testUnifiedCompletionError() {
+ var e = assertThrows(UnifiedChatCompletionException.class, () -> testUnifiedStream(404, """
{
"error": "The model `rainbow-sprinkles` does not exist or you do not have access to it."
- }""", """
- {\
- "error":{\
- "code":"not_found",\
- "message":"Received an unsuccessful status code for request from inference entity id [id] status \
- [404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]",\
- "type":"error"\
- }}""");
+ }"""));
+ assertThat(
+ e.getMessage(),
+ equalTo(
+ "Received an unsuccessful status code for request from inference entity id [id] status "
+ + "[404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]"
+ )
+ );
}
public void testUnifiedCompletionErrorMidStream() throws Exception {
@@ -1054,6 +1054,25 @@ public class ElasticInferenceServiceTests extends ESTestCase {
}
private void testUnifiedStreamError(int responseCode, String responseJson, String expectedJson) throws Exception {
+ testUnifiedStream(responseCode, responseJson).hasNoEvents().hasErrorMatching(e -> {
+ e = unwrapCause(e);
+ assertThat(e, isA(UnifiedChatCompletionException.class));
+ try (var builder = XContentFactory.jsonBuilder()) {
+ ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
+ try {
+ xContent.toXContent(builder, EMPTY_PARAMS);
+ } catch (IOException ex) {
+ throw new RuntimeException(ex);
+ }
+ });
+ var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
+
+ assertThat(json, is(expectedJson));
+ }
+ });
+ }
+
+ private InferenceEventsAssertion testUnifiedStream(int responseCode, String responseJson) throws Exception {
var eisGatewayUrl = getUrl(webServer);
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = createService(senderFactory, eisGatewayUrl)) {
@@ -1077,24 +1096,7 @@ public class ElasticInferenceServiceTests extends ESTestCase {
listener
);
- var result = listener.actionGet(TIMEOUT);
-
- InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> {
- e = unwrapCause(e);
- assertThat(e, isA(UnifiedChatCompletionException.class));
- try (var builder = XContentFactory.jsonBuilder()) {
- ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
- try {
- xContent.toXContent(builder, EMPTY_PARAMS);
- } catch (IOException ex) {
- throw new RuntimeException(ex);
- }
- });
- var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
-
- assertThat(json, is(expectedJson));
- }
- });
+ return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
}
}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index 13782a538f1f..d608f4a33ff5 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -13,6 +13,7 @@ import org.apache.http.HttpHeaders;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
@@ -28,6 +29,7 @@ import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
@@ -63,6 +65,7 @@ import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.ExceptionsHelper.unwrapCause;
@@ -1079,13 +1082,60 @@ public class OpenAiServiceTests extends ESTestCase {
}
}""";
webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
+
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+ try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+ var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user");
+ var latch = new CountDownLatch(1);
+ service.unifiedCompletionInfer(
+ model,
+ UnifiedCompletionRequest.of(
+ List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
+ ),
+ InferenceAction.Request.DEFAULT_TIMEOUT,
+ ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> {
+ try (var builder = XContentFactory.jsonBuilder()) {
+ var t = unwrapCause(e);
+ assertThat(t, isA(UnifiedChatCompletionException.class));
+ ((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
+ try {
+ xContent.toXContent(builder, EMPTY_PARAMS);
+ } catch (IOException ex) {
+ throw new RuntimeException(ex);
+ }
+ });
+ var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
+
+ assertThat(json, is("""
+ {\
+ "error":{\
+ "code":"model_not_found",\
+ "message":"Received an unsuccessful status code for request from inference entity id [id] status \
+ [404]. Error message: [The model `gpt-4awero` does not exist or you do not have access to it.]",\
+ "type":"invalid_request_error"\
+ }}"""));
+ } catch (IOException ex) {
+ throw new RuntimeException(ex);
+ }
+ }), latch::countDown)
+ );
+ assertTrue(latch.await(30, TimeUnit.SECONDS));
+ }
+ }
+
+ public void testMidStreamUnifiedCompletionError() throws Exception {
+ String responseJson = """
+ event: error
+ data: { "error": { "message": "Timed out waiting for more data", "type": "timeout" } }
+
+ """;
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
testStreamError("""
{\
"error":{\
- "code":"model_not_found",\
- "message":"Received an unsuccessful status code for request from inference entity id [id] status \
- [404]. Error message: [The model `gpt-4awero` does not exist or you do not have access to it.]",\
- "type":"invalid_request_error"\
+ "message":"Received an error response for request from inference entity id [id]. Error message: \
+ [Timed out waiting for more data]",\
+ "type":"timeout"\
}}""");
}
@@ -1124,22 +1174,6 @@ public class OpenAiServiceTests extends ESTestCase {
}
}
- public void testMidStreamUnifiedCompletionError() throws Exception {
- String responseJson = """
- event: error
- data: { "error": { "message": "Timed out waiting for more data", "type": "timeout" } }
-
- """;
- webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
- testStreamError("""
- {\
- "error":{\
- "message":"Received an error response for request from inference entity id [id]. Error message: \
- [Timed out waiting for more data]",\
- "type":"timeout"\
- }}""");
- }
-
public void testUnifiedCompletionMalformedError() throws Exception {
String responseJson = """
data: { invalid json }
@@ -1179,13 +1213,11 @@ public class OpenAiServiceTests extends ESTestCase {
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
- var result = streamCompletion();
-
- InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
+ streamCompletion().hasNoErrors().hasEvent("""
{"completion":[{"delta":"hello, world"}]}""");
}
- private InferenceServiceResults streamCompletion() throws IOException {
+ private InferenceEventsAssertion streamCompletion() throws Exception {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
var model = OpenAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), "org", "secret", "model", "user");
@@ -1201,7 +1233,7 @@ public class OpenAiServiceTests extends ESTestCase {
listener
);
- return listener.actionGet(TIMEOUT);
+ return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
}
}
@@ -1217,13 +1249,48 @@ public class OpenAiServiceTests extends ESTestCase {
}""";
webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));
- var result = streamCompletion();
+ var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion);
+ assertThat(e.status(), equalTo(RestStatus.UNAUTHORIZED));
+ assertThat(
+ e.getMessage(),
+ equalTo(
+ "Received an authentication error status code for request from inference entity id [id] status [401]. "
+ + "Error message: [You didn't provide an API key...]"
+ )
+ );
+ }
- InferenceEventsAssertion.assertThat(result)
- .hasFinishedStream()
- .hasNoEvents()
- .hasErrorWithStatusCode(401)
- .hasErrorContaining("You didn't provide an API key...");
+ public void testInfer_StreamRequestRetry() throws Exception {
+ webServer.enqueue(new MockResponse().setResponseCode(503).setBody("""
+ {
+ "error": {
+ "message": "server busy",
+ "type": "server_busy"
+ }
+ }"""));
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
+ data: {\
+ "id":"12345",\
+ "object":"chat.completion.chunk",\
+ "created":123456789,\
+ "model":"gpt-4o-mini",\
+ "system_fingerprint": "123456789",\
+ "choices":[\
+ {\
+ "index":0,\
+ "delta":{\
+ "content":"hello, world"\
+ },\
+ "logprobs":null,\
+ "finish_reason":null\
+ }\
+ ]\
+ }
+
+ """));
+
+ streamCompletion().hasNoErrors().hasEvent("""
+ {"completion":[{"delta":"hello, world"}]}""");
}
public void testSupportsStreaming() throws IOException {