mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 17:34:17 -04:00
[ML] Inference duration and error metrics (#115876)
Add `es.inference.requests.time` metric around `infer` API. As recommended by OTel spec, errors are determined by the presence or absence of the `error.type` attribute in the metric. "error.type" will be the http status code (as a string) if it is available, otherwise it will be the name of the exception (e.g. NullPointerException). Additional notes: - ApmInferenceStats is merged into InferenceStats. Originally we planned to have multiple implementations, but now we're only using APM. - Request count is now always recorded, even when there are failures loading the endpoint configuration. - Added a hook in streaming for cancel messages, so we can close the metrics when a user cancels the stream.
This commit is contained in:
parent
38c7ddd409
commit
26870ef38d
11 changed files with 826 additions and 151 deletions
5
docs/changelog/115876.yaml
Normal file
5
docs/changelog/115876.yaml
Normal file
|
@ -0,0 +1,5 @@
|
|||
pr: 115876
|
||||
summary: Inference duration and error metrics
|
||||
area: Machine Learning
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -101,7 +101,6 @@ import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceE
|
|||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
|
||||
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
|
||||
import org.elasticsearch.xpack.inference.telemetry.ApmInferenceStats;
|
||||
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -239,7 +238,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
|
|||
shardBulkInferenceActionFilter.set(actionFilter);
|
||||
|
||||
var meterRegistry = services.telemetryProvider().getMeterRegistry();
|
||||
var stats = new PluginComponentBinding<>(InferenceStats.class, ApmInferenceStats.create(meterRegistry));
|
||||
var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
|
||||
|
||||
return List.of(modelRegistry, registry, httpClientManager, stats);
|
||||
}
|
||||
|
|
|
@ -7,12 +7,15 @@
|
|||
|
||||
package org.elasticsearch.xpack.inference.action;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.ActionFilters;
|
||||
import org.elasticsearch.action.support.HandledTransportAction;
|
||||
import org.elasticsearch.common.util.concurrent.EsExecutors;
|
||||
import org.elasticsearch.common.xcontent.ChunkedToXContent;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.InferenceService;
|
||||
import org.elasticsearch.inference.InferenceServiceRegistry;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
|
@ -25,20 +28,22 @@ import org.elasticsearch.tasks.Task;
|
|||
import org.elasticsearch.transport.TransportService;
|
||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
|
||||
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
|
||||
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
|
||||
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
|
||||
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
|
||||
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.elasticsearch.core.Strings.format;
|
||||
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
|
||||
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
|
||||
|
||||
public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {
|
||||
private static final Logger log = LogManager.getLogger(TransportInferenceAction.class);
|
||||
private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
|
||||
private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";
|
||||
|
||||
private static final Set<Class<? extends InferenceService>> supportsStreaming = Set.of();
|
||||
|
||||
private final ModelRegistry modelRegistry;
|
||||
private final InferenceServiceRegistry serviceRegistry;
|
||||
private final InferenceStats inferenceStats;
|
||||
|
@ -62,17 +67,22 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
|
|||
|
||||
@Override
|
||||
protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
|
||||
var timer = InferenceTimer.start();
|
||||
|
||||
ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
|
||||
var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
|
||||
var service = serviceRegistry.getService(unparsedModel.service());
|
||||
if (service.isEmpty()) {
|
||||
listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
|
||||
var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
|
||||
recordMetrics(unparsedModel, timer, e);
|
||||
listener.onFailure(e);
|
||||
return;
|
||||
}
|
||||
|
||||
if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
|
||||
// not the wildcard task type and not the model task type
|
||||
listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()));
|
||||
var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
|
||||
recordMetrics(unparsedModel, timer, e);
|
||||
listener.onFailure(e);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -83,20 +93,69 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
|
|||
unparsedModel.settings(),
|
||||
unparsedModel.secrets()
|
||||
);
|
||||
inferOnService(model, request, service.get(), delegate);
|
||||
inferOnServiceWithMetrics(model, request, service.get(), timer, listener);
|
||||
}, e -> {
|
||||
try {
|
||||
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e));
|
||||
} catch (Exception metricsException) {
|
||||
log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics");
|
||||
}
|
||||
listener.onFailure(e);
|
||||
});
|
||||
|
||||
modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
|
||||
}
|
||||
|
||||
private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
|
||||
try {
|
||||
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
|
||||
} catch (Exception e) {
|
||||
log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
|
||||
}
|
||||
}
|
||||
|
||||
private void inferOnServiceWithMetrics(
|
||||
Model model,
|
||||
InferenceAction.Request request,
|
||||
InferenceService service,
|
||||
InferenceTimer timer,
|
||||
ActionListener<InferenceAction.Response> listener
|
||||
) {
|
||||
inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
|
||||
inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
|
||||
if (request.isStreaming()) {
|
||||
var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
|
||||
inferenceResults.publisher().subscribe(taskProcessor);
|
||||
|
||||
var instrumentedStream = new PublisherWithMetrics(timer, model);
|
||||
taskProcessor.subscribe(instrumentedStream);
|
||||
|
||||
listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
|
||||
} else {
|
||||
recordMetrics(model, timer, null);
|
||||
listener.onResponse(new InferenceAction.Response(inferenceResults));
|
||||
}
|
||||
}, e -> {
|
||||
recordMetrics(model, timer, e);
|
||||
listener.onFailure(e);
|
||||
}));
|
||||
}
|
||||
|
||||
private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
|
||||
try {
|
||||
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
|
||||
} catch (Exception e) {
|
||||
log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
|
||||
}
|
||||
}
|
||||
|
||||
private void inferOnService(
|
||||
Model model,
|
||||
InferenceAction.Request request,
|
||||
InferenceService service,
|
||||
ActionListener<InferenceAction.Response> listener
|
||||
ActionListener<InferenceServiceResults> listener
|
||||
) {
|
||||
if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
|
||||
inferenceStats.incrementRequestCount(model);
|
||||
service.infer(
|
||||
model,
|
||||
request.getQuery(),
|
||||
|
@ -105,7 +164,7 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
|
|||
request.getTaskSettings(),
|
||||
request.getInputType(),
|
||||
request.getInferenceTimeout(),
|
||||
createListener(request, listener)
|
||||
listener
|
||||
);
|
||||
} else {
|
||||
listener.onFailure(unsupportedStreamingTaskException(request, service));
|
||||
|
@ -133,20 +192,6 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
|
|||
}
|
||||
}
|
||||
|
||||
private ActionListener<InferenceServiceResults> createListener(
|
||||
InferenceAction.Request request,
|
||||
ActionListener<InferenceAction.Response> listener
|
||||
) {
|
||||
if (request.isStreaming()) {
|
||||
return listener.delegateFailureAndWrap((l, inferenceResults) -> {
|
||||
var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
|
||||
inferenceResults.publisher().subscribe(taskProcessor);
|
||||
l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor));
|
||||
});
|
||||
}
|
||||
return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)));
|
||||
}
|
||||
|
||||
private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
|
||||
return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
|
||||
}
|
||||
|
@ -160,4 +205,37 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
|
|||
);
|
||||
}
|
||||
|
||||
private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent, ChunkedToXContent> {
|
||||
private final InferenceTimer timer;
|
||||
private final Model model;
|
||||
|
||||
private PublisherWithMetrics(InferenceTimer timer, Model model) {
|
||||
this.timer = timer;
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void next(ChunkedToXContent item) {
|
||||
downstream().onNext(item);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable throwable) {
|
||||
recordMetrics(model, timer, throwable);
|
||||
super.onError(throwable);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void onCancel() {
|
||||
recordMetrics(model, timer, null);
|
||||
super.onCancel();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
recordMetrics(model, timer, null);
|
||||
super.onComplete();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -61,11 +61,14 @@ public abstract class DelegatingProcessor<T, R> implements Flow.Processor<T, R>
|
|||
public void cancel() {
|
||||
if (isClosed.compareAndSet(false, true) && upstream != null) {
|
||||
upstream.cancel();
|
||||
onCancel();
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
protected void onCancel() {}
|
||||
|
||||
@Override
|
||||
public void onSubscribe(Flow.Subscription subscription) {
|
||||
if (upstream != null) {
|
||||
|
|
|
@ -1,49 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.telemetry;
|
||||
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.telemetry.metric.LongCounter;
|
||||
import org.elasticsearch.telemetry.metric.MeterRegistry;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Objects;
|
||||
|
||||
public class ApmInferenceStats implements InferenceStats {
|
||||
private final LongCounter inferenceAPMRequestCounter;
|
||||
|
||||
public ApmInferenceStats(LongCounter inferenceAPMRequestCounter) {
|
||||
this.inferenceAPMRequestCounter = Objects.requireNonNull(inferenceAPMRequestCounter);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void incrementRequestCount(Model model) {
|
||||
var service = model.getConfigurations().getService();
|
||||
var taskType = model.getTaskType();
|
||||
var modelId = model.getServiceSettings().modelId();
|
||||
|
||||
var attributes = new HashMap<String, Object>(5);
|
||||
attributes.put("service", service);
|
||||
attributes.put("task_type", taskType.toString());
|
||||
if (modelId != null) {
|
||||
attributes.put("model_id", modelId);
|
||||
}
|
||||
|
||||
inferenceAPMRequestCounter.incrementBy(1, attributes);
|
||||
}
|
||||
|
||||
public static ApmInferenceStats create(MeterRegistry meterRegistry) {
|
||||
return new ApmInferenceStats(
|
||||
meterRegistry.registerLongCounter(
|
||||
"es.inference.requests.count.total",
|
||||
"Inference API request counts for a particular service, task type, model ID",
|
||||
"operations"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
|
@ -7,15 +7,87 @@
|
|||
|
||||
package org.elasticsearch.xpack.inference.telemetry;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.UnparsedModel;
|
||||
import org.elasticsearch.telemetry.metric.LongCounter;
|
||||
import org.elasticsearch.telemetry.metric.LongHistogram;
|
||||
import org.elasticsearch.telemetry.metric.MeterRegistry;
|
||||
|
||||
public interface InferenceStats {
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
* Increment the counter for a particular value in a thread safe manner.
|
||||
* @param model the model to increment request count for
|
||||
*/
|
||||
void incrementRequestCount(Model model);
|
||||
import static java.util.Map.entry;
|
||||
import static java.util.stream.Stream.concat;
|
||||
|
||||
InferenceStats NOOP = model -> {};
|
||||
public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) {
|
||||
|
||||
public InferenceStats {
|
||||
Objects.requireNonNull(requestCount);
|
||||
Objects.requireNonNull(inferenceDuration);
|
||||
}
|
||||
|
||||
public static InferenceStats create(MeterRegistry meterRegistry) {
|
||||
return new InferenceStats(
|
||||
meterRegistry.registerLongCounter(
|
||||
"es.inference.requests.count.total",
|
||||
"Inference API request counts for a particular service, task type, model ID",
|
||||
"operations"
|
||||
),
|
||||
meterRegistry.registerLongHistogram(
|
||||
"es.inference.requests.time",
|
||||
"Inference API request counts for a particular service, task type, model ID",
|
||||
"ms"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
public static Map<String, Object> modelAttributes(Model model) {
|
||||
return toMap(modelAttributeEntries(model));
|
||||
}
|
||||
|
||||
private static Stream<Map.Entry<String, Object>> modelAttributeEntries(Model model) {
|
||||
var stream = Stream.<Map.Entry<String, Object>>builder()
|
||||
.add(entry("service", model.getConfigurations().getService()))
|
||||
.add(entry("task_type", model.getTaskType().toString()));
|
||||
if (model.getServiceSettings().modelId() != null) {
|
||||
stream.add(entry("model_id", model.getServiceSettings().modelId()));
|
||||
}
|
||||
return stream.build();
|
||||
}
|
||||
|
||||
private static Map<String, Object> toMap(Stream<Map.Entry<String, Object>> stream) {
|
||||
return stream.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||
}
|
||||
|
||||
public static Map<String, Object> responseAttributes(Model model, @Nullable Throwable t) {
|
||||
return toMap(concat(modelAttributeEntries(model), errorAttributes(t)));
|
||||
}
|
||||
|
||||
public static Map<String, Object> responseAttributes(UnparsedModel model, @Nullable Throwable t) {
|
||||
var unknownModelAttributes = Stream.<Map.Entry<String, Object>>builder()
|
||||
.add(entry("service", model.service()))
|
||||
.add(entry("task_type", model.taskType().toString()))
|
||||
.build();
|
||||
|
||||
return toMap(concat(unknownModelAttributes, errorAttributes(t)));
|
||||
}
|
||||
|
||||
public static Map<String, Object> responseAttributes(@Nullable Throwable t) {
|
||||
return toMap(errorAttributes(t));
|
||||
}
|
||||
|
||||
private static Stream<Map.Entry<String, Object>> errorAttributes(@Nullable Throwable t) {
|
||||
return switch (t) {
|
||||
case null -> Stream.of(entry("status_code", 200));
|
||||
case ElasticsearchStatusException ese -> Stream.<Map.Entry<String, Object>>builder()
|
||||
.add(entry("status_code", ese.status().getStatus()))
|
||||
.add(entry("error.type", String.valueOf(ese.status().getStatus())))
|
||||
.build();
|
||||
default -> Stream.of(entry("error.type", t.getClass().getSimpleName()));
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.telemetry;
|
||||
|
||||
import java.time.Clock;
|
||||
import java.time.Duration;
|
||||
import java.time.Instant;
|
||||
import java.util.Objects;
|
||||
|
||||
public record InferenceTimer(Instant startTime, Clock clock) {
|
||||
|
||||
public InferenceTimer {
|
||||
Objects.requireNonNull(startTime);
|
||||
Objects.requireNonNull(clock);
|
||||
}
|
||||
|
||||
public static InferenceTimer start() {
|
||||
return start(Clock.systemUTC());
|
||||
}
|
||||
|
||||
public static InferenceTimer start(Clock clock) {
|
||||
return new InferenceTimer(clock.instant(), clock);
|
||||
}
|
||||
|
||||
public long elapsedMillis() {
|
||||
return Duration.between(startTime(), clock().instant()).toMillis();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,354 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.action;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.ActionFilters;
|
||||
import org.elasticsearch.common.xcontent.ChunkedToXContent;
|
||||
import org.elasticsearch.inference.InferenceService;
|
||||
import org.elasticsearch.inference.InferenceServiceRegistry;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.UnparsedModel;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.transport.TransportService;
|
||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
|
||||
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
|
||||
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
|
||||
import org.junit.Before;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.Flow;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.isA;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyBoolean;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.ArgumentMatchers.assertArg;
|
||||
import static org.mockito.ArgumentMatchers.same;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class TransportInferenceActionTests extends ESTestCase {
|
||||
private static final String serviceId = "serviceId";
|
||||
private static final TaskType taskType = TaskType.COMPLETION;
|
||||
private static final String inferenceId = "inferenceEntityId";
|
||||
private ModelRegistry modelRegistry;
|
||||
private InferenceServiceRegistry serviceRegistry;
|
||||
private InferenceStats inferenceStats;
|
||||
private StreamingTaskManager streamingTaskManager;
|
||||
private TransportInferenceAction action;
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
TransportService transportService = mock();
|
||||
ActionFilters actionFilters = mock();
|
||||
modelRegistry = mock();
|
||||
serviceRegistry = mock();
|
||||
inferenceStats = new InferenceStats(mock(), mock());
|
||||
streamingTaskManager = mock();
|
||||
action = new TransportInferenceAction(
|
||||
transportService,
|
||||
actionFilters,
|
||||
modelRegistry,
|
||||
serviceRegistry,
|
||||
inferenceStats,
|
||||
streamingTaskManager
|
||||
);
|
||||
}
|
||||
|
||||
public void testMetricsAfterModelRegistryError() {
|
||||
var expectedException = new IllegalStateException("hello");
|
||||
var expectedError = expectedException.getClass().getSimpleName();
|
||||
|
||||
doAnswer(ans -> {
|
||||
ActionListener<?> listener = ans.getArgument(1);
|
||||
listener.onFailure(expectedException);
|
||||
return null;
|
||||
}).when(modelRegistry).getModelWithSecrets(any(), any());
|
||||
|
||||
var listener = doExecute(taskType);
|
||||
verify(listener).onFailure(same(expectedException));
|
||||
|
||||
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), nullValue());
|
||||
assertThat(attributes.get("task_type"), nullValue());
|
||||
assertThat(attributes.get("model_id"), nullValue());
|
||||
assertThat(attributes.get("status_code"), nullValue());
|
||||
assertThat(attributes.get("error.type"), is(expectedError));
|
||||
}));
|
||||
}
|
||||
|
||||
private ActionListener<InferenceAction.Response> doExecute(TaskType taskType) {
|
||||
return doExecute(taskType, false);
|
||||
}
|
||||
|
||||
private ActionListener<InferenceAction.Response> doExecute(TaskType taskType, boolean stream) {
|
||||
InferenceAction.Request request = mock();
|
||||
when(request.getInferenceEntityId()).thenReturn(inferenceId);
|
||||
when(request.getTaskType()).thenReturn(taskType);
|
||||
when(request.isStreaming()).thenReturn(stream);
|
||||
ActionListener<InferenceAction.Response> listener = mock();
|
||||
action.doExecute(mock(), request, listener);
|
||||
return listener;
|
||||
}
|
||||
|
||||
public void testMetricsAfterMissingService() {
|
||||
mockModelRegistry(taskType);
|
||||
|
||||
when(serviceRegistry.getService(any())).thenReturn(Optional.empty());
|
||||
|
||||
var listener = doExecute(taskType);
|
||||
|
||||
verify(listener).onFailure(assertArg(e -> {
|
||||
assertThat(e, isA(ElasticsearchStatusException.class));
|
||||
assertThat(e.getMessage(), is("Unknown service [" + serviceId + "] for model [" + inferenceId + "]. "));
|
||||
assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST));
|
||||
}));
|
||||
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), is(serviceId));
|
||||
assertThat(attributes.get("task_type"), is(taskType.toString()));
|
||||
assertThat(attributes.get("model_id"), nullValue());
|
||||
assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus()));
|
||||
assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus())));
|
||||
}));
|
||||
}
|
||||
|
||||
private void mockModelRegistry(TaskType expectedTaskType) {
|
||||
var unparsedModel = new UnparsedModel(inferenceId, expectedTaskType, serviceId, Map.of(), Map.of());
|
||||
doAnswer(ans -> {
|
||||
ActionListener<UnparsedModel> listener = ans.getArgument(1);
|
||||
listener.onResponse(unparsedModel);
|
||||
return null;
|
||||
}).when(modelRegistry).getModelWithSecrets(any(), any());
|
||||
}
|
||||
|
||||
public void testMetricsAfterUnknownTaskType() {
|
||||
var modelTaskType = TaskType.RERANK;
|
||||
var requestTaskType = TaskType.SPARSE_EMBEDDING;
|
||||
mockModelRegistry(modelTaskType);
|
||||
when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock()));
|
||||
|
||||
var listener = doExecute(requestTaskType);
|
||||
|
||||
verify(listener).onFailure(assertArg(e -> {
|
||||
assertThat(e, isA(ElasticsearchStatusException.class));
|
||||
assertThat(
|
||||
e.getMessage(),
|
||||
is(
|
||||
"Incompatible task_type, the requested type ["
|
||||
+ requestTaskType
|
||||
+ "] does not match the model type ["
|
||||
+ modelTaskType
|
||||
+ "]"
|
||||
)
|
||||
);
|
||||
assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST));
|
||||
}));
|
||||
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), is(serviceId));
|
||||
assertThat(attributes.get("task_type"), is(modelTaskType.toString()));
|
||||
assertThat(attributes.get("model_id"), nullValue());
|
||||
assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus()));
|
||||
assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus())));
|
||||
}));
|
||||
}
|
||||
|
||||
public void testMetricsAfterInferError() {
|
||||
var expectedException = new IllegalStateException("hello");
|
||||
var expectedError = expectedException.getClass().getSimpleName();
|
||||
mockService(listener -> listener.onFailure(expectedException));
|
||||
|
||||
var listener = doExecute(taskType);
|
||||
|
||||
verify(listener).onFailure(same(expectedException));
|
||||
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), is(serviceId));
|
||||
assertThat(attributes.get("task_type"), is(taskType.toString()));
|
||||
assertThat(attributes.get("model_id"), nullValue());
|
||||
assertThat(attributes.get("status_code"), nullValue());
|
||||
assertThat(attributes.get("error.type"), is(expectedError));
|
||||
}));
|
||||
}
|
||||
|
||||
public void testMetricsAfterStreamUnsupported() {
|
||||
var expectedStatus = RestStatus.METHOD_NOT_ALLOWED;
|
||||
var expectedError = String.valueOf(expectedStatus.getStatus());
|
||||
mockService(l -> {});
|
||||
|
||||
var listener = doExecute(taskType, true);
|
||||
|
||||
verify(listener).onFailure(assertArg(e -> {
|
||||
assertThat(e, isA(ElasticsearchStatusException.class));
|
||||
var ese = (ElasticsearchStatusException) e;
|
||||
assertThat(ese.getMessage(), is("Streaming is not allowed for service [" + serviceId + "]."));
|
||||
assertThat(ese.status(), is(expectedStatus));
|
||||
}));
|
||||
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), is(serviceId));
|
||||
assertThat(attributes.get("task_type"), is(taskType.toString()));
|
||||
assertThat(attributes.get("model_id"), nullValue());
|
||||
assertThat(attributes.get("status_code"), is(expectedStatus.getStatus()));
|
||||
assertThat(attributes.get("error.type"), is(expectedError));
|
||||
}));
|
||||
}
|
||||
|
||||
public void testMetricsAfterInferSuccess() {
|
||||
mockService(listener -> listener.onResponse(mock()));
|
||||
|
||||
var listener = doExecute(taskType);
|
||||
|
||||
verify(listener).onResponse(any());
|
||||
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), is(serviceId));
|
||||
assertThat(attributes.get("task_type"), is(taskType.toString()));
|
||||
assertThat(attributes.get("model_id"), nullValue());
|
||||
assertThat(attributes.get("status_code"), is(200));
|
||||
assertThat(attributes.get("error.type"), nullValue());
|
||||
}));
|
||||
}
|
||||
|
||||
public void testMetricsAfterStreamInferSuccess() {
|
||||
mockStreamResponse(Flow.Subscriber::onComplete);
|
||||
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), is(serviceId));
|
||||
assertThat(attributes.get("task_type"), is(taskType.toString()));
|
||||
assertThat(attributes.get("model_id"), nullValue());
|
||||
assertThat(attributes.get("status_code"), is(200));
|
||||
assertThat(attributes.get("error.type"), nullValue());
|
||||
}));
|
||||
}
|
||||
|
||||
public void testMetricsAfterStreamInferFailure() {
|
||||
var expectedException = new IllegalStateException("hello");
|
||||
var expectedError = expectedException.getClass().getSimpleName();
|
||||
mockStreamResponse(subscriber -> {
|
||||
subscriber.subscribe(mock());
|
||||
subscriber.onError(expectedException);
|
||||
});
|
||||
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), is(serviceId));
|
||||
assertThat(attributes.get("task_type"), is(taskType.toString()));
|
||||
assertThat(attributes.get("model_id"), nullValue());
|
||||
assertThat(attributes.get("status_code"), nullValue());
|
||||
assertThat(attributes.get("error.type"), is(expectedError));
|
||||
}));
|
||||
}
|
||||
|
||||
public void testMetricsAfterStreamCancel() {
|
||||
var response = mockStreamResponse(s -> s.onSubscribe(mock()));
|
||||
response.subscribe(new Flow.Subscriber<>() {
|
||||
@Override
|
||||
public void onSubscribe(Flow.Subscription subscription) {
|
||||
subscription.cancel();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onNext(ChunkedToXContent item) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable throwable) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete() {
|
||||
|
||||
}
|
||||
});
|
||||
|
||||
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), is(serviceId));
|
||||
assertThat(attributes.get("task_type"), is(taskType.toString()));
|
||||
assertThat(attributes.get("model_id"), nullValue());
|
||||
assertThat(attributes.get("status_code"), is(200));
|
||||
assertThat(attributes.get("error.type"), nullValue());
|
||||
}));
|
||||
}
|
||||
|
||||
private Flow.Publisher<ChunkedToXContent> mockStreamResponse(Consumer<Flow.Processor<?, ?>> action) {
|
||||
mockService(true, Set.of(), listener -> {
|
||||
Flow.Processor<ChunkedToXContent, ChunkedToXContent> taskProcessor = mock();
|
||||
doAnswer(innerAns -> {
|
||||
action.accept(innerAns.getArgument(0));
|
||||
return null;
|
||||
}).when(taskProcessor).subscribe(any());
|
||||
when(streamingTaskManager.<ChunkedToXContent>create(any(), any())).thenReturn(taskProcessor);
|
||||
var inferenceServiceResults = mock(InferenceServiceResults.class);
|
||||
when(inferenceServiceResults.publisher()).thenReturn(mock());
|
||||
listener.onResponse(inferenceServiceResults);
|
||||
});
|
||||
|
||||
var listener = doExecute(taskType, true);
|
||||
var captor = ArgumentCaptor.forClass(InferenceAction.Response.class);
|
||||
verify(listener).onResponse(captor.capture());
|
||||
assertTrue(captor.getValue().isStreaming());
|
||||
assertNotNull(captor.getValue().publisher());
|
||||
return captor.getValue().publisher();
|
||||
}
|
||||
|
||||
private void mockService(Consumer<ActionListener<InferenceServiceResults>> listenerAction) {
|
||||
mockService(false, Set.of(), listenerAction);
|
||||
}
|
||||
|
||||
private void mockService(
|
||||
boolean stream,
|
||||
Set<TaskType> supportedStreamingTasks,
|
||||
Consumer<ActionListener<InferenceServiceResults>> listenerAction
|
||||
) {
|
||||
InferenceService service = mock();
|
||||
Model model = mockModel();
|
||||
when(service.parsePersistedConfigWithSecrets(any(), any(), any(), any())).thenReturn(model);
|
||||
when(service.name()).thenReturn(serviceId);
|
||||
|
||||
when(service.canStream(any())).thenReturn(stream);
|
||||
when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks);
|
||||
doAnswer(ans -> {
|
||||
listenerAction.accept(ans.getArgument(7));
|
||||
return null;
|
||||
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
||||
mockModelAndServiceRegistry(service);
|
||||
}
|
||||
|
||||
private Model mockModel() {
|
||||
Model model = mock();
|
||||
ModelConfigurations modelConfigurations = mock();
|
||||
when(modelConfigurations.getService()).thenReturn(serviceId);
|
||||
when(model.getConfigurations()).thenReturn(modelConfigurations);
|
||||
when(model.getTaskType()).thenReturn(taskType);
|
||||
when(model.getServiceSettings()).thenReturn(mock());
|
||||
return model;
|
||||
}
|
||||
|
||||
private void mockModelAndServiceRegistry(InferenceService service) {
|
||||
var unparsedModel = new UnparsedModel(inferenceId, taskType, serviceId, Map.of(), Map.of());
|
||||
doAnswer(ans -> {
|
||||
ActionListener<UnparsedModel> listener = ans.getArgument(1);
|
||||
listener.onResponse(unparsedModel);
|
||||
return null;
|
||||
}).when(modelRegistry).getModelWithSecrets(any(), any());
|
||||
|
||||
when(serviceRegistry.getService(any())).thenReturn(Optional.of(service));
|
||||
}
|
||||
}
|
|
@ -1,69 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.telemetry;
|
||||
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ServiceSettings;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.telemetry.metric.LongCounter;
|
||||
import org.elasticsearch.telemetry.metric.MeterRegistry;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class ApmInferenceStatsTests extends ESTestCase {
|
||||
|
||||
public void testRecordWithModel() {
|
||||
var longCounter = mock(LongCounter.class);
|
||||
|
||||
var stats = new ApmInferenceStats(longCounter);
|
||||
|
||||
stats.incrementRequestCount(model("service", TaskType.ANY, "modelId"));
|
||||
|
||||
verify(longCounter).incrementBy(
|
||||
eq(1L),
|
||||
eq(Map.of("service", "service", "task_type", TaskType.ANY.toString(), "model_id", "modelId"))
|
||||
);
|
||||
}
|
||||
|
||||
public void testRecordWithoutModel() {
|
||||
var longCounter = mock(LongCounter.class);
|
||||
|
||||
var stats = new ApmInferenceStats(longCounter);
|
||||
|
||||
stats.incrementRequestCount(model("service", TaskType.ANY, null));
|
||||
|
||||
verify(longCounter).incrementBy(eq(1L), eq(Map.of("service", "service", "task_type", TaskType.ANY.toString())));
|
||||
}
|
||||
|
||||
public void testCreation() {
|
||||
assertNotNull(ApmInferenceStats.create(MeterRegistry.NOOP));
|
||||
}
|
||||
|
||||
private Model model(String service, TaskType taskType, String modelId) {
|
||||
var configuration = mock(ModelConfigurations.class);
|
||||
when(configuration.getService()).thenReturn(service);
|
||||
var settings = mock(ServiceSettings.class);
|
||||
if (modelId != null) {
|
||||
when(settings.modelId()).thenReturn(modelId);
|
||||
}
|
||||
|
||||
var model = mock(Model.class);
|
||||
when(model.getTaskType()).thenReturn(taskType);
|
||||
when(model.getConfigurations()).thenReturn(configuration);
|
||||
when(model.getServiceSettings()).thenReturn(settings);
|
||||
|
||||
return model;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,217 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.telemetry;
|
||||
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ServiceSettings;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.UnparsedModel;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.telemetry.metric.LongCounter;
|
||||
import org.elasticsearch.telemetry.metric.LongHistogram;
|
||||
import org.elasticsearch.telemetry.metric.MeterRegistry;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
|
||||
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.nullValue;
|
||||
import static org.mockito.ArgumentMatchers.assertArg;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class InferenceStatsTests extends ESTestCase {
|
||||
|
||||
public void testRecordWithModel() {
|
||||
var longCounter = mock(LongCounter.class);
|
||||
var stats = new InferenceStats(longCounter, mock());
|
||||
|
||||
stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, "modelId")));
|
||||
|
||||
verify(longCounter).incrementBy(
|
||||
eq(1L),
|
||||
eq(Map.of("service", "service", "task_type", TaskType.ANY.toString(), "model_id", "modelId"))
|
||||
);
|
||||
}
|
||||
|
||||
public void testRecordWithoutModel() {
|
||||
var longCounter = mock(LongCounter.class);
|
||||
var stats = new InferenceStats(longCounter, mock());
|
||||
|
||||
stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, null)));
|
||||
|
||||
verify(longCounter).incrementBy(eq(1L), eq(Map.of("service", "service", "task_type", TaskType.ANY.toString())));
|
||||
}
|
||||
|
||||
public void testCreation() {
|
||||
assertNotNull(InferenceStats.create(MeterRegistry.NOOP));
|
||||
}
|
||||
|
||||
public void testRecordDurationWithoutError() {
|
||||
var expectedLong = randomLong();
|
||||
var histogramCounter = mock(LongHistogram.class);
|
||||
var stats = new InferenceStats(mock(), histogramCounter);
|
||||
|
||||
stats.inferenceDuration().record(expectedLong, responseAttributes(model("service", TaskType.ANY, "modelId"), null));
|
||||
|
||||
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), is("service"));
|
||||
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
|
||||
assertThat(attributes.get("model_id"), is("modelId"));
|
||||
assertThat(attributes.get("status_code"), is(200));
|
||||
assertThat(attributes.get("error.type"), nullValue());
|
||||
}));
|
||||
}
|
||||
|
||||
/**
|
||||
* "If response status code was sent or received and status indicates an error according to HTTP span status definition,
|
||||
* error.type SHOULD be set to the status code number (represented as a string)"
|
||||
* - https://opentelemetry.io/docs/specs/semconv/http/http-metrics/
|
||||
*/
|
||||
public void testRecordDurationWithElasticsearchStatusException() {
|
||||
var expectedLong = randomLong();
|
||||
var histogramCounter = mock(LongHistogram.class);
|
||||
var stats = new InferenceStats(mock(), histogramCounter);
|
||||
var statusCode = RestStatus.BAD_REQUEST;
|
||||
var exception = new ElasticsearchStatusException("hello", statusCode);
|
||||
var expectedError = String.valueOf(statusCode.getStatus());
|
||||
|
||||
stats.inferenceDuration().record(expectedLong, responseAttributes(model("service", TaskType.ANY, "modelId"), exception));
|
||||
|
||||
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), is("service"));
|
||||
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
|
||||
assertThat(attributes.get("model_id"), is("modelId"));
|
||||
assertThat(attributes.get("status_code"), is(statusCode.getStatus()));
|
||||
assertThat(attributes.get("error.type"), is(expectedError));
|
||||
}));
|
||||
}
|
||||
|
||||
/**
|
||||
* "If the request fails with an error before response status code was sent or received,
|
||||
* error.type SHOULD be set to exception type"
|
||||
* - https://opentelemetry.io/docs/specs/semconv/http/http-metrics/
|
||||
*/
|
||||
public void testRecordDurationWithOtherException() {
|
||||
var expectedLong = randomLong();
|
||||
var histogramCounter = mock(LongHistogram.class);
|
||||
var stats = new InferenceStats(mock(), histogramCounter);
|
||||
var exception = new IllegalStateException("ahh");
|
||||
var expectedError = exception.getClass().getSimpleName();
|
||||
|
||||
stats.inferenceDuration().record(expectedLong, responseAttributes(model("service", TaskType.ANY, "modelId"), exception));
|
||||
|
||||
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), is("service"));
|
||||
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
|
||||
assertThat(attributes.get("model_id"), is("modelId"));
|
||||
assertThat(attributes.get("status_code"), nullValue());
|
||||
assertThat(attributes.get("error.type"), is(expectedError));
|
||||
}));
|
||||
}
|
||||
|
||||
public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException() {
|
||||
var expectedLong = randomLong();
|
||||
var histogramCounter = mock(LongHistogram.class);
|
||||
var stats = new InferenceStats(mock(), histogramCounter);
|
||||
var statusCode = RestStatus.BAD_REQUEST;
|
||||
var exception = new ElasticsearchStatusException("hello", statusCode);
|
||||
var expectedError = String.valueOf(statusCode.getStatus());
|
||||
|
||||
var unparsedModel = new UnparsedModel("inferenceEntityId", TaskType.ANY, "service", Map.of(), Map.of());
|
||||
|
||||
stats.inferenceDuration().record(expectedLong, responseAttributes(unparsedModel, exception));
|
||||
|
||||
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), is("service"));
|
||||
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
|
||||
assertThat(attributes.get("model_id"), nullValue());
|
||||
assertThat(attributes.get("status_code"), is(statusCode.getStatus()));
|
||||
assertThat(attributes.get("error.type"), is(expectedError));
|
||||
}));
|
||||
}
|
||||
|
||||
public void testRecordDurationWithUnparsedModelAndOtherException() {
|
||||
var expectedLong = randomLong();
|
||||
var histogramCounter = mock(LongHistogram.class);
|
||||
var stats = new InferenceStats(mock(), histogramCounter);
|
||||
var exception = new IllegalStateException("ahh");
|
||||
var expectedError = exception.getClass().getSimpleName();
|
||||
|
||||
var unparsedModel = new UnparsedModel("inferenceEntityId", TaskType.ANY, "service", Map.of(), Map.of());
|
||||
|
||||
stats.inferenceDuration().record(expectedLong, responseAttributes(unparsedModel, exception));
|
||||
|
||||
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), is("service"));
|
||||
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
|
||||
assertThat(attributes.get("model_id"), nullValue());
|
||||
assertThat(attributes.get("status_code"), nullValue());
|
||||
assertThat(attributes.get("error.type"), is(expectedError));
|
||||
}));
|
||||
}
|
||||
|
||||
public void testRecordDurationWithUnknownModelAndElasticsearchStatusException() {
|
||||
var expectedLong = randomLong();
|
||||
var histogramCounter = mock(LongHistogram.class);
|
||||
var stats = new InferenceStats(mock(), histogramCounter);
|
||||
var statusCode = RestStatus.BAD_REQUEST;
|
||||
var exception = new ElasticsearchStatusException("hello", statusCode);
|
||||
var expectedError = String.valueOf(statusCode.getStatus());
|
||||
|
||||
stats.inferenceDuration().record(expectedLong, responseAttributes(exception));
|
||||
|
||||
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), nullValue());
|
||||
assertThat(attributes.get("task_type"), nullValue());
|
||||
assertThat(attributes.get("model_id"), nullValue());
|
||||
assertThat(attributes.get("status_code"), is(statusCode.getStatus()));
|
||||
assertThat(attributes.get("error.type"), is(expectedError));
|
||||
}));
|
||||
}
|
||||
|
||||
public void testRecordDurationWithUnknownModelAndOtherException() {
|
||||
var expectedLong = randomLong();
|
||||
var histogramCounter = mock(LongHistogram.class);
|
||||
var stats = new InferenceStats(mock(), histogramCounter);
|
||||
var exception = new IllegalStateException("ahh");
|
||||
var expectedError = exception.getClass().getSimpleName();
|
||||
|
||||
stats.inferenceDuration().record(expectedLong, responseAttributes(exception));
|
||||
|
||||
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
|
||||
assertThat(attributes.get("service"), nullValue());
|
||||
assertThat(attributes.get("task_type"), nullValue());
|
||||
assertThat(attributes.get("model_id"), nullValue());
|
||||
assertThat(attributes.get("status_code"), nullValue());
|
||||
assertThat(attributes.get("error.type"), is(expectedError));
|
||||
}));
|
||||
}
|
||||
|
||||
private Model model(String service, TaskType taskType, String modelId) {
|
||||
var configuration = mock(ModelConfigurations.class);
|
||||
when(configuration.getService()).thenReturn(service);
|
||||
var settings = mock(ServiceSettings.class);
|
||||
if (modelId != null) {
|
||||
when(settings.modelId()).thenReturn(modelId);
|
||||
}
|
||||
|
||||
var model = mock(Model.class);
|
||||
when(model.getTaskType()).thenReturn(taskType);
|
||||
when(model.getConfigurations()).thenReturn(configuration);
|
||||
when(model.getServiceSettings()).thenReturn(settings);
|
||||
|
||||
return model;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.telemetry;
|
||||
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
||||
import java.time.Clock;
|
||||
import java.time.Instant;
|
||||
import java.time.temporal.ChronoUnit;
|
||||
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class InferenceTimerTests extends ESTestCase {
|
||||
|
||||
public void testElapsedMillis() {
|
||||
var expectedDuration = randomLongBetween(10, 300);
|
||||
|
||||
var startTime = Instant.now();
|
||||
var clock = mock(Clock.class);
|
||||
when(clock.instant()).thenReturn(startTime).thenReturn(startTime.plus(expectedDuration, ChronoUnit.MILLIS));
|
||||
var timer = InferenceTimer.start(clock);
|
||||
|
||||
assertThat(expectedDuration, is(timer.elapsedMillis()));
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue