[ML] Inference duration and error metrics (#115876)

Add `es.inference.requests.time` metric around `infer` API.

As recommended by OTel spec, errors are determined by the
presence or absence of the `error.type` attribute in the metric.
"error.type" will be the http status code (as a string) if it is
available, otherwise it will be the name of the exception (e.g.
NullPointerException).

Additional notes:
- ApmInferenceStats is merged into InferenceStats. Originally we planned
  to have multiple implementations, but now we're only using APM.
- Request count is now always recorded, even when there are failures
  loading the endpoint configuration.
- Added a hook in streaming for cancel messages, so we can close the
  metrics when a user cancels the stream.
This commit is contained in:
Pat Whelan 2024-11-05 09:21:30 -05:00 committed by GitHub
parent 38c7ddd409
commit 26870ef38d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 826 additions and 151 deletions

View file

@ -0,0 +1,5 @@
pr: 115876
summary: Inference duration and error metrics
area: Machine Learning
type: enhancement
issues: []

View file

@ -101,7 +101,6 @@ import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceE
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService; import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.mistral.MistralService;
import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
import org.elasticsearch.xpack.inference.telemetry.ApmInferenceStats;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
import java.util.ArrayList; import java.util.ArrayList;
@ -239,7 +238,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
shardBulkInferenceActionFilter.set(actionFilter); shardBulkInferenceActionFilter.set(actionFilter);
var meterRegistry = services.telemetryProvider().getMeterRegistry(); var meterRegistry = services.telemetryProvider().getMeterRegistry();
var stats = new PluginComponentBinding<>(InferenceStats.class, ApmInferenceStats.create(meterRegistry)); var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
return List.of(modelRegistry, registry, httpClientManager, stats); return List.of(modelRegistry, registry, httpClientManager, stats);
} }

View file

@ -7,12 +7,15 @@
package org.elasticsearch.xpack.inference.action; package org.elasticsearch.xpack.inference.action;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InferenceServiceResults;
@ -25,20 +28,22 @@ import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> { public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {
private static final Logger log = LogManager.getLogger(TransportInferenceAction.class);
private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference"; private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]"; private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";
private static final Set<Class<? extends InferenceService>> supportsStreaming = Set.of();
private final ModelRegistry modelRegistry; private final ModelRegistry modelRegistry;
private final InferenceServiceRegistry serviceRegistry; private final InferenceServiceRegistry serviceRegistry;
private final InferenceStats inferenceStats; private final InferenceStats inferenceStats;
@ -62,17 +67,22 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
@Override @Override
protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) { protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
var timer = InferenceTimer.start();
ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> { var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
var service = serviceRegistry.getService(unparsedModel.service()); var service = serviceRegistry.getService(unparsedModel.service());
if (service.isEmpty()) { if (service.isEmpty()) {
listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId())); var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
recordMetrics(unparsedModel, timer, e);
listener.onFailure(e);
return; return;
} }
if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) { if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
// not the wildcard task type and not the model task type // not the wildcard task type and not the model task type
listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType())); var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
recordMetrics(unparsedModel, timer, e);
listener.onFailure(e);
return; return;
} }
@ -83,20 +93,69 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
unparsedModel.settings(), unparsedModel.settings(),
unparsedModel.secrets() unparsedModel.secrets()
); );
inferOnService(model, request, service.get(), delegate); inferOnServiceWithMetrics(model, request, service.get(), timer, listener);
}, e -> {
try {
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e));
} catch (Exception metricsException) {
log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics");
}
listener.onFailure(e);
}); });
modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
} }
private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
try {
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
} catch (Exception e) {
log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
}
}
private void inferOnServiceWithMetrics(
Model model,
InferenceAction.Request request,
InferenceService service,
InferenceTimer timer,
ActionListener<InferenceAction.Response> listener
) {
inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
if (request.isStreaming()) {
var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
inferenceResults.publisher().subscribe(taskProcessor);
var instrumentedStream = new PublisherWithMetrics(timer, model);
taskProcessor.subscribe(instrumentedStream);
listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
} else {
recordMetrics(model, timer, null);
listener.onResponse(new InferenceAction.Response(inferenceResults));
}
}, e -> {
recordMetrics(model, timer, e);
listener.onFailure(e);
}));
}
private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
try {
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
} catch (Exception e) {
log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
}
}
private void inferOnService( private void inferOnService(
Model model, Model model,
InferenceAction.Request request, InferenceAction.Request request,
InferenceService service, InferenceService service,
ActionListener<InferenceAction.Response> listener ActionListener<InferenceServiceResults> listener
) { ) {
if (request.isStreaming() == false || service.canStream(request.getTaskType())) { if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
inferenceStats.incrementRequestCount(model);
service.infer( service.infer(
model, model,
request.getQuery(), request.getQuery(),
@ -105,7 +164,7 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
request.getTaskSettings(), request.getTaskSettings(),
request.getInputType(), request.getInputType(),
request.getInferenceTimeout(), request.getInferenceTimeout(),
createListener(request, listener) listener
); );
} else { } else {
listener.onFailure(unsupportedStreamingTaskException(request, service)); listener.onFailure(unsupportedStreamingTaskException(request, service));
@ -133,20 +192,6 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
} }
} }
private ActionListener<InferenceServiceResults> createListener(
InferenceAction.Request request,
ActionListener<InferenceAction.Response> listener
) {
if (request.isStreaming()) {
return listener.delegateFailureAndWrap((l, inferenceResults) -> {
var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
inferenceResults.publisher().subscribe(taskProcessor);
l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor));
});
}
return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)));
}
private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) { private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId); return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
} }
@ -160,4 +205,37 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
); );
} }
private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent, ChunkedToXContent> {
private final InferenceTimer timer;
private final Model model;
private PublisherWithMetrics(InferenceTimer timer, Model model) {
this.timer = timer;
this.model = model;
}
@Override
protected void next(ChunkedToXContent item) {
downstream().onNext(item);
}
@Override
public void onError(Throwable throwable) {
recordMetrics(model, timer, throwable);
super.onError(throwable);
}
@Override
protected void onCancel() {
recordMetrics(model, timer, null);
super.onCancel();
}
@Override
public void onComplete() {
recordMetrics(model, timer, null);
super.onComplete();
}
}
} }

