Mirror of https://github.com/elastic/elasticsearch.git (synced 2025-06-28 09:28:55 -04:00)
[ML] Retry on streaming errors (#123076)
We now always retry based on the provider's configured retry logic rather than on the HTTP status code. Some providers (e.g. Cohere, Anthropic) return a 200 status code with an error body, while others (e.g. OpenAI, Azure) return a non-200 status code with a non-streaming body.

Notes:
- Refactored from HttpResult to StreamingHttpResult: the byte body is now the streaming element, while the HTTP response lives outside the stream.
- Refactored StreamingHttpResultPublisher so that it only pushes the byte body into a queue.
- Tests now have to wait for the response to be fully consumed before closing the service; otherwise the close method shuts down the mock web server and Apache throws an error.
Parent: 6a36f31f41
Commit: dfe2adb592
16 changed files with 548 additions and 553 deletions
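Not part of the commit, but to make the described behavior concrete: the sketch below is a minimal, self-contained Java illustration of the new branching. SketchResult and providerSaysRetryable are hypothetical stand-ins, not the StreamingHttpResult or provider-specific ResponseHandler classes that appear in the diff below.

import java.nio.charset.StandardCharsets;
import java.util.List;

// Hypothetical stand-in for StreamingHttpResult: a status code plus the streamed body chunks.
record SketchResult(int statusCode, List<byte[]> bodyChunks) {
    boolean isSuccessfulResponse() {
        return statusCode >= 200 && statusCode < 300;
    }
}

public class RetryBranchSketch {

    // Successful responses go straight to the streaming parser (Cohere/Anthropic-style errors can
    // still appear inside that stream); unsuccessful responses are read in full so the
    // provider-configured logic, not the raw status code, decides whether to retry.
    static void handle(SketchResult result) {
        if (result.isSuccessfulResponse()) {
            for (byte[] chunk : result.bodyChunks()) {
                System.out.println("stream event: " + new String(chunk, StandardCharsets.UTF_8));
            }
        } else {
            var fullBody = new StringBuilder();
            result.bodyChunks().forEach(chunk -> fullBody.append(new String(chunk, StandardCharsets.UTF_8)));
            var retryable = providerSaysRetryable(fullBody.toString());
            System.out.println(retryable ? "error body is retryable, will retry" : "error body is terminal, fail now");
        }
    }

    // Placeholder for the provider-specific decision; the real rule lives in each ResponseHandler.
    static boolean providerSaysRetryable(String errorBody) {
        return errorBody.contains("rate_limit");
    }

    public static void main(String[] args) {
        handle(new SketchResult(200, List.of("data: {\"delta\":\"Hello\"}".getBytes(StandardCharsets.UTF_8))));
        handle(new SketchResult(429, List.of("{\"error\":\"rate_limit\"}".getBytes(StandardCharsets.UTF_8))));
    }
}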
docs/changelog/123076.yaml (new file, +5)
@@ -0,0 +1,5 @@
pr: 123076
summary: Retry on streaming errors
area: Machine Learning
type: bug
issues: []
@@ -27,7 +27,6 @@ import java.io.Closeable;
 import java.io.IOException;
 import java.util.Objects;
 import java.util.concurrent.CancellationException;
-import java.util.concurrent.Flow;
 import java.util.concurrent.atomic.AtomicReference;

 import static org.elasticsearch.core.Strings.format;

@@ -154,7 +153,7 @@ public class HttpClient implements Closeable {
         threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> listener.onFailure(exception));
     }

-    public void stream(HttpRequest request, HttpContext context, ActionListener<Flow.Publisher<HttpResult>> listener) throws IOException {
+    public void stream(HttpRequest request, HttpContext context, ActionListener<StreamingHttpResult> listener) throws IOException {
         // The caller must call start() first before attempting to send a request
         assert status.get() == Status.STARTED : "call start() before attempting to send a request";

@@ -162,7 +161,7 @@ public class HttpClient implements Closeable {

         SocketAccess.doPrivileged(() -> client.execute(request.requestProducer(), streamingProcessor, context, new FutureCallback<>() {
             @Override
-            public void completed(HttpResponse response) {
+            public void completed(Void response) {
                 streamingProcessor.close();
             }

@@ -0,0 +1,76 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.external.http;

import org.apache.http.HttpResponse;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.rest.RestStatus;

import java.io.ByteArrayOutputStream;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicReference;

public record StreamingHttpResult(HttpResponse response, Flow.Publisher<byte[]> body) {
    public boolean isSuccessfulResponse() {
        return RestStatus.isSuccessful(response.getStatusLine().getStatusCode());
    }

    public Flow.Publisher<HttpResult> toHttpResult() {
        return subscriber -> body().subscribe(new Flow.Subscriber<>() {
            @Override
            public void onSubscribe(Flow.Subscription subscription) {
                subscriber.onSubscribe(subscription);
            }

            @Override
            public void onNext(byte[] item) {
                subscriber.onNext(new HttpResult(response(), item));
            }

            @Override
            public void onError(Throwable throwable) {
                subscriber.onError(throwable);
            }

            @Override
            public void onComplete() {
                subscriber.onComplete();
            }
        });
    }

    public void readFullResponse(ActionListener<HttpResult> fullResponse) {
        var stream = new ByteArrayOutputStream();
        AtomicReference<Flow.Subscription> upstream = new AtomicReference<>(null);
        body.subscribe(new Flow.Subscriber<>() {
            @Override
            public void onSubscribe(Flow.Subscription subscription) {
                upstream.set(subscription);
                upstream.get().request(1);
            }

            @Override
            public void onNext(byte[] item) {
                stream.writeBytes(item);
                upstream.get().request(1);
            }

            @Override
            public void onError(Throwable throwable) {
                ExceptionsHelper.maybeError(throwable).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
                fullResponse.onFailure(new RuntimeException("Fatal while fully consuming stream", throwable));
            }

            @Override
            public void onComplete() {
                fullResponse.onResponse(new HttpResult(response, stream.toByteArray()));
            }
        });
    }
}
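A usage note, not part of the commit: readFullResponse above relies on a pull-one, request-one aggregation pattern over the body publisher. The sketch below reproduces that pattern with the JDK's SubmissionPublisher standing in for the Apache-fed Flow.Publisher<byte[]>; the payload and names are made up.

import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Flow;
import java.util.concurrent.SubmissionPublisher;

public class ReadFullBodySketch {
    public static void main(String[] args) throws InterruptedException {
        var done = new CountDownLatch(1);
        var collected = new ByteArrayOutputStream();

        try (var body = new SubmissionPublisher<byte[]>()) {
            body.subscribe(new Flow.Subscriber<>() {
                private Flow.Subscription upstream;

                @Override
                public void onSubscribe(Flow.Subscription subscription) {
                    upstream = subscription;
                    upstream.request(1); // pull the first chunk
                }

                @Override
                public void onNext(byte[] item) {
                    collected.writeBytes(item);
                    upstream.request(1); // only ask for more once the previous chunk is buffered
                }

                @Override
                public void onError(Throwable throwable) {
                    done.countDown();
                }

                @Override
                public void onComplete() {
                    done.countDown();
                }
            });

            body.submit("{\"error\":".getBytes(StandardCharsets.UTF_8));
            body.submit("\"boom\"}".getBytes(StandardCharsets.UTF_8));
        } // closing the publisher delivers onComplete after the submitted chunks

        done.await();
        System.out.println(collected.toString(StandardCharsets.UTF_8)); // {"error":"boom"}
    }
}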
@ -13,6 +13,7 @@ import org.apache.http.nio.IOControl;
|
|||
import org.apache.http.nio.protocol.HttpAsyncResponseConsumer;
|
||||
import org.apache.http.nio.util.SimpleInputBuffer;
|
||||
import org.apache.http.protocol.HttpContext;
|
||||
import org.elasticsearch.ExceptionsHelper;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
|
||||
|
@ -39,51 +40,31 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_P
|
|||
* so this publisher will send a single HttpResult. If the HttpResponse is healthy, Apache will send an HttpResponse with or without
|
||||
* the HttpEntity.</p>
|
||||
*/
|
||||
class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResponse>, Flow.Publisher<HttpResult> {
|
||||
private final HttpSettings settings;
|
||||
private final ActionListener<Flow.Publisher<HttpResult>> listener;
|
||||
class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<Void> {
|
||||
private final ActionListener<StreamingHttpResult> listener;
|
||||
private final AtomicBoolean listenerCalled = new AtomicBoolean(false);
|
||||
|
||||
// used to manage the HTTP response
|
||||
private volatile HttpResponse response;
|
||||
private volatile Exception ex;
|
||||
|
||||
// used to control the state of this publisher (Apache) and its interaction with its subscriber
|
||||
private final AtomicBoolean isDone = new AtomicBoolean(false);
|
||||
private final AtomicBoolean subscriptionCanceled = new AtomicBoolean(false);
|
||||
private volatile Flow.Subscriber<? super HttpResult> subscriber;
|
||||
|
||||
private final RequestBasedTaskRunner taskRunner;
|
||||
private final AtomicBoolean pendingRequest = new AtomicBoolean(false);
|
||||
private final Deque<Runnable> queue = new ConcurrentLinkedDeque<>();
|
||||
|
||||
// used to control the flow of data from the Apache client, if we're producing more bytes than we can consume then we'll pause
|
||||
private final SimpleInputBuffer inputBuffer = new SimpleInputBuffer(4096);
|
||||
private final AtomicLong bytesInQueue = new AtomicLong(0);
|
||||
private final Object ioLock = new Object();
|
||||
private volatile IOControl savedIoControl;
|
||||
private final DataPublisher publisher;
|
||||
private final ApacheClientBackpressure backpressure;
|
||||
|
||||
StreamingHttpResultPublisher(ThreadPool threadPool, HttpSettings settings, ActionListener<Flow.Publisher<HttpResult>> listener) {
|
||||
this.settings = Objects.requireNonNull(settings);
|
||||
private volatile Exception exception;
|
||||
|
||||
StreamingHttpResultPublisher(ThreadPool threadPool, HttpSettings settings, ActionListener<StreamingHttpResult> listener) {
|
||||
this.listener = ActionListener.notifyOnce(Objects.requireNonNull(listener));
|
||||
|
||||
this.taskRunner = new RequestBasedTaskRunner(new OffloadThread(), threadPool, UTILITY_THREAD_POOL_NAME);
|
||||
this.publisher = new DataPublisher(threadPool);
|
||||
this.backpressure = new ApacheClientBackpressure(Objects.requireNonNull(settings));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void responseReceived(HttpResponse httpResponse) {
|
||||
this.response = httpResponse;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void subscribe(Flow.Subscriber<? super HttpResult> subscriber) {
|
||||
if (this.subscriber != null) {
|
||||
subscriber.onError(new IllegalStateException("Only one subscriber is allowed for this Publisher."));
|
||||
return;
|
||||
if (listenerCalled.compareAndSet(false, true)) {
|
||||
listener.onResponse(new StreamingHttpResult(httpResponse, publisher));
|
||||
}
|
||||
|
||||
this.subscriber = subscriber;
|
||||
subscriber.onSubscribe(new HttpSubscription());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -100,49 +81,20 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResp
|
|||
if (consumed > 0) {
|
||||
var allBytes = new byte[consumed];
|
||||
inputBuffer.read(allBytes);
|
||||
queue.offer(() -> {
|
||||
subscriber.onNext(new HttpResult(response, allBytes));
|
||||
var currentBytesInQueue = bytesInQueue.updateAndGet(current -> Long.max(0, current - allBytes.length));
|
||||
if (savedIoControl != null) {
|
||||
var maxBytes = settings.getMaxResponseSize().getBytes() * 0.5;
|
||||
if (currentBytesInQueue <= maxBytes) {
|
||||
resumeProducer();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// always check if totalByteSize > the configured setting in case the settings change
|
||||
if (bytesInQueue.accumulateAndGet(allBytes.length, Long::sum) >= settings.getMaxResponseSize().getBytes()) {
|
||||
pauseProducer(ioControl);
|
||||
}
|
||||
|
||||
taskRunner.requestNextRun();
|
||||
|
||||
if (listenerCalled.compareAndSet(false, true)) {
|
||||
listener.onResponse(this);
|
||||
}
|
||||
backpressure.addBytesAndMaybePause(consumed, ioControl);
|
||||
publisher.onNext(allBytes);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
// if the provider closes the connection in the middle of the stream,
|
||||
// the contentDecoder will throw an exception trying to read the payload,
|
||||
// we should catch that and forward it downstream so we can properly handle it
|
||||
exception = e;
|
||||
publisher.onError(e);
|
||||
} finally {
|
||||
inputBuffer.reset();
|
||||
}
|
||||
}
|
||||
|
||||
private void pauseProducer(IOControl ioControl) {
|
||||
ioControl.suspendInput();
|
||||
synchronized (ioLock) {
|
||||
savedIoControl = ioControl;
|
||||
}
|
||||
}
|
||||
|
||||
private void resumeProducer() {
|
||||
synchronized (ioLock) {
|
||||
if (savedIoControl != null) {
|
||||
savedIoControl.requestInput();
|
||||
savedIoControl = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void responseCompleted(HttpContext httpContext) {}
|
||||
|
||||
|
@ -153,9 +105,8 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResp
|
|||
if (listenerCalled.compareAndSet(false, true)) {
|
||||
listener.onFailure(e);
|
||||
} else {
|
||||
ex = e;
|
||||
queue.offer(() -> subscriber.onError(e));
|
||||
taskRunner.requestNextRun();
|
||||
exception = e;
|
||||
publisher.onError(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -164,8 +115,7 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResp
|
|||
@Override
|
||||
public void close() {
|
||||
if (isDone.compareAndSet(false, true)) {
|
||||
queue.offer(() -> subscriber.onComplete());
|
||||
taskRunner.requestNextRun();
|
||||
publisher.onComplete();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -178,12 +128,12 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResp
|
|||
|
||||
@Override
|
||||
public Exception getException() {
|
||||
return ex;
|
||||
return exception;
|
||||
}
|
||||
|
||||
@Override
|
||||
public HttpResponse getResult() {
|
||||
return response;
|
||||
public Void getResult() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -191,44 +141,148 @@ class StreamingHttpResultPublisher implements HttpAsyncResponseConsumer<HttpResp
|
|||
return isDone.get();
|
||||
}
|
||||
|
||||
private class HttpSubscription implements Flow.Subscription {
|
||||
@Override
|
||||
public void request(long n) {
|
||||
if (subscriptionCanceled.get()) {
|
||||
/**
|
||||
* We only want to push payload data when the client is ready to receive it, so the client will use
|
||||
* {@link Flow.Subscription#request(long)} to request more data. We collect the payload bytes in a queue and process them on a
|
||||
* separate thread from both the Apache IO thread reading from the provider and the client's transport thread requesting more data.
|
||||
* Clients use {@link Flow.Subscription#cancel()} to exit early, and we'll forward that cancellation to the provider.
|
||||
*/
|
||||
private class DataPublisher implements Flow.Processor<byte[], byte[]> {
|
||||
private final RequestBasedTaskRunner taskRunner;
|
||||
private final Deque<byte[]> contentQueue = new ConcurrentLinkedDeque<>();
|
||||
private final AtomicLong pendingRequests = new AtomicLong(0);
|
||||
private volatile Exception pendingError = null;
|
||||
private volatile boolean completed = false;
|
||||
private volatile Flow.Subscriber<? super byte[]> downstream;
|
||||
|
||||
private DataPublisher(ThreadPool threadPool) {
|
||||
this.taskRunner = new RequestBasedTaskRunner(this::sendToSubscriber, threadPool, UTILITY_THREAD_POOL_NAME);
|
||||
}
|
||||
|
||||
private void sendToSubscriber() {
|
||||
if (downstream == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (n > 0) {
|
||||
pendingRequest.set(true);
|
||||
taskRunner.requestNextRun();
|
||||
} else {
|
||||
// per Subscription's spec, fail the subscriber and stop the processor
|
||||
cancel();
|
||||
subscriber.onError(new IllegalArgumentException("Subscriber requested a non-positive number " + n));
|
||||
if (pendingRequests.get() > 0 && pendingError != null) {
|
||||
pendingRequests.decrementAndGet();
|
||||
downstream.onError(pendingError);
|
||||
return;
|
||||
}
|
||||
byte[] nextBytes;
|
||||
while (pendingRequests.get() > 0 && (nextBytes = contentQueue.poll()) != null) {
|
||||
pendingRequests.decrementAndGet();
|
||||
backpressure.subtractBytesAndMaybeUnpause(nextBytes.length);
|
||||
downstream.onNext(nextBytes);
|
||||
}
|
||||
if (pendingRequests.get() > 0 && contentQueue.isEmpty() && completed) {
|
||||
pendingRequests.decrementAndGet();
|
||||
downstream.onComplete();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void cancel() {
|
||||
if (subscriptionCanceled.compareAndSet(false, true)) {
|
||||
taskRunner.cancel();
|
||||
public void subscribe(Flow.Subscriber<? super byte[]> subscriber) {
|
||||
if (this.downstream != null) {
|
||||
subscriber.onError(new IllegalStateException("Only one subscriber is allowed for this Publisher."));
|
||||
return;
|
||||
}
|
||||
|
||||
this.downstream = subscriber;
|
||||
downstream.onSubscribe(new Flow.Subscription() {
|
||||
@Override
|
||||
public void request(long n) {
|
||||
if (n > 0) {
|
||||
pendingRequests.addAndGet(n);
|
||||
taskRunner.requestNextRun();
|
||||
} else {
|
||||
// per Subscription's spec, fail the subscriber and stop the processor
|
||||
cancel();
|
||||
subscriber.onError(new IllegalArgumentException("Subscriber requested a non-positive number " + n));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void cancel() {
|
||||
if (subscriptionCanceled.compareAndSet(false, true)) {
|
||||
taskRunner.cancel();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNext(byte[] item) {
|
||||
contentQueue.offer(item);
|
||||
taskRunner.requestNextRun();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable throwable) {
|
||||
if (throwable instanceof Exception e) {
|
||||
pendingError = e;
|
||||
} else {
|
||||
ExceptionsHelper.maybeError(throwable).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
|
||||
pendingError = new RuntimeException("Unhandled error while streaming");
|
||||
}
|
||||
taskRunner.requestNextRun();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
completed = true;
|
||||
taskRunner.requestNextRun();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSubscribe(Flow.Subscription subscription) {
|
||||
assert false : "Apache never calls this.";
|
||||
throw new UnsupportedOperationException("Apache never calls this.");
|
||||
}
|
||||
}
|
||||
|
||||
private class OffloadThread implements Runnable {
|
||||
@Override
|
||||
public void run() {
|
||||
if (subscriptionCanceled.get()) {
|
||||
return;
|
||||
}
|
||||
/**
|
||||
* We want to keep track of how much memory we are consuming while reading the payload from the external provider. Apache continuously
|
||||
* pushes payload data to us, whereas the client only requests the next set of bytes when they are ready, so we want to track how much
|
||||
* data we are holding in memory and potentially pause the Apache client if we have reached our limit.
|
||||
*/
|
||||
private static class ApacheClientBackpressure {
|
||||
private final HttpSettings settings;
|
||||
private final AtomicLong bytesInQueue = new AtomicLong(0);
|
||||
private final Object ioLock = new Object();
|
||||
private volatile IOControl savedIoControl;
|
||||
|
||||
if (queue.isEmpty() == false && pendingRequest.compareAndSet(true, false)) {
|
||||
var next = queue.poll();
|
||||
if (next != null) {
|
||||
next.run();
|
||||
} else {
|
||||
pendingRequest.set(true);
|
||||
private ApacheClientBackpressure(HttpSettings settings) {
|
||||
this.settings = settings;
|
||||
}
|
||||
|
||||
private void addBytesAndMaybePause(long count, IOControl ioControl) {
|
||||
if (bytesInQueue.addAndGet(count) >= settings.getMaxResponseSize().getBytes()) {
|
||||
pauseProducer(ioControl);
|
||||
}
|
||||
}
|
||||
|
||||
private void pauseProducer(IOControl ioControl) {
|
||||
ioControl.suspendInput();
|
||||
synchronized (ioLock) {
|
||||
savedIoControl = ioControl;
|
||||
}
|
||||
}
|
||||
|
||||
private void subtractBytesAndMaybeUnpause(long count) {
|
||||
var currentBytesInQueue = bytesInQueue.updateAndGet(current -> Long.max(0, current - count));
|
||||
if (savedIoControl != null) {
|
||||
var maxBytes = settings.getMaxResponseSize().getBytes() * 0.5;
|
||||
if (currentBytesInQueue <= maxBytes) {
|
||||
resumeProducer();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void resumeProducer() {
|
||||
synchronized (ioLock) {
|
||||
if (savedIoControl != null) {
|
||||
savedIoControl.requestInput();
|
||||
savedIoControl = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
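Not from the PR: a self-contained sketch of the pause/resume thresholding that ApacheClientBackpressure applies above. The producer is suspended once the queued bytes reach the configured maximum and resumed once the queue drains below half of it; Producer here is a stand-in for Apache's IOControl and the threshold values are illustrative.

import java.util.concurrent.atomic.AtomicLong;

public class BackpressureSketch {

    interface Producer { // stand-in for Apache's IOControl
        void suspend();
        void resume();
    }

    private final long maxQueuedBytes;
    private final AtomicLong bytesInQueue = new AtomicLong();
    private volatile Producer pausedProducer;

    BackpressureSketch(long maxQueuedBytes) {
        this.maxQueuedBytes = maxQueuedBytes;
    }

    // Called from the producer side: suspend once the queue holds the configured maximum.
    void addBytesAndMaybePause(long count, Producer producer) {
        if (bytesInQueue.addAndGet(count) >= maxQueuedBytes) {
            producer.suspend();
            pausedProducer = producer;
        }
    }

    // Called from the consumer side: resume once the queue drains below half the maximum.
    void subtractBytesAndMaybeResume(long count) {
        var remaining = bytesInQueue.updateAndGet(current -> Long.max(0, current - count));
        var paused = pausedProducer;
        if (paused != null && remaining <= maxQueuedBytes * 0.5) {
            paused.resume();
            pausedProducer = null;
        }
    }

    public static void main(String[] args) {
        var backpressure = new BackpressureSketch(100);
        Producer producer = new Producer() {
            public void suspend() { System.out.println("paused"); }
            public void resume() { System.out.println("resumed"); }
        };
        backpressure.addBytesAndMaybePause(120, producer); // 120 >= 100, prints "paused"
        backpressure.subtractBytesAndMaybeResume(80);      // 40 <= 50, prints "resumed"
    }
}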
@@ -116,9 +116,20 @@ public class RetryingHttpSender implements RequestSender {
         try {
             if (request.isStreaming() && responseHandler.canHandleStreamingResponses()) {
                 httpClient.stream(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
-                    var streamingResponseHandler = new StreamingResponseHandler(throttlerManager, logger, request, responseHandler);
-                    r.subscribe(streamingResponseHandler);
-                    l.onResponse(responseHandler.parseResult(request, streamingResponseHandler));
+                    if (r.isSuccessfulResponse()) {
+                        l.onResponse(responseHandler.parseResult(request, r.toHttpResult()));
+                    } else {
+                        r.readFullResponse(l.delegateFailureAndWrap((ll, httpResult) -> {
+                            try {
+                                responseHandler.validateResponse(throttlerManager, logger, request, httpResult);
+                                InferenceServiceResults inferenceResults = responseHandler.parseResult(request, httpResult);
+                                ll.onResponse(inferenceResults);
+                            } catch (Exception e) {
+                                logException(logger, request, httpResult, responseHandler.getRequestType(), e);
+                                listener.onFailure(e); // skip retrying
+                            }
+                        }));
+                    }
                 }));
             } else {
                 httpClient.send(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
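One of the commit notes above says that tests must now wait for the response to be fully consumed before closing the service, because close() also shuts down the mock web server while Apache may still be reading. Below is a minimal, framework-free sketch of that ordering; FakeService and the transport publisher are hypothetical, and the real service tests use InferenceEventsAssertion's hasFinishedStream() for this.

import java.nio.charset.StandardCharsets;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.SubmissionPublisher;
import java.util.concurrent.TimeUnit;

public class ConsumeBeforeCloseSketch {

    // Stand-in for the inference service whose close() also tears down the mock web server.
    record FakeService(SubmissionPublisher<byte[]> transport) implements AutoCloseable {
        @Override
        public void close() {
            transport.close();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        var transport = new SubmissionPublisher<byte[]>();
        var streamFinished = new CountDownLatch(1);

        try (var service = new FakeService(transport)) {
            transport.consume(chunk -> System.out.println("chunk: " + new String(chunk, StandardCharsets.UTF_8)))
                .whenComplete((ignored, error) -> streamFinished.countDown());

            transport.submit("data: done".getBytes(StandardCharsets.UTF_8));
            transport.close(); // the provider finished sending

            // The point of the test change: block until the stream is fully consumed *before*
            // service.close() runs, otherwise the transport is torn down mid-read.
            if (streamFinished.await(5, TimeUnit.SECONDS) == false) {
                throw new AssertionError("stream was not fully consumed before closing the service");
            }
        }
    }
}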
@ -1,140 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.http.retry;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.elasticsearch.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
|
||||
import java.util.concurrent.Flow;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
import static org.elasticsearch.core.Strings.format;
|
||||
|
||||
class StreamingResponseHandler implements Flow.Processor<HttpResult, HttpResult> {
|
||||
private static final Logger log = LogManager.getLogger(StreamingResponseHandler.class);
|
||||
private final ThrottlerManager throttlerManager;
|
||||
private final Logger throttlerLogger;
|
||||
private final Request request;
|
||||
private final ResponseHandler responseHandler;
|
||||
|
||||
private final AtomicBoolean upstreamIsClosed = new AtomicBoolean(false);
|
||||
private final AtomicBoolean processedFirstItem = new AtomicBoolean(false);
|
||||
|
||||
private volatile Flow.Subscription upstream;
|
||||
private volatile Flow.Subscriber<? super HttpResult> downstream;
|
||||
|
||||
StreamingResponseHandler(ThrottlerManager throttlerManager, Logger throttlerLogger, Request request, ResponseHandler responseHandler) {
|
||||
this.throttlerManager = throttlerManager;
|
||||
this.throttlerLogger = throttlerLogger;
|
||||
this.request = request;
|
||||
this.responseHandler = responseHandler;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void subscribe(Flow.Subscriber<? super HttpResult> subscriber) {
|
||||
if (downstream != null) {
|
||||
subscriber.onError(
|
||||
new IllegalStateException("Failed to initialize streaming response. Another subscriber is already subscribed.")
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
downstream = subscriber;
|
||||
subscriber.onSubscribe(forwardingSubscription());
|
||||
}
|
||||
|
||||
private Flow.Subscription forwardingSubscription() {
|
||||
return new Flow.Subscription() {
|
||||
@Override
|
||||
public void request(long n) {
|
||||
if (upstreamIsClosed.get()) {
|
||||
downstream.onComplete(); // shouldn't happen, but reinforce that we're no longer listening
|
||||
} else if (upstream != null) {
|
||||
upstream.request(n);
|
||||
} else {
|
||||
// this shouldn't happen, the expected call pattern is onNext -> subscribe after the listener is invoked
|
||||
var errorMessage = "Failed to initialize streaming response. onSubscribe must be called first to set the upstream";
|
||||
assert false : errorMessage;
|
||||
downstream.onError(new IllegalStateException(errorMessage));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void cancel() {
|
||||
if (upstreamIsClosed.compareAndSet(false, true) && upstream != null) {
|
||||
upstream.cancel();
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSubscribe(Flow.Subscription subscription) {
|
||||
upstream = subscription;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNext(HttpResult item) {
|
||||
if (processedFirstItem.compareAndSet(false, true)) {
|
||||
try {
|
||||
responseHandler.validateResponse(throttlerManager, throttlerLogger, request, item);
|
||||
} catch (Exception e) {
|
||||
logException(throttlerLogger, request, item, responseHandler.getRequestType(), e);
|
||||
upstream.cancel();
|
||||
onError(e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
downstream.onNext(item);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable throwable) {
|
||||
if (upstreamIsClosed.compareAndSet(false, true)) {
|
||||
if (downstream != null) {
|
||||
downstream.onError(throwable);
|
||||
} else {
|
||||
log.warn(
|
||||
"Flow failed before the InferenceServiceResults were generated. The error should go to the listener directly.",
|
||||
throwable
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
if (upstreamIsClosed.compareAndSet(false, true)) {
|
||||
if (downstream != null) {
|
||||
downstream.onComplete();
|
||||
} else {
|
||||
log.debug("Flow completed before the InferenceServiceResults were generated. Shutting down this Processor.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void logException(Logger logger, Request request, HttpResult result, String requestType, Exception exception) {
|
||||
var causeException = ExceptionsHelper.unwrapCause(exception);
|
||||
|
||||
throttlerManager.warn(
|
||||
logger,
|
||||
format(
|
||||
"Failed to process the stream connection for request from inference entity id [%s] of type [%s] with status [%s] [%s]",
|
||||
request.getInferenceEntityId(),
|
||||
requestType,
|
||||
result.response().getStatusLine().getStatusCode(),
|
||||
result.response().getStatusLine().getReasonPhrase()
|
||||
),
|
||||
causeException
|
||||
);
|
||||
}
|
||||
}
|
|
@ -41,7 +41,6 @@ import java.net.URI;
|
|||
import java.net.URISyntaxException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.concurrent.CancellationException;
|
||||
import java.util.concurrent.Flow;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
|
@ -174,7 +173,7 @@ public class HttpClientTests extends ESTestCase {
|
|||
try (var client = new HttpClient(emptyHttpSettings(), asyncClient, threadPool, mockThrottlerManager())) {
|
||||
client.start();
|
||||
|
||||
PlainActionFuture<Flow.Publisher<HttpResult>> listener = new PlainActionFuture<>();
|
||||
PlainActionFuture<StreamingHttpResult> listener = new PlainActionFuture<>();
|
||||
client.stream(httpPost, HttpClientContext.create(), listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
@ -196,7 +195,7 @@ public class HttpClientTests extends ESTestCase {
|
|||
try (var client = new HttpClient(emptyHttpSettings(), asyncClient, threadPool, mockThrottlerManager())) {
|
||||
client.start();
|
||||
|
||||
PlainActionFuture<Flow.Publisher<HttpResult>> listener = new PlainActionFuture<>();
|
||||
PlainActionFuture<StreamingHttpResult> listener = new PlainActionFuture<>();
|
||||
client.stream(httpPost, HttpClientContext.create(), listener);
|
||||
|
||||
var thrownException = expectThrows(CancellationException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
|
|
@ -14,8 +14,10 @@ import org.elasticsearch.action.ActionListener;
|
|||
import org.elasticsearch.action.support.ActionTestUtils;
|
||||
import org.elasticsearch.common.unit.ByteSizeValue;
|
||||
import org.elasticsearch.common.util.concurrent.EsExecutors;
|
||||
import org.elasticsearch.core.Tuple;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
|
||||
|
@ -53,7 +55,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
private static final long maxBytes = message.length;
|
||||
private ThreadPool threadPool;
|
||||
private HttpSettings settings;
|
||||
private ActionListener<Flow.Publisher<HttpResult>> listener;
|
||||
private final AtomicReference<Tuple<StreamingHttpResult, Exception>> result = new AtomicReference<>(null);
|
||||
private StreamingHttpResultPublisher publisher;
|
||||
|
||||
@Before
|
||||
|
@ -61,12 +63,21 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
super.setUp();
|
||||
threadPool = mock(ThreadPool.class);
|
||||
settings = mock(HttpSettings.class);
|
||||
listener = spy(ActionListener.noop());
|
||||
|
||||
when(threadPool.executor(UTILITY_THREAD_POOL_NAME)).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
|
||||
when(settings.getMaxResponseSize()).thenReturn(ByteSizeValue.ofBytes(maxBytes));
|
||||
|
||||
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
|
||||
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener());
|
||||
}
|
||||
|
||||
private ActionListener<StreamingHttpResult> listener() {
|
||||
return ActionListener.wrap(r -> result.set(Tuple.tuple(r, null)), e -> result.set(Tuple.tuple(null, e)));
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() throws Exception {
|
||||
super.tearDown();
|
||||
result.set(null);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -76,7 +87,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
*/
|
||||
public void testFirstResponseCallsListener() throws IOException {
|
||||
var latch = new CountDownLatch(1);
|
||||
var listener = ActionTestUtils.<Flow.Publisher<HttpResult>>assertNoFailureListener(r -> latch.countDown());
|
||||
var listener = ActionTestUtils.<StreamingHttpResult>assertNoFailureListener(r -> latch.countDown());
|
||||
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
|
||||
|
||||
publisher.responseReceived(mock(HttpResponse.class));
|
||||
|
@ -92,7 +103,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
*/
|
||||
public void testNonEmptyFirstResponseCallsListener() throws IOException {
|
||||
var latch = new CountDownLatch(1);
|
||||
var listener = ActionTestUtils.<Flow.Publisher<HttpResult>>assertNoFailureListener(r -> latch.countDown());
|
||||
var listener = ActionTestUtils.<StreamingHttpResult>assertNoFailureListener(r -> latch.countDown());
|
||||
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
|
||||
|
||||
when(settings.getMaxResponseSize()).thenReturn(ByteSizeValue.ofBytes(9000));
|
||||
|
@ -127,7 +138,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
|
||||
|
||||
// subscribe
|
||||
publisher.subscribe(subscriber);
|
||||
testPublisher().subscribe(subscriber);
|
||||
assertThat("subscribe must call onSubscribe", subscriber.subscription, notNullValue());
|
||||
assertThat("onNext should only be called once we have requested data", subscriber.httpResult, nullValue());
|
||||
|
||||
|
@ -142,7 +153,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
|
||||
// publisher sends data
|
||||
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
|
||||
assertThat("onNext was called with " + new String(message, StandardCharsets.UTF_8), subscriber.httpResult.body(), equalTo(message));
|
||||
assertThat("onNext was called with " + new String(message, StandardCharsets.UTF_8), subscriber.httpResult, equalTo(message));
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -157,7 +168,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
publisher.close();
|
||||
|
||||
// subscriber requests data
|
||||
publisher.subscribe(subscriber);
|
||||
testPublisher().subscribe(subscriber);
|
||||
assertThat("subscribe must call onSubscribe", subscriber.subscription, notNullValue());
|
||||
subscriber.requestData();
|
||||
assertThat("onNext was called with the initial HttpResponse", subscriber.httpResult, notNullValue());
|
||||
|
@ -187,7 +198,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
public void testResumeApache() throws IOException {
|
||||
var subscriber = new TestSubscriber();
|
||||
publisher.responseReceived(mock(HttpResponse.class));
|
||||
publisher.subscribe(subscriber);
|
||||
testPublisher().subscribe(subscriber);
|
||||
subscriber.requestData();
|
||||
subscriber.httpResult = null;
|
||||
|
||||
|
@ -212,7 +223,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
var subscriber = new TestSubscriber();
|
||||
publisher.responseReceived(mock(HttpResponse.class));
|
||||
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
|
||||
publisher.subscribe(subscriber);
|
||||
testPublisher().subscribe(subscriber);
|
||||
subscriber.requestData();
|
||||
subscriber.httpResult = null;
|
||||
|
||||
|
@ -243,7 +254,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
public void testErrorBeforeRequest() {
|
||||
var exception = new NullPointerException("test");
|
||||
publisher.failed(exception);
|
||||
verify(listener).onFailure(exception);
|
||||
assertThat(result.get().v2(), equalTo(exception));
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -361,21 +372,17 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
* When cancel is called
|
||||
* Then we only send onComplete once
|
||||
*/
|
||||
public void testCancelIsIdempotent() throws IOException {
|
||||
Flow.Subscriber<HttpResult> subscriber = mock();
|
||||
|
||||
var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
|
||||
publisher.subscribe(subscriber);
|
||||
verify(subscriber).onSubscribe(subscription.capture());
|
||||
public void testCancelIsIdempotent() {
|
||||
Flow.Subscriber<byte[]> subscriber = mock();
|
||||
|
||||
publisher.responseReceived(mock());
|
||||
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
|
||||
subscription.getValue().request(1);
|
||||
|
||||
subscription.getValue().request(1);
|
||||
var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
|
||||
testPublisher().subscribe(subscriber);
|
||||
verify(subscriber).onSubscribe(subscription.capture());
|
||||
|
||||
subscription.getValue().request(2);
|
||||
publisher.cancel();
|
||||
verify(subscriber, times(1)).onComplete();
|
||||
subscription.getValue().request(1);
|
||||
publisher.cancel();
|
||||
verify(subscriber, times(1)).onComplete();
|
||||
}
|
||||
|
@ -384,21 +391,17 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
* When close is called
|
||||
* Then we only send onComplete once
|
||||
*/
|
||||
public void testCloseIsIdempotent() throws IOException {
|
||||
Flow.Subscriber<HttpResult> subscriber = mock();
|
||||
|
||||
var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
|
||||
publisher.subscribe(subscriber);
|
||||
verify(subscriber).onSubscribe(subscription.capture());
|
||||
public void testCloseIsIdempotent() {
|
||||
Flow.Subscriber<byte[]> subscriber = mock();
|
||||
|
||||
publisher.responseReceived(mock());
|
||||
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
|
||||
subscription.getValue().request(1);
|
||||
|
||||
subscription.getValue().request(1);
|
||||
var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
|
||||
testPublisher().subscribe(subscriber);
|
||||
verify(subscriber).onSubscribe(subscription.capture());
|
||||
|
||||
subscription.getValue().request(2);
|
||||
publisher.close();
|
||||
verify(subscriber, times(1)).onComplete();
|
||||
subscription.getValue().request(1);
|
||||
publisher.close();
|
||||
verify(subscriber, times(1)).onComplete();
|
||||
}
|
||||
|
@ -409,20 +412,16 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
*/
|
||||
public void testFailedIsIdempotent() throws IOException {
|
||||
var expectedException = new IllegalStateException("wow");
|
||||
Flow.Subscriber<HttpResult> subscriber = mock();
|
||||
|
||||
var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
|
||||
publisher.subscribe(subscriber);
|
||||
verify(subscriber).onSubscribe(subscription.capture());
|
||||
Flow.Subscriber<byte[]> subscriber = mock();
|
||||
|
||||
publisher.responseReceived(mock());
|
||||
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
|
||||
subscription.getValue().request(1);
|
||||
|
||||
subscription.getValue().request(1);
|
||||
var subscription = ArgumentCaptor.forClass(Flow.Subscription.class);
|
||||
testPublisher().subscribe(subscriber);
|
||||
verify(subscriber).onSubscribe(subscription.capture());
|
||||
|
||||
subscription.getValue().request(2);
|
||||
publisher.failed(expectedException);
|
||||
verify(subscriber, times(1)).onError(eq(expectedException));
|
||||
subscription.getValue().request(1);
|
||||
publisher.failed(expectedException);
|
||||
verify(subscriber, times(1)).onError(eq(expectedException));
|
||||
}
|
||||
|
@ -492,10 +491,11 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
* Then that subscriber should receive an IllegalStateException
|
||||
*/
|
||||
public void testDoubleSubscribeFails() {
|
||||
publisher.subscribe(mock());
|
||||
publisher.responseReceived(mock());
|
||||
testPublisher().subscribe(mock());
|
||||
|
||||
var subscriber = new TestSubscriber();
|
||||
publisher.subscribe(subscriber);
|
||||
testPublisher().subscribe(subscriber);
|
||||
assertThat(subscriber.throwable, notNullValue());
|
||||
assertThat(subscriber.throwable, instanceOf(IllegalStateException.class));
|
||||
}
|
||||
|
@ -508,10 +508,10 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
public void testReuseMlThread() throws ExecutionException, InterruptedException, TimeoutException {
|
||||
try {
|
||||
threadPool = spy(createThreadPool(inferenceUtilityPool()));
|
||||
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
|
||||
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener());
|
||||
var subscriber = new TestSubscriber();
|
||||
publisher.responseReceived(mock(HttpResponse.class));
|
||||
publisher.subscribe(subscriber);
|
||||
testPublisher().subscribe(subscriber);
|
||||
|
||||
CompletableFuture.runAsync(() -> {
|
||||
try {
|
||||
|
@ -523,7 +523,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
}, threadPool.executor(UTILITY_THREAD_POOL_NAME)).get(5, TimeUnit.SECONDS);
|
||||
verify(threadPool, times(1)).executor(UTILITY_THREAD_POOL_NAME);
|
||||
assertThat("onNext was called with the initial HttpResponse", subscriber.httpResult, notNullValue());
|
||||
assertFalse("Expected HttpResult to have data", subscriber.httpResult.isBodyEmpty());
|
||||
assertNotEquals("Expected HttpResult to have data", 0, subscriber.httpResult.length);
|
||||
} finally {
|
||||
terminate(threadPool);
|
||||
}
|
||||
|
@ -548,11 +548,11 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
return executorServiceSpy;
|
||||
}).when(threadPool).executor(UTILITY_THREAD_POOL_NAME);
|
||||
|
||||
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener);
|
||||
publisher = new StreamingHttpResultPublisher(threadPool, settings, listener());
|
||||
publisher.responseReceived(mock(HttpResponse.class));
|
||||
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
|
||||
// create an infinitely running Subscriber
|
||||
var subscriber = new Flow.Subscriber<HttpResult>() {
|
||||
var subscriber = new Flow.Subscriber<byte[]>() {
|
||||
Flow.Subscription subscription;
|
||||
boolean completed = false;
|
||||
|
||||
|
@ -563,7 +563,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void onNext(HttpResult item) {
|
||||
public void onNext(byte[] item) {
|
||||
try {
|
||||
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
|
||||
} catch (IOException e) {
|
||||
|
@ -582,7 +582,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
completed = true;
|
||||
}
|
||||
};
|
||||
publisher.subscribe(subscriber);
|
||||
testPublisher().subscribe(subscriber);
|
||||
|
||||
// verify the thread has started
|
||||
assertThat("Thread should have started on subscribe", futureHolder.get(), notNullValue());
|
||||
|
@ -604,7 +604,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
// start with a message published
|
||||
publisher.responseReceived(mock(HttpResponse.class));
|
||||
TestSubscriber subscriber = new TestSubscriber() {
|
||||
public void onNext(HttpResult item) {
|
||||
public void onNext(byte[] item) {
|
||||
try {
|
||||
// publish a second message
|
||||
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
|
||||
|
@ -617,7 +617,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
super.onNext(item);
|
||||
}
|
||||
};
|
||||
publisher.subscribe(subscriber);
|
||||
testPublisher().subscribe(subscriber);
|
||||
|
||||
verify(threadPool, times(0)).executor(UTILITY_THREAD_POOL_NAME);
|
||||
subscriber.requestData();
|
||||
|
@ -646,8 +646,9 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
}
|
||||
|
||||
private TestSubscriber subscribe() {
|
||||
publisher.responseReceived(mock());
|
||||
var subscriber = new TestSubscriber();
|
||||
publisher.subscribe(subscriber);
|
||||
testPublisher().subscribe(subscriber);
|
||||
return subscriber;
|
||||
}
|
||||
|
||||
|
@ -655,12 +656,12 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
publisher.responseReceived(mock(HttpResponse.class));
|
||||
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
|
||||
TestSubscriber subscriber = new TestSubscriber() {
|
||||
public void onNext(HttpResult item) {
|
||||
public void onNext(byte[] item) {
|
||||
runDuringOnNext.run();
|
||||
super.onNext(item);
|
||||
}
|
||||
};
|
||||
publisher.subscribe(subscriber);
|
||||
testPublisher().subscribe(subscriber);
|
||||
return subscriber;
|
||||
}
|
||||
|
||||
|
@ -668,19 +669,23 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
publisher.responseReceived(mock(HttpResponse.class));
|
||||
publisher.consumeContent(contentDecoder(message), mock(IOControl.class));
|
||||
TestSubscriber subscriber = new TestSubscriber() {
|
||||
public void onNext(HttpResult item) {
|
||||
public void onNext(byte[] item) {
|
||||
runDuringOnNext.run();
|
||||
super.requestData();
|
||||
super.onNext(item);
|
||||
}
|
||||
};
|
||||
publisher.subscribe(subscriber);
|
||||
testPublisher().subscribe(subscriber);
|
||||
return subscriber;
|
||||
}
|
||||
|
||||
private static class TestSubscriber implements Flow.Subscriber<HttpResult> {
|
||||
private Flow.Publisher<byte[]> testPublisher() {
|
||||
return result.get().v1().body();
|
||||
}
|
||||
|
||||
private static class TestSubscriber implements Flow.Subscriber<byte[]> {
|
||||
private Flow.Subscription subscription;
|
||||
private HttpResult httpResult;
|
||||
private byte[] httpResult;
|
||||
private Throwable throwable;
|
||||
private boolean completed;
|
||||
|
||||
|
@ -690,7 +695,7 @@ public class StreamingHttpResultPublisherTests extends ESTestCase {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void onNext(HttpResult item) {
|
||||
public void onNext(byte[] item) {
|
||||
this.httpResult = item;
|
||||
}
|
||||
|
||||
|
|
|
@ -24,15 +24,18 @@ import org.elasticsearch.test.ESTestCase;
|
|||
import org.elasticsearch.threadpool.TestThreadPool;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpClient;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.http.StreamingHttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.request.HttpRequestTests;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
import org.junit.Before;
|
||||
import org.mockito.ArgumentMatchers;
|
||||
import org.mockito.stubbing.Answer;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.UnknownHostException;
|
||||
import java.util.concurrent.Flow;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.createDefaultRetrySettings;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
|
@ -42,7 +45,6 @@ import static org.mockito.ArgumentMatchers.any;
|
|||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.only;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
|
@ -459,12 +461,12 @@ public class RetryingHttpSenderTests extends ESTestCase {
|
|||
verifyNoMoreInteractions(httpClient);
|
||||
}
|
||||
|
||||
public void testStream() throws IOException {
|
||||
public void testStreamSuccess() throws IOException {
|
||||
var httpClient = mock(HttpClient.class);
|
||||
Flow.Publisher<HttpResult> publisher = mock();
|
||||
StreamingHttpResult streamingHttpResult = new StreamingHttpResult(mockHttpResponse(), randomPublisher());
|
||||
doAnswer(ans -> {
|
||||
ActionListener<Flow.Publisher<HttpResult>> listener = ans.getArgument(2);
|
||||
listener.onResponse(publisher);
|
||||
ActionListener<StreamingHttpResult> listener = ans.getArgument(2);
|
||||
listener.onResponse(streamingHttpResult);
|
||||
return null;
|
||||
}).when(httpClient).stream(any(), any(), any());
|
||||
|
||||
|
@ -479,7 +481,28 @@ public class RetryingHttpSenderTests extends ESTestCase {
|
|||
|
||||
verify(httpClient, times(1)).stream(any(), any(), any());
|
||||
verifyNoMoreInteractions(httpClient);
|
||||
verify(publisher, only()).subscribe(any(StreamingResponseHandler.class));
|
||||
verify(responseHandler, times(1)).parseResult(any(), ArgumentMatchers.<Flow.Publisher<HttpResult>>any());
|
||||
}
|
||||
|
||||
private Flow.Publisher<byte[]> randomPublisher() {
|
||||
var calls = new AtomicInteger(randomIntBetween(1, 4));
|
||||
return subscriber -> {
|
||||
subscriber.onSubscribe(new Flow.Subscription() {
|
||||
@Override
|
||||
public void request(long n) {
|
||||
if (calls.getAndDecrement() > 0) {
|
||||
subscriber.onNext(randomByteArrayOfLength(3));
|
||||
} else {
|
||||
subscriber.onComplete();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void cancel() {
|
||||
|
||||
}
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
public void testStream_ResponseHandlerDoesNotHandleStreams() throws IOException {
|
||||
|
@ -549,6 +572,44 @@ public class RetryingHttpSenderTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testStream_DoesNotRetryIndefinitely() throws IOException {
|
||||
var threadPool = new TestThreadPool(getTestName());
|
||||
try {
|
||||
var httpClient = mock(HttpClient.class);
|
||||
doAnswer(ans -> {
|
||||
ActionListener<StreamingHttpResult> listener = ans.getArgument(2);
|
||||
listener.onFailure(new ConnectionClosedException("failed"));
|
||||
return null;
|
||||
}).when(httpClient).stream(any(), any(), any());
|
||||
|
||||
var handler = mock(ResponseHandler.class);
|
||||
when(handler.canHandleStreamingResponses()).thenReturn(true);
|
||||
|
||||
var retrier = new RetryingHttpSender(
|
||||
httpClient,
|
||||
mock(ThrottlerManager.class),
|
||||
createDefaultRetrySettings(),
|
||||
threadPool,
|
||||
EsExecutors.DIRECT_EXECUTOR_SERVICE
|
||||
);
|
||||
|
||||
var listener = new PlainActionFuture<InferenceServiceResults>();
|
||||
var request = mockRequest();
|
||||
when(request.isStreaming()).thenReturn(true);
|
||||
retrier.send(mock(Logger.class), request, () -> false, handler, listener);
|
||||
|
||||
// Assert that the retrying sender stopped after max retries even though the exception is retryable
|
||||
var thrownException = expectThrows(UncategorizedExecutionException.class, () -> listener.actionGet(TIMEOUT));
|
||||
assertThat(thrownException.getCause(), instanceOf(ConnectionClosedException.class));
|
||||
assertThat(thrownException.getMessage(), is("Failed execution"));
|
||||
assertThat(thrownException.getSuppressed().length, is(0));
|
||||
verify(httpClient, times(RetryingHttpSender.MAX_RETIES)).stream(any(), any(), any());
|
||||
verifyNoMoreInteractions(httpClient);
|
||||
} finally {
|
||||
terminate(threadPool);
|
||||
}
|
||||
}
|
||||
|
||||
public void testSend_DoesNotRetryIndefinitely_WithAlwaysRetryingResponseHandler() throws IOException {
|
||||
var threadPool = new TestThreadPool(getTestName());
|
||||
try {
|
||||
|
|
|
@ -1,127 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.external.http.retry;
|
||||
|
||||
import org.apache.http.HttpResponse;
|
||||
import org.apache.http.StatusLine;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.InjectMocks;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.MockitoAnnotations;
|
||||
|
||||
import java.util.concurrent.Flow;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.same;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class StreamingResponseHandlerTests extends ESTestCase {
|
||||
@Mock
|
||||
private HttpResponse response;
|
||||
@Mock
|
||||
private ThrottlerManager throttlerManager;
|
||||
@Mock
|
||||
private Logger logger;
|
||||
@Mock
|
||||
private Request request;
|
||||
@Mock
|
||||
private ResponseHandler responseHandler;
|
||||
@Mock
|
||||
private Flow.Subscriber<HttpResult> downstreamSubscriber;
|
||||
@InjectMocks
|
||||
private StreamingResponseHandler streamingResponseHandler;
|
||||
private AutoCloseable mocks;
|
||||
private HttpResult item;
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
mocks = MockitoAnnotations.openMocks(this);
|
||||
item = new HttpResult(response, new byte[0]);
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() throws Exception {
|
||||
super.tearDown();
|
||||
mocks.close();
|
||||
}
|
||||
|
||||
public void testResponseHandlerFailureIsForwardedToSubscriber() {
|
||||
var upstreamSubscription = upstreamSubscription();
|
||||
var expectedException = new RetryException(true, "ah");
|
||||
doThrow(expectedException).when(responseHandler).validateResponse(any(), any(), any(), any());
|
||||
|
||||
var statusLine = mock(StatusLine.class);
|
||||
when(statusLine.getStatusCode()).thenReturn(404);
|
||||
when(statusLine.getReasonPhrase()).thenReturn("not found");
|
||||
when(response.getStatusLine()).thenReturn(statusLine);
|
||||
|
||||
streamingResponseHandler.onNext(item);
|
||||
|
||||
verify(upstreamSubscription, times(1)).cancel();
|
||||
verify(downstreamSubscriber, times(1)).onError(expectedException);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private Flow.Subscription upstreamSubscription() {
|
||||
var upstreamSubscription = mock(Flow.Subscription.class);
|
||||
streamingResponseHandler.onSubscribe(upstreamSubscription);
|
||||
streamingResponseHandler.subscribe(downstreamSubscriber);
|
||||
return upstreamSubscription;
|
||||
}
|
||||
|
||||
public void testOnNextCallsDownstream() {
|
||||
upstreamSubscription();
|
||||
|
||||
streamingResponseHandler.onNext(item);
|
||||
|
||||
verify(downstreamSubscriber, times(1)).onNext(item);
|
||||
}
|
||||
|
||||
public void testCompleteForwardsComplete() {
|
||||
upstreamSubscription();
|
||||
|
||||
streamingResponseHandler.onComplete();
|
||||
|
||||
verify(downstreamSubscriber, times(1)).onSubscribe(any());
|
||||
verify(downstreamSubscriber, times(1)).onComplete();
|
||||
}
|
||||
|
||||
public void testErrorForwardsError() {
|
||||
var expectedError = new RetryException(false, "ah");
|
||||
upstreamSubscription();
|
||||
|
||||
streamingResponseHandler.onError(expectedError);
|
||||
|
||||
verify(downstreamSubscriber, times(1)).onSubscribe(any());
|
||||
verify(downstreamSubscriber, times(1)).onError(same(expectedError));
|
||||
}
|
||||
|
||||
public void testSubscriptionForwardsRequest() {
|
||||
var upstreamSubscription = upstreamSubscription();
|
||||
|
||||
var downstream = ArgumentCaptor.forClass(Flow.Subscription.class);
|
||||
verify(downstreamSubscriber, times(1)).onSubscribe(downstream.capture());
|
||||
var downstreamSubscription = downstream.getValue();
|
||||
|
||||
var requestCount = randomIntBetween(2, 200);
|
||||
downstreamSubscription.request(requestCount);
|
||||
verify(upstreamSubscription, times(1)).request(requestCount);
|
||||
}
|
||||
}
|
|
@ -554,13 +554,11 @@ public class AnthropicServiceTests extends ESTestCase {
|
|||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var result = streamChatCompletion();
|
||||
|
||||
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
|
||||
streamChatCompletion().hasNoErrors().hasEvent("""
|
||||
{"completion":[{"delta":"Hello"},{"delta":", World"}]}""");
|
||||
}
|
||||
|
||||
private InferenceServiceResults streamChatCompletion() throws IOException {
|
||||
private InferenceEventsAssertion streamChatCompletion() throws Exception {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||
var model = AnthropicChatCompletionModelTests.createChatCompletionModel(
|
||||
|
@ -581,7 +579,7 @@ public class AnthropicServiceTests extends ESTestCase {
|
|||
listener
|
||||
);
|
||||
|
||||
return listener.actionGet(TIMEOUT);
|
||||
return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -592,11 +590,7 @@ public class AnthropicServiceTests extends ESTestCase {
|
|||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var result = streamChatCompletion();
|
||||
|
||||
InferenceEventsAssertion.assertThat(result)
|
||||
.hasFinishedStream()
|
||||
.hasNoEvents()
|
||||
streamChatCompletion().hasNoEvents()
|
||||
.hasErrorWithStatusCode(RestStatus.REQUEST_ENTITY_TOO_LARGE.getStatus())
|
||||
.hasErrorContaining("blah");
|
||||
}
|
||||
|
|
|
@ -61,7 +61,6 @@ import org.junit.After;
|
|||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.EnumSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
@ -1345,13 +1344,11 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
""";
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var result = streamChatCompletion();
|
||||
|
||||
InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
|
||||
streamChatCompletion().hasNoErrors().hasEvent("""
|
||||
{"completion":[{"delta":"hello, world"}]}""");
|
||||
}
|
||||
|
||||
private InferenceServiceResults streamChatCompletion() throws IOException, URISyntaxException {
|
||||
private InferenceEventsAssertion streamChatCompletion() throws Exception {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
|
||||
var model = AzureAiStudioChatCompletionModelTests.createModel(
|
||||
|
@ -1373,7 +1370,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|||
listener
|
||||
);
|
||||
|
||||
return listener.actionGet(TIMEOUT);
|
||||
return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1389,13 +1386,14 @@ public class AzureAiStudioServiceTests extends ESTestCase {
            }""";
        webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));

-        var result = streamChatCompletion();
-
-        InferenceEventsAssertion.assertThat(result)
-            .hasFinishedStream()
-            .hasNoEvents()
-            .hasErrorWithStatusCode(401)
-            .hasErrorContaining("You didn't provide an API key...");
+        var e = assertThrows(ElasticsearchStatusException.class, this::streamChatCompletion);
+        assertThat(
+            e.getMessage(),
+            equalTo(
+                "Received an authentication error status code for request from inference entity id [id] status [401]. "
+                    + "Error message: [You didn't provide an API key...]"
+            )
+        );
    }

    @SuppressWarnings("checkstyle:LineLength")
AzureOpenAiServiceTests.java
@@ -1410,13 +1410,11 @@ public class AzureOpenAiServiceTests extends ESTestCase {
            """;
        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

-        var result = streamChatCompletion();
-
-        InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
+        streamChatCompletion().hasNoErrors().hasEvent("""
            {"completion":[{"delta":"hello, world"}]}""");
    }

-    private InferenceServiceResults streamChatCompletion() throws IOException, URISyntaxException {
+    private InferenceEventsAssertion streamChatCompletion() throws Exception {
        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
        try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
            var model = AzureOpenAiCompletionModelTests.createCompletionModel(
@@ -1441,7 +1439,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
                listener
            );

-            return listener.actionGet(TIMEOUT);
+            return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
        }
    }

@@ -1457,13 +1455,14 @@ public class AzureOpenAiServiceTests extends ESTestCase {
            }""";
        webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));

-        var result = streamChatCompletion();
-
-        InferenceEventsAssertion.assertThat(result)
-            .hasFinishedStream()
-            .hasNoEvents()
-            .hasErrorWithStatusCode(401)
-            .hasErrorContaining("You didn't provide an API key...");
+        var e = assertThrows(ElasticsearchStatusException.class, this::streamChatCompletion);
+        assertThat(
+            e.getMessage(),
+            equalTo(
+                "Received an authentication error status code for request from inference entity id [id] status [401]."
+                    + " Error message: [You didn't provide an API key...]"
+            )
+        );
    }

    @SuppressWarnings("checkstyle:LineLength")
CohereServiceTests.java
@@ -1611,13 +1611,11 @@ public class CohereServiceTests extends ESTestCase {
            """;
        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

-        var result = streamChatCompletion();
-
-        InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
+        streamChatCompletion().hasNoErrors().hasEvent("""
            {"completion":[{"delta":"hello"},{"delta":"there"}]}""");
    }

-    private InferenceServiceResults streamChatCompletion() throws IOException {
+    private InferenceEventsAssertion streamChatCompletion() throws Exception {
        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
        try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) {
            var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", "model");
@@ -1633,7 +1631,7 @@ public class CohereServiceTests extends ESTestCase {
                listener
            );

-            return listener.actionGet(TIMEOUT);
+            return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
        }
    }

@@ -1643,13 +1641,7 @@ public class CohereServiceTests extends ESTestCase {
            """;
        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

-        var result = streamChatCompletion();
-
-        InferenceEventsAssertion.assertThat(result)
-            .hasFinishedStream()
-            .hasNoEvents()
-            .hasErrorWithStatusCode(500)
-            .hasErrorContaining("how dare you");
+        streamChatCompletion().hasNoEvents().hasErrorWithStatusCode(500).hasErrorContaining("how dare you");
    }

    @SuppressWarnings("checkstyle:LineLength")
ElasticInferenceServiceTests.java
@@ -1012,18 +1012,18 @@ public class ElasticInferenceServiceTests extends ESTestCase {
        }
    }

-    public void testUnifiedCompletionError() throws Exception {
-        testUnifiedStreamError(404, """
+    public void testUnifiedCompletionError() {
+        var e = assertThrows(UnifiedChatCompletionException.class, () -> testUnifiedStream(404, """
            {
                "error": "The model `rainbow-sprinkles` does not exist or you do not have access to it."
-            }""", """
-            {\
-            "error":{\
-            "code":"not_found",\
-            "message":"Received an unsuccessful status code for request from inference entity id [id] status \
-            [404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]",\
-            "type":"error"\
-            }}""");
+            }"""));
+        assertThat(
+            e.getMessage(),
+            equalTo(
+                "Received an unsuccessful status code for request from inference entity id [id] status "
+                    + "[404]. Error message: [The model `rainbow-sprinkles` does not exist or you do not have access to it.]"
+            )
+        );
    }

    public void testUnifiedCompletionErrorMidStream() throws Exception {
@@ -1054,6 +1054,25 @@ public class ElasticInferenceServiceTests extends ESTestCase {
    }

    private void testUnifiedStreamError(int responseCode, String responseJson, String expectedJson) throws Exception {
+        testUnifiedStream(responseCode, responseJson).hasNoEvents().hasErrorMatching(e -> {
+            e = unwrapCause(e);
+            assertThat(e, isA(UnifiedChatCompletionException.class));
+            try (var builder = XContentFactory.jsonBuilder()) {
+                ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
+                    try {
+                        xContent.toXContent(builder, EMPTY_PARAMS);
+                    } catch (IOException ex) {
+                        throw new RuntimeException(ex);
+                    }
+                });
+                var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
+
+                assertThat(json, is(expectedJson));
+            }
+        });
+    }
+
+    private InferenceEventsAssertion testUnifiedStream(int responseCode, String responseJson) throws Exception {
        var eisGatewayUrl = getUrl(webServer);
        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
        try (var service = createService(senderFactory, eisGatewayUrl)) {
@@ -1077,24 +1096,7 @@ public class ElasticInferenceServiceTests extends ESTestCase {
                listener
            );

-            var result = listener.actionGet(TIMEOUT);
-
-            InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> {
-                e = unwrapCause(e);
-                assertThat(e, isA(UnifiedChatCompletionException.class));
-                try (var builder = XContentFactory.jsonBuilder()) {
-                    ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
-                        try {
-                            xContent.toXContent(builder, EMPTY_PARAMS);
-                        } catch (IOException ex) {
-                            throw new RuntimeException(ex);
-                        }
-                    });
-                    var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
-
-                    assertThat(json, is(expectedJson));
-                }
-            });
+            return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
        }
    }

OpenAiServiceTests.java
@@ -13,6 +13,7 @@ import org.apache.http.HttpHeaders;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
@@ -28,6 +29,7 @@ import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
@@ -63,6 +65,7 @@ import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.ExceptionsHelper.unwrapCause;
@@ -1079,13 +1082,60 @@ public class OpenAiServiceTests extends ESTestCase {
            }
            }""";
        webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson));
+
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+            var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user");
+            var latch = new CountDownLatch(1);
+            service.unifiedCompletionInfer(
+                model,
+                UnifiedCompletionRequest.of(
+                    List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null))
+                ),
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> {
+                    try (var builder = XContentFactory.jsonBuilder()) {
+                        var t = unwrapCause(e);
+                        assertThat(t, isA(UnifiedChatCompletionException.class));
+                        ((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> {
+                            try {
+                                xContent.toXContent(builder, EMPTY_PARAMS);
+                            } catch (IOException ex) {
+                                throw new RuntimeException(ex);
+                            }
+                        });
+                        var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType());
+
+                        assertThat(json, is("""
+                            {\
+                            "error":{\
+                            "code":"model_not_found",\
+                            "message":"Received an unsuccessful status code for request from inference entity id [id] status \
+                            [404]. Error message: [The model `gpt-4awero` does not exist or you do not have access to it.]",\
+                            "type":"invalid_request_error"\
+                            }}"""));
+                    } catch (IOException ex) {
+                        throw new RuntimeException(ex);
+                    }
+                }), latch::countDown)
+            );
+            assertTrue(latch.await(30, TimeUnit.SECONDS));
+        }
+    }
+
+    public void testMidStreamUnifiedCompletionError() throws Exception {
+        String responseJson = """
+            event: error
+            data: { "error": { "message": "Timed out waiting for more data", "type": "timeout" } }
+
+            """;
+        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
        testStreamError("""
            {\
            "error":{\
-            "code":"model_not_found",\
-            "message":"Received an unsuccessful status code for request from inference entity id [id] status \
-            [404]. Error message: [The model `gpt-4awero` does not exist or you do not have access to it.]",\
-            "type":"invalid_request_error"\
+            "message":"Received an error response for request from inference entity id [id]. Error message: \
+            [Timed out waiting for more data]",\
+            "type":"timeout"\
            }}""");
    }

@@ -1124,22 +1174,6 @@ public class OpenAiServiceTests extends ESTestCase {
        }
    }

-    public void testMidStreamUnifiedCompletionError() throws Exception {
-        String responseJson = """
-            event: error
-            data: { "error": { "message": "Timed out waiting for more data", "type": "timeout" } }
-
-            """;
-        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
-        testStreamError("""
-            {\
-            "error":{\
-            "message":"Received an error response for request from inference entity id [id]. Error message: \
-            [Timed out waiting for more data]",\
-            "type":"timeout"\
-            }}""");
-    }
-
    public void testUnifiedCompletionMalformedError() throws Exception {
        String responseJson = """
            data: { invalid json }
@@ -1179,13 +1213,11 @@ public class OpenAiServiceTests extends ESTestCase {
            """;
        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

-        var result = streamCompletion();
-
-        InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
+        streamCompletion().hasNoErrors().hasEvent("""
            {"completion":[{"delta":"hello, world"}]}""");
    }

-    private InferenceServiceResults streamCompletion() throws IOException {
+    private InferenceEventsAssertion streamCompletion() throws Exception {
        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
        try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
            var model = OpenAiChatCompletionModelTests.createCompletionModel(getUrl(webServer), "org", "secret", "model", "user");
@@ -1201,7 +1233,7 @@ public class OpenAiServiceTests extends ESTestCase {
                listener
            );

-            return listener.actionGet(TIMEOUT);
+            return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream();
        }
    }

@@ -1217,13 +1249,48 @@ public class OpenAiServiceTests extends ESTestCase {
            }""";
        webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));

-        var result = streamCompletion();
+        var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion);
+        assertThat(e.status(), equalTo(RestStatus.UNAUTHORIZED));
+        assertThat(
+            e.getMessage(),
+            equalTo(
+                "Received an authentication error status code for request from inference entity id [id] status [401]. "
+                    + "Error message: [You didn't provide an API key...]"
+            )
+        );
+    }

-        InferenceEventsAssertion.assertThat(result)
-            .hasFinishedStream()
-            .hasNoEvents()
-            .hasErrorWithStatusCode(401)
-            .hasErrorContaining("You didn't provide an API key...");
+    public void testInfer_StreamRequestRetry() throws Exception {
+        webServer.enqueue(new MockResponse().setResponseCode(503).setBody("""
+            {
+              "error": {
+                "message": "server busy",
+                "type": "server_busy"
+              }
+            }"""));
+        webServer.enqueue(new MockResponse().setResponseCode(200).setBody("""
+            data: {\
+            "id":"12345",\
+            "object":"chat.completion.chunk",\
+            "created":123456789,\
+            "model":"gpt-4o-mini",\
+            "system_fingerprint": "123456789",\
+            "choices":[\
+            {\
+            "index":0,\
+            "delta":{\
+            "content":"hello, world"\
+            },\
+            "logprobs":null,\
+            "finish_reason":null\
+            }\
+            ]\
+            }
+
+            """));
+
+        streamCompletion().hasNoErrors().hasEvent("""
+            {"completion":[{"delta":"hello, world"}]}""");
    }

    public void testSupportsStreaming() throws IOException {
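The test hunks above all move in the same direction: when a streaming request fails, the failure now surfaces as an exception from the listener (after the provider's retry logic has run), instead of arriving as an error event inside an otherwise "successful" stream. The snippet below is a minimal, self-contained sketch of that assertion contract, not code from this commit; StreamingClient, healthy, and unauthorized are hypothetical stand-ins for the service plus the streamCompletion() helper used in the real tests, and it assumes plain JUnit 4.13+ on the classpath rather than the Elasticsearch test framework.

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;

import org.junit.Test;

public class StreamingErrorContractSketch {

    // Hypothetical stand-in for "service + streamCompletion() helper": on success it hands
    // back the streamed payload, on failure it throws instead of delivering an error event.
    interface StreamingClient {
        String streamCompletion() throws Exception;
    }

    private final StreamingClient healthy = () -> "hello, world";

    private final StreamingClient unauthorized = () -> {
        throw new IllegalStateException(
            "Received an authentication error status code for request from inference entity id [id] status [401]"
        );
    };

    @Test
    public void successPathAssertsOnTheStreamedPayload() throws Exception {
        // Mirrors streamCompletion().hasNoErrors().hasEvent(...) in the rewritten tests.
        assertEquals("hello, world", healthy.streamCompletion());
    }

    @Test
    public void errorPathAssertsOnTheThrownException() {
        // Mirrors assertThrows(ElasticsearchStatusException.class, this::streamCompletion).
        var e = assertThrows(IllegalStateException.class, unauthorized::streamCompletion);
        assertEquals(
            "Received an authentication error status code for request from inference entity id [id] status [401]",
            e.getMessage()
        );
    }
}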