[Enterprise Search] Attach ML Inference Pipeline - Pipeline re-use (#143979)

* Added a CreateMlInferencePipelineParameters interface

* Updated NLP_CONFIG_KEYS to use the common constant as its source so it matches the server code

* Attach an existing ML inference pipeline

Added the ability to choose an existing ML inference pipeline and attach
it to the index. This re-uses the existing pipeline instead of creating a
new one.

* Added tests for the ML inference logic

* Added tests for parseMlInferenceParametersFromPipeline
Rodney Norris · 2022-10-27 12:34:28 -05:00 · committed by GitHub
parent 61505e5edd
commit 756916db3f
18 changed files with 1009 additions and 78 deletions
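
For context: attaching re-uses an existing ml-inference-* pipeline by referencing it from the index's parent @ml-inference pipeline instead of creating a new pipeline. A minimal sketch of the resulting parent pipeline, with hypothetical index and pipeline names:

```typescript
// Hypothetical parent pipeline for the index "search-movies" after attaching
// the existing pipeline "ml-inference-my-pipeline". The attach flow only
// appends the `pipeline` processor reference; the referenced pipeline itself
// is left untouched.
const parentPipeline = {
  id: 'search-movies@ml-inference',
  processors: [{ pipeline: { name: 'ml-inference-my-pipeline' } }],
};
```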


@@ -17,6 +17,7 @@ import {
getMlModelTypesForModelConfig,
getSetProcessorForInferenceType,
SUPPORTED_PYTORCH_TASKS as LOCAL_SUPPORTED_PYTORCH_TASKS,
parseMlInferenceParametersFromPipeline,
} from '.';
const mockModel: MlTrainedModelConfig = {
@@ -198,3 +199,45 @@ describe('generateMlInferencePipelineBody lib function', () => {
);
});
});
describe('parseMlInferenceParametersFromPipeline', () => {
it('returns pipeline parameters from ingest pipeline', () => {
expect(
parseMlInferenceParametersFromPipeline('unit-test', {
processors: [
{
inference: {
field_map: {
body: 'text_field',
},
model_id: 'test-model',
target_field: 'ml.inference.test',
},
},
],
})
).toEqual({
destination_field: 'test',
model_id: 'test-model',
pipeline_name: 'unit-test',
source_field: 'body',
});
});
it('returns null if the pipeline is missing an inference processor', () => {
expect(parseMlInferenceParametersFromPipeline('unit-test', { processors: [] })).toBeNull();
});
it('returns null if the pipeline is missing a field_map', () => {
expect(
parseMlInferenceParametersFromPipeline('unit-test', {
processors: [
{
inference: {
model_id: 'test-model',
target_field: 'test',
},
},
],
})
).toBeNull();
});
});


@@ -5,9 +5,13 @@
* 2.0.
*/
import { IngestSetProcessor, MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/types';
import {
IngestPipeline,
IngestSetProcessor,
MlTrainedModelConfig,
} from '@elastic/elasticsearch/lib/api/types';
import { MlInferencePipeline } from '../types/pipelines';
import { MlInferencePipeline, CreateMlInferencePipelineParameters } from '../types/pipelines';
// Getting an error importing this from '@kbn/ml-plugin/common/constants/data_frame_analytics',
// so defining it locally for now with a test to make sure it matches.
@@ -151,3 +155,25 @@ export const formatPipelineName = (rawName: string) =>
.trim()
.replace(/\s+/g, '_') // Convert whitespaces to underscores
.toLowerCase();
export const parseMlInferenceParametersFromPipeline = (
name: string,
pipeline: IngestPipeline
): CreateMlInferencePipelineParameters | null => {
const processor = pipeline?.processors?.find((proc) => proc.inference !== undefined);
if (!processor || processor?.inference === undefined) {
return null;
}
const { inference: inferenceProcessor } = processor;
const sourceFields = Object.keys(inferenceProcessor.field_map ?? {});
const sourceField = sourceFields.length === 1 ? sourceFields[0] : null;
if (!sourceField) {
return null;
}
return {
destination_field: inferenceProcessor.target_field.replace('ml.inference.', ''),
model_id: inferenceProcessor.model_id,
pipeline_name: name,
source_field: sourceField,
};
};
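
A usage sketch for the new parser, mirroring the unit tests above (pipeline name and contents hypothetical):

```typescript
import { IngestPipeline } from '@elastic/elasticsearch/lib/api/types';
import { parseMlInferenceParametersFromPipeline } from '.';

// A hypothetical ingest pipeline body, e.g. fetched with GET _ingest/pipeline/<name>.
const pipeline: IngestPipeline = {
  processors: [
    {
      inference: {
        field_map: { body: 'text_field' },
        model_id: 'my-ner-model',
        target_field: 'ml.inference.my-destination',
      },
    },
  ],
};

// Strips the "ml.inference." prefix from target_field and requires exactly one
// source field in field_map:
const params = parseMlInferenceParametersFromPipeline('my-pipeline', pipeline);
// => { destination_field: 'my-destination', model_id: 'my-ner-model',
//      pipeline_name: 'my-pipeline', source_field: 'body' }
```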


@@ -64,3 +64,10 @@ export interface DeleteMlInferencePipelineResponse {
deleted?: string;
updated?: string;
}
export interface CreateMlInferencePipelineParameters {
destination_field?: string;
model_id: string;
pipeline_name: string;
source_field: string;
}
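
For illustration, a value satisfying the new interface (field values hypothetical):

```typescript
const createParams: CreateMlInferencePipelineParameters = {
  destination_field: 'my-destination', // optional
  model_id: 'my-ner-model',
  pipeline_name: 'my-pipeline',
  source_field: 'body',
};
```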


@@ -0,0 +1,47 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { mockHttpValues } from '../../../__mocks__/kea_logic';
import {
attachMlInferencePipeline,
AttachMlInferencePipelineApiLogicArgs,
AttachMlInferencePipelineResponse,
} from './attach_ml_inference_pipeline';
describe('AttachMlInferencePipelineApiLogic', () => {
const { http } = mockHttpValues;
beforeEach(() => {
jest.clearAllMocks();
});
describe('createMlInferencePipeline', () => {
it('calls the api', async () => {
const response: Promise<AttachMlInferencePipelineResponse> = Promise.resolve({
addedToParentPipeline: true,
created: false,
id: 'unit-test',
});
http.post.mockReturnValue(response);
const args: AttachMlInferencePipelineApiLogicArgs = {
indexName: 'unit-test-index',
pipelineName: 'unit-test',
};
const result = await attachMlInferencePipeline(args);
expect(http.post).toHaveBeenCalledWith(
'/internal/enterprise_search/indices/unit-test-index/ml_inference/pipeline_processors/attach',
{
body: '{"pipeline_name":"unit-test"}',
}
);
expect(result).toEqual({
addedToParentPipeline: true,
created: false,
id: args.pipelineName,
});
});
});
});


@@ -0,0 +1,36 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { AttachMlInferencePipelineResponse } from '../../../../../common/types/pipelines';
import { createApiLogic } from '../../../shared/api_logic/create_api_logic';
import { HttpLogic } from '../../../shared/http';
export interface AttachMlInferencePipelineApiLogicArgs {
indexName: string;
pipelineName: string;
}
export type { AttachMlInferencePipelineResponse };
export const attachMlInferencePipeline = async (
args: AttachMlInferencePipelineApiLogicArgs
): Promise<AttachMlInferencePipelineResponse> => {
const route = `/internal/enterprise_search/indices/${args.indexName}/ml_inference/pipeline_processors/attach`;
const params = {
pipeline_name: args.pipelineName,
};
return await HttpLogic.values.http.post<AttachMlInferencePipelineResponse>(route, {
body: JSON.stringify(params),
});
};
export const AttachMlInferencePipelineApiLogic = createApiLogic(
['attach_ml_inference_pipeline_api_logic'],
attachMlInferencePipeline
);
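
A usage sketch for the attach helper, matching the expectations encoded in the test above (index and pipeline names hypothetical):

```typescript
import { attachMlInferencePipeline } from './attach_ml_inference_pipeline';

async function attachExample() {
  // Issues POST /internal/enterprise_search/indices/search-movies/ml_inference/pipeline_processors/attach
  // with body {"pipeline_name":"my-pipeline"}.
  const response = await attachMlInferencePipeline({
    indexName: 'search-movies',
    pipelineName: 'my-pipeline',
  });
  // Expected shape: { addedToParentPipeline: true, created: false, id: 'my-pipeline' }
  return response;
}
```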


@@ -4,6 +4,7 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { CreateMlInferencePipelineParameters } from '../../../../../common/types/pipelines';
import { createApiLogic } from '../../../shared/api_logic/create_api_logic';
import { HttpLogic } from '../../../shared/http';
@@ -23,7 +24,7 @@ export const createMlInferencePipeline = async (
args: CreateMlInferencePipelineApiLogicArgs
): Promise<CreateMlInferencePipelineResponse> => {
const route = `/internal/enterprise_search/indices/${args.indexName}/ml_inference/pipeline_processors`;
const params = {
const params: CreateMlInferencePipelineParameters = {
destination_field: args.destinationField,
model_id: args.modelId,
pipeline_name: args.pipelineName,


@@ -9,10 +9,18 @@ import { InferencePipeline } from '../../../../../common/types/pipelines';
import { createApiLogic } from '../../../shared/api_logic/create_api_logic';
import { HttpLogic } from '../../../shared/http';
export const fetchMlInferencePipelineProcessors = async ({ indexName }: { indexName: string }) => {
export interface FetchMlInferencePipelineProcessorsApiLogicArgs {
indexName: string;
}
export type FetchMlInferencePipelineProcessorsResponse = InferencePipeline[];
export const fetchMlInferencePipelineProcessors = async ({
indexName,
}: FetchMlInferencePipelineProcessorsApiLogicArgs) => {
const route = `/internal/enterprise_search/indices/${indexName}/ml_inference/pipeline_processors`;
return await HttpLogic.values.http.get<InferencePipeline[]>(route);
return await HttpLogic.values.http.get<FetchMlInferencePipelineProcessorsResponse>(route);
};
export const FetchMlInferencePipelineProcessorsApiLogic = createApiLogic(


@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MlInferencePipeline } from '../../../../../common/types/pipelines';
import { createApiLogic } from '../../../shared/api_logic/create_api_logic';
import { HttpLogic } from '../../../shared/http';
export type FetchMlInferencePipelinesArgs = undefined;
export type FetchMlInferencePipelinesResponse = Record<string, MlInferencePipeline | undefined>;
export const fetchMlInferencePipelines = async () => {
const route = '/internal/enterprise_search/pipelines/ml_inference';
return await HttpLogic.values.http.get<FetchMlInferencePipelinesResponse>(route);
};
export const FetchMlInferencePipelinesApiLogic = createApiLogic(
['fetch_ml_inference_pipelines_api_logic'],
fetchMlInferencePipelines
);
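
The response maps pipeline names to pipeline bodies; a consumption sketch (response contents hypothetical):

```typescript
import {
  fetchMlInferencePipelines,
  FetchMlInferencePipelinesResponse,
} from './fetch_ml_inference_pipelines';

async function listPipelineNames(): Promise<string[]> {
  // e.g. { 'ml-inference-my-pipeline': { processors: [/* ... */], version: 1 } }
  const pipelines: FetchMlInferencePipelinesResponse = await fetchMlInferencePipelines();
  // Pipeline names offered for re-use in the "attach existing" flow:
  return Object.keys(pipelines);
}
```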


@@ -92,7 +92,7 @@ const AddProcessorContent: React.FC<AddMLInferencePipelineModalProps> = ({ onClo
</EuiModalBody>
);
}
if (supportedMLModels === undefined || supportedMLModels?.length === 0) {
if (supportedMLModels.length === 0) {
return <NoModelsPanel />;
}
return (
@@ -188,8 +188,10 @@ const ModalFooter: React.FC<AddMLInferencePipelineModalProps & { ingestionMethod
onClose,
}) => {
const { addInferencePipelineModal: modal, isPipelineDataValid } = useValues(MLInferenceLogic);
const { createPipeline, setAddInferencePipelineStep } = useActions(MLInferenceLogic);
const { attachPipeline, createPipeline, setAddInferencePipelineStep } =
useActions(MLInferenceLogic);
const attachExistingPipeline = Boolean(modal.configuration.existingPipeline);
let nextStep: AddInferencePipelineSteps | undefined;
let previousStep: AddInferencePipelineSteps | undefined;
switch (modal.step) {
@@ -239,6 +241,21 @@ const ModalFooter: React.FC<AddMLInferencePipelineModalProps & { ingestionMethod
>
{CONTINUE_BUTTON_LABEL}
</EuiButton>
) : attachExistingPipeline ? (
<EuiButton
color="primary"
data-telemetry-id={`entSearchContent-${ingestionMethod}-pipelines-addMlInference-attach`}
disabled={!isPipelineDataValid}
fill
onClick={attachPipeline}
>
{i18n.translate(
'xpack.enterpriseSearch.content.indices.transforms.addInferencePipelineModal.footer.attach',
{
defaultMessage: 'Attach',
}
)}
</EuiButton>
) : (
<EuiButton
color="success"


@@ -30,10 +30,25 @@ import { docLinks } from '../../../../../shared/doc_links';
import { IndexViewLogic } from '../../index_view_logic';
import { MLInferenceLogic } from './ml_inference_logic';
import { EMPTY_PIPELINE_CONFIGURATION, MLInferenceLogic } from './ml_inference_logic';
import { MlModelSelectOption } from './model_select_option';
import { PipelineSelectOption } from './pipeline_select_option';
const MODEL_SELECT_PLACEHOLDER_VALUE = 'model_placeholder$$';
const PIPELINE_SELECT_PLACEHOLDER_VALUE = 'pipeline_placeholder$$';
const CHOOSE_EXISTING_LABEL = i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.chooseLabel',
{ defaultMessage: 'Choose' }
);
const CHOOSE_NEW_LABEL = i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.newLabel',
{ defaultMessage: 'New' }
);
const CHOOSE_PIPELINE_LABEL = i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.existingLabel',
{ defaultMessage: 'Existing' }
);
const NoSourceFieldsError: React.FC = () => (
<FormattedMessage
@@ -56,14 +71,15 @@ export const ConfigurePipeline: React.FC = () => {
const {
addInferencePipelineModal: { configuration },
formErrors,
existingInferencePipelines,
supportedMLModels,
sourceFields,
} = useValues(MLInferenceLogic);
const { setInferencePipelineConfiguration } = useActions(MLInferenceLogic);
const { selectExistingPipeline, setInferencePipelineConfiguration } =
useActions(MLInferenceLogic);
const { ingestionMethod } = useValues(IndexViewLogic);
const { destinationField, modelID, pipelineName, sourceField } = configuration;
const models = supportedMLModels ?? [];
const nameError = formErrors.pipelineName !== undefined && pipelineName.length > 0;
const emptySourceFields = (sourceFields?.length ?? 0) === 0;
@@ -76,12 +92,30 @@ export const ConfigurePipeline: React.FC = () => {
),
value: MODEL_SELECT_PLACEHOLDER_VALUE,
},
...models.map((model) => ({
...supportedMLModels.map((model) => ({
dropdownDisplay: <MlModelSelectOption model={model} />,
inputDisplay: model.model_id,
value: model.model_id,
})),
];
const pipelineOptions: Array<EuiSuperSelectOption<string>> = [
{
disabled: true,
inputDisplay: i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.placeholder',
{ defaultMessage: 'Select one' }
),
value: PIPELINE_SELECT_PLACEHOLDER_VALUE,
},
...(existingInferencePipelines?.map((pipeline) => ({
disabled: pipeline.disabled,
dropdownDisplay: <PipelineSelectOption pipeline={pipeline} />,
inputDisplay: pipeline.pipelineName,
value: pipeline.pipelineName,
})) ?? []),
];
const inputsDisabled = configuration.existingPipeline !== false;
return (
<>
@@ -106,45 +140,107 @@ export const ConfigurePipeline: React.FC = () => {
</EuiText>
<EuiSpacer />
<EuiForm component="form">
<EuiFormRow
fullWidth
label={i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.nameLabel',
{
defaultMessage: 'Name',
}
)}
helpText={
!nameError &&
i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.name.helpText',
{
defaultMessage:
'Pipeline names are unique within a deployment and can only contain letters, numbers, underscores, and hyphens. The pipeline name will be automatically prefixed with "ml-inference-".',
}
)
}
error={nameError && formErrors.pipelineName}
isInvalid={nameError}
>
<EuiFieldText
data-telemetry-id={`entSearchContent-${ingestionMethod}-pipelines-configureInferencePipeline-uniqueName`}
fullWidth
placeholder={i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.namePlaceholder',
{
defaultMessage: 'Enter a unique name for this pipeline',
}
<EuiFlexGroup>
<EuiFlexItem grow={false}>
<EuiFormRow
label={i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.chooseExistingLabel',
{ defaultMessage: 'New or existing' }
)}
>
<EuiSelect
options={[
{
disabled: true,
text: CHOOSE_EXISTING_LABEL,
value: '',
},
{
text: CHOOSE_NEW_LABEL,
value: 'false',
},
{
disabled:
!existingInferencePipelines || existingInferencePipelines.length === 0,
text: CHOOSE_PIPELINE_LABEL,
value: 'true',
},
]}
onChange={(e) =>
setInferencePipelineConfiguration({
...EMPTY_PIPELINE_CONFIGURATION,
existingPipeline: e.target.value === 'true',
})
}
/>
</EuiFormRow>
</EuiFlexItem>
<EuiFlexItem>
{configuration.existingPipeline === true ? (
<EuiFormRow
fullWidth
label={i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipelineLabel',
{
defaultMessage: 'Select an existing inference pipeline',
}
)}
>
<EuiSuperSelect
fullWidth
hasDividers
data-telemetry-id={`entSearchContent-${ingestionMethod}-pipelines-configureInferencePipeline-selectExistingPipeline`}
valueOfSelected={
pipelineName.length > 0 ? pipelineName : PIPELINE_SELECT_PLACEHOLDER_VALUE
}
options={pipelineOptions}
onChange={(value) => selectExistingPipeline(value)}
/>
</EuiFormRow>
) : (
<EuiFormRow
fullWidth
label={i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.nameLabel',
{
defaultMessage: 'Name',
}
)}
helpText={
!nameError &&
i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.name.helpText',
{
defaultMessage:
'Pipeline names are unique within a deployment and can only contain letters, numbers, underscores, and hyphens. The pipeline name will be automatically prefixed with "ml-inference-".',
}
)
}
error={nameError && formErrors.pipelineName}
isInvalid={nameError}
>
<EuiFieldText
data-telemetry-id={`entSearchContent-${ingestionMethod}-pipelines-configureInferencePipeline-uniqueName`}
disabled={inputsDisabled}
fullWidth
placeholder={i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.namePlaceholder',
{
defaultMessage: 'Enter a unique name for this pipeline',
}
)}
value={pipelineName}
onChange={(e) =>
setInferencePipelineConfiguration({
...configuration,
pipelineName: e.target.value,
})
}
/>
</EuiFormRow>
)}
value={pipelineName}
onChange={(e) =>
setInferencePipelineConfiguration({
...configuration,
pipelineName: e.target.value,
})
}
/>
</EuiFormRow>
</EuiFlexItem>
</EuiFlexGroup>
<EuiSpacer />
<EuiFormRow
label={i18n.translate(
@@ -159,6 +255,7 @@ export const ConfigurePipeline: React.FC = () => {
data-telemetry-id={`entSearchContent-${ingestionMethod}-pipelines-configureInferencePipeline-selectTrainedModel`}
fullWidth
hasDividers
disabled={inputsDisabled}
itemLayoutAlign="top"
onChange={(value) =>
setInferencePipelineConfiguration({
@@ -185,6 +282,7 @@ export const ConfigurePipeline: React.FC = () => {
>
<EuiSelect
data-telemetry-id={`entSearchContent-${ingestionMethod}-pipelines-configureInferencePipeline-selectSchemaField`}
disabled={inputsDisabled}
value={sourceField}
options={[
{
@@ -235,6 +333,7 @@ export const ConfigurePipeline: React.FC = () => {
>
<EuiFieldText
data-telemetry-id={`entSearchContent-${ingestionMethod}-pipelines-configureInferencePipeline-destinationField`}
disabled={inputsDisabled}
placeholder="custom_field_name"
value={destinationField}
onChange={(e) =>


@@ -7,20 +7,27 @@
import { LogicMounter } from '../../../../../__mocks__/kea_logic';
import { HttpError, Status } from '../../../../../../../common/types/api';
import { HttpResponse } from '@kbn/core/public';
import { ErrorResponse, HttpError, Status } from '../../../../../../../common/types/api';
import { TrainedModelState } from '../../../../../../../common/types/pipelines';
import { MappingsApiLogic } from '../../../../api/mappings/mappings_logic';
import { CreateMlInferencePipelineApiLogic } from '../../../../api/ml_models/create_ml_inference_pipeline';
import { MLModelsApiLogic } from '../../../../api/ml_models/ml_models_logic';
import { AttachMlInferencePipelineApiLogic } from '../../../../api/pipelines/attach_ml_inference_pipeline';
import { CreateMlInferencePipelineApiLogic } from '../../../../api/pipelines/create_ml_inference_pipeline';
import { FetchMlInferencePipelineProcessorsApiLogic } from '../../../../api/pipelines/fetch_ml_inference_pipeline_processors';
import { FetchMlInferencePipelinesApiLogic } from '../../../../api/pipelines/fetch_ml_inference_pipelines';
import { SimulateMlInterfacePipelineApiLogic } from '../../../../api/pipelines/simulate_ml_inference_pipeline_processors';
import {
MLInferenceLogic,
EMPTY_PIPELINE_CONFIGURATION,
AddInferencePipelineSteps,
MLInferenceProcessorsValues,
} from './ml_inference_logic';
const DEFAULT_VALUES = {
const DEFAULT_VALUES: MLInferenceProcessorsValues = {
addInferencePipelineModal: {
configuration: {
...EMPTY_PIPELINE_CONFIGURATION,
@@ -46,6 +53,7 @@ const DEFAULT_VALUES = {
step: AddInferencePipelineSteps.Configuration,
},
createErrors: [],
existingInferencePipelines: [],
formErrors: {
modelID: 'Field is required.',
pipelineName: 'Field is required.',
@@ -57,6 +65,8 @@ const DEFAULT_VALUES = {
mappingData: undefined,
mappingStatus: 0,
mlInferencePipeline: undefined,
mlInferencePipelineProcessors: undefined,
mlInferencePipelinesData: undefined,
mlModelsData: undefined,
mlModelsStatus: 0,
simulatePipelineData: undefined,
@@ -64,7 +74,7 @@ const DEFAULT_VALUES = {
simulatePipelineResult: undefined,
simulatePipelineStatus: 0,
sourceFields: undefined,
supportedMLModels: undefined,
supportedMLModels: [],
};
describe('MlInferenceLogic', () => {
@@ -77,13 +87,25 @@ describe('MlInferenceLogic', () => {
const { mount: mountCreateMlInferencePipelineApiLogic } = new LogicMounter(
CreateMlInferencePipelineApiLogic
);
const { mount: mountAttachMlInferencePipelineApiLogic } = new LogicMounter(
AttachMlInferencePipelineApiLogic
);
const { mount: mountFetchMlInferencePipelineProcessorsApiLogic } = new LogicMounter(
FetchMlInferencePipelineProcessorsApiLogic
);
const { mount: mountFetchMlInferencePipelinesApiLogic } = new LogicMounter(
FetchMlInferencePipelinesApiLogic
);
beforeEach(() => {
jest.clearAllMocks();
mountMappingApiLogic();
mountMLModelsApiLogic();
mountFetchMlInferencePipelineProcessorsApiLogic();
mountFetchMlInferencePipelinesApiLogic();
mountSimulateMlInterfacePipelineApiLogic();
mountCreateMlInferencePipelineApiLogic();
mountAttachMlInferencePipelineApiLogic();
mount();
});
@@ -110,6 +132,70 @@ });
});
});
});
describe('attachApiError', () => {
it('updates create errors', () => {
MLInferenceLogic.actions.attachApiError({
body: {
error: '',
message: 'this is an error',
statusCode: 500,
},
} as HttpResponse<ErrorResponse>);
expect(MLInferenceLogic.values.createErrors).toEqual(['this is an error']);
});
});
describe('createApiError', () => {
it('updates create errors', () => {
MLInferenceLogic.actions.createApiError({
body: {
error: '',
message: 'this is an error',
statusCode: 500,
},
} as HttpResponse<ErrorResponse>);
expect(MLInferenceLogic.values.createErrors).toEqual(['this is an error']);
});
});
describe('makeAttachPipelineRequest', () => {
it('clears existing errors', () => {
MLInferenceLogic.actions.attachApiError({
body: {
error: '',
message: 'this is an error',
statusCode: 500,
},
} as HttpResponse<ErrorResponse>);
expect(MLInferenceLogic.values.createErrors).not.toHaveLength(0);
MLInferenceLogic.actions.makeAttachPipelineRequest({
indexName: 'test',
pipelineName: 'unit-test',
});
expect(MLInferenceLogic.values.createErrors).toHaveLength(0);
});
});
describe('makeCreatePipelineRequest', () => {
it('clears existing errors', () => {
MLInferenceLogic.actions.createApiError({
body: {
error: '',
message: 'this is an error',
statusCode: 500,
},
} as HttpResponse<ErrorResponse>);
expect(MLInferenceLogic.values.createErrors).not.toHaveLength(0);
MLInferenceLogic.actions.makeCreatePipelineRequest({
indexName: 'test',
pipelineName: 'unit-test',
modelId: 'test-model',
sourceField: 'body',
});
expect(MLInferenceLogic.values.createErrors).toHaveLength(0);
});
});
});
describe('selectors', () => {
@@ -162,6 +248,220 @@ describe('MlInferenceLogic', () => {
expect(MLInferenceLogic.values.simulatePipelineResult).toEqual(simulateResponse);
});
});
describe('existingInferencePipelines', () => {
beforeEach(() => {
MappingsApiLogic.actions.apiSuccess({
mappings: {
properties: {
body: {
type: 'text',
},
},
},
});
});
it('returns an empty list when there are no existing pipelines available', () => {
expect(MLInferenceLogic.values.existingInferencePipelines).toEqual([]);
});
it('returns existing pipeline option', () => {
FetchMlInferencePipelinesApiLogic.actions.apiSuccess({
'unit-test': {
processors: [
{
inference: {
field_map: {
body: 'text_field',
},
model_id: 'test-model',
target_field: 'ml.inference.test-field',
},
},
],
version: 1,
},
});
expect(MLInferenceLogic.values.existingInferencePipelines).toEqual([
{
destinationField: 'test-field',
disabled: false,
pipelineName: 'unit-test',
modelType: '',
modelId: 'test-model',
sourceField: 'body',
},
]);
});
it('returns disabled pipeline option if missing source field', () => {
FetchMlInferencePipelinesApiLogic.actions.apiSuccess({
'unit-test': {
processors: [
{
inference: {
field_map: {
body_content: 'text_field',
},
model_id: 'test-model',
target_field: 'ml.inference.test-field',
},
},
],
version: 1,
},
});
expect(MLInferenceLogic.values.existingInferencePipelines).toEqual([
{
destinationField: 'test-field',
disabled: true,
disabledReason: expect.any(String),
pipelineName: 'unit-test',
modelType: '',
modelId: 'test-model',
sourceField: 'body_content',
},
]);
});
it('returns disabled pipeline option if model is redacted', () => {
FetchMlInferencePipelinesApiLogic.actions.apiSuccess({
'unit-test': {
processors: [
{
inference: {
field_map: {
body: 'text_field',
},
model_id: '',
target_field: 'ml.inference.test-field',
},
},
],
version: 1,
},
});
expect(MLInferenceLogic.values.existingInferencePipelines).toEqual([
{
destinationField: 'test-field',
disabled: true,
disabledReason: expect.any(String),
pipelineName: 'unit-test',
modelType: '',
modelId: '',
sourceField: 'body',
},
]);
});
it('returns disabled pipeline option if pipeline already attached', () => {
FetchMlInferencePipelineProcessorsApiLogic.actions.apiSuccess([
{
modelId: 'test-model',
modelState: TrainedModelState.Started,
pipelineName: 'unit-test',
pipelineReferences: ['test@ml-inference'],
types: ['ner', 'pytorch'],
},
]);
FetchMlInferencePipelinesApiLogic.actions.apiSuccess({
'unit-test': {
processors: [
{
inference: {
field_map: {
body: 'text_field',
},
model_id: 'test-model',
target_field: 'ml.inference.test-field',
},
},
],
version: 1,
},
});
expect(MLInferenceLogic.values.existingInferencePipelines).toEqual([
{
destinationField: 'test-field',
disabled: true,
disabledReason: expect.any(String),
pipelineName: 'unit-test',
modelType: '',
modelId: 'test-model',
sourceField: 'body',
},
]);
});
});
describe('mlInferencePipeline', () => {
it('returns undefined when configuration is invalid', () => {
MLInferenceLogic.actions.setInferencePipelineConfiguration({
destinationField: '',
modelID: '',
pipelineName: 'unit-test',
sourceField: '',
});
expect(MLInferenceLogic.values.mlInferencePipeline).toBeUndefined();
});
it('generates inference pipeline', () => {
MLModelsApiLogic.actions.apiSuccess([
{
inference_config: {
text_classification: {
classification_labels: ['one', 'two'],
tokenization: {
bert: {},
},
},
},
input: {
field_names: ['text_field'],
},
model_id: 'test-model',
model_type: 'pytorch',
tags: [],
version: '1.0.0',
},
]);
MLInferenceLogic.actions.setInferencePipelineConfiguration({
destinationField: '',
modelID: 'test-model',
pipelineName: 'unit-test',
sourceField: 'body',
});
expect(MLInferenceLogic.values.mlInferencePipeline).not.toBeUndefined();
});
it('returns undefined when existing pipeline not yet selected', () => {
MLInferenceLogic.actions.setInferencePipelineConfiguration({
existingPipeline: true,
destinationField: '',
modelID: '',
pipelineName: '',
sourceField: '',
});
expect(MLInferenceLogic.values.mlInferencePipeline).toBeUndefined();
});
it('returns the existing pipeline when selected', () => {
const existingPipeline = {
description: 'this is a test',
processors: [],
version: 1,
};
FetchMlInferencePipelinesApiLogic.actions.apiSuccess({
'unit-test': existingPipeline,
});
MLInferenceLogic.actions.setInferencePipelineConfiguration({
existingPipeline: true,
destinationField: '',
modelID: '',
pipelineName: 'unit-test',
sourceField: '',
});
expect(MLInferenceLogic.values.mlInferencePipeline).not.toBeUndefined();
expect(MLInferenceLogic.values.mlInferencePipeline).toEqual(existingPipeline);
});
});
});
describe('listeners', () => {


@@ -15,6 +15,8 @@ import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_
import {
formatPipelineName,
generateMlInferencePipelineBody,
getMlModelTypesForModelConfig,
parseMlInferenceParametersFromPipeline,
} from '../../../../../../../common/ml_inference_pipeline';
import { Status } from '../../../../../../../common/types/api';
import { MlInferencePipeline } from '../../../../../../../common/types/pipelines';
@@ -30,16 +32,30 @@ import {
GetMappingsResponse,
MappingsApiLogic,
} from '../../../../api/mappings/mappings_logic';
import {
CreateMlInferencePipelineApiLogic,
CreateMlInferencePipelineApiLogicArgs,
CreateMlInferencePipelineResponse,
} from '../../../../api/ml_models/create_ml_inference_pipeline';
import {
GetMlModelsArgs,
GetMlModelsResponse,
MLModelsApiLogic,
} from '../../../../api/ml_models/ml_models_logic';
import {
AttachMlInferencePipelineApiLogic,
AttachMlInferencePipelineApiLogicArgs,
AttachMlInferencePipelineResponse,
} from '../../../../api/pipelines/attach_ml_inference_pipeline';
import {
CreateMlInferencePipelineApiLogic,
CreateMlInferencePipelineApiLogicArgs,
CreateMlInferencePipelineResponse,
} from '../../../../api/pipelines/create_ml_inference_pipeline';
import {
FetchMlInferencePipelineProcessorsApiLogic,
FetchMlInferencePipelineProcessorsResponse,
} from '../../../../api/pipelines/fetch_ml_inference_pipeline_processors';
import {
FetchMlInferencePipelinesApiLogic,
FetchMlInferencePipelinesArgs,
FetchMlInferencePipelinesResponse,
} from '../../../../api/pipelines/fetch_ml_inference_pipelines';
import {
SimulateMlInterfacePipelineApiLogic,
SimulateMlInterfacePipelineArgs,
@@ -47,11 +63,20 @@ import {
} from '../../../../api/pipelines/simulate_ml_inference_pipeline_processors';
import { isConnectorIndex } from '../../../../utils/indices';
import { isSupportedMLModel, sortSourceFields } from '../../../shared/ml_inference/utils';
import {
getMLType,
isSupportedMLModel,
sortSourceFields,
} from '../../../shared/ml_inference/utils';
import { AddInferencePipelineFormErrors, InferencePipelineConfiguration } from './types';
import { validateInferencePipelineConfiguration } from './utils';
import {
validateInferencePipelineConfiguration,
EXISTING_PIPELINE_DISABLED_MODEL_REDACTED,
EXISTING_PIPELINE_DISABLED_MISSING_SOURCE_FIELD,
EXISTING_PIPELINE_DISABLED_PIPELINE_EXISTS,
} from './utils';
export const EMPTY_PIPELINE_CONFIGURATION: InferencePipelineConfiguration = {
destinationField: '',
@@ -69,7 +94,26 @@ export enum AddInferencePipelineSteps {
const API_REQUEST_COMPLETE_STATUSES = [Status.SUCCESS, Status.ERROR];
const DEFAULT_CONNECTOR_FIELDS = ['body', 'title', 'id', 'type', 'url'];
export interface MLInferencePipelineOption {
destinationField: string;
disabled: boolean;
disabledReason?: string;
modelId: string;
modelType: string;
pipelineName: string;
sourceField: string;
}
interface MLInferenceProcessorsActions {
attachApiError: Actions<
AttachMlInferencePipelineApiLogicArgs,
AttachMlInferencePipelineResponse
>['apiError'];
attachApiSuccess: Actions<
AttachMlInferencePipelineApiLogicArgs,
AttachMlInferencePipelineResponse
>['apiSuccess'];
attachPipeline: () => void;
createApiError: Actions<
CreateMlInferencePipelineApiLogicArgs,
CreateMlInferencePipelineResponse
@@ -79,18 +123,29 @@ interface MLInferenceProcessorsActions {
CreateMlInferencePipelineResponse
>['apiSuccess'];
createPipeline: () => void;
makeAttachPipelineRequest: Actions<
AttachMlInferencePipelineApiLogicArgs,
AttachMlInferencePipelineResponse
>['makeRequest'];
makeCreatePipelineRequest: Actions<
CreateMlInferencePipelineApiLogicArgs,
CreateMlInferencePipelineResponse
>['makeRequest'];
makeMLModelsRequest: Actions<GetMlModelsArgs, GetMlModelsResponse>['makeRequest'];
makeMappingRequest: Actions<GetMappingsArgs, GetMappingsResponse>['makeRequest'];
makeMlInferencePipelinesRequest: Actions<
FetchMlInferencePipelinesArgs,
FetchMlInferencePipelinesResponse
>['makeRequest'];
makeSimulatePipelineRequest: Actions<
SimulateMlInterfacePipelineArgs,
SimulateMlInterfacePipelineResponse
>['makeRequest'];
mappingsApiError: Actions<GetMappingsArgs, GetMappingsResponse>['apiError'];
mlModelsApiError: Actions<GetMlModelsArgs, GetMlModelsResponse>['apiError'];
selectExistingPipeline: (pipelineName: string) => {
pipelineName: string;
};
setAddInferencePipelineStep: (step: AddInferencePipelineSteps) => {
step: AddInferencePipelineSteps;
};
@@ -120,21 +175,24 @@ export interface AddInferencePipelineModal {
step: AddInferencePipelineSteps;
}
interface MLInferenceProcessorsValues {
export interface MLInferenceProcessorsValues {
addInferencePipelineModal: AddInferencePipelineModal;
createErrors: string[];
existingInferencePipelines: MLInferencePipelineOption[];
formErrors: AddInferencePipelineFormErrors;
index: FetchIndexApiResponse;
index: FetchIndexApiResponse | undefined;
isLoading: boolean;
isPipelineDataValid: boolean;
mappingData: typeof MappingsApiLogic.values.data;
mappingStatus: Status;
mlInferencePipeline?: MlInferencePipeline;
mlModelsData: TrainedModelConfigResponse[];
mlInferencePipeline: MlInferencePipeline | undefined;
mlInferencePipelineProcessors: FetchMlInferencePipelineProcessorsResponse | undefined;
mlInferencePipelinesData: FetchMlInferencePipelinesResponse | undefined;
mlModelsData: TrainedModelConfigResponse[] | undefined;
mlModelsStatus: Status;
simulatePipelineData: typeof SimulateMlInterfacePipelineApiLogic.values.data;
simulatePipelineErrors: string[];
simulatePipelineResult: IngestSimulateResponse;
simulatePipelineResult: IngestSimulateResponse | undefined;
simulatePipelineStatus: Status;
sourceFields: string[] | undefined;
supportedMLModels: TrainedModelConfigResponse[];
@@ -144,8 +202,10 @@ export const MLInferenceLogic = kea<
MakeLogicType<MLInferenceProcessorsValues, MLInferenceProcessorsActions>
>({
actions: {
attachPipeline: true,
clearFormErrors: true,
createPipeline: true,
selectExistingPipeline: (pipelineName: string) => ({ pipelineName }),
setAddInferencePipelineStep: (step: AddInferencePipelineSteps) => ({ step }),
setFormErrors: (inputErrors: AddInferencePipelineFormErrors) => ({ inputErrors }),
setIndexName: (indexName: string) => ({ indexName }),
@@ -160,6 +220,8 @@
},
connect: {
actions: [
FetchMlInferencePipelinesApiLogic,
['makeRequest as makeMlInferencePipelinesRequest'],
MappingsApiLogic,
['makeRequest as makeMappingRequest', 'apiError as mappingsApiError'],
MLModelsApiLogic,
@@ -176,20 +238,43 @@
'apiSuccess as createApiSuccess',
'makeRequest as makeCreatePipelineRequest',
],
AttachMlInferencePipelineApiLogic,
[
'apiError as attachApiError',
'apiSuccess as attachApiSuccess',
'makeRequest as makeAttachPipelineRequest',
],
],
values: [
FetchIndexApiLogic,
['data as index'],
FetchMlInferencePipelinesApiLogic,
['data as mlInferencePipelinesData'],
MappingsApiLogic,
['data as mappingData', 'status as mappingStatus'],
MLModelsApiLogic,
['data as mlModelsData', 'status as mlModelsStatus'],
SimulateMlInterfacePipelineApiLogic,
['data as simulatePipelineData', 'status as simulatePipelineStatus'],
FetchMlInferencePipelineProcessorsApiLogic,
['data as mlInferencePipelineProcessors'],
],
},
events: {},
listeners: ({ values, actions }) => ({
attachPipeline: () => {
const {
addInferencePipelineModal: {
configuration: { pipelineName },
indexName,
},
} = values;
actions.makeAttachPipelineRequest({
indexName,
pipelineName,
});
},
createPipeline: () => {
const {
addInferencePipelineModal: { configuration, indexName },
@@ -206,7 +291,21 @@
sourceField: configuration.sourceField,
});
},
selectExistingPipeline: ({ pipelineName }) => {
const pipeline = values.mlInferencePipelinesData?.[pipelineName];
if (!pipeline) return;
const params = parseMlInferenceParametersFromPipeline(pipelineName, pipeline);
if (params === null) return;
actions.setInferencePipelineConfiguration({
destinationField: params.destination_field ?? '',
existingPipeline: true,
modelID: params.model_id,
pipelineName,
sourceField: params.source_field,
});
},
setIndexName: ({ indexName }) => {
actions.makeMlInferencePipelinesRequest(undefined);
actions.makeMLModelsRequest(undefined);
actions.makeMappingRequest({ indexName });
},
@@ -264,7 +363,9 @@
createErrors: [
[],
{
attachApiError: (_, error) => getErrorsFromHttpResponse(error),
createApiError: (_, error) => getErrorsFromHttpResponse(error),
makeAttachPipelineRequest: () => [],
makeCreatePipelineRequest: () => [],
},
],
@@ -297,12 +398,24 @@
selectors.isPipelineDataValid,
selectors.addInferencePipelineModal,
selectors.mlModelsData,
selectors.mlInferencePipelinesData,
],
(
isPipelineDataValid: boolean,
{ configuration }: AddInferencePipelineModal,
models: MLInferenceProcessorsValues['mlModelsData']
isPipelineDataValid: MLInferenceProcessorsValues['isPipelineDataValid'],
{ configuration }: MLInferenceProcessorsValues['addInferencePipelineModal'],
models: MLInferenceProcessorsValues['mlModelsData'],
mlInferencePipelinesData: MLInferenceProcessorsValues['mlInferencePipelinesData']
) => {
if (configuration.existingPipeline) {
if (configuration.pipelineName.length === 0) {
return undefined;
}
const pipeline = mlInferencePipelinesData?.[configuration.pipelineName];
if (!pipeline) {
return undefined;
}
return pipeline as MlInferencePipeline;
}
if (!isPipelineDataValid) return undefined;
const model = models?.find((mlModel) => mlModel.model_id === configuration.modelID);
if (!model) return undefined;
@@ -350,7 +463,69 @@
supportedMLModels: [
() => [selectors.mlModelsData],
(mlModelsData: TrainedModelConfigResponse[] | undefined) => {
return mlModelsData?.filter(isSupportedMLModel);
return mlModelsData?.filter(isSupportedMLModel) ?? [];
},
],
existingInferencePipelines: [
() => [
selectors.mlInferencePipelinesData,
selectors.sourceFields,
selectors.supportedMLModels,
selectors.mlInferencePipelineProcessors,
],
(
mlInferencePipelinesData: MLInferenceProcessorsValues['mlInferencePipelinesData'],
sourceFields: MLInferenceProcessorsValues['sourceFields'],
supportedMLModels: MLInferenceProcessorsValues['supportedMLModels'],
mlInferencePipelineProcessors: MLInferenceProcessorsValues['mlInferencePipelineProcessors']
) => {
if (!mlInferencePipelinesData) {
return [];
}
const indexProcessorNames =
mlInferencePipelineProcessors?.map((processor) => processor.pipelineName) ?? [];
const existingPipelines: MLInferencePipelineOption[] = Object.entries(
mlInferencePipelinesData
)
.map(([pipelineName, pipeline]): MLInferencePipelineOption | undefined => {
if (!pipeline) return undefined;
const pipelineParams = parseMlInferenceParametersFromPipeline(pipelineName, pipeline);
if (!pipelineParams) return undefined;
const {
destination_field: destinationField,
model_id: modelId,
source_field: sourceField,
} = pipelineParams;
let disabled: boolean = false;
let disabledReason: string | undefined;
if (!(sourceFields?.includes(sourceField) ?? false)) {
disabled = true;
disabledReason = EXISTING_PIPELINE_DISABLED_MISSING_SOURCE_FIELD;
} else if (indexProcessorNames.includes(pipelineName)) {
disabled = true;
disabledReason = EXISTING_PIPELINE_DISABLED_PIPELINE_EXISTS;
} else if (pipelineParams.model_id.length === 0) {
disabled = true;
disabledReason = EXISTING_PIPELINE_DISABLED_MODEL_REDACTED;
}
const mlModel = supportedMLModels.find((model) => model.model_id === modelId);
const modelType = mlModel ? getMLType(getMlModelTypesForModelConfig(mlModel)) : '';
return {
destinationField: destinationField ?? '',
disabled,
disabledReason,
modelId,
modelType,
pipelineName,
sourceField,
};
})
.filter((p): p is MLInferencePipelineOption => p !== undefined);
return existingPipelines;
},
],
}),
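
Given the selector above, a pipeline whose single source field is absent from the index mapping is still listed, but disabled with a reason. A sketch of the resulting option, mirroring the unit tests (values hypothetical):

```typescript
import { MLInferencePipelineOption } from './ml_inference_logic';
import { EXISTING_PIPELINE_DISABLED_MISSING_SOURCE_FIELD } from './utils';

const disabledOption: MLInferencePipelineOption = {
  destinationField: 'my-destination',
  disabled: true,
  disabledReason: EXISTING_PIPELINE_DISABLED_MISSING_SOURCE_FIELD,
  modelId: 'my-ner-model',
  modelType: '', // empty when the model is not among supportedMLModels
  pipelineName: 'ml-inference-my-pipeline',
  sourceField: 'body_content', // not present in the index mapping
};
```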


@@ -0,0 +1,96 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import React from 'react';
import { EuiBadge, EuiFlexGroup, EuiFlexItem, EuiIcon, EuiTextColor, EuiTitle } from '@elastic/eui';
import { i18n } from '@kbn/i18n';
import { MLInferencePipelineOption } from './ml_inference_logic';
import { EXISTING_PIPELINE_DISABLED_MISSING_SOURCE_FIELD } from './utils';
export interface PipelineSelectOptionProps {
pipeline: MLInferencePipelineOption;
}
const REDACTED_MODEL_ID_DISPLAY = i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.redactedModel',
{
defaultMessage: 'Trained model not available in this space',
}
);
export const PipelineSelectOption: React.FC<PipelineSelectOptionProps> = ({ pipeline }) => {
const modelIdDisplay = pipeline.modelId.length > 0 ? pipeline.modelId : REDACTED_MODEL_ID_DISPLAY;
return (
<EuiFlexGroup direction="column" gutterSize="xs">
{pipeline.disabled && (
<EuiFlexItem>
<EuiFlexGroup>
<EuiFlexItem grow={false}>
<EuiIcon type="alert" color="warning" />
</EuiFlexItem>
<EuiFlexItem>
<EuiTextColor color="default">
{pipeline.disabledReason ?? EXISTING_PIPELINE_DISABLED_MISSING_SOURCE_FIELD}
</EuiTextColor>
</EuiFlexItem>
</EuiFlexGroup>
</EuiFlexItem>
)}
<EuiFlexItem>
<EuiTitle size="xs">
<h4>{pipeline.pipelineName}</h4>
</EuiTitle>
</EuiFlexItem>
<EuiFlexItem>
<EuiFlexGroup gutterSize="s" alignItems="center" justifyContent="flexEnd">
<EuiFlexItem>
{pipeline.disabled ? (
modelIdDisplay
) : (
<EuiTextColor color="subdued">{modelIdDisplay}</EuiTextColor>
)}
</EuiFlexItem>
{pipeline.modelType.length > 0 && (
<EuiFlexItem grow={false}>
<span>
<EuiBadge color="hollow">{pipeline.modelType}</EuiBadge>
</span>
</EuiFlexItem>
)}
</EuiFlexGroup>
</EuiFlexItem>
<EuiFlexItem>
<EuiFlexGroup>
<EuiFlexItem>
<strong>
{i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.sourceField',
{ defaultMessage: 'Source field' }
)}
</strong>
</EuiFlexItem>
<EuiFlexItem>{pipeline.sourceField}</EuiFlexItem>
</EuiFlexGroup>
</EuiFlexItem>
<EuiFlexItem>
<EuiFlexGroup>
<EuiFlexItem>
<strong>
{i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.destinationField',
{ defaultMessage: 'Destination field' }
)}
</strong>
</EuiFlexItem>
<EuiFlexItem>{pipeline.destinationField}</EuiFlexItem>
</EuiFlexGroup>
</EuiFlexItem>
</EuiFlexGroup>
);
};
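
For reference, the component consumes the MLInferencePipelineOption shape defined in ml_inference_logic.ts; a hypothetical enabled option:

```tsx
import React from 'react';
import { MLInferencePipelineOption } from './ml_inference_logic';
import { PipelineSelectOption } from './pipeline_select_option';

const option: MLInferencePipelineOption = {
  destinationField: 'my-destination',
  disabled: false,
  modelId: 'my-ner-model',
  modelType: 'ner',
  pipelineName: 'ml-inference-my-pipeline',
  sourceField: 'body',
};

export const Example: React.FC = () => <PipelineSelectOption pipeline={option} />;
```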


@@ -7,6 +7,7 @@
export interface InferencePipelineConfiguration {
destinationField: string;
existingPipeline?: boolean;
modelID: string;
pipelineName: string;
sourceField: string;


@@ -31,6 +31,12 @@ export const validateInferencePipelineConfiguration = (
config: InferencePipelineConfiguration
): AddInferencePipelineFormErrors => {
const errors: AddInferencePipelineFormErrors = {};
if (config.existingPipeline === true) {
if (config.pipelineName.length === 0) {
errors.pipelineName = FIELD_REQUIRED_ERROR;
}
return errors;
}
if (config.pipelineName.trim().length === 0) {
errors.pipelineName = FIELD_REQUIRED_ERROR;
} else if (!isValidPipelineName(config.pipelineName)) {
@@ -45,3 +51,27 @@
return errors;
};
export const EXISTING_PIPELINE_DISABLED_MISSING_SOURCE_FIELD = i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.disabledSourceFieldDescription',
{
defaultMessage:
'This pipeline cannot be selected because the source field does not exist on this index.',
}
);
export const EXISTING_PIPELINE_DISABLED_PIPELINE_EXISTS = i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.disabledPipelineExistsDescription',
{
defaultMessage: 'This pipeline cannot be selected because it is already attached.',
}
);
// TODO: remove when we support attaching pipelines with unavailable models
export const EXISTING_PIPELINE_DISABLED_MODEL_REDACTED = i18n.translate(
'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.disabledModelRedactedDescription',
{
defaultMessage:
'This pipeline cannot be selected because it uses a trained model not available in this Kibana space.',
}
);
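
To illustrate the early return added above: when re-using an existing pipeline, only the selection itself is validated (configuration values hypothetical):

```typescript
import { validateInferencePipelineConfiguration } from './utils';

// No pipeline selected yet: only pipelineName is flagged; the per-field checks
// for new pipelines are skipped entirely.
validateInferencePipelineConfiguration({
  destinationField: '',
  existingPipeline: true,
  modelID: '',
  pipelineName: '',
  sourceField: '',
});
// => { pipelineName: 'Field is required.' }

// Once a pipeline is selected, the configuration validates cleanly:
validateInferencePipelineConfiguration({
  destinationField: '',
  existingPipeline: true,
  modelID: '',
  pipelineName: 'ml-inference-my-pipeline',
  sourceField: '',
});
// => {}
```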


@@ -47,12 +47,21 @@ import {
FetchIndexApiParams,
FetchIndexApiResponse,
} from '../../../api/index/fetch_index_api_logic';
import { CreateMlInferencePipelineApiLogic } from '../../../api/ml_models/create_ml_inference_pipeline';
import {
DeleteMlInferencePipelineApiLogic,
DeleteMlInferencePipelineApiLogicArgs,
DeleteMlInferencePipelineResponse,
} from '../../../api/ml_models/delete_ml_inference_pipeline';
import {
AttachMlInferencePipelineApiLogic,
AttachMlInferencePipelineApiLogicArgs,
AttachMlInferencePipelineResponse,
} from '../../../api/pipelines/attach_ml_inference_pipeline';
import {
CreateMlInferencePipelineApiLogic,
CreateMlInferencePipelineApiLogicArgs,
CreateMlInferencePipelineResponse,
} from '../../../api/pipelines/create_ml_inference_pipeline';
import { FetchMlInferencePipelineProcessorsApiLogic } from '../../../api/pipelines/fetch_ml_inference_pipeline_processors';
import { isApiIndex, isConnectorIndex, isCrawlerIndex } from '../../../utils/indices';
@@ -60,6 +69,10 @@ type PipelinesActions = Pick<
Actions<PostPipelineArgs, PostPipelineResponse>,
'apiError' | 'apiSuccess' | 'makeRequest'
> & {
attachMlInferencePipelineSuccess: Actions<
AttachMlInferencePipelineApiLogicArgs,
AttachMlInferencePipelineResponse
>['apiSuccess'];
closeAddMlInferencePipelineModal: () => void;
closeModal: () => void;
createCustomPipeline: Actions<
@@ -74,6 +87,10 @@ type PipelinesActions = Pick<
CreateCustomPipelineApiLogicArgs,
CreateCustomPipelineApiLogicResponse
>['apiSuccess'];
createMlInferencePipelineSuccess: Actions<
CreateMlInferencePipelineApiLogicArgs,
CreateMlInferencePipelineResponse
>['apiSuccess'];
deleteMlPipeline: Actions<
DeleteMlInferencePipelineApiLogicArgs,
DeleteMlInferencePipelineResponse
@@ -153,6 +170,8 @@ export const PipelinesLogic = kea<MakeLogicType<PipelinesValues, PipelinesAction
'makeRequest as fetchMlInferenceProcessors',
'apiError as fetchMlInferenceProcessorsApiError',
],
AttachMlInferencePipelineApiLogic,
['apiSuccess as attachMlInferencePipelineSuccess'],
CreateMlInferencePipelineApiLogic,
['apiSuccess as createMlInferencePipelineSuccess'],
DeleteMlInferencePipelineApiLogic,
@@ -201,6 +220,12 @@
})
);
},
attachMlInferencePipelineSuccess: () => {
// Re-fetch processors to ensure we display the newly added ML processor
actions.fetchMlInferenceProcessors({ indexName: values.index.name });
// Needed to ensure correct JSON is available in the JSON configurations tab
actions.fetchCustomPipeline({ indexName: values.index.name });
},
closeModal: () =>
actions.setPipelineState(
isConnectorIndex(values.index) || isCrawlerIndex(values.index)
@@ -287,6 +312,7 @@
showAddMlInferencePipelineModal: [
false,
{
attachMlInferencePipelineSuccess: () => false,
closeAddMlInferencePipelineModal: () => false,
createMlInferencePipelineSuccess: () => false,
openAddMlInferencePipelineModal: () => true,


@@ -8,14 +8,9 @@
import { i18n } from '@kbn/i18n';
import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_models';
export const NLP_CONFIG_KEYS = [
'fill_mask',
'ner',
'text_classification',
'text_embedding',
'question_answering',
'zero_shot_classification',
];
import { SUPPORTED_PYTORCH_TASKS } from '../../../../../../common/ml_inference_pipeline';
export const NLP_CONFIG_KEYS: string[] = Object.values(SUPPORTED_PYTORCH_TASKS);
export const RECOMMENDED_FIELDS = ['body', 'body_content', 'title'];
export const NLP_DISPLAY_TITLES: Record<string, string | undefined> = {
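
For the NLP_CONFIG_KEYS change above, the removed literal list and the shared constant should carry the same task names. A sketch of that assumption (the exact shape of SUPPORTED_PYTORCH_TASKS is not shown in this diff):

```typescript
// Hypothetical shape of the shared constant; if it enumerates these tasks,
// NLP_CONFIG_KEYS keeps the same contents but can no longer drift from the
// server-side definition.
const SUPPORTED_PYTORCH_TASKS_SKETCH = {
  FILL_MASK: 'fill_mask',
  NER: 'ner',
  QUESTION_ANSWERING: 'question_answering',
  TEXT_CLASSIFICATION: 'text_classification',
  TEXT_EMBEDDING: 'text_embedding',
  ZERO_SHOT_CLASSIFICATION: 'zero_shot_classification',
} as const;

const nlpConfigKeys: string[] = Object.values(SUPPORTED_PYTORCH_TASKS_SKETCH);
```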