[ML] Trained Models: Count tests run against trained models (#212927)

Part of: https://github.com/elastic/kibana/issues/200725
This PR adds UI counters for successful and failed tests run against
trained models.
This commit is contained in:
Robert Jaszczurek 2025-03-10 13:42:30 +01:00 committed by GitHub
parent d6afbe9675
commit 6a184b4b4b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 166 additions and 32 deletions

View file

@ -16,6 +16,7 @@ import { ES_FIELD_TYPES } from '@kbn/field-types';
import type { MLHttpFetchError } from '@kbn/ml-error-utils';
import type { trainedModelsApiProvider } from '../../../services/ml_api_service/trained_models';
import { getInferenceInfoComponent } from './inference_info';
import type { ITelemetryClient } from '../../../services/telemetry/types';
export type InferenceType =
| SupportedPytorchTasksType
@ -84,7 +85,8 @@ export abstract class InferenceBase<TInferResponse> {
protected readonly trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
protected readonly model: estypes.MlTrainedModelConfig,
protected readonly inputType: INPUT_TYPE,
protected readonly deploymentId: string
protected readonly deploymentId: string,
private readonly telemetryClient: ITelemetryClient
) {
this.modelInputField = model.input?.field_names[0] ?? DEFAULT_INPUT_FIELD;
this.inputField$.next(this.modelInputField);
@ -317,9 +319,14 @@ export abstract class InferenceBase<TInferResponse> {
this.inferenceResult$.next([processedResponse]);
this.setFinished();
this.trackModelTested('success');
return [processedResponse];
} catch (error) {
this.setFinishedWithErrors(error);
this.trackModelTested('failure');
throw error;
}
}
@ -336,9 +343,15 @@ export abstract class InferenceBase<TInferResponse> {
const processedResponse = docs.map((d) => processResponse(this.getDocFromResponse(d)));
this.inferenceResult$.next(processedResponse);
this.setFinished();
this.trackModelTested('success');
return processedResponse;
} catch (error) {
this.setFinishedWithErrors(error);
this.trackModelTested('failure');
throw error;
}
}
@ -391,4 +404,13 @@ export abstract class InferenceBase<TInferResponse> {
}
return doc;
}
/**
 * Reports a "Trained Model Tested" EBT telemetry event for this model.
 *
 * Called once per inference attempt: with 'success' after the result is
 * published, or with 'failure' after setFinishedWithErrors (the error is
 * still rethrown to the caller).
 *
 * @param result - 'success' when inference completed, 'failure' when it threw.
 */
private trackModelTested(result: 'success' | 'failure') {
// NOTE(review): model_type and task_type may be undefined for some models;
// the EBT schema marks both optional, so that is acceptable.
this.telemetryClient.trackTrainedModelsModelTested({
model_id: this.model.model_id,
model_type: this.model.model_type,
task_type: this.inferenceType,
result,
});
}
}

View file

@ -13,6 +13,7 @@ import { InferenceBase, INPUT_TYPE } from '../inference_base';
import type { InferResponse } from '../inference_base';
import { getGeneralInputComponent } from '../text_input';
import { getNerOutputComponent } from './ner_output';
import type { ITelemetryClient } from '../../../../services/telemetry/types';
export type FormattedNerResponse = Array<{
value: string;
@ -37,9 +38,10 @@ export class NerInference extends InferenceBase<NerResponse> {
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);
this.initialize();
}

View file

