[Enterprise Search] Add ML Inference Pipeline - Review Step (#142133)

* refactor: ml inference config live validation

Migrated validation from when you click create to be a selector that
updates as the form values are updated. This is to make it easier to
enable/disable the continue button as we introduce steps to the ml
inference modal.

* ml inference modal: introduce steps

Added footer, steps and placeholder components for the test and review
steps of the add ml inference pipeline modal.

* refactor: abstract ml inference body gen to common

Abstracted the code to generate the ml inference pipeline body from the
server to common utility functions. We will need these functions in the
frontend to display the JSON for review. And we want to use the same
code on the frontend and backend.

* add ml inference pipeline review

Added review component for the ml inference pipeline using a selector in
the kea logic to generate the pipeline. In the future this may need to
be a part of the state so it can be modified, but for now a selector
seemed to fit better when it's read-only.

Co-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
Rodney Norris 2022-09-30 08:32:23 -05:00 committed by GitHub
parent 85463eae67
commit 6157f0be86
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 445 additions and 168 deletions

View file

@ -0,0 +1,52 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { BUILT_IN_MODEL_TAG } from '@kbn/ml-plugin/common/constants/data_frame_analytics';
import { getMlModelTypesForModelConfig, BUILT_IN_MODEL_TAG as LOCAL_BUILT_IN_MODEL_TAG } from '.';
describe('getMlModelTypesForModelConfig lib function', () => {
const mockModel: MlTrainedModelConfig = {
inference_config: {
ner: {},
},
input: {
field_names: [],
},
model_id: 'test_id',
model_type: 'pytorch',
tags: ['test_tag'],
};
const builtInMockModel: MlTrainedModelConfig = {
inference_config: {
text_classification: {},
},
input: {
field_names: [],
},
model_id: 'test_id',
model_type: 'lang_ident',
tags: [BUILT_IN_MODEL_TAG],
};
it('should return the model type and inference config type', () => {
const expected = ['pytorch', 'ner'];
const response = getMlModelTypesForModelConfig(mockModel);
expect(response.sort()).toEqual(expected.sort());
});
it('should include the built in type', () => {
const expected = ['lang_ident', 'text_classification', BUILT_IN_MODEL_TAG];
const response = getMlModelTypesForModelConfig(builtInMockModel);
expect(response.sort()).toEqual(expected.sort());
});
it('local BUILT_IN_MODEL_TAG matches ml plugin', () => {
expect(LOCAL_BUILT_IN_MODEL_TAG).toEqual(BUILT_IN_MODEL_TAG);
});
});

View file

@ -0,0 +1,96 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { MlInferencePipeline } from '../types/pipelines';
// Getting an error importing this from @kbn/ml-plugin/common/constants/data_frame_analytics'
// So defining it locally for now with a test to make sure it matches.
export const BUILT_IN_MODEL_TAG = 'prepackaged';

// Parameters required to generate an ML inference pipeline body.
export interface MlInferencePipelineParams {
  // Optional human-readable description; defaults to '' in the generated body.
  description?: string;
  // Target field name; inference output lands under `ml.inference.<destinationField>`.
  destinationField: string;
  // Full trained-model configuration returned by the ML APIs.
  model: MlTrainedModelConfig;
  // Name recorded in the provenance metadata appended to processed documents.
  pipelineName: string;
  // Document field that is mapped onto the model's input field.
  sourceField: string;
}
/**
 * Generates the pipeline body for a machine learning inference pipeline.
 *
 * The pipeline consists of three processors: a `remove` that clears any stale
 * inference output for the destination field, the `inference` call itself, and
 * an `append` that records provenance metadata on the document.
 *
 * @param pipelineConfiguration machine learning inference pipeline configuration parameters
 * @returns pipeline body
 */
export const generateMlInferencePipelineBody = ({
  description,
  destinationField,
  model,
  pipelineName,
  sourceField,
}: MlInferencePipelineParams): MlInferencePipeline => {
  // If the model returned no input field, insert a placeholder. The `?? 0`
  // makes the comparison explicit (and strict-mode safe) instead of relying
  // on `undefined > 0` coercing to false when `input`/`field_names` is absent.
  const modelInputField =
    (model.input?.field_names?.length ?? 0) > 0 ? model.input.field_names[0] : 'MODEL_INPUT_FIELD';
  return {
    description: description ?? '',
    processors: [
      {
        // Drop any previous inference output before re-running the model.
        remove: {
          field: `ml.inference.${destinationField}`,
          ignore_missing: true,
        },
      },
      {
        inference: {
          field_map: {
            [sourceField]: modelInputField,
          },
          model_id: model.model_id,
          target_field: `ml.inference.${destinationField}`,
        },
      },
      {
        // Record which pipeline/model processed the document and when.
        append: {
          field: '_source._ingest.processors',
          value: [
            {
              model_version: model.version,
              pipeline: pipelineName,
              processed_timestamp: '{{{ _ingest.timestamp }}}',
              types: getMlModelTypesForModelConfig(model),
            },
          ],
        },
      },
    ],
    version: 1,
  };
};
/**
 * Parses model types list from the given configuration of a trained machine learning model.
 * The list contains the model type, every inference_config key, and the
 * BUILT_IN_MODEL_TAG when the model is tagged as built in.
 * @param trainedModel configuration for a trained machine learning model
 * @returns list of model types
 */
export const getMlModelTypesForModelConfig = (trainedModel: MlTrainedModelConfig): string[] => {
  if (!trainedModel) return [];
  const types: Array<string | undefined> = [
    trainedModel.model_type,
    ...Object.keys(trainedModel.inference_config || {}),
  ];
  if (trainedModel.tags?.includes(BUILT_IN_MODEL_TAG)) {
    types.push(BUILT_IN_MODEL_TAG);
  }
  return types.filter((type): type is string => type !== undefined);
};
/**
 * Normalizes a user-entered pipeline name: trims surrounding whitespace,
 * collapses internal whitespace runs into single underscores, and lowercases.
 */
export const formatPipelineName = (rawName: string) => {
  const underscored = rawName.trim().replace(/\s+/g, '_');
  return underscored.toLowerCase();
};

View file

@ -5,6 +5,8 @@
* 2.0.
*/
import { IngestPipeline } from '@elastic/elasticsearch/lib/api/types';
export interface InferencePipeline {
modelState: TrainedModelState;
modelStateReason?: string;
@ -19,3 +21,7 @@ export enum TrainedModelState {
Started = 'started',
Failed = 'failed',
}
// An ingest pipeline definition for ML inference; `version` is optional here
// because IngestPipeline does not carry it while generated bodies set it.
export interface MlInferencePipeline extends IngestPipeline {
  version?: number;
}

View file

@ -15,6 +15,8 @@ import {
EuiCallOut,
EuiFlexGroup,
EuiFlexItem,
EuiStepsHorizontal,
EuiStepsHorizontalProps,
EuiModal,
EuiModalBody,
EuiModalFooter,
@ -26,11 +28,18 @@ import {
import { i18n } from '@kbn/i18n';
import {
BACK_BUTTON_LABEL,
CANCEL_BUTTON_LABEL,
CONTINUE_BUTTON_LABEL,
} from '../../../../../shared/constants';
import { IndexNameLogic } from '../../index_name_logic';
import { ConfigurePipeline } from './configure_pipeline';
import { MLInferenceLogic, AddInferencePipelineModal } from './ml_inference_logic';
import { AddInferencePipelineSteps, MLInferenceLogic } from './ml_inference_logic';
import { NoModelsPanel } from './no_models';
import { ReviewPipeline } from './review_pipeline';
import { TestPipeline } from './test_pipeline';
interface AddMLInferencePipelineModalProps {
onClose: () => void;
@ -64,15 +73,13 @@ export const AddMLInferencePipelineModal: React.FC<AddMLInferencePipelineModalPr
);
};
// Returns true while any of the required configuration fields (pipeline name,
// model, source field) is still empty, i.e. the pipeline cannot be created yet.
const isProcessorConfigurationInvalid = ({ configuration }: AddInferencePipelineModal): boolean => {
  const { modelID, pipelineName, sourceField } = configuration;
  if (pipelineName.trim().length === 0) return true;
  if (modelID.length === 0) return true;
  return sourceField.length === 0;
};
const AddProcessorContent: React.FC<AddMLInferencePipelineModalProps> = ({ onClose }) => {
const { addInferencePipelineModal, createErrors, supportedMLModels, isLoading } =
useValues(MLInferenceLogic);
const { createPipeline } = useActions(MLInferenceLogic);
const {
createErrors,
supportedMLModels,
isLoading,
addInferencePipelineModal: { step },
} = useValues(MLInferenceLogic);
if (isLoading) {
return (
<EuiModalBody>
@ -103,37 +110,126 @@ const AddProcessorContent: React.FC<AddMLInferencePipelineModalProps> = ({ onClo
<EuiSpacer />
</>
)}
<ConfigurePipeline />
<ModalSteps />
{step === AddInferencePipelineSteps.Configuration && <ConfigurePipeline />}
{step === AddInferencePipelineSteps.Test && <TestPipeline />}
{step === AddInferencePipelineSteps.Review && <ReviewPipeline />}
</EuiModalBody>
<EuiModalFooter>
<EuiFlexGroup>
<EuiFlexItem />
<EuiFlexItem grow={false}>
<EuiButtonEmpty onClick={onClose}>
{i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.footer.cancel',
{
defaultMessage: 'Cancel',
}
)}
</EuiButtonEmpty>
</EuiFlexItem>
<EuiFlexItem grow={false}>
<EuiButton
color="success"
disabled={isProcessorConfigurationInvalid(addInferencePipelineModal)}
onClick={createPipeline}
<ModalFooter onClose={onClose} />
</>
);
};
// Horizontal step navigation for the add-inference-pipeline modal.
const ModalSteps: React.FC = () => {
  const {
    addInferencePipelineModal: { step },
    isPipelineDataValid,
  } = useValues(MLInferenceLogic);
  const { setAddInferencePipelineStep } = useActions(MLInferenceLogic);

  // The configure step never shows "incomplete": once the form is valid it is
  // complete, while invalid it is either the active step or disabled.
  const configureStatus = isPipelineDataValid
    ? 'complete'
    : step === AddInferencePipelineSteps.Configuration
    ? 'current'
    : 'disabled';

  // Test/Review are disabled until the configuration is valid, "current" when active.
  const statusForLaterStep = (ownStep: AddInferencePipelineSteps) => {
    if (step === ownStep) return 'current';
    return isPipelineDataValid ? 'incomplete' : 'disabled';
  };

  const navSteps: EuiStepsHorizontalProps['steps'] = [
    {
      onClick: () => setAddInferencePipelineStep(AddInferencePipelineSteps.Configuration),
      status: configureStatus,
      title: i18n.translate(
        'xpack.enterpriseSearch.content.indices.transforms.addInferencePipelineModal.steps.configure.title',
        {
          defaultMessage: 'Configure',
        }
      ),
    },
    {
      onClick: () => setAddInferencePipelineStep(AddInferencePipelineSteps.Test),
      status: statusForLaterStep(AddInferencePipelineSteps.Test),
      title: i18n.translate(
        'xpack.enterpriseSearch.content.indices.transforms.addInferencePipelineModal.steps.test.title',
        {
          defaultMessage: 'Test',
        }
      ),
    },
    {
      onClick: () => setAddInferencePipelineStep(AddInferencePipelineSteps.Review),
      status: statusForLaterStep(AddInferencePipelineSteps.Review),
      title: i18n.translate(
        'xpack.enterpriseSearch.content.indices.transforms.addInferencePipelineModal.steps.review.title',
        {
          defaultMessage: 'Review',
        }
      ),
    },
  ];

  return <EuiStepsHorizontal steps={navSteps} />;
};
const ModalFooter: React.FC<AddMLInferencePipelineModalProps> = ({ onClose }) => {
const { addInferencePipelineModal: modal, isPipelineDataValid } = useValues(MLInferenceLogic);
const { createPipeline, setAddInferencePipelineStep } = useActions(MLInferenceLogic);
let nextStep: AddInferencePipelineSteps | undefined;
let previousStep: AddInferencePipelineSteps | undefined;
switch (modal.step) {
case AddInferencePipelineSteps.Test:
nextStep = AddInferencePipelineSteps.Review;
previousStep = AddInferencePipelineSteps.Configuration;
break;
case AddInferencePipelineSteps.Review:
previousStep = AddInferencePipelineSteps.Test;
break;
case AddInferencePipelineSteps.Configuration:
nextStep = AddInferencePipelineSteps.Test;
break;
}
return (
<EuiModalFooter>
<EuiFlexGroup>
<EuiFlexItem grow={false}>
{previousStep !== undefined ? (
<EuiButtonEmpty
flush="both"
iconType="arrowLeft"
onClick={() => setAddInferencePipelineStep(previousStep as AddInferencePipelineSteps)}
>
{BACK_BUTTON_LABEL}
</EuiButtonEmpty>
) : null}
</EuiFlexItem>
<EuiFlexItem />
<EuiFlexItem grow={false}>
<EuiButtonEmpty onClick={onClose}>{CANCEL_BUTTON_LABEL}</EuiButtonEmpty>
</EuiFlexItem>
<EuiFlexItem grow={false}>
{nextStep !== undefined ? (
<EuiButton
iconType="arrowRight"
iconSide="right"
onClick={() => setAddInferencePipelineStep(nextStep as AddInferencePipelineSteps)}
disabled={!isPipelineDataValid}
>
{CONTINUE_BUTTON_LABEL}
</EuiButton>
) : (
<EuiButton color="success" disabled={!isPipelineDataValid} onClick={createPipeline}>
{i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.footer.create',
'xpack.enterpriseSearch.content.indices.transforms.addInferencePipelineModal.footer.create',
{
defaultMessage: 'Create',
}
)}
</EuiButton>
</EuiFlexItem>
</EuiFlexGroup>
</EuiModalFooter>
</>
)}
</EuiFlexItem>
</EuiFlexGroup>
</EuiModalFooter>
);
};

View file

@ -36,6 +36,7 @@ export const ConfigurePipeline: React.FC = () => {
const { destinationField, modelID, pipelineName, sourceField } = configuration;
const models = supportedMLModels ?? [];
const nameError = formErrors.pipelineName !== undefined && pipelineName.length > 0;
return (
<>
@ -73,7 +74,7 @@ export const ConfigurePipeline: React.FC = () => {
}
)}
helpText={
formErrors.pipelineName === undefined &&
!nameError &&
i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.name.helpText',
{
@ -82,8 +83,8 @@ export const ConfigurePipeline: React.FC = () => {
}
)
}
error={formErrors.pipelineName}
isInvalid={formErrors.pipelineName !== undefined}
error={nameError && formErrors.pipelineName}
isInvalid={nameError}
>
<EuiFieldText
fullWidth

View file

@ -11,7 +11,12 @@ import { IndicesGetMappingIndexMappingRecord } from '@elastic/elasticsearch/lib/
import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_models';
import {
formatPipelineName,
generateMlInferencePipelineBody,
} from '../../../../../../../common/ml_inference_pipeline';
import { HttpError, Status } from '../../../../../../../common/types/api';
import { MlInferencePipeline } from '../../../../../../../common/types/pipelines';
import { generateEncodedPath } from '../../../../../shared/encode_path_params';
import { getErrorsFromHttpResponse } from '../../../../../shared/flash_messages/handle_api_errors';
@ -37,10 +42,15 @@ export const EMPTY_PIPELINE_CONFIGURATION: InferencePipelineConfiguration = {
sourceField: '',
};
// Ordered steps of the add-inference-pipeline modal
// (Configuration -> Test -> Review), used by the step navigation and footer.
export enum AddInferencePipelineSteps {
  Configuration,
  Test,
  Review,
}
const API_REQUEST_COMPLETE_STATUSES = [Status.SUCCESS, Status.ERROR];
interface MLInferenceProcessorsActions {
clearFormErrors: () => void;
createApiError: (error: HttpError) => HttpError;
createApiSuccess: typeof CreateMlInferencePipelineApiLogic.actions.apiSuccess;
createPipeline: () => void;
@ -49,10 +59,10 @@ interface MLInferenceProcessorsActions {
makeMappingRequest: typeof MappingsApiLogic.actions.makeRequest;
mappingsApiError(error: HttpError): HttpError;
mlModelsApiError(error: HttpError): HttpError;
setCreateErrors(errors: string[]): { errors: string[] };
setFormErrors: (inputErrors: AddInferencePipelineFormErrors) => {
inputErrors: AddInferencePipelineFormErrors;
setAddInferencePipelineStep: (step: AddInferencePipelineSteps) => {
step: AddInferencePipelineSteps;
};
setCreateErrors(errors: string[]): { errors: string[] };
setIndexName: (indexName: string) => { indexName: string };
setInferencePipelineConfiguration: (configuration: InferencePipelineConfiguration) => {
configuration: InferencePipelineConfiguration;
@ -62,6 +72,7 @@ interface MLInferenceProcessorsActions {
export interface AddInferencePipelineModal {
configuration: InferencePipelineConfiguration;
indexName: string;
step: AddInferencePipelineSteps;
}
interface MLInferenceProcessorsValues {
@ -69,8 +80,10 @@ interface MLInferenceProcessorsValues {
createErrors: string[];
formErrors: AddInferencePipelineFormErrors;
isLoading: boolean;
isPipelineDataValid: boolean;
mappingData: typeof MappingsApiLogic.values.data;
mappingStatus: Status;
mlInferencePipeline?: MlInferencePipeline;
mlModelsData: typeof MLModelsApiLogic.values.data;
mlModelsStatus: typeof MLModelsApiLogic.values.apiStatus;
sourceFields: string[] | undefined;
@ -83,6 +96,7 @@ export const MLInferenceLogic = kea<
actions: {
clearFormErrors: true,
createPipeline: true,
setAddInferencePipelineStep: (step: AddInferencePipelineSteps) => ({ step }),
setCreateErrors: (errors: string[]) => ({ errors }),
setFormErrors: (inputErrors: AddInferencePipelineFormErrors) => ({ inputErrors }),
setIndexName: (indexName: string) => ({ indexName }),
@ -124,12 +138,6 @@ export const MLInferenceLogic = kea<
const {
addInferencePipelineModal: { configuration, indexName },
} = values;
const validationErrors = validateInferencePipelineConfiguration(configuration);
if (validationErrors !== undefined) {
actions.setFormErrors(validationErrors);
return;
}
actions.clearFormErrors();
actions.makeCreatePipelineRequest({
indexName,
@ -155,8 +163,10 @@ export const MLInferenceLogic = kea<
...EMPTY_PIPELINE_CONFIGURATION,
},
indexName: '',
step: AddInferencePipelineSteps.Configuration,
},
{
setAddInferencePipelineStep: (modal, { step }) => ({ ...modal, step }),
setIndexName: (modal, { indexName }) => ({ ...modal, indexName }),
setInferencePipelineConfiguration: (modal, { configuration }) => ({
...modal,
@ -171,21 +181,47 @@ export const MLInferenceLogic = kea<
setCreateErrors: (_, { errors }) => errors,
},
],
formErrors: [
{},
{
clearFormErrors: () => ({}),
setFormErrors: (_, { inputErrors }) => inputErrors,
},
],
},
selectors: ({ selectors }) => ({
formErrors: [
() => [selectors.addInferencePipelineModal],
(modal: AddInferencePipelineModal) =>
validateInferencePipelineConfiguration(modal.configuration),
],
isLoading: [
() => [selectors.mlModelsStatus, selectors.mappingStatus],
(mlModelsStatus, mappingStatus) =>
!API_REQUEST_COMPLETE_STATUSES.includes(mlModelsStatus) ||
!API_REQUEST_COMPLETE_STATUSES.includes(mappingStatus),
],
isPipelineDataValid: [
() => [selectors.formErrors],
(errors: AddInferencePipelineFormErrors) => Object.keys(errors).length === 0,
],
mlInferencePipeline: [
() => [
selectors.isPipelineDataValid,
selectors.addInferencePipelineModal,
selectors.mlModelsData,
],
(
isPipelineDataValid: boolean,
{ configuration }: AddInferencePipelineModal,
models: MLInferenceProcessorsValues['mlModelsData']
) => {
if (!isPipelineDataValid) return undefined;
const model = models?.find((mlModel) => mlModel.model_id === configuration.modelID);
if (!model) return undefined;
return generateMlInferencePipelineBody({
destinationField:
configuration.destinationField || formatPipelineName(configuration.pipelineName),
model,
pipelineName: configuration.pipelineName,
sourceField: configuration.sourceField,
});
},
],
sourceFields: [
() => [selectors.mappingStatus, selectors.mappingData],
(status: Status, mapping: IndicesGetMappingIndexMappingRecord) => {

View file

@ -0,0 +1,49 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import React from 'react';
import { useValues } from 'kea';
import { EuiCodeBlock, EuiSpacer, EuiText } from '@elastic/eui';
import { i18n } from '@kbn/i18n';
import { MLInferenceLogic } from './ml_inference_logic';
// Review step of the add-inference-pipeline modal: renders the generated
// pipeline body as read-only JSON plus an explanation of what creation does.
export const ReviewPipeline: React.FC = () => {
  const { mlInferencePipeline } = useValues(MLInferenceLogic);

  // The selector yields undefined while the configuration is incomplete;
  // render an empty object in that case.
  const pipelineJson = JSON.stringify(mlInferencePipeline ?? {}, null, 2);

  return (
    <>
      <EuiText>
        <h4>
          {i18n.translate(
            'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.review.title',
            { defaultMessage: 'Pipeline configuration' }
          )}
        </h4>
      </EuiText>
      <EuiCodeBlock language="json" isCopyable overflowHeight={300}>
        {pipelineJson}
      </EuiCodeBlock>
      <EuiSpacer />
      <EuiText>
        <p>
          {i18n.translate(
            'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.review.description',
            {
              defaultMessage:
                "This pipeline will be created and injected as a processor into your default pipeline for this index. You'll be able to use this new pipeline independently as well.",
            }
          )}
        </p>
      </EuiText>
    </>
  );
};

View file

@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import React from 'react';
export const TestPipeline: React.FC = () => {
return <div>Test Pipeline</div>;
};

View file

@ -14,5 +14,7 @@ export interface InferencePipelineConfiguration {
export interface AddInferencePipelineFormErrors {
destinationField?: string;
modelID?: string;
pipelineName?: string;
sourceField?: string;
}

View file

@ -41,17 +41,34 @@ export const isValidPipelineName = (input: string): boolean => {
return input.length > 0 && VALID_PIPELINE_NAME_REGEX.test(input);
};
// Shown when the name fails isValidPipelineName (characters outside letters,
// numbers, underscores, and hyphens).
const INVALID_PIPELINE_NAME_ERROR = i18n.translate(
  'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.invalidPipelineName',
  {
    defaultMessage: 'Name must only contain letters, numbers, underscores, and hyphens.',
  }
);
// Generic required-field message shared by all empty-value validations.
const FIELD_REQUIRED_ERROR = i18n.translate(
  'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.emptyValueError',
  {
    defaultMessage: 'Field is required.',
  }
);
export const validateInferencePipelineConfiguration = (
config: InferencePipelineConfiguration
): AddInferencePipelineFormErrors | undefined => {
): AddInferencePipelineFormErrors => {
const errors: AddInferencePipelineFormErrors = {};
if (!isValidPipelineName(config.pipelineName)) {
errors.pipelineName = i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.invalidPipelineName',
{
defaultMessage: 'Name must only contain letters, numbers, underscores, and hyphens.',
}
);
return errors;
if (config.pipelineName.trim().length === 0) {
errors.pipelineName = FIELD_REQUIRED_ERROR;
} else if (!isValidPipelineName(config.pipelineName)) {
errors.pipelineName = INVALID_PIPELINE_NAME_ERROR;
}
if (config.modelID.trim().length === 0) {
errors.modelID = FIELD_REQUIRED_ERROR;
}
if (config.sourceField.trim().length === 0) {
errors.sourceField = FIELD_REQUIRED_ERROR;
}
return errors;
};

View file

@ -45,6 +45,10 @@ export const CONTINUE_BUTTON_LABEL = i18n.translate(
{ defaultMessage: 'Continue' }
);
// Shared label for "Back" navigation buttons (used by multi-step modals).
export const BACK_BUTTON_LABEL = i18n.translate('xpack.enterpriseSearch.actions.backButtonLabel', {
  defaultMessage: 'Back',
});
export const CLOSE_BUTTON_LABEL = i18n.translate(
'xpack.enterpriseSearch.actions.closeButtonLabel',
{ defaultMessage: 'Close' }

View file

@ -5,15 +5,12 @@
* 2.0.
*/
import { MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { ElasticsearchClient } from '@kbn/core/server';
import { BUILT_IN_MODEL_TAG } from '@kbn/ml-plugin/common/constants/data_frame_analytics';
import { InferencePipeline, TrainedModelState } from '../../../common/types/pipelines';
import {
fetchAndAddTrainedModelData,
getMlModelTypesForModelConfig,
getMlModelConfigsForModelIds,
fetchMlInferencePipelineProcessorNames,
fetchMlInferencePipelineProcessors,
@ -320,43 +317,6 @@ describe('fetchPipelineProcessorInferenceData lib function', () => {
});
});
describe('getMlModelTypesForModelConfig lib function', () => {
const mockModel: MlTrainedModelConfig = {
inference_config: {
ner: {},
},
input: {
field_names: [],
},
model_id: 'test_id',
model_type: 'pytorch',
tags: ['test_tag'],
};
const builtInMockModel: MlTrainedModelConfig = {
inference_config: {
text_classification: {},
},
input: {
field_names: [],
},
model_id: 'test_id',
model_type: 'lang_ident',
tags: [BUILT_IN_MODEL_TAG],
};
it('should return the model type and inference config type', () => {
const expected = ['pytorch', 'ner'];
const response = getMlModelTypesForModelConfig(mockModel);
expect(response.sort()).toEqual(expected.sort());
});
it('should include the built in type', () => {
const expected = ['lang_ident', 'text_classification', BUILT_IN_MODEL_TAG];
const response = getMlModelTypesForModelConfig(builtInMockModel);
expect(response.sort()).toEqual(expected.sort());
});
});
describe('getMlModelConfigsForModelIds lib function', () => {
const mockClient = {
ml: {

View file

@ -5,10 +5,9 @@
* 2.0.
*/
import { MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { ElasticsearchClient } from '@kbn/core/server';
import { BUILT_IN_MODEL_TAG } from '@kbn/ml-plugin/common/constants/data_frame_analytics';
import { getMlModelTypesForModelConfig } from '../../../common/ml_inference_pipeline';
import { InferencePipeline, TrainedModelState } from '../../../common/types/pipelines';
import { getInferencePipelineNameFromIndexName } from '../../utils/ml_inference_pipeline_utils';
@ -70,18 +69,6 @@ export const fetchPipelineProcessorInferenceData = async (
);
};
// Builds the list of "types" for a trained model: its model_type, every key
// of its inference_config, and BUILT_IN_MODEL_TAG when the model carries that
// tag. Undefined entries (e.g. a missing model_type) are filtered out.
export const getMlModelTypesForModelConfig = (trainedModel: MlTrainedModelConfig): string[] => {
  if (!trainedModel) return [];
  const isBuiltIn = trainedModel.tags?.includes(BUILT_IN_MODEL_TAG);
  return [
    trainedModel.model_type,
    ...Object.keys(trainedModel.inference_config || {}),
    ...(isBuiltIn ? [BUILT_IN_MODEL_TAG] : []),
  ].filter((type): type is string => type !== undefined);
};
export const getMlModelConfigsForModelIds = async (
client: ElasticsearchClient,
trainedModelNames: string[]

View file

@ -5,20 +5,16 @@
* 2.0.
*/
import { IngestPipeline } from '@elastic/elasticsearch/lib/api/types';
import { ElasticsearchClient } from '@kbn/core/server';
import { generateMlInferencePipelineBody } from '../../../common/ml_inference_pipeline';
import { MlInferencePipeline } from '../../../common/types/pipelines';
import { getInferencePipelineNameFromIndexName } from '../../utils/ml_inference_pipeline_utils';
import { getMlModelTypesForModelConfig } from '../indices/fetch_ml_inference_pipeline_processors';
export interface CreatedPipelines {
created: string[];
}
export interface MlInferencePipeline extends IngestPipeline {
version?: number;
}
/**
* Used to create index-specific Ingest Pipelines to be used in conjunction with Enterprise Search
* ingestion mechanisms. Three pipelines are created:
@ -237,43 +233,10 @@ export const formatMlPipelineBody = async (
// this will raise a 404 if model doesn't exist
const models = await esClient.ml.getTrainedModels({ model_id: modelId });
const model = models.trained_model_configs[0];
// if model returned no input field, insert a placeholder
const modelInputField =
model.input?.field_names?.length > 0 ? model.input.field_names[0] : 'MODEL_INPUT_FIELD';
const modelTypes = getMlModelTypesForModelConfig(model);
const modelVersion = model.version;
return {
description: '',
processors: [
{
remove: {
field: `ml.inference.${destinationField}`,
ignore_missing: true,
},
},
{
inference: {
field_map: {
[sourceField]: modelInputField,
},
model_id: modelId,
target_field: `ml.inference.${destinationField}`,
},
},
{
append: {
field: '_source._ingest.processors',
value: [
{
model_version: modelVersion,
pipeline: pipelineName,
processed_timestamp: '{{{ _ingest.timestamp }}}',
types: modelTypes,
},
],
},
},
],
version: 1,
};
return generateMlInferencePipelineBody({
destinationField,
model,
pipelineName,
sourceField,
});
};

View file

@ -8,6 +8,7 @@
import { IngestGetPipelineResponse, IngestPipeline } from '@elastic/elasticsearch/lib/api/types';
import { ElasticsearchClient } from '@kbn/core/server';
import { formatPipelineName } from '../../common/ml_inference_pipeline';
import { ErrorCode } from '../../common/types/error_codes';
import { formatMlPipelineBody } from '../lib/pipelines/create_pipeline_definitions';
@ -15,7 +16,6 @@ import { formatMlPipelineBody } from '../lib/pipelines/create_pipeline_definitio
import {
getInferencePipelineNameFromIndexName,
getPrefixedInferencePipelineProcessorName,
formatPipelineName,
} from './ml_inference_pipeline_utils';
/**

View file

@ -5,6 +5,8 @@
* 2.0.
*/
import { formatPipelineName } from '../../common/ml_inference_pipeline';
// Derives the canonical ML inference pipeline name for a given index.
export const getInferencePipelineNameFromIndexName = (indexName: string) =>
  indexName + '@ml-inference';
@ -12,9 +14,3 @@ export const getPrefixedInferencePipelineProcessorName = (pipelineName: string)
pipelineName.startsWith('ml-inference-')
? formatPipelineName(pipelineName)
: `ml-inference-${formatPipelineName(pipelineName)}`;
// Normalizes a raw pipeline name: trim surrounding whitespace, convert
// internal whitespace runs to underscores, then lowercase.
export const formatPipelineName = (rawName: string) =>
  rawName
    .trim()
    .replace(/\s+/g, '_') // Convert whitespaces to underscores
    .toLowerCase();