[ML] Refactor inference request executor to leverage scheduled execution (#126858)

* Using threadpool schedule and fixing tests

* Update docs/changelog/126858.yaml

* Clean up

* change log
This commit is contained in:
Jonathan Buttner 2025-04-16 14:14:02 -04:00 committed by GitHub
parent e42c118ec6
commit 7a0f63c1a0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 57 additions and 74 deletions

View file

@ -0,0 +1,6 @@
pr: 126858
summary: Leverage threadpool schedule for inference api to avoid long running thread
area: Machine Learning
type: bug
issues:
- 126853

View file

@ -57,15 +57,6 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_P
*/
public class RequestExecutorService implements RequestExecutor {
/**
* Provides dependency injection mainly for testing
*/
interface Sleeper {
void sleep(TimeValue sleepTime) throws InterruptedException;
}
// default for tests
static final Sleeper DEFAULT_SLEEPER = sleepTime -> sleepTime.timeUnit().sleep(sleepTime.duration());
// default for tests
static final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> DEFAULT_QUEUE_CREATOR =
new AdjustableCapacityBlockingQueue.QueueCreator<>() {
@ -118,7 +109,6 @@ public class RequestExecutorService implements RequestExecutor {
private final Clock clock;
private final AtomicBoolean shutdown = new AtomicBoolean(false);
private final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> queueCreator;
private final Sleeper sleeper;
private final RateLimiterCreator rateLimiterCreator;
private final AtomicReference<Scheduler.Cancellable> cancellableCleanupTask = new AtomicReference<>();
private final AtomicBoolean started = new AtomicBoolean(false);
@ -129,16 +119,7 @@ public class RequestExecutorService implements RequestExecutor {
RequestExecutorServiceSettings settings,
RequestSender requestSender
) {
this(
threadPool,
DEFAULT_QUEUE_CREATOR,
startupLatch,
settings,
requestSender,
Clock.systemUTC(),
DEFAULT_SLEEPER,
DEFAULT_RATE_LIMIT_CREATOR
);
this(threadPool, DEFAULT_QUEUE_CREATOR, startupLatch, settings, requestSender, Clock.systemUTC(), DEFAULT_RATE_LIMIT_CREATOR);
}
public RequestExecutorService(
@ -148,7 +129,6 @@ public class RequestExecutorService implements RequestExecutor {
RequestExecutorServiceSettings settings,
RequestSender requestSender,
Clock clock,
Sleeper sleeper,
RateLimiterCreator rateLimiterCreator
) {
this.threadPool = Objects.requireNonNull(threadPool);
@ -157,7 +137,6 @@ public class RequestExecutorService implements RequestExecutor {
this.requestSender = Objects.requireNonNull(requestSender);
this.settings = Objects.requireNonNull(settings);
this.clock = Objects.requireNonNull(clock);
this.sleeper = Objects.requireNonNull(sleeper);
this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator);
}
@ -213,15 +192,10 @@ public class RequestExecutorService implements RequestExecutor {
startCleanupTask();
signalStartInitiated();
while (isShutdown() == false) {
handleTasks();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} finally {
shutdown();
notifyRequestsOfShutdown();
terminationLatch.countDown();
handleTasks();
} catch (Exception e) {
logger.warn("Failed to start request executor", e);
cleanup();
}
}
@ -256,13 +230,44 @@ public class RequestExecutorService implements RequestExecutor {
}
}
private void handleTasks() throws InterruptedException {
var timeToWait = settings.getTaskPollFrequency();
for (var endpoint : rateLimitGroupings.values()) {
timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
private void scheduleNextHandleTasks(TimeValue timeToWait) {
if (shutdown.get()) {
logger.debug("Shutdown requested while scheduling next handle task call, cleaning up");
cleanup();
return;
}
sleeper.sleep(timeToWait);
threadPool.schedule(this::handleTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME));
}
private void cleanup() {
try {
shutdown();
notifyRequestsOfShutdown();
terminationLatch.countDown();
} catch (Exception e) {
logger.warn("Encountered an error while cleaning up", e);
}
}
private void handleTasks() {
try {
if (shutdown.get()) {
logger.debug("Shutdown requested while handling tasks, cleaning up");
cleanup();
return;
}
var timeToWait = settings.getTaskPollFrequency();
for (var endpoint : rateLimitGroupings.values()) {
timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
}
scheduleNextHandleTasks(timeToWait);
} catch (Exception e) {
logger.warn("Encountered an error while handling tasks", e);
cleanup();
}
}
private void notifyRequestsOfShutdown() {

View file

@ -50,6 +50,7 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
import static org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceRequestTests.randomElasticInferenceServiceRequestMetadata;
import static org.elasticsearch.xpack.inference.services.openai.OpenAiUtils.ORGANIZATION_HEADER;
@ -90,7 +91,7 @@ public class HttpRequestSenderTests extends ESTestCase {
}
public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
var senderFactory = createSenderFactory(clientManager, threadRef);
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
try (var sender = createSender(senderFactory)) {
sender.start();

View file

@ -51,7 +51,6 @@ import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -206,7 +205,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
assertFalse(thrownException.isExecutorShutdown());
}
public void testTaskThrowsError_CallsOnFailure() {
public void testTaskThrowsError_CallsOnFailure() throws InterruptedException {
var requestSender = mock(RetryingHttpSender.class);
var service = createRequestExecutorService(null, requestSender);
@ -229,6 +228,8 @@ public class RequestExecutorServiceTests extends ESTestCase {
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(thrownException.getMessage(), is(format("Failed to send request from inference entity id [%s]", "id")));
assertThat(thrownException.getCause(), instanceOf(IllegalArgumentException.class));
service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
assertTrue(service.isTerminated());
}
@ -361,7 +362,6 @@ public class RequestExecutorServiceTests extends ESTestCase {
createRequestExecutorServiceSettingsEmpty(),
requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
);
@ -375,36 +375,7 @@ public class RequestExecutorServiceTests extends ESTestCase {
});
service.start();
assertTrue(service.isTerminated());
}
public void testSleep_ThrowingInterruptedException_TerminatesService() throws Exception {
@SuppressWarnings("unchecked")
BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
var sleeper = mock(RequestExecutorService.Sleeper.class);
doThrow(new InterruptedException("failed")).when(sleeper).sleep(any());
var service = new RequestExecutorService(
threadPool,
mockQueueCreator(queue),
null,
createRequestExecutorServiceSettingsEmpty(),
mock(RetryingHttpSender.class),
Clock.systemUTC(),
sleeper,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
);
Future<?> executorTermination = threadPool.generic().submit(() -> {
try {
service.start();
} catch (Exception e) {
fail(Strings.format("Failed to shutdown executor: %s", e));
}
});
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
assertTrue(service.isTerminated());
}
@ -581,7 +552,6 @@ public class RequestExecutorServiceTests extends ESTestCase {
settings,
requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
rateLimiterCreator
);
var requestManager = RequestManagerTests.createMock(requestSender);
@ -614,7 +584,6 @@ public class RequestExecutorServiceTests extends ESTestCase {
settings,
requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
rateLimiterCreator
);
var requestManager = RequestManagerTests.createMock(requestSender);
@ -626,11 +595,15 @@ public class RequestExecutorServiceTests extends ESTestCase {
doAnswer(invocation -> {
service.shutdown();
ActionListener<InferenceServiceResults> passedListener = invocation.getArgument(4);
passedListener.onResponse(null);
return Void.TYPE;
}).when(requestSender).send(any(), any(), any(), any(), any());
service.start();
listener.actionGet(TIMEOUT);
verify(requestSender, times(1)).send(any(), any(), any(), any(), any());
}
@ -648,7 +621,6 @@ public class RequestExecutorServiceTests extends ESTestCase {
settings,
requestSender,
clock,
RequestExecutorService.DEFAULT_SLEEPER,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
);
var requestManager = RequestManagerTests.createMock(requestSender, "id1");
@ -682,7 +654,6 @@ public class RequestExecutorServiceTests extends ESTestCase {
settings,
requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
);