@ -15,6 +15,7 @@ import type { InferResponse, INPUT_TYPE } from '../inference_base';
import { getQuestionAnsweringInput } from './question_answering_input';
import { getQuestionAnsweringOutputComponent } from './question_answering_output';
import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import type { ITelemetryClient } from '../../../../services/telemetry/types';
export interface RawQuestionAnsweringResponse {
inference_results: Array<{
@ -63,9 +64,10 @@ export class QuestionAnsweringInference extends InferenceBase<QuestionAnsweringR
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);
this.initialize(
[this.questionText$.pipe(map((questionText) => questionText !== ''))],

View file

@ -15,6 +15,7 @@ import { processResponse, processInferenceResult } from './common';
import { getGeneralInputComponent } from '../text_input';
import { getFillMaskOutputComponent } from './fill_mask_output';
import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import type { ITelemetryClient } from '../../../../services/telemetry/types';
const DEFAULT_MASK_TOKEN = '[MASK]';
@ -36,9 +37,10 @@ export class FillMaskInference extends InferenceBase<TextClassificationResponse>
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);
const maskToken = model.inference_config?.[this.inferenceType]?.mask_token;
if (maskToken) {
this.maskToken = maskToken;

View file

@ -14,6 +14,7 @@ import { getGeneralInputComponent } from '../text_input';
import { getLangIdentOutputComponent } from './lang_ident_output';
import type { TextClassificationResponse, RawTextClassificationResponse } from './common';
import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import type { ITelemetryClient } from '../../../../services/telemetry/types';
export class LangIdentInference extends InferenceBase<TextClassificationResponse> {
protected inferenceType: InferenceType = 'classification';
@ -32,9 +33,10 @@ export class LangIdentInference extends InferenceBase<TextClassificationResponse
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);
this.initialize();
}

View file

@ -14,6 +14,7 @@ import type { TextClassificationResponse, RawTextClassificationResponse } from '
import { getGeneralInputComponent } from '../text_input';
import { getTextClassificationOutputComponent } from './text_classification_output';
import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import type { ITelemetryClient } from '../../../../services/telemetry/types';
export class TextClassificationInference extends InferenceBase<TextClassificationResponse> {
protected inferenceType = SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION;
@ -31,9 +32,10 @@ export class TextClassificationInference extends InferenceBase<TextClassificatio
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);
this.initialize();
}

View file

@ -18,6 +18,7 @@ import type { TextClassificationResponse, RawTextClassificationResponse } from '
import { getZeroShotClassificationInput } from './zero_shot_classification_input';
import { getTextClassificationOutputComponent } from './text_classification_output';
import type { ITelemetryClient } from '../../../../services/telemetry/types';
export class ZeroShotClassificationInference extends InferenceBase<TextClassificationResponse> {
protected inferenceType = SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION;
@ -39,9 +40,10 @@ export class ZeroShotClassificationInference extends InferenceBase<TextClassific
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);
this.initialize(
[this.labelsText$.pipe(map((labelsText) => labelsText !== ''))],

View file

@ -13,6 +13,7 @@ import type { InferResponse } from '../inference_base';
import { getGeneralInputComponent } from '../text_input';
import { getTextEmbeddingOutputComponent } from './text_embedding_output';
import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import type { ITelemetryClient } from '../../../../services/telemetry/types';
export interface RawTextEmbeddingResponse {
inference_results: Array<{ predicted_value: number[] }>;
@ -43,9 +44,10 @@ export class TextEmbeddingInference extends InferenceBase<TextEmbeddingResponse>
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);
this.initialize();
}

View file

@ -15,6 +15,7 @@ import type { INPUT_TYPE } from '../inference_base';
import { InferenceBase, type InferResponse } from '../inference_base';
import { getTextExpansionOutputComponent } from './text_expansion_output';
import { getTextExpansionInput } from './text_expansion_input';
import type { ITelemetryClient } from '../../../../services/telemetry/types';
export interface TextExpansionPair {
token: string;
@ -53,9 +54,10 @@ export class TextExpansionInference extends InferenceBase<TextExpansionResponse>
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);
this.initialize(
[this.queryText$.pipe(map((questionText) => questionText !== ''))],

View file

@ -35,6 +35,7 @@ import {
isMlIngestInferenceProcessor,
isMlInferencePipelineInferenceConfig,
} from '../create_pipeline_for_model/get_inference_properties_from_pipeline_config';
import { useMlTelemetryClient } from '../../contexts/ml/ml_telemetry_context';
interface Props {
model: estypes.MlTrainedModelConfig;
@ -54,6 +55,7 @@ export const SelectedModel: FC<Props> = ({
setCurrentContext,
}) => {
const { trainedModels } = useMlApi();
const { telemetryClient } = useMlTelemetryClient();
const inferrer = useMemo<InferrerType | undefined>(() => {
const taskType = Object.keys(model.inference_config ?? {})[0];
@ -65,14 +67,21 @@ export const SelectedModel: FC<Props> = ({
if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) {
switch (taskType) {
case SUPPORTED_PYTORCH_TASKS.NER:
tempInferrer = new NerInference(trainedModels, model, inputType, deploymentId);
tempInferrer = new NerInference(
trainedModels,
model,
inputType,
deploymentId,
telemetryClient
);
break;
case SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION:
tempInferrer = new TextClassificationInference(
trainedModels,
model,
inputType,
deploymentId
deploymentId,
telemetryClient
);
break;
case SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION:
@ -80,7 +89,8 @@ export const SelectedModel: FC<Props> = ({
trainedModels,
model,
inputType,
deploymentId
deploymentId,
telemetryClient
);
if (pipelineConfigValues) {
const { labels, multi_label: multiLabel } = pipelineConfigValues;
@ -91,30 +101,55 @@ export const SelectedModel: FC<Props> = ({
}
break;
case SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING:
tempInferrer = new TextEmbeddingInference(trainedModels, model, inputType, deploymentId);
tempInferrer = new TextEmbeddingInference(
trainedModels,
model,
inputType,
deploymentId,
telemetryClient
);
break;
case SUPPORTED_PYTORCH_TASKS.FILL_MASK:
tempInferrer = new FillMaskInference(trainedModels, model, inputType, deploymentId);
tempInferrer = new FillMaskInference(
trainedModels,
model,
inputType,
deploymentId,
telemetryClient
);
break;
case SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING:
tempInferrer = new QuestionAnsweringInference(
trainedModels,
model,
inputType,
deploymentId
deploymentId,
telemetryClient
);
if (pipelineConfigValues?.question) {
tempInferrer.setQuestionText(pipelineConfigValues.question);
}
break;
case SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION:
tempInferrer = new TextExpansionInference(trainedModels, model, inputType, deploymentId);
tempInferrer = new TextExpansionInference(
trainedModels,
model,
inputType,
deploymentId,
telemetryClient
);
break;
default:
break;
}
} else if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) {
tempInferrer = new LangIdentInference(trainedModels, model, inputType, deploymentId);
tempInferrer = new LangIdentInference(
trainedModels,
model,
inputType,
deploymentId,
telemetryClient
);
}
if (tempInferrer) {
if (pipelineConfigValues) {

View file

@ -89,7 +89,7 @@ describe('TrainedModelsService', () => {
mockTelemetryService = {
trackTrainedModelsDeploymentCreated: jest.fn(),
};
} as unknown as jest.Mocked<ITelemetryClient>;
mockTrainedModelsApiService = {
getTrainedModelsList: jest.fn(),

View file

@ -6,6 +6,7 @@
*/
import type { SchemaObject } from '@elastic/ebt';
import type { TrainedModelsModelTestedEbtProps } from './types';
import {
TrainedModelsTelemetryEventTypes,
type TrainedModelsDeploymentEbtProps,
@ -66,11 +67,47 @@ const trainedModelsDeploymentSchema: SchemaObject<TrainedModelsDeploymentEbtProp
},
};
// EBT schema for the "Trained Model Tested" event. Field names and optionality
// must stay in sync with TrainedModelsModelTestedEbtProps (model_type and
// task_type are optional there, hence `optional: true` below).
const trainedModelsModelTestedSchema: SchemaObject<TrainedModelsModelTestedEbtProps>['properties'] =
{
model_id: {
type: 'keyword',
_meta: {
description: 'The ID of the trained model',
},
},
model_type: {
type: 'keyword',
_meta: {
description: 'The type of the trained model',
optional: true,
},
},
task_type: {
type: 'keyword',
_meta: {
description: 'The type of the task',
optional: true,
},
},
result: {
type: 'keyword',
_meta: {
description: 'The result of the task',
},
},
};
// Event registration descriptor pairing the DEPLOYMENT_CREATED event type
// with its schema; registered with the analytics service at setup time.
const trainedModelsDeploymentCreatedEventType: TrainedModelsTelemetryEvent = {
eventType: TrainedModelsTelemetryEventTypes.DEPLOYMENT_CREATED,
schema: trainedModelsDeploymentSchema,
};
// Event registration descriptor pairing the MODEL_TESTED event type with its
// schema; registered with the analytics service at setup time.
const trainedModelsModelTestedEventType: TrainedModelsTelemetryEvent = {
eventType: TrainedModelsTelemetryEventTypes.MODEL_TESTED,
schema: trainedModelsModelTestedSchema,
};
// All trained-models EBT event descriptors, consumed by TelemetryService.setup
// via analytics.registerEventType(...). Add new events here so they are
// registered alongside the existing ones.
export const trainedModelsEbtEvents = {
trainedModelsDeploymentCreatedEventType,
trainedModelsModelTestedEventType,
};

View file

@ -6,7 +6,11 @@
*/
import type { AnalyticsServiceSetup } from '@kbn/core-analytics-browser';
import type { ITelemetryClient, TrainedModelsDeploymentEbtProps } from './types';
import type {
ITelemetryClient,
TrainedModelsDeploymentEbtProps,
TrainedModelsModelTestedEbtProps,
} from './types';
import { TrainedModelsTelemetryEventTypes } from './types';
export class TelemetryClient implements ITelemetryClient {
@ -15,4 +19,8 @@ export class TelemetryClient implements ITelemetryClient {
public trackTrainedModelsDeploymentCreated = (eventProps: TrainedModelsDeploymentEbtProps) => {
this.analytics.reportEvent(TrainedModelsTelemetryEventTypes.DEPLOYMENT_CREATED, eventProps);
};
/** Emits the "Trained Model Tested" EBT event through the analytics service. */
public trackTrainedModelsModelTested = (eventProps: TrainedModelsModelTestedEbtProps) => {
this.analytics.reportEvent(TrainedModelsTelemetryEventTypes.MODEL_TESTED, eventProps);
};
}

View file

@ -23,6 +23,7 @@ export class TelemetryService {
this.analytics = analytics;
analytics.registerEventType(trainedModelsEbtEvents.trainedModelsDeploymentCreatedEventType);
analytics.registerEventType(trainedModelsEbtEvents.trainedModelsModelTestedEventType);
}
public start(): ITelemetryClient {

View file

@ -6,6 +6,7 @@
*/
import type { RootSchema } from '@kbn/core/public';
import type { TrainedModelType } from '@kbn/ml-trained-models-utils';
export interface TrainedModelsDeploymentEbtProps {
model_id: string;
@ -18,15 +19,29 @@ export interface TrainedModelsDeploymentEbtProps {
vcpu_usage: 'low' | 'medium' | 'high';
}
export enum TrainedModelsTelemetryEventTypes {
DEPLOYMENT_CREATED = 'Trained Models Deployment Created',
/**
 * Payload for the "Trained Model Tested" telemetry event, reported when a
 * user runs a test (inference) against a trained model in the UI.
 */
export interface TrainedModelsModelTestedEbtProps {
// ID of the trained model the test ran against.
model_id: string;
// Optional: not all models expose a model_type.
model_type?: TrainedModelType;
// Optional: inference task type (e.g. ner, text_classification).
task_type?: string;
// Outcome of the test run.
result: 'success' | 'failure';
}
export interface TrainedModelsTelemetryEvent {
eventType: TrainedModelsTelemetryEventTypes.DEPLOYMENT_CREATED;
schema: RootSchema<TrainedModelsDeploymentEbtProps>;
// EBT event type names for trained-models telemetry. These string values are
// the registered event identifiers — changing them breaks existing dashboards.
export enum TrainedModelsTelemetryEventTypes {
DEPLOYMENT_CREATED = 'Trained Models Deployment Created',
MODEL_TESTED = 'Trained Model Tested',
}
/**
 * Discriminated union of trained-models event registrations: each variant ties
 * an event type to the root schema of its payload, so a registration cannot
 * pair an event type with the wrong schema.
 */
export type TrainedModelsTelemetryEvent =
| {
eventType: TrainedModelsTelemetryEventTypes.DEPLOYMENT_CREATED;
schema: RootSchema<TrainedModelsDeploymentEbtProps>;
}
| {
eventType: TrainedModelsTelemetryEventTypes.MODEL_TESTED;
schema: RootSchema<TrainedModelsModelTestedEbtProps>;
};
/**
 * Client contract for reporting trained-models telemetry events; implemented
 * by TelemetryClient over the Kibana analytics service.
 */
export interface ITelemetryClient {
trackTrainedModelsDeploymentCreated: (eventProps: TrainedModelsDeploymentEbtProps) => void;
trackTrainedModelsModelTested: (eventProps: TrainedModelsModelTestedEbtProps) => void;
}