View file

@ -61,11 +61,14 @@ public abstract class DelegatingProcessor<T, R> implements Flow.Processor<T, R>
public void cancel() { public void cancel() {
if (isClosed.compareAndSet(false, true) && upstream != null) { if (isClosed.compareAndSet(false, true) && upstream != null) {
upstream.cancel(); upstream.cancel();
onCancel();
} }
} }
}; };
} }
protected void onCancel() {}
@Override @Override
public void onSubscribe(Flow.Subscription subscription) { public void onSubscribe(Flow.Subscription subscription) {
if (upstream != null) { if (upstream != null) {

View file

@ -1,49 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.telemetry;
import org.elasticsearch.inference.Model;
import org.elasticsearch.telemetry.metric.LongCounter;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import java.util.HashMap;
import java.util.Objects;
public class ApmInferenceStats implements InferenceStats {
private final LongCounter inferenceAPMRequestCounter;
public ApmInferenceStats(LongCounter inferenceAPMRequestCounter) {
this.inferenceAPMRequestCounter = Objects.requireNonNull(inferenceAPMRequestCounter);
}
@Override
public void incrementRequestCount(Model model) {
var service = model.getConfigurations().getService();
var taskType = model.getTaskType();
var modelId = model.getServiceSettings().modelId();
var attributes = new HashMap<String, Object>(5);
attributes.put("service", service);
attributes.put("task_type", taskType.toString());
if (modelId != null) {
attributes.put("model_id", modelId);
}
inferenceAPMRequestCounter.incrementBy(1, attributes);
}
public static ApmInferenceStats create(MeterRegistry meterRegistry) {
return new ApmInferenceStats(
meterRegistry.registerLongCounter(
"es.inference.requests.count.total",
"Inference API request counts for a particular service, task type, model ID",
"operations"
)
);
}
}

View file

@ -7,15 +7,87 @@
package org.elasticsearch.xpack.inference.telemetry; package org.elasticsearch.xpack.inference.telemetry;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.Model; import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.telemetry.metric.LongCounter;
import org.elasticsearch.telemetry.metric.LongHistogram;
import org.elasticsearch.telemetry.metric.MeterRegistry;
public interface InferenceStats { import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/** import static java.util.Map.entry;
* Increment the counter for a particular value in a thread safe manner. import static java.util.stream.Stream.concat;
* @param model the model to increment request count for
*/
void incrementRequestCount(Model model);
InferenceStats NOOP = model -> {}; public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) {
public InferenceStats {
Objects.requireNonNull(requestCount);
Objects.requireNonNull(inferenceDuration);
}
public static InferenceStats create(MeterRegistry meterRegistry) {
return new InferenceStats(
meterRegistry.registerLongCounter(
"es.inference.requests.count.total",
"Inference API request counts for a particular service, task type, model ID",
"operations"
),
meterRegistry.registerLongHistogram(
"es.inference.requests.time",
"Inference API request counts for a particular service, task type, model ID",
"ms"
)
);
}
public static Map<String, Object> modelAttributes(Model model) {
return toMap(modelAttributeEntries(model));
}
private static Stream<Map.Entry<String, Object>> modelAttributeEntries(Model model) {
var stream = Stream.<Map.Entry<String, Object>>builder()
.add(entry("service", model.getConfigurations().getService()))
.add(entry("task_type", model.getTaskType().toString()));
if (model.getServiceSettings().modelId() != null) {
stream.add(entry("model_id", model.getServiceSettings().modelId()));
}
return stream.build();
}
private static Map<String, Object> toMap(Stream<Map.Entry<String, Object>> stream) {
return stream.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
public static Map<String, Object> responseAttributes(Model model, @Nullable Throwable t) {
return toMap(concat(modelAttributeEntries(model), errorAttributes(t)));
}
public static Map<String, Object> responseAttributes(UnparsedModel model, @Nullable Throwable t) {
var unknownModelAttributes = Stream.<Map.Entry<String, Object>>builder()
.add(entry("service", model.service()))
.add(entry("task_type", model.taskType().toString()))
.build();
return toMap(concat(unknownModelAttributes, errorAttributes(t)));
}
public static Map<String, Object> responseAttributes(@Nullable Throwable t) {
return toMap(errorAttributes(t));
}
private static Stream<Map.Entry<String, Object>> errorAttributes(@Nullable Throwable t) {
return switch (t) {
case null -> Stream.of(entry("status_code", 200));
case ElasticsearchStatusException ese -> Stream.<Map.Entry<String, Object>>builder()
.add(entry("status_code", ese.status().getStatus()))
.add(entry("error.type", String.valueOf(ese.status().getStatus())))
.build();
default -> Stream.of(entry("error.type", t.getClass().getSimpleName()));
};
}
} }

View file

@ -0,0 +1,33 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.telemetry;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Objects;
public record InferenceTimer(Instant startTime, Clock clock) {
public InferenceTimer {
Objects.requireNonNull(startTime);
Objects.requireNonNull(clock);
}
public static InferenceTimer start() {
return start(Clock.systemUTC());
}
public static InferenceTimer start(Clock clock) {
return new InferenceTimer(clock.instant(), clock);
}
public long elapsedMillis() {
return Duration.between(startTime(), clock().instant()).toMillis();
}
}

View file

@ -0,0 +1,354 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.action;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Flow;
import java.util.function.Consumer;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.isA;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.assertArg;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class TransportInferenceActionTests extends ESTestCase {
private static final String serviceId = "serviceId";
private static final TaskType taskType = TaskType.COMPLETION;
private static final String inferenceId = "inferenceEntityId";
private ModelRegistry modelRegistry;
private InferenceServiceRegistry serviceRegistry;
private InferenceStats inferenceStats;
private StreamingTaskManager streamingTaskManager;
private TransportInferenceAction action;
@Before
public void setUp() throws Exception {
super.setUp();
TransportService transportService = mock();
ActionFilters actionFilters = mock();
modelRegistry = mock();
serviceRegistry = mock();
inferenceStats = new InferenceStats(mock(), mock());
streamingTaskManager = mock();
action = new TransportInferenceAction(
transportService,
actionFilters,
modelRegistry,
serviceRegistry,
inferenceStats,
streamingTaskManager
);
}
public void testMetricsAfterModelRegistryError() {
var expectedException = new IllegalStateException("hello");
var expectedError = expectedException.getClass().getSimpleName();
doAnswer(ans -> {
ActionListener<?> listener = ans.getArgument(1);
listener.onFailure(expectedException);
return null;
}).when(modelRegistry).getModelWithSecrets(any(), any());
var listener = doExecute(taskType);
verify(listener).onFailure(same(expectedException));
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
assertThat(attributes.get("service"), nullValue());
assertThat(attributes.get("task_type"), nullValue());
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), nullValue());
assertThat(attributes.get("error.type"), is(expectedError));
}));
}
private ActionListener<InferenceAction.Response> doExecute(TaskType taskType) {
return doExecute(taskType, false);
}
private ActionListener<InferenceAction.Response> doExecute(TaskType taskType, boolean stream) {
InferenceAction.Request request = mock();
when(request.getInferenceEntityId()).thenReturn(inferenceId);
when(request.getTaskType()).thenReturn(taskType);
when(request.isStreaming()).thenReturn(stream);
ActionListener<InferenceAction.Response> listener = mock();
action.doExecute(mock(), request, listener);
return listener;
}
public void testMetricsAfterMissingService() {
mockModelRegistry(taskType);
when(serviceRegistry.getService(any())).thenReturn(Optional.empty());
var listener = doExecute(taskType);
verify(listener).onFailure(assertArg(e -> {
assertThat(e, isA(ElasticsearchStatusException.class));
assertThat(e.getMessage(), is("Unknown service [" + serviceId + "] for model [" + inferenceId + "]. "));
assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST));
}));
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
assertThat(attributes.get("service"), is(serviceId));
assertThat(attributes.get("task_type"), is(taskType.toString()));
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus()));
assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus())));
}));
}
private void mockModelRegistry(TaskType expectedTaskType) {
var unparsedModel = new UnparsedModel(inferenceId, expectedTaskType, serviceId, Map.of(), Map.of());
doAnswer(ans -> {
ActionListener<UnparsedModel> listener = ans.getArgument(1);
listener.onResponse(unparsedModel);
return null;
}).when(modelRegistry).getModelWithSecrets(any(), any());
}
public void testMetricsAfterUnknownTaskType() {
var modelTaskType = TaskType.RERANK;
var requestTaskType = TaskType.SPARSE_EMBEDDING;
mockModelRegistry(modelTaskType);
when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock()));
var listener = doExecute(requestTaskType);
verify(listener).onFailure(assertArg(e -> {
assertThat(e, isA(ElasticsearchStatusException.class));
assertThat(
e.getMessage(),
is(
"Incompatible task_type, the requested type ["
+ requestTaskType
+ "] does not match the model type ["
+ modelTaskType
+ "]"
)
);
assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST));
}));
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
assertThat(attributes.get("service"), is(serviceId));
assertThat(attributes.get("task_type"), is(modelTaskType.toString()));
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus()));
assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus())));
}));
}
public void testMetricsAfterInferError() {
var expectedException = new IllegalStateException("hello");
var expectedError = expectedException.getClass().getSimpleName();
mockService(listener -> listener.onFailure(expectedException));
var listener = doExecute(taskType);
verify(listener).onFailure(same(expectedException));
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
assertThat(attributes.get("service"), is(serviceId));
assertThat(attributes.get("task_type"), is(taskType.toString()));
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), nullValue());
assertThat(attributes.get("error.type"), is(expectedError));
}));
}
public void testMetricsAfterStreamUnsupported() {
var expectedStatus = RestStatus.METHOD_NOT_ALLOWED;
var expectedError = String.valueOf(expectedStatus.getStatus());
mockService(l -> {});
var listener = doExecute(taskType, true);
verify(listener).onFailure(assertArg(e -> {
assertThat(e, isA(ElasticsearchStatusException.class));
var ese = (ElasticsearchStatusException) e;
assertThat(ese.getMessage(), is("Streaming is not allowed for service [" + serviceId + "]."));
assertThat(ese.status(), is(expectedStatus));
}));
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
assertThat(attributes.get("service"), is(serviceId));
assertThat(attributes.get("task_type"), is(taskType.toString()));
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), is(expectedStatus.getStatus()));
assertThat(attributes.get("error.type"), is(expectedError));
}));
}
public void testMetricsAfterInferSuccess() {
mockService(listener -> listener.onResponse(mock()));
var listener = doExecute(taskType);
verify(listener).onResponse(any());
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
assertThat(attributes.get("service"), is(serviceId));
assertThat(attributes.get("task_type"), is(taskType.toString()));
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), is(200));
assertThat(attributes.get("error.type"), nullValue());
}));
}
public void testMetricsAfterStreamInferSuccess() {
mockStreamResponse(Flow.Subscriber::onComplete);
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
assertThat(attributes.get("service"), is(serviceId));
assertThat(attributes.get("task_type"), is(taskType.toString()));
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), is(200));
assertThat(attributes.get("error.type"), nullValue());
}));
}
public void testMetricsAfterStreamInferFailure() {
var expectedException = new IllegalStateException("hello");
var expectedError = expectedException.getClass().getSimpleName();
mockStreamResponse(subscriber -> {
subscriber.subscribe(mock());
subscriber.onError(expectedException);
});
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
assertThat(attributes.get("service"), is(serviceId));
assertThat(attributes.get("task_type"), is(taskType.toString()));
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), nullValue());
assertThat(attributes.get("error.type"), is(expectedError));
}));
}
public void testMetricsAfterStreamCancel() {
var response = mockStreamResponse(s -> s.onSubscribe(mock()));
response.subscribe(new Flow.Subscriber<>() {
@Override
public void onSubscribe(Flow.Subscription subscription) {
subscription.cancel();
}
@Override
public void onNext(ChunkedToXContent item) {
}
@Override
public void onError(Throwable throwable) {
}
@Override
public void onComplete() {
}
});
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
assertThat(attributes.get("service"), is(serviceId));
assertThat(attributes.get("task_type"), is(taskType.toString()));
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), is(200));
assertThat(attributes.get("error.type"), nullValue());
}));
}
private Flow.Publisher<ChunkedToXContent> mockStreamResponse(Consumer<Flow.Processor<?, ?>> action) {
mockService(true, Set.of(), listener -> {
Flow.Processor<ChunkedToXContent, ChunkedToXContent> taskProcessor = mock();
doAnswer(innerAns -> {
action.accept(innerAns.getArgument(0));
return null;
}).when(taskProcessor).subscribe(any());
when(streamingTaskManager.<ChunkedToXContent>create(any(), any())).thenReturn(taskProcessor);
var inferenceServiceResults = mock(InferenceServiceResults.class);
when(inferenceServiceResults.publisher()).thenReturn(mock());
listener.onResponse(inferenceServiceResults);
});
var listener = doExecute(taskType, true);
var captor = ArgumentCaptor.forClass(InferenceAction.Response.class);
verify(listener).onResponse(captor.capture());
assertTrue(captor.getValue().isStreaming());
assertNotNull(captor.getValue().publisher());
return captor.getValue().publisher();
}
private void mockService(Consumer<ActionListener<InferenceServiceResults>> listenerAction) {
mockService(false, Set.of(), listenerAction);
}
private void mockService(
boolean stream,
Set<TaskType> supportedStreamingTasks,
Consumer<ActionListener<InferenceServiceResults>> listenerAction
) {
InferenceService service = mock();
Model model = mockModel();
when(service.parsePersistedConfigWithSecrets(any(), any(), any(), any())).thenReturn(model);
when(service.name()).thenReturn(serviceId);
when(service.canStream(any())).thenReturn(stream);
when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks);
doAnswer(ans -> {
listenerAction.accept(ans.getArgument(7));
return null;
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
mockModelAndServiceRegistry(service);
}
private Model mockModel() {
Model model = mock();
ModelConfigurations modelConfigurations = mock();
when(modelConfigurations.getService()).thenReturn(serviceId);
when(model.getConfigurations()).thenReturn(modelConfigurations);
when(model.getTaskType()).thenReturn(taskType);
when(model.getServiceSettings()).thenReturn(mock());
return model;
}
private void mockModelAndServiceRegistry(InferenceService service) {
var unparsedModel = new UnparsedModel(inferenceId, taskType, serviceId, Map.of(), Map.of());
doAnswer(ans -> {
ActionListener<UnparsedModel> listener = ans.getArgument(1);
listener.onResponse(unparsedModel);
return null;
}).when(modelRegistry).getModelWithSecrets(any(), any());
when(serviceRegistry.getService(any())).thenReturn(Optional.of(service));
}
}

