mirror of
https://github.com/elastic/elasticsearch.git
synced 2025-06-28 17:34:17 -04:00
[ML] Inference duration and error metrics (#115876)
Add `es.inference.requests.time` metric around `infer` API. As recommended by OTel spec, errors are determined by the presence or absence of the `error.type` attribute in the metric. "error.type" will be the http status code (as a string) if it is available, otherwise it will be the name of the exception (e.g. NullPointerException). Additional notes: - ApmInferenceStats is merged into InferenceStats. Originally we planned to have multiple implementations, but now we're only using APM. - Request count is now always recorded, even when there are failures loading the endpoint configuration. - Added a hook in streaming for cancel messages, so we can close the metrics when a user cancels the stream.
This commit is contained in:
parent
38c7ddd409
commit
26870ef38d
11 changed files with 826 additions and 151 deletions
5
docs/changelog/115876.yaml
Normal file
5
docs/changelog/115876.yaml
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
pr: 115876
|
||||||
|
summary: Inference duration and error metrics
|
||||||
|
area: Machine Learning
|
||||||
|
type: enhancement
|
||||||
|
issues: []
|
|
@ -101,7 +101,6 @@ import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceE
|
||||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
|
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
|
||||||
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
|
import org.elasticsearch.xpack.inference.services.mistral.MistralService;
|
||||||
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
|
import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
|
||||||
import org.elasticsearch.xpack.inference.telemetry.ApmInferenceStats;
|
|
||||||
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
|
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -239,7 +238,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
|
||||||
shardBulkInferenceActionFilter.set(actionFilter);
|
shardBulkInferenceActionFilter.set(actionFilter);
|
||||||
|
|
||||||
var meterRegistry = services.telemetryProvider().getMeterRegistry();
|
var meterRegistry = services.telemetryProvider().getMeterRegistry();
|
||||||
var stats = new PluginComponentBinding<>(InferenceStats.class, ApmInferenceStats.create(meterRegistry));
|
var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
|
||||||
|
|
||||||
return List.of(modelRegistry, registry, httpClientManager, stats);
|
return List.of(modelRegistry, registry, httpClientManager, stats);
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,12 +7,15 @@
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.action;
|
package org.elasticsearch.xpack.inference.action;
|
||||||
|
|
||||||
|
import org.apache.logging.log4j.LogManager;
|
||||||
|
import org.apache.logging.log4j.Logger;
|
||||||
import org.elasticsearch.ElasticsearchStatusException;
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
import org.elasticsearch.action.ActionListener;
|
import org.elasticsearch.action.ActionListener;
|
||||||
import org.elasticsearch.action.support.ActionFilters;
|
import org.elasticsearch.action.support.ActionFilters;
|
||||||
import org.elasticsearch.action.support.HandledTransportAction;
|
import org.elasticsearch.action.support.HandledTransportAction;
|
||||||
import org.elasticsearch.common.util.concurrent.EsExecutors;
|
import org.elasticsearch.common.util.concurrent.EsExecutors;
|
||||||
import org.elasticsearch.common.xcontent.ChunkedToXContent;
|
import org.elasticsearch.common.xcontent.ChunkedToXContent;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.inference.InferenceService;
|
import org.elasticsearch.inference.InferenceService;
|
||||||
import org.elasticsearch.inference.InferenceServiceRegistry;
|
import org.elasticsearch.inference.InferenceServiceRegistry;
|
||||||
import org.elasticsearch.inference.InferenceServiceResults;
|
import org.elasticsearch.inference.InferenceServiceResults;
|
||||||
|
@ -25,20 +28,22 @@ import org.elasticsearch.tasks.Task;
|
||||||
import org.elasticsearch.transport.TransportService;
|
import org.elasticsearch.transport.TransportService;
|
||||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||||
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
|
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
|
||||||
|
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
|
||||||
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
|
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
|
||||||
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
|
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
|
||||||
|
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
|
||||||
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static org.elasticsearch.core.Strings.format;
|
import static org.elasticsearch.core.Strings.format;
|
||||||
|
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
|
||||||
|
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
|
||||||
|
|
||||||
public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {
|
public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {
|
||||||
|
private static final Logger log = LogManager.getLogger(TransportInferenceAction.class);
|
||||||
private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
|
private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
|
||||||
private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";
|
private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";
|
||||||
|
|
||||||
private static final Set<Class<? extends InferenceService>> supportsStreaming = Set.of();
|
|
||||||
|
|
||||||
private final ModelRegistry modelRegistry;
|
private final ModelRegistry modelRegistry;
|
||||||
private final InferenceServiceRegistry serviceRegistry;
|
private final InferenceServiceRegistry serviceRegistry;
|
||||||
private final InferenceStats inferenceStats;
|
private final InferenceStats inferenceStats;
|
||||||
|
@ -62,17 +67,22 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
|
protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
|
||||||
|
var timer = InferenceTimer.start();
|
||||||
|
|
||||||
ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
|
var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
|
||||||
var service = serviceRegistry.getService(unparsedModel.service());
|
var service = serviceRegistry.getService(unparsedModel.service());
|
||||||
if (service.isEmpty()) {
|
if (service.isEmpty()) {
|
||||||
listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
|
var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
|
||||||
|
recordMetrics(unparsedModel, timer, e);
|
||||||
|
listener.onFailure(e);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
|
if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
|
||||||
// not the wildcard task type and not the model task type
|
// not the wildcard task type and not the model task type
|
||||||
listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()));
|
var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
|
||||||
|
recordMetrics(unparsedModel, timer, e);
|
||||||
|
listener.onFailure(e);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,20 +93,69 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
|
||||||
unparsedModel.settings(),
|
unparsedModel.settings(),
|
||||||
unparsedModel.secrets()
|
unparsedModel.secrets()
|
||||||
);
|
);
|
||||||
inferOnService(model, request, service.get(), delegate);
|
inferOnServiceWithMetrics(model, request, service.get(), timer, listener);
|
||||||
|
}, e -> {
|
||||||
|
try {
|
||||||
|
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e));
|
||||||
|
} catch (Exception metricsException) {
|
||||||
|
log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics");
|
||||||
|
}
|
||||||
|
listener.onFailure(e);
|
||||||
});
|
});
|
||||||
|
|
||||||
modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
|
modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
|
||||||
|
try {
|
||||||
|
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void inferOnServiceWithMetrics(
|
||||||
|
Model model,
|
||||||
|
InferenceAction.Request request,
|
||||||
|
InferenceService service,
|
||||||
|
InferenceTimer timer,
|
||||||
|
ActionListener<InferenceAction.Response> listener
|
||||||
|
) {
|
||||||
|
inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
|
||||||
|
inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
|
||||||
|
if (request.isStreaming()) {
|
||||||
|
var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
|
||||||
|
inferenceResults.publisher().subscribe(taskProcessor);
|
||||||
|
|
||||||
|
var instrumentedStream = new PublisherWithMetrics(timer, model);
|
||||||
|
taskProcessor.subscribe(instrumentedStream);
|
||||||
|
|
||||||
|
listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
|
||||||
|
} else {
|
||||||
|
recordMetrics(model, timer, null);
|
||||||
|
listener.onResponse(new InferenceAction.Response(inferenceResults));
|
||||||
|
}
|
||||||
|
}, e -> {
|
||||||
|
recordMetrics(model, timer, e);
|
||||||
|
listener.onFailure(e);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
|
||||||
|
try {
|
||||||
|
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private void inferOnService(
|
private void inferOnService(
|
||||||
Model model,
|
Model model,
|
||||||
InferenceAction.Request request,
|
InferenceAction.Request request,
|
||||||
InferenceService service,
|
InferenceService service,
|
||||||
ActionListener<InferenceAction.Response> listener
|
ActionListener<InferenceServiceResults> listener
|
||||||
) {
|
) {
|
||||||
if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
|
if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
|
||||||
inferenceStats.incrementRequestCount(model);
|
|
||||||
service.infer(
|
service.infer(
|
||||||
model,
|
model,
|
||||||
request.getQuery(),
|
request.getQuery(),
|
||||||
|
@ -105,7 +164,7 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
|
||||||
request.getTaskSettings(),
|
request.getTaskSettings(),
|
||||||
request.getInputType(),
|
request.getInputType(),
|
||||||
request.getInferenceTimeout(),
|
request.getInferenceTimeout(),
|
||||||
createListener(request, listener)
|
listener
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
listener.onFailure(unsupportedStreamingTaskException(request, service));
|
listener.onFailure(unsupportedStreamingTaskException(request, service));
|
||||||
|
@ -133,20 +192,6 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private ActionListener<InferenceServiceResults> createListener(
|
|
||||||
InferenceAction.Request request,
|
|
||||||
ActionListener<InferenceAction.Response> listener
|
|
||||||
) {
|
|
||||||
if (request.isStreaming()) {
|
|
||||||
return listener.delegateFailureAndWrap((l, inferenceResults) -> {
|
|
||||||
var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
|
|
||||||
inferenceResults.publisher().subscribe(taskProcessor);
|
|
||||||
l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)));
|
|
||||||
}
|
|
||||||
|
|
||||||
private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
|
private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
|
||||||
return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
|
return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
|
||||||
}
|
}
|
||||||
|
@ -160,4 +205,37 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent, ChunkedToXContent> {
|
||||||
|
private final InferenceTimer timer;
|
||||||
|
private final Model model;
|
||||||
|
|
||||||
|
private PublisherWithMetrics(InferenceTimer timer, Model model) {
|
||||||
|
this.timer = timer;
|
||||||
|
this.model = model;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void next(ChunkedToXContent item) {
|
||||||
|
downstream().onNext(item);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onError(Throwable throwable) {
|
||||||
|
recordMetrics(model, timer, throwable);
|
||||||
|
super.onError(throwable);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void onCancel() {
|
||||||
|
recordMetrics(model, timer, null);
|
||||||
|
super.onCancel();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onComplete() {
|
||||||
|
recordMetrics(model, timer, null);
|
||||||
|
super.onComplete();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,11 +61,14 @@ public abstract class DelegatingProcessor<T, R> implements Flow.Processor<T, R>
|
||||||
public void cancel() {
|
public void cancel() {
|
||||||
if (isClosed.compareAndSet(false, true) && upstream != null) {
|
if (isClosed.compareAndSet(false, true) && upstream != null) {
|
||||||
upstream.cancel();
|
upstream.cancel();
|
||||||
|
onCancel();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Extension hook invoked when the downstream subscription is cancelled.
 * <p>
 * It is called from the subscription's {@code cancel()} right after {@code upstream.cancel()}, and only when
 * that call is the one that flips {@code isClosed} from {@code false} to {@code true} and an upstream
 * subscription exists. The default implementation does nothing; subclasses may override it to react to a
 * cancellation (for example to close metrics).
 */
protected void onCancel() {}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onSubscribe(Flow.Subscription subscription) {
|
public void onSubscribe(Flow.Subscription subscription) {
|
||||||
if (upstream != null) {
|
if (upstream != null) {
|
||||||
|
|
|
@ -1,49 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
|
||||||
* or more contributor license agreements. Licensed under the Elastic License
|
|
||||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
|
||||||
* 2.0.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.telemetry;
|
|
||||||
|
|
||||||
import org.elasticsearch.inference.Model;
|
|
||||||
import org.elasticsearch.telemetry.metric.LongCounter;
|
|
||||||
import org.elasticsearch.telemetry.metric.MeterRegistry;
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Objects;
|
|
||||||
|
|
||||||
public class ApmInferenceStats implements InferenceStats {
|
|
||||||
private final LongCounter inferenceAPMRequestCounter;
|
|
||||||
|
|
||||||
public ApmInferenceStats(LongCounter inferenceAPMRequestCounter) {
|
|
||||||
this.inferenceAPMRequestCounter = Objects.requireNonNull(inferenceAPMRequestCounter);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void incrementRequestCount(Model model) {
|
|
||||||
var service = model.getConfigurations().getService();
|
|
||||||
var taskType = model.getTaskType();
|
|
||||||
var modelId = model.getServiceSettings().modelId();
|
|
||||||
|
|
||||||
var attributes = new HashMap<String, Object>(5);
|
|
||||||
attributes.put("service", service);
|
|
||||||
attributes.put("task_type", taskType.toString());
|
|
||||||
if (modelId != null) {
|
|
||||||
attributes.put("model_id", modelId);
|
|
||||||
}
|
|
||||||
|
|
||||||
inferenceAPMRequestCounter.incrementBy(1, attributes);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static ApmInferenceStats create(MeterRegistry meterRegistry) {
|
|
||||||
return new ApmInferenceStats(
|
|
||||||
meterRegistry.registerLongCounter(
|
|
||||||
"es.inference.requests.count.total",
|
|
||||||
"Inference API request counts for a particular service, task type, model ID",
|
|
||||||
"operations"
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -7,15 +7,87 @@
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.telemetry;
|
package org.elasticsearch.xpack.inference.telemetry;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
|
import org.elasticsearch.core.Nullable;
|
||||||
import org.elasticsearch.inference.Model;
|
import org.elasticsearch.inference.Model;
|
||||||
|
import org.elasticsearch.inference.UnparsedModel;
|
||||||
|
import org.elasticsearch.telemetry.metric.LongCounter;
|
||||||
|
import org.elasticsearch.telemetry.metric.LongHistogram;
|
||||||
|
import org.elasticsearch.telemetry.metric.MeterRegistry;
|
||||||
|
|
||||||
public interface InferenceStats {
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
/**
|
import static java.util.Map.entry;
|
||||||
* Increment the counter for a particular value in a thread safe manner.
|
import static java.util.stream.Stream.concat;
|
||||||
* @param model the model to increment request count for
|
|
||||||
*/
|
|
||||||
void incrementRequestCount(Model model);
|
|
||||||
|
|
||||||
InferenceStats NOOP = model -> {};
|
public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) {
|
||||||
|
|
||||||
|
public InferenceStats {
|
||||||
|
Objects.requireNonNull(requestCount);
|
||||||
|
Objects.requireNonNull(inferenceDuration);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static InferenceStats create(MeterRegistry meterRegistry) {
|
||||||
|
return new InferenceStats(
|
||||||
|
meterRegistry.registerLongCounter(
|
||||||
|
"es.inference.requests.count.total",
|
||||||
|
"Inference API request counts for a particular service, task type, model ID",
|
||||||
|
"operations"
|
||||||
|
),
|
||||||
|
meterRegistry.registerLongHistogram(
|
||||||
|
"es.inference.requests.time",
|
||||||
|
"Inference API request counts for a particular service, task type, model ID",
|
||||||
|
"ms"
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Map<String, Object> modelAttributes(Model model) {
|
||||||
|
return toMap(modelAttributeEntries(model));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Stream<Map.Entry<String, Object>> modelAttributeEntries(Model model) {
|
||||||
|
var stream = Stream.<Map.Entry<String, Object>>builder()
|
||||||
|
.add(entry("service", model.getConfigurations().getService()))
|
||||||
|
.add(entry("task_type", model.getTaskType().toString()));
|
||||||
|
if (model.getServiceSettings().modelId() != null) {
|
||||||
|
stream.add(entry("model_id", model.getServiceSettings().modelId()));
|
||||||
|
}
|
||||||
|
return stream.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Map<String, Object> toMap(Stream<Map.Entry<String, Object>> stream) {
|
||||||
|
return stream.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Map<String, Object> responseAttributes(Model model, @Nullable Throwable t) {
|
||||||
|
return toMap(concat(modelAttributeEntries(model), errorAttributes(t)));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Map<String, Object> responseAttributes(UnparsedModel model, @Nullable Throwable t) {
|
||||||
|
var unknownModelAttributes = Stream.<Map.Entry<String, Object>>builder()
|
||||||
|
.add(entry("service", model.service()))
|
||||||
|
.add(entry("task_type", model.taskType().toString()))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
return toMap(concat(unknownModelAttributes, errorAttributes(t)));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Map<String, Object> responseAttributes(@Nullable Throwable t) {
|
||||||
|
return toMap(errorAttributes(t));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Stream<Map.Entry<String, Object>> errorAttributes(@Nullable Throwable t) {
|
||||||
|
return switch (t) {
|
||||||
|
case null -> Stream.of(entry("status_code", 200));
|
||||||
|
case ElasticsearchStatusException ese -> Stream.<Map.Entry<String, Object>>builder()
|
||||||
|
.add(entry("status_code", ese.status().getStatus()))
|
||||||
|
.add(entry("error.type", String.valueOf(ese.status().getStatus())))
|
||||||
|
.build();
|
||||||
|
default -> Stream.of(entry("error.type", t.getClass().getSimpleName()));
|
||||||
|
};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,33 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.telemetry;
|
||||||
|
|
||||||
|
import java.time.Clock;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
/**
 * Measures the time elapsed since a fixed start instant, using an injectable {@link Clock}.
 *
 * @param startTime the instant the timer was started; must not be {@code null}
 * @param clock     the clock used to read the current time; must not be {@code null}
 */
public record InferenceTimer(Instant startTime, Clock clock) {

    /**
     * Validates that neither component is {@code null}.
     *
     * @throws NullPointerException if {@code startTime} or {@code clock} is {@code null}
     */
    public InferenceTimer {
        Objects.requireNonNull(startTime);
        Objects.requireNonNull(clock);
    }

    /**
     * Starts a timer backed by the system UTC clock.
     *
     * @return a timer whose start time is "now" on {@link Clock#systemUTC()}
     */
    public static InferenceTimer start() {
        return start(Clock.systemUTC());
    }

    /**
     * Starts a timer backed by the given clock.
     *
     * @param clock the clock to read the start time (and later the current time) from
     * @return a timer whose start time is the clock's current instant
     */
    public static InferenceTimer start(Clock clock) {
        return new InferenceTimer(clock.instant(), clock);
    }

    /**
     * Returns the number of whole milliseconds between the start time and the clock's current instant.
     * <p>
     * The clock is a wall clock and is not guaranteed to be monotonic, so the current instant can be
     * earlier than the start time (for example after a clock adjustment). In that case this method
     * returns {@code 0} rather than a negative duration.
     *
     * @return the elapsed time in milliseconds, never negative
     */
    public long elapsedMillis() {
        return Math.max(0L, Duration.between(startTime(), clock().instant()).toMillis());
    }
}
|
|
@ -0,0 +1,354 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.action;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
|
import org.elasticsearch.action.ActionListener;
|
||||||
|
import org.elasticsearch.action.support.ActionFilters;
|
||||||
|
import org.elasticsearch.common.xcontent.ChunkedToXContent;
|
||||||
|
import org.elasticsearch.inference.InferenceService;
|
||||||
|
import org.elasticsearch.inference.InferenceServiceRegistry;
|
||||||
|
import org.elasticsearch.inference.InferenceServiceResults;
|
||||||
|
import org.elasticsearch.inference.Model;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.inference.UnparsedModel;
|
||||||
|
import org.elasticsearch.rest.RestStatus;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.transport.TransportService;
|
||||||
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||||
|
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
|
||||||
|
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
|
||||||
|
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.concurrent.Flow;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
import static org.hamcrest.Matchers.isA;
|
||||||
|
import static org.hamcrest.Matchers.nullValue;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyBoolean;
|
||||||
|
import static org.mockito.ArgumentMatchers.anyLong;
|
||||||
|
import static org.mockito.ArgumentMatchers.assertArg;
|
||||||
|
import static org.mockito.ArgumentMatchers.same;
|
||||||
|
import static org.mockito.Mockito.doAnswer;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
public class TransportInferenceActionTests extends ESTestCase {
|
||||||
|
private static final String serviceId = "serviceId";
|
||||||
|
private static final TaskType taskType = TaskType.COMPLETION;
|
||||||
|
private static final String inferenceId = "inferenceEntityId";
|
||||||
|
private ModelRegistry modelRegistry;
|
||||||
|
private InferenceServiceRegistry serviceRegistry;
|
||||||
|
private InferenceStats inferenceStats;
|
||||||
|
private StreamingTaskManager streamingTaskManager;
|
||||||
|
private TransportInferenceAction action;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUp() throws Exception {
|
||||||
|
super.setUp();
|
||||||
|
TransportService transportService = mock();
|
||||||
|
ActionFilters actionFilters = mock();
|
||||||
|
modelRegistry = mock();
|
||||||
|
serviceRegistry = mock();
|
||||||
|
inferenceStats = new InferenceStats(mock(), mock());
|
||||||
|
streamingTaskManager = mock();
|
||||||
|
action = new TransportInferenceAction(
|
||||||
|
transportService,
|
||||||
|
actionFilters,
|
||||||
|
modelRegistry,
|
||||||
|
serviceRegistry,
|
||||||
|
inferenceStats,
|
||||||
|
streamingTaskManager
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMetricsAfterModelRegistryError() {
|
||||||
|
var expectedException = new IllegalStateException("hello");
|
||||||
|
var expectedError = expectedException.getClass().getSimpleName();
|
||||||
|
|
||||||
|
doAnswer(ans -> {
|
||||||
|
ActionListener<?> listener = ans.getArgument(1);
|
||||||
|
listener.onFailure(expectedException);
|
||||||
|
return null;
|
||||||
|
}).when(modelRegistry).getModelWithSecrets(any(), any());
|
||||||
|
|
||||||
|
var listener = doExecute(taskType);
|
||||||
|
verify(listener).onFailure(same(expectedException));
|
||||||
|
|
||||||
|
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), nullValue());
|
||||||
|
assertThat(attributes.get("task_type"), nullValue());
|
||||||
|
assertThat(attributes.get("model_id"), nullValue());
|
||||||
|
assertThat(attributes.get("status_code"), nullValue());
|
||||||
|
assertThat(attributes.get("error.type"), is(expectedError));
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
private ActionListener<InferenceAction.Response> doExecute(TaskType taskType) {
|
||||||
|
return doExecute(taskType, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
private ActionListener<InferenceAction.Response> doExecute(TaskType taskType, boolean stream) {
|
||||||
|
InferenceAction.Request request = mock();
|
||||||
|
when(request.getInferenceEntityId()).thenReturn(inferenceId);
|
||||||
|
when(request.getTaskType()).thenReturn(taskType);
|
||||||
|
when(request.isStreaming()).thenReturn(stream);
|
||||||
|
ActionListener<InferenceAction.Response> listener = mock();
|
||||||
|
action.doExecute(mock(), request, listener);
|
||||||
|
return listener;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMetricsAfterMissingService() {
|
||||||
|
mockModelRegistry(taskType);
|
||||||
|
|
||||||
|
when(serviceRegistry.getService(any())).thenReturn(Optional.empty());
|
||||||
|
|
||||||
|
var listener = doExecute(taskType);
|
||||||
|
|
||||||
|
verify(listener).onFailure(assertArg(e -> {
|
||||||
|
assertThat(e, isA(ElasticsearchStatusException.class));
|
||||||
|
assertThat(e.getMessage(), is("Unknown service [" + serviceId + "] for model [" + inferenceId + "]. "));
|
||||||
|
assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST));
|
||||||
|
}));
|
||||||
|
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), is(serviceId));
|
||||||
|
assertThat(attributes.get("task_type"), is(taskType.toString()));
|
||||||
|
assertThat(attributes.get("model_id"), nullValue());
|
||||||
|
assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus()));
|
||||||
|
assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus())));
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void mockModelRegistry(TaskType expectedTaskType) {
|
||||||
|
var unparsedModel = new UnparsedModel(inferenceId, expectedTaskType, serviceId, Map.of(), Map.of());
|
||||||
|
doAnswer(ans -> {
|
||||||
|
ActionListener<UnparsedModel> listener = ans.getArgument(1);
|
||||||
|
listener.onResponse(unparsedModel);
|
||||||
|
return null;
|
||||||
|
}).when(modelRegistry).getModelWithSecrets(any(), any());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMetricsAfterUnknownTaskType() {
|
||||||
|
var modelTaskType = TaskType.RERANK;
|
||||||
|
var requestTaskType = TaskType.SPARSE_EMBEDDING;
|
||||||
|
mockModelRegistry(modelTaskType);
|
||||||
|
when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock()));
|
||||||
|
|
||||||
|
var listener = doExecute(requestTaskType);
|
||||||
|
|
||||||
|
verify(listener).onFailure(assertArg(e -> {
|
||||||
|
assertThat(e, isA(ElasticsearchStatusException.class));
|
||||||
|
assertThat(
|
||||||
|
e.getMessage(),
|
||||||
|
is(
|
||||||
|
"Incompatible task_type, the requested type ["
|
||||||
|
+ requestTaskType
|
||||||
|
+ "] does not match the model type ["
|
||||||
|
+ modelTaskType
|
||||||
|
+ "]"
|
||||||
|
)
|
||||||
|
);
|
||||||
|
assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST));
|
||||||
|
}));
|
||||||
|
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), is(serviceId));
|
||||||
|
assertThat(attributes.get("task_type"), is(modelTaskType.toString()));
|
||||||
|
assertThat(attributes.get("model_id"), nullValue());
|
||||||
|
assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus()));
|
||||||
|
assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus())));
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMetricsAfterInferError() {
|
||||||
|
var expectedException = new IllegalStateException("hello");
|
||||||
|
var expectedError = expectedException.getClass().getSimpleName();
|
||||||
|
mockService(listener -> listener.onFailure(expectedException));
|
||||||
|
|
||||||
|
var listener = doExecute(taskType);
|
||||||
|
|
||||||
|
verify(listener).onFailure(same(expectedException));
|
||||||
|
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), is(serviceId));
|
||||||
|
assertThat(attributes.get("task_type"), is(taskType.toString()));
|
||||||
|
assertThat(attributes.get("model_id"), nullValue());
|
||||||
|
assertThat(attributes.get("status_code"), nullValue());
|
||||||
|
assertThat(attributes.get("error.type"), is(expectedError));
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMetricsAfterStreamUnsupported() {
|
||||||
|
var expectedStatus = RestStatus.METHOD_NOT_ALLOWED;
|
||||||
|
var expectedError = String.valueOf(expectedStatus.getStatus());
|
||||||
|
mockService(l -> {});
|
||||||
|
|
||||||
|
var listener = doExecute(taskType, true);
|
||||||
|
|
||||||
|
verify(listener).onFailure(assertArg(e -> {
|
||||||
|
assertThat(e, isA(ElasticsearchStatusException.class));
|
||||||
|
var ese = (ElasticsearchStatusException) e;
|
||||||
|
assertThat(ese.getMessage(), is("Streaming is not allowed for service [" + serviceId + "]."));
|
||||||
|
assertThat(ese.status(), is(expectedStatus));
|
||||||
|
}));
|
||||||
|
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), is(serviceId));
|
||||||
|
assertThat(attributes.get("task_type"), is(taskType.toString()));
|
||||||
|
assertThat(attributes.get("model_id"), nullValue());
|
||||||
|
assertThat(attributes.get("status_code"), is(expectedStatus.getStatus()));
|
||||||
|
assertThat(attributes.get("error.type"), is(expectedError));
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMetricsAfterInferSuccess() {
|
||||||
|
mockService(listener -> listener.onResponse(mock()));
|
||||||
|
|
||||||
|
var listener = doExecute(taskType);
|
||||||
|
|
||||||
|
verify(listener).onResponse(any());
|
||||||
|
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), is(serviceId));
|
||||||
|
assertThat(attributes.get("task_type"), is(taskType.toString()));
|
||||||
|
assertThat(attributes.get("model_id"), nullValue());
|
||||||
|
assertThat(attributes.get("status_code"), is(200));
|
||||||
|
assertThat(attributes.get("error.type"), nullValue());
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMetricsAfterStreamInferSuccess() {
|
||||||
|
mockStreamResponse(Flow.Subscriber::onComplete);
|
||||||
|
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), is(serviceId));
|
||||||
|
assertThat(attributes.get("task_type"), is(taskType.toString()));
|
||||||
|
assertThat(attributes.get("model_id"), nullValue());
|
||||||
|
assertThat(attributes.get("status_code"), is(200));
|
||||||
|
assertThat(attributes.get("error.type"), nullValue());
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMetricsAfterStreamInferFailure() {
|
||||||
|
var expectedException = new IllegalStateException("hello");
|
||||||
|
var expectedError = expectedException.getClass().getSimpleName();
|
||||||
|
mockStreamResponse(subscriber -> {
|
||||||
|
subscriber.subscribe(mock());
|
||||||
|
subscriber.onError(expectedException);
|
||||||
|
});
|
||||||
|
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), is(serviceId));
|
||||||
|
assertThat(attributes.get("task_type"), is(taskType.toString()));
|
||||||
|
assertThat(attributes.get("model_id"), nullValue());
|
||||||
|
assertThat(attributes.get("status_code"), nullValue());
|
||||||
|
assertThat(attributes.get("error.type"), is(expectedError));
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMetricsAfterStreamCancel() {
|
||||||
|
var response = mockStreamResponse(s -> s.onSubscribe(mock()));
|
||||||
|
response.subscribe(new Flow.Subscriber<>() {
|
||||||
|
@Override
|
||||||
|
public void onSubscribe(Flow.Subscription subscription) {
|
||||||
|
subscription.cancel();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onNext(ChunkedToXContent item) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onError(Throwable throwable) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onComplete() {
|
||||||
|
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), is(serviceId));
|
||||||
|
assertThat(attributes.get("task_type"), is(taskType.toString()));
|
||||||
|
assertThat(attributes.get("model_id"), nullValue());
|
||||||
|
assertThat(attributes.get("status_code"), is(200));
|
||||||
|
assertThat(attributes.get("error.type"), nullValue());
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Flow.Publisher<ChunkedToXContent> mockStreamResponse(Consumer<Flow.Processor<?, ?>> action) {
|
||||||
|
mockService(true, Set.of(), listener -> {
|
||||||
|
Flow.Processor<ChunkedToXContent, ChunkedToXContent> taskProcessor = mock();
|
||||||
|
doAnswer(innerAns -> {
|
||||||
|
action.accept(innerAns.getArgument(0));
|
||||||
|
return null;
|
||||||
|
}).when(taskProcessor).subscribe(any());
|
||||||
|
when(streamingTaskManager.<ChunkedToXContent>create(any(), any())).thenReturn(taskProcessor);
|
||||||
|
var inferenceServiceResults = mock(InferenceServiceResults.class);
|
||||||
|
when(inferenceServiceResults.publisher()).thenReturn(mock());
|
||||||
|
listener.onResponse(inferenceServiceResults);
|
||||||
|
});
|
||||||
|
|
||||||
|
var listener = doExecute(taskType, true);
|
||||||
|
var captor = ArgumentCaptor.forClass(InferenceAction.Response.class);
|
||||||
|
verify(listener).onResponse(captor.capture());
|
||||||
|
assertTrue(captor.getValue().isStreaming());
|
||||||
|
assertNotNull(captor.getValue().publisher());
|
||||||
|
return captor.getValue().publisher();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void mockService(Consumer<ActionListener<InferenceServiceResults>> listenerAction) {
|
||||||
|
mockService(false, Set.of(), listenerAction);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void mockService(
|
||||||
|
boolean stream,
|
||||||
|
Set<TaskType> supportedStreamingTasks,
|
||||||
|
Consumer<ActionListener<InferenceServiceResults>> listenerAction
|
||||||
|
) {
|
||||||
|
InferenceService service = mock();
|
||||||
|
Model model = mockModel();
|
||||||
|
when(service.parsePersistedConfigWithSecrets(any(), any(), any(), any())).thenReturn(model);
|
||||||
|
when(service.name()).thenReturn(serviceId);
|
||||||
|
|
||||||
|
when(service.canStream(any())).thenReturn(stream);
|
||||||
|
when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks);
|
||||||
|
doAnswer(ans -> {
|
||||||
|
listenerAction.accept(ans.getArgument(7));
|
||||||
|
return null;
|
||||||
|
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
|
||||||
|
mockModelAndServiceRegistry(service);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Model mockModel() {
|
||||||
|
Model model = mock();
|
||||||
|
ModelConfigurations modelConfigurations = mock();
|
||||||
|
when(modelConfigurations.getService()).thenReturn(serviceId);
|
||||||
|
when(model.getConfigurations()).thenReturn(modelConfigurations);
|
||||||
|
when(model.getTaskType()).thenReturn(taskType);
|
||||||
|
when(model.getServiceSettings()).thenReturn(mock());
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void mockModelAndServiceRegistry(InferenceService service) {
|
||||||
|
var unparsedModel = new UnparsedModel(inferenceId, taskType, serviceId, Map.of(), Map.of());
|
||||||
|
doAnswer(ans -> {
|
||||||
|
ActionListener<UnparsedModel> listener = ans.getArgument(1);
|
||||||
|
listener.onResponse(unparsedModel);
|
||||||
|
return null;
|
||||||
|
}).when(modelRegistry).getModelWithSecrets(any(), any());
|
||||||
|
|
||||||
|
when(serviceRegistry.getService(any())).thenReturn(Optional.of(service));
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,69 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
|
||||||
* or more contributor license agreements. Licensed under the Elastic License
|
|
||||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
|
||||||
* 2.0.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.elasticsearch.xpack.inference.telemetry;
|
|
||||||
|
|
||||||
import org.elasticsearch.inference.Model;
|
|
||||||
import org.elasticsearch.inference.ModelConfigurations;
|
|
||||||
import org.elasticsearch.inference.ServiceSettings;
|
|
||||||
import org.elasticsearch.inference.TaskType;
|
|
||||||
import org.elasticsearch.telemetry.metric.LongCounter;
|
|
||||||
import org.elasticsearch.telemetry.metric.MeterRegistry;
|
|
||||||
import org.elasticsearch.test.ESTestCase;
|
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import static org.mockito.ArgumentMatchers.eq;
|
|
||||||
import static org.mockito.Mockito.mock;
|
|
||||||
import static org.mockito.Mockito.verify;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
public class ApmInferenceStatsTests extends ESTestCase {
|
|
||||||
|
|
||||||
public void testRecordWithModel() {
|
|
||||||
var longCounter = mock(LongCounter.class);
|
|
||||||
|
|
||||||
var stats = new ApmInferenceStats(longCounter);
|
|
||||||
|
|
||||||
stats.incrementRequestCount(model("service", TaskType.ANY, "modelId"));
|
|
||||||
|
|
||||||
verify(longCounter).incrementBy(
|
|
||||||
eq(1L),
|
|
||||||
eq(Map.of("service", "service", "task_type", TaskType.ANY.toString(), "model_id", "modelId"))
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testRecordWithoutModel() {
|
|
||||||
var longCounter = mock(LongCounter.class);
|
|
||||||
|
|
||||||
var stats = new ApmInferenceStats(longCounter);
|
|
||||||
|
|
||||||
stats.incrementRequestCount(model("service", TaskType.ANY, null));
|
|
||||||
|
|
||||||
verify(longCounter).incrementBy(eq(1L), eq(Map.of("service", "service", "task_type", TaskType.ANY.toString())));
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testCreation() {
|
|
||||||
assertNotNull(ApmInferenceStats.create(MeterRegistry.NOOP));
|
|
||||||
}
|
|
||||||
|
|
||||||
private Model model(String service, TaskType taskType, String modelId) {
|
|
||||||
var configuration = mock(ModelConfigurations.class);
|
|
||||||
when(configuration.getService()).thenReturn(service);
|
|
||||||
var settings = mock(ServiceSettings.class);
|
|
||||||
if (modelId != null) {
|
|
||||||
when(settings.modelId()).thenReturn(modelId);
|
|
||||||
}
|
|
||||||
|
|
||||||
var model = mock(Model.class);
|
|
||||||
when(model.getTaskType()).thenReturn(taskType);
|
|
||||||
when(model.getConfigurations()).thenReturn(configuration);
|
|
||||||
when(model.getServiceSettings()).thenReturn(settings);
|
|
||||||
|
|
||||||
return model;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,217 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.telemetry;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
|
import org.elasticsearch.inference.Model;
|
||||||
|
import org.elasticsearch.inference.ModelConfigurations;
|
||||||
|
import org.elasticsearch.inference.ServiceSettings;
|
||||||
|
import org.elasticsearch.inference.TaskType;
|
||||||
|
import org.elasticsearch.inference.UnparsedModel;
|
||||||
|
import org.elasticsearch.rest.RestStatus;
|
||||||
|
import org.elasticsearch.telemetry.metric.LongCounter;
|
||||||
|
import org.elasticsearch.telemetry.metric.LongHistogram;
|
||||||
|
import org.elasticsearch.telemetry.metric.MeterRegistry;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
|
||||||
|
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
import static org.hamcrest.Matchers.nullValue;
|
||||||
|
import static org.mockito.ArgumentMatchers.assertArg;
|
||||||
|
import static org.mockito.ArgumentMatchers.eq;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
public class InferenceStatsTests extends ESTestCase {
|
||||||
|
|
||||||
|
public void testRecordWithModel() {
|
||||||
|
var longCounter = mock(LongCounter.class);
|
||||||
|
var stats = new InferenceStats(longCounter, mock());
|
||||||
|
|
||||||
|
stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, "modelId")));
|
||||||
|
|
||||||
|
verify(longCounter).incrementBy(
|
||||||
|
eq(1L),
|
||||||
|
eq(Map.of("service", "service", "task_type", TaskType.ANY.toString(), "model_id", "modelId"))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRecordWithoutModel() {
|
||||||
|
var longCounter = mock(LongCounter.class);
|
||||||
|
var stats = new InferenceStats(longCounter, mock());
|
||||||
|
|
||||||
|
stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, null)));
|
||||||
|
|
||||||
|
verify(longCounter).incrementBy(eq(1L), eq(Map.of("service", "service", "task_type", TaskType.ANY.toString())));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testCreation() {
|
||||||
|
assertNotNull(InferenceStats.create(MeterRegistry.NOOP));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRecordDurationWithoutError() {
|
||||||
|
var expectedLong = randomLong();
|
||||||
|
var histogramCounter = mock(LongHistogram.class);
|
||||||
|
var stats = new InferenceStats(mock(), histogramCounter);
|
||||||
|
|
||||||
|
stats.inferenceDuration().record(expectedLong, responseAttributes(model("service", TaskType.ANY, "modelId"), null));
|
||||||
|
|
||||||
|
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), is("service"));
|
||||||
|
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
|
||||||
|
assertThat(attributes.get("model_id"), is("modelId"));
|
||||||
|
assertThat(attributes.get("status_code"), is(200));
|
||||||
|
assertThat(attributes.get("error.type"), nullValue());
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* "If response status code was sent or received and status indicates an error according to HTTP span status definition,
|
||||||
|
* error.type SHOULD be set to the status code number (represented as a string)"
|
||||||
|
* - https://opentelemetry.io/docs/specs/semconv/http/http-metrics/
|
||||||
|
*/
|
||||||
|
public void testRecordDurationWithElasticsearchStatusException() {
|
||||||
|
var expectedLong = randomLong();
|
||||||
|
var histogramCounter = mock(LongHistogram.class);
|
||||||
|
var stats = new InferenceStats(mock(), histogramCounter);
|
||||||
|
var statusCode = RestStatus.BAD_REQUEST;
|
||||||
|
var exception = new ElasticsearchStatusException("hello", statusCode);
|
||||||
|
var expectedError = String.valueOf(statusCode.getStatus());
|
||||||
|
|
||||||
|
stats.inferenceDuration().record(expectedLong, responseAttributes(model("service", TaskType.ANY, "modelId"), exception));
|
||||||
|
|
||||||
|
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), is("service"));
|
||||||
|
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
|
||||||
|
assertThat(attributes.get("model_id"), is("modelId"));
|
||||||
|
assertThat(attributes.get("status_code"), is(statusCode.getStatus()));
|
||||||
|
assertThat(attributes.get("error.type"), is(expectedError));
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* "If the request fails with an error before response status code was sent or received,
|
||||||
|
* error.type SHOULD be set to exception type"
|
||||||
|
* - https://opentelemetry.io/docs/specs/semconv/http/http-metrics/
|
||||||
|
*/
|
||||||
|
public void testRecordDurationWithOtherException() {
|
||||||
|
var expectedLong = randomLong();
|
||||||
|
var histogramCounter = mock(LongHistogram.class);
|
||||||
|
var stats = new InferenceStats(mock(), histogramCounter);
|
||||||
|
var exception = new IllegalStateException("ahh");
|
||||||
|
var expectedError = exception.getClass().getSimpleName();
|
||||||
|
|
||||||
|
stats.inferenceDuration().record(expectedLong, responseAttributes(model("service", TaskType.ANY, "modelId"), exception));
|
||||||
|
|
||||||
|
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), is("service"));
|
||||||
|
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
|
||||||
|
assertThat(attributes.get("model_id"), is("modelId"));
|
||||||
|
assertThat(attributes.get("status_code"), nullValue());
|
||||||
|
assertThat(attributes.get("error.type"), is(expectedError));
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException() {
|
||||||
|
var expectedLong = randomLong();
|
||||||
|
var histogramCounter = mock(LongHistogram.class);
|
||||||
|
var stats = new InferenceStats(mock(), histogramCounter);
|
||||||
|
var statusCode = RestStatus.BAD_REQUEST;
|
||||||
|
var exception = new ElasticsearchStatusException("hello", statusCode);
|
||||||
|
var expectedError = String.valueOf(statusCode.getStatus());
|
||||||
|
|
||||||
|
var unparsedModel = new UnparsedModel("inferenceEntityId", TaskType.ANY, "service", Map.of(), Map.of());
|
||||||
|
|
||||||
|
stats.inferenceDuration().record(expectedLong, responseAttributes(unparsedModel, exception));
|
||||||
|
|
||||||
|
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), is("service"));
|
||||||
|
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
|
||||||
|
assertThat(attributes.get("model_id"), nullValue());
|
||||||
|
assertThat(attributes.get("status_code"), is(statusCode.getStatus()));
|
||||||
|
assertThat(attributes.get("error.type"), is(expectedError));
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRecordDurationWithUnparsedModelAndOtherException() {
|
||||||
|
var expectedLong = randomLong();
|
||||||
|
var histogramCounter = mock(LongHistogram.class);
|
||||||
|
var stats = new InferenceStats(mock(), histogramCounter);
|
||||||
|
var exception = new IllegalStateException("ahh");
|
||||||
|
var expectedError = exception.getClass().getSimpleName();
|
||||||
|
|
||||||
|
var unparsedModel = new UnparsedModel("inferenceEntityId", TaskType.ANY, "service", Map.of(), Map.of());
|
||||||
|
|
||||||
|
stats.inferenceDuration().record(expectedLong, responseAttributes(unparsedModel, exception));
|
||||||
|
|
||||||
|
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), is("service"));
|
||||||
|
assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
|
||||||
|
assertThat(attributes.get("model_id"), nullValue());
|
||||||
|
assertThat(attributes.get("status_code"), nullValue());
|
||||||
|
assertThat(attributes.get("error.type"), is(expectedError));
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRecordDurationWithUnknownModelAndElasticsearchStatusException() {
|
||||||
|
var expectedLong = randomLong();
|
||||||
|
var histogramCounter = mock(LongHistogram.class);
|
||||||
|
var stats = new InferenceStats(mock(), histogramCounter);
|
||||||
|
var statusCode = RestStatus.BAD_REQUEST;
|
||||||
|
var exception = new ElasticsearchStatusException("hello", statusCode);
|
||||||
|
var expectedError = String.valueOf(statusCode.getStatus());
|
||||||
|
|
||||||
|
stats.inferenceDuration().record(expectedLong, responseAttributes(exception));
|
||||||
|
|
||||||
|
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), nullValue());
|
||||||
|
assertThat(attributes.get("task_type"), nullValue());
|
||||||
|
assertThat(attributes.get("model_id"), nullValue());
|
||||||
|
assertThat(attributes.get("status_code"), is(statusCode.getStatus()));
|
||||||
|
assertThat(attributes.get("error.type"), is(expectedError));
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRecordDurationWithUnknownModelAndOtherException() {
|
||||||
|
var expectedLong = randomLong();
|
||||||
|
var histogramCounter = mock(LongHistogram.class);
|
||||||
|
var stats = new InferenceStats(mock(), histogramCounter);
|
||||||
|
var exception = new IllegalStateException("ahh");
|
||||||
|
var expectedError = exception.getClass().getSimpleName();
|
||||||
|
|
||||||
|
stats.inferenceDuration().record(expectedLong, responseAttributes(exception));
|
||||||
|
|
||||||
|
verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
|
||||||
|
assertThat(attributes.get("service"), nullValue());
|
||||||
|
assertThat(attributes.get("task_type"), nullValue());
|
||||||
|
assertThat(attributes.get("model_id"), nullValue());
|
||||||
|
assertThat(attributes.get("status_code"), nullValue());
|
||||||
|
assertThat(attributes.get("error.type"), is(expectedError));
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Model model(String service, TaskType taskType, String modelId) {
|
||||||
|
var configuration = mock(ModelConfigurations.class);
|
||||||
|
when(configuration.getService()).thenReturn(service);
|
||||||
|
var settings = mock(ServiceSettings.class);
|
||||||
|
if (modelId != null) {
|
||||||
|
when(settings.modelId()).thenReturn(modelId);
|
||||||
|
}
|
||||||
|
|
||||||
|
var model = mock(Model.class);
|
||||||
|
when(model.getTaskType()).thenReturn(taskType);
|
||||||
|
when(model.getConfigurations()).thenReturn(configuration);
|
||||||
|
when(model.getServiceSettings()).thenReturn(settings);
|
||||||
|
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License
|
||||||
|
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||||
|
* 2.0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.inference.telemetry;
|
||||||
|
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
|
||||||
|
import java.time.Clock;
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.time.temporal.ChronoUnit;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
public class InferenceTimerTests extends ESTestCase {
|
||||||
|
|
||||||
|
public void testElapsedMillis() {
|
||||||
|
var expectedDuration = randomLongBetween(10, 300);
|
||||||
|
|
||||||
|
var startTime = Instant.now();
|
||||||
|
var clock = mock(Clock.class);
|
||||||
|
when(clock.instant()).thenReturn(startTime).thenReturn(startTime.plus(expectedDuration, ChronoUnit.MILLIS));
|
||||||
|
var timer = InferenceTimer.start(clock);
|
||||||
|
|
||||||
|
assertThat(expectedDuration, is(timer.elapsedMillis()));
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue