Mirror of https://github.com/elastic/elasticsearch.git, synced 2025-04-19 04:45:07 -04:00
[ML] Refactor inference request executor to leverage scheduled execution (#126858)
* Using threadpool schedule and fixing tests
* Update docs/changelog/126858.yaml
* Clean up
* change log
parent e42c118ec6
commit 7a0f63c1a0

4 changed files with 57 additions and 74 deletions
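For context: before this change the executor ran a dedicated long-lived thread that looped over the rate-limited request queues and slept between passes; the refactor has each pass hand the next pass to the thread pool scheduler instead, so no thread is held while waiting. Below is a minimal sketch of that pattern, using plain java.util.concurrent types rather than Elasticsearch's ThreadPool; the class and member names are illustrative only, not the project's code.

import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

// Illustrative sketch of the scheduling pattern, not the Elasticsearch classes.
class PollingLoopSketch {

    private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
    private final AtomicBoolean shutdown = new AtomicBoolean(false);

    // Before: one thread is occupied for the whole lifetime of the service,
    // spending most of its time blocked in sleep().
    void runBlockingLoop() throws InterruptedException {
        while (shutdown.get() == false) {
            long delayMillis = pollOnce();
            Thread.sleep(delayMillis);
        }
    }

    // After: each pass hands the next pass to the scheduler and returns,
    // so no thread is held between passes.
    void startScheduledLoop() {
        scheduler.execute(this::handleTasks);
    }

    private void handleTasks() {
        if (shutdown.get()) {
            scheduler.shutdown();
            return;
        }
        long delayMillis = pollOnce();
        scheduler.schedule(this::handleTasks, delayMillis, TimeUnit.MILLISECONDS);
    }

    private long pollOnce() {
        return 50; // placeholder for draining the rate-limited request queues
    }
}

In the actual diff below, threadPool.schedule(this::handleTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME)) plays the role of scheduler.schedule, and the Sleeper abstraction that the old blocking loop relied on is removed.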
docs/changelog/126858.yaml (new file, 6 lines added)

@@ -0,0 +1,6 @@
+pr: 126858
+summary: Leverage threadpool schedule for inference api to avoid long running thread
+area: Machine Learning
+type: bug
+issues:
+ - 126853
RequestExecutorService.java

@@ -57,15 +57,6 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_P
 */
public class RequestExecutorService implements RequestExecutor {

    /**
     * Provides dependency injection mainly for testing
     */
    interface Sleeper {
        void sleep(TimeValue sleepTime) throws InterruptedException;
    }

    // default for tests
    static final Sleeper DEFAULT_SLEEPER = sleepTime -> sleepTime.timeUnit().sleep(sleepTime.duration());
    // default for tests
    static final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> DEFAULT_QUEUE_CREATOR =
        new AdjustableCapacityBlockingQueue.QueueCreator<>() {

@@ -118,7 +109,6 @@ public class RequestExecutorService implements RequestExecutor {
    private final Clock clock;
    private final AtomicBoolean shutdown = new AtomicBoolean(false);
    private final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> queueCreator;
    private final Sleeper sleeper;
    private final RateLimiterCreator rateLimiterCreator;
    private final AtomicReference<Scheduler.Cancellable> cancellableCleanupTask = new AtomicReference<>();
    private final AtomicBoolean started = new AtomicBoolean(false);

@@ -129,16 +119,7 @@ public class RequestExecutorService implements RequestExecutor {
        RequestExecutorServiceSettings settings,
        RequestSender requestSender
    ) {
        this(
            threadPool,
            DEFAULT_QUEUE_CREATOR,
            startupLatch,
            settings,
            requestSender,
            Clock.systemUTC(),
            DEFAULT_SLEEPER,
            DEFAULT_RATE_LIMIT_CREATOR
        );
        this(threadPool, DEFAULT_QUEUE_CREATOR, startupLatch, settings, requestSender, Clock.systemUTC(), DEFAULT_RATE_LIMIT_CREATOR);
    }

    public RequestExecutorService(

@@ -148,7 +129,6 @@ public class RequestExecutorService implements RequestExecutor {
        RequestExecutorServiceSettings settings,
        RequestSender requestSender,
        Clock clock,
        Sleeper sleeper,
        RateLimiterCreator rateLimiterCreator
    ) {
        this.threadPool = Objects.requireNonNull(threadPool);

@@ -157,7 +137,6 @@ public class RequestExecutorService implements RequestExecutor {
        this.requestSender = Objects.requireNonNull(requestSender);
        this.settings = Objects.requireNonNull(settings);
        this.clock = Objects.requireNonNull(clock);
        this.sleeper = Objects.requireNonNull(sleeper);
        this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator);
    }

@@ -213,15 +192,10 @@ public class RequestExecutorService implements RequestExecutor {
            startCleanupTask();
            signalStartInitiated();

            while (isShutdown() == false) {
                handleTasks();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            shutdown();
            notifyRequestsOfShutdown();
            terminationLatch.countDown();
            handleTasks();
        } catch (Exception e) {
            logger.warn("Failed to start request executor", e);
            cleanup();
        }
    }

@@ -256,13 +230,44 @@ public class RequestExecutorService implements RequestExecutor {
        }
    }

    private void handleTasks() throws InterruptedException {
        var timeToWait = settings.getTaskPollFrequency();
        for (var endpoint : rateLimitGroupings.values()) {
            timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
    private void scheduleNextHandleTasks(TimeValue timeToWait) {
        if (shutdown.get()) {
            logger.debug("Shutdown requested while scheduling next handle task call, cleaning up");
            cleanup();
            return;
        }

        sleeper.sleep(timeToWait);
        threadPool.schedule(this::handleTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME));
    }

    private void cleanup() {
        try {
            shutdown();
            notifyRequestsOfShutdown();
            terminationLatch.countDown();
        } catch (Exception e) {
            logger.warn("Encountered an error while cleaning up", e);
        }
    }

    private void handleTasks() {
        try {
            if (shutdown.get()) {
                logger.debug("Shutdown requested while handling tasks, cleaning up");
                cleanup();
                return;
            }

            var timeToWait = settings.getTaskPollFrequency();
            for (var endpoint : rateLimitGroupings.values()) {
                timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
            }

            scheduleNextHandleTasks(timeToWait);
        } catch (Exception e) {
            logger.warn("Encountered an error while handling tasks", e);
            cleanup();
        }
    }

    private void notifyRequestsOfShutdown() {
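The hunks above are fragmented, so the new shutdown handling is easy to miss: handleTasks checks the shutdown flag before doing any work, scheduleNextHandleTasks checks it again before re-arming the schedule, and both fall through to cleanup(), which notifies queued requests and counts down the termination latch, replacing the old finally block in start(). A condensed sketch of that shutdown path follows, again with plain JDK types and illustrative names rather than the actual Elasticsearch source.

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

// Illustrative shutdown and termination handling for a self-rescheduling task.
class ShutdownAwareSchedulingSketch {

    private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
    private final AtomicBoolean shutdown = new AtomicBoolean(false);
    private final CountDownLatch terminationLatch = new CountDownLatch(1);

    void shutdown() {
        shutdown.set(true);
    }

    boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
        return terminationLatch.await(timeout, unit);
    }

    void handleTasks() {
        try {
            if (shutdown.get()) {          // requested while this pass was pending
                cleanup();
                return;
            }
            long delayMillis = pollOnce();
            scheduleNextHandleTasks(delayMillis);
        } catch (Exception e) {
            cleanup();                     // never leave awaitTermination() hanging
        }
    }

    private void scheduleNextHandleTasks(long delayMillis) {
        if (shutdown.get()) {              // requested while this pass was running
            cleanup();
            return;
        }
        scheduler.schedule(this::handleTasks, delayMillis, TimeUnit.MILLISECONDS);
    }

    private void cleanup() {
        // The real class also rejects queued requests here (notifyRequestsOfShutdown).
        terminationLatch.countDown();
        scheduler.shutdown();
    }

    private long pollOnce() {
        return 50;                         // placeholder for draining the request queues
    }
}

Checking the flag at both points matters because shutdown() can be called from another thread either while a pass is executing or while the next pass is merely scheduled; in either case the chain terminates and awaitTermination() returns.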
HttpRequestSenderTests.java

@@ -50,6 +50,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata;
import static org.elasticsearch.xpack.inference.services.openai.OpenAiUtils.ORGANIZATION_HEADER;

@@ -90,7 +91,7 @@ public class HttpRequestSenderTests extends ESTestCase {
    }

    public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
        var senderFactory = createSenderFactory(clientManager, threadRef);
        var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());

        try (var sender = createSender(senderFactory)) {
            sender.start();
RequestExecutorServiceTests.java

@@ -51,7 +51,6 @@ import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

@@ -206,7 +205,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
        assertFalse(thrownException.isExecutorShutdown());
    }

    public void testTaskThrowsError_CallsOnFailure() {
    public void testTaskThrowsError_CallsOnFailure() throws InterruptedException {
        var requestSender = mock(RetryingHttpSender.class);

        var service = createRequestExecutorService(null, requestSender);

@@ -229,6 +228,8 @@ public class RequestExecutorServiceTests extends ESTestCase {
        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
        assertThat(thrownException.getMessage(), is(format("Failed to send request from inference entity id [%s]", "id")));
        assertThat(thrownException.getCause(), instanceOf(IllegalArgumentException.class));
        service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS);

        assertTrue(service.isTerminated());
    }

@@ -361,7 +362,6 @@ public class RequestExecutorServiceTests extends ESTestCase {
            createRequestExecutorServiceSettingsEmpty(),
            requestSender,
            Clock.systemUTC(),
            RequestExecutorService.DEFAULT_SLEEPER,
            RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
        );

@@ -375,36 +375,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
        });
        service.start();

        assertTrue(service.isTerminated());
    }

    public void testSleep_ThrowingInterruptedException_TerminatesService() throws Exception {
        @SuppressWarnings("unchecked")
        BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
        var sleeper = mock(RequestExecutorService.Sleeper.class);
        doThrow(new InterruptedException("failed")).when(sleeper).sleep(any());

        var service = new RequestExecutorService(
            threadPool,
            mockQueueCreator(queue),
            null,
            createRequestExecutorServiceSettingsEmpty(),
            mock(RetryingHttpSender.class),
            Clock.systemUTC(),
            sleeper,
            RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
        );

        Future<?> executorTermination = threadPool.generic().submit(() -> {
            try {
                service.start();
            } catch (Exception e) {
                fail(Strings.format("Failed to shutdown executor: %s", e));
            }
        });

        executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);

        service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
        assertTrue(service.isTerminated());
    }

@@ -581,7 +552,6 @@ public class RequestExecutorServiceTests extends ESTestCase {
            settings,
            requestSender,
            Clock.systemUTC(),
            RequestExecutorService.DEFAULT_SLEEPER,
            rateLimiterCreator
        );
        var requestManager = RequestManagerTests.createMock(requestSender);

@@ -614,7 +584,6 @@ public class RequestExecutorServiceTests extends ESTestCase {
            settings,
            requestSender,
            Clock.systemUTC(),
            RequestExecutorService.DEFAULT_SLEEPER,
            rateLimiterCreator
        );
        var requestManager = RequestManagerTests.createMock(requestSender);

@@ -626,11 +595,15 @@ public class RequestExecutorServiceTests extends ESTestCase {

        doAnswer(invocation -> {
            service.shutdown();
            ActionListener<InferenceServiceResults> passedListener = invocation.getArgument(4);
            passedListener.onResponse(null);

            return Void.TYPE;
        }).when(requestSender).send(any(), any(), any(), any(), any());

        service.start();

        listener.actionGet(TIMEOUT);
        verify(requestSender, times(1)).send(any(), any(), any(), any(), any());
    }

@@ -648,7 +621,6 @@ public class RequestExecutorServiceTests extends ESTestCase {
            settings,
            requestSender,
            clock,
            RequestExecutorService.DEFAULT_SLEEPER,
            RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
        );
        var requestManager = RequestManagerTests.createMock(requestSender, "id1");

@@ -682,7 +654,6 @@ public class RequestExecutorServiceTests extends ESTestCase {
            settings,
            requestSender,
            Clock.systemUTC(),
            RequestExecutorService.DEFAULT_SLEEPER,
            RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
        );