View file

@ -1,69 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.telemetry;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.telemetry.metric.LongCounter;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
import java.util.Map;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class ApmInferenceStatsTests extends ESTestCase {
public void testRecordWithModel() {
var longCounter = mock(LongCounter.class);
var stats = new ApmInferenceStats(longCounter);
stats.incrementRequestCount(model("service", TaskType.ANY, "modelId"));
verify(longCounter).incrementBy(
eq(1L),
eq(Map.of("service", "service", "task_type", TaskType.ANY.toString(), "model_id", "modelId"))
);
}
public void testRecordWithoutModel() {
var longCounter = mock(LongCounter.class);
var stats = new ApmInferenceStats(longCounter);
stats.incrementRequestCount(model("service", TaskType.ANY, null));
verify(longCounter).incrementBy(eq(1L), eq(Map.of("service", "service", "task_type", TaskType.ANY.toString())));
}
public void testCreation() {
assertNotNull(ApmInferenceStats.create(MeterRegistry.NOOP));
}
private Model model(String service, TaskType taskType, String modelId) {
var configuration = mock(ModelConfigurations.class);
when(configuration.getService()).thenReturn(service);
var settings = mock(ServiceSettings.class);
if (modelId != null) {
when(settings.modelId()).thenReturn(modelId);
}
var model = mock(Model.class);
when(model.getTaskType()).thenReturn(taskType);
when(model.getConfigurations()).thenReturn(configuration);
when(model.getServiceSettings()).thenReturn(settings);
return model;
}
}

View file

@ -0,0 +1,217 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.telemetry;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.telemetry.metric.LongCounter;
import org.elasticsearch.telemetry.metric.LongHistogram;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.test.ESTestCase;
import java.util.Map;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.ArgumentMatchers.assertArg;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class InferenceStatsTests extends ESTestCase {
public void testRecordWithModel() {
var longCounter = mock(LongCounter.class);
var stats = new InferenceStats(longCounter, mock());
stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, "modelId")));
verify(longCounter).incrementBy(
eq(1L),
eq(Map.of("service", "service", "task_type", TaskType.ANY.toString(), "model_id", "modelId"))
);
}
public void testRecordWithoutModel() {
var longCounter = mock(LongCounter.class);
var stats = new InferenceStats(longCounter, mock());
stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, null)));
verify(longCounter).incrementBy(eq(1L), eq(Map.of("service", "service", "task_type", TaskType.ANY.toString())));
}
public void testCreation() {
assertNotNull(InferenceStats.create(MeterRegistry.NOOP));
}
public void testRecordDurationWithoutError() {
var expectedLong = randomLong();
var histogramCounter = mock(LongHistogram.class);
var stats = new InferenceStats(mock(), histogramCounter);
stats.inferenceDuration().record(expectedLong, responseAttributes(model("service", TaskType.ANY, "modelId"), null));
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
assertThat(attributes.get("service"), is("service"));
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
assertThat(attributes.get("model_id"), is("modelId"));
assertThat(attributes.get("status_code"), is(200));
assertThat(attributes.get("error.type"), nullValue());
}));
}
/**
* "If response status code was sent or received and status indicates an error according to HTTP span status definition,
* error.type SHOULD be set to the status code number (represented as a string)"
* - https://opentelemetry.io/docs/specs/semconv/http/http-metrics/
*/
public void testRecordDurationWithElasticsearchStatusException() {
var expectedLong = randomLong();
var histogramCounter = mock(LongHistogram.class);
var stats = new InferenceStats(mock(), histogramCounter);
var statusCode = RestStatus.BAD_REQUEST;
var exception = new ElasticsearchStatusException("hello", statusCode);
var expectedError = String.valueOf(statusCode.getStatus());
stats.inferenceDuration().record(expectedLong, responseAttributes(model("service", TaskType.ANY, "modelId"), exception));
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
assertThat(attributes.get("service"), is("service"));
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
assertThat(attributes.get("model_id"), is("modelId"));
assertThat(attributes.get("status_code"), is(statusCode.getStatus()));
assertThat(attributes.get("error.type"), is(expectedError));
}));
}
/**
* "If the request fails with an error before response status code was sent or received,
* error.type SHOULD be set to exception type"
* - https://opentelemetry.io/docs/specs/semconv/http/http-metrics/
*/
public void testRecordDurationWithOtherException() {
var expectedLong = randomLong();
var histogramCounter = mock(LongHistogram.class);
var stats = new InferenceStats(mock(), histogramCounter);
var exception = new IllegalStateException("ahh");
var expectedError = exception.getClass().getSimpleName();
stats.inferenceDuration().record(expectedLong, responseAttributes(model("service", TaskType.ANY, "modelId"), exception));
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
assertThat(attributes.get("service"), is("service"));
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
assertThat(attributes.get("model_id"), is("modelId"));
assertThat(attributes.get("status_code"), nullValue());
assertThat(attributes.get("error.type"), is(expectedError));
}));
}
public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException() {
var expectedLong = randomLong();
var histogramCounter = mock(LongHistogram.class);
var stats = new InferenceStats(mock(), histogramCounter);
var statusCode = RestStatus.BAD_REQUEST;
var exception = new ElasticsearchStatusException("hello", statusCode);
var expectedError = String.valueOf(statusCode.getStatus());
var unparsedModel = new UnparsedModel("inferenceEntityId", TaskType.ANY, "service", Map.of(), Map.of());
stats.inferenceDuration().record(expectedLong, responseAttributes(unparsedModel, exception));
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
assertThat(attributes.get("service"), is("service"));
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), is(statusCode.getStatus()));
assertThat(attributes.get("error.type"), is(expectedError));
}));
}
public void testRecordDurationWithUnparsedModelAndOtherException() {
var expectedLong = randomLong();
var histogramCounter = mock(LongHistogram.class);
var stats = new InferenceStats(mock(), histogramCounter);
var exception = new IllegalStateException("ahh");
var expectedError = exception.getClass().getSimpleName();
var unparsedModel = new UnparsedModel("inferenceEntityId", TaskType.ANY, "service", Map.of(), Map.of());
stats.inferenceDuration().record(expectedLong, responseAttributes(unparsedModel, exception));
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
assertThat(attributes.get("service"), is("service"));
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), nullValue());
assertThat(attributes.get("error.type"), is(expectedError));
}));
}
public void testRecordDurationWithUnknownModelAndElasticsearchStatusException() {
var expectedLong = randomLong();
var histogramCounter = mock(LongHistogram.class);
var stats = new InferenceStats(mock(), histogramCounter);
var statusCode = RestStatus.BAD_REQUEST;
var exception = new ElasticsearchStatusException("hello", statusCode);
var expectedError = String.valueOf(statusCode.getStatus());
stats.inferenceDuration().record(expectedLong, responseAttributes(exception));
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
assertThat(attributes.get("service"), nullValue());
assertThat(attributes.get("task_type"), nullValue());
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), is(statusCode.getStatus()));
assertThat(attributes.get("error.type"), is(expectedError));
}));
}
public void testRecordDurationWithUnknownModelAndOtherException() {
var expectedLong = randomLong();
var histogramCounter = mock(LongHistogram.class);
var stats = new InferenceStats(mock(), histogramCounter);
var exception = new IllegalStateException("ahh");
var expectedError = exception.getClass().getSimpleName();
stats.inferenceDuration().record(expectedLong, responseAttributes(exception));
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
assertThat(attributes.get("service"), nullValue());
assertThat(attributes.get("task_type"), nullValue());
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), nullValue());
assertThat(attributes.get("error.type"), is(expectedError));
}));
}
private Model model(String service, TaskType taskType, String modelId) {
var configuration = mock(ModelConfigurations.class);
when(configuration.getService()).thenReturn(service);
var settings = mock(ServiceSettings.class);
if (modelId != null) {
when(settings.modelId()).thenReturn(modelId);
}
var model = mock(Model.class);
when(model.getTaskType()).thenReturn(taskType);
when(model.getConfigurations()).thenReturn(configuration);
when(model.getServiceSettings()).thenReturn(settings);
return model;
}
}

View file

@ -0,0 +1,32 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.telemetry;
import org.elasticsearch.test.ESTestCase;
import java.time.Clock;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class InferenceTimerTests extends ESTestCase {
public void testElapsedMillis() {
var expectedDuration = randomLongBetween(10, 300);
var startTime = Instant.now();
var clock = mock(Clock.class);
when(clock.instant()).thenReturn(startTime).thenReturn(startTime.plus(expectedDuration, ChronoUnit.MILLIS));
var timer = InferenceTimer.start(clock);
assertThat(expectedDuration, is(timer.elapsedMillis()));
}
}