[Index management] Make create Inference endpoint from flyout async task (#184615)

Currently, when a new inference endpoint is created from the inference
flyout, the flyout stays open until the endpoint has been created. This can
take a long time when an Elasticsearch model (`.elser_model_2` or
`.multilingual-e5-small`) has to be downloaded and deployed before the
inference endpoint is created.

In this PR, when a new inference endpoint is saved, the inference flyout
closes immediately and the endpoint is created asynchronously by a callback
function in the component.
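
At a high level the new flow looks like the sketch below. The hook name, `createEndpoint`, and the exact wiring are illustrative placeholders rather than the identifiers used in this PR (the real logic lives in `select_inference_id.tsx` and the new `useMLModelNotificationToasts` hook); the sketch only shows the ordering: close the flyout first, then create the endpoint in the background and surface failures as toasts.

```ts
import { useCallback, useState } from 'react';
import type { InferenceTaskType } from '@elastic/elasticsearch/lib/api/types';
import type { ModelConfig } from '@kbn/inference_integration_flyout/types';

// Sketch only: close the flyout right away, then create the inference
// endpoint asynchronously. `createEndpoint` and `showErrorToasts` stand in
// for the real helpers used in the PR.
export function useSaveInferenceEndpoint(
  createEndpoint: (id: string, task: InferenceTaskType, config: ModelConfig) => Promise<void>,
  showErrorToasts: (error: unknown) => void
) {
  const [isInferenceFlyoutVisible, setIsInferenceFlyoutVisible] = useState(false);

  const onSaveInferenceEndpoint = useCallback(
    async (inferenceId: string, taskType: InferenceTaskType, modelConfig: ModelConfig) => {
      // Close the flyout immediately instead of blocking on the request.
      setIsInferenceFlyoutVisible(false);
      try {
        await createEndpoint(inferenceId, taskType, modelConfig);
      } catch (error) {
        // The flyout is already closed, so report the failure via a toast.
        showErrorToasts(error);
      }
    },
    [createEndpoint, showErrorToasts]
  );

  return { isInferenceFlyoutVisible, setIsInferenceFlyoutVisible, onSaveInferenceEndpoint };
}
```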
 
**Screen Recording**


8eabba1a-108a-4bf2-813a-66ceb291467c

**Testing instructions**

**Update Elasticsearch to latest (only needed to test saving mappings)**

Since the Elasticsearch changes for `semantic_text` have been merged to main, this can
be tested against ES running from source or from the latest snapshot.

- Update your local branch with the latest Elasticsearch changes from main
- Run Elasticsearch: `./gradlew :run -Drun.license_type=trial`
- Test manually in the UI

**Frontend**

- Set `xpack.index_management.dev.enableSemanticText` to `true` in
`config/kibana.dev.yml`
- Add a new field with the type `Semantic text`
- Click the dropdown menu below `Select an inference endpoint`
- Click `Add inference endpoint`
- Type a new inference endpoint name and click `Save endpoint`
- The `Save endpoint` button should close the flyout
- A success notification toast is shown with the text "1 model is being
deployed on your ml_node."
- Add the new field
- Click `Save mappings`
- A modal with the model deployment status should be shown
- After the new endpoint is created, the refresh button should hide the modal
and `Save mappings` should update the mappings

---------

Co-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>
Author: Saarika Bhasi, 2024-06-17 17:12:13 -04:00 (committed by GitHub)
Parent: 51e84f49ba
Commit: c091dd89ff
18 changed files with 601 additions and 309 deletions


@@ -78,7 +78,7 @@ export interface SaveMappingOnClick {
     taskType: InferenceTaskType,
     modelConfig: ModelConfig
   ) => void;
-  isCreateInferenceApiLoading: boolean;
+  isCreateInferenceApiLoading?: boolean;
 }
 export interface DocumentationProps {
   elserv2documentationUrl?: string;


@@ -54,7 +54,7 @@ export interface ElasticsearchService {
 export enum Service {
   cohere = 'cohere',
   elser = 'elser',
-  huggingFace = 'huggingFace',
+  huggingFace = 'hugging_face',
   openai = 'openai',
   elasticsearch = 'elasticsearch',
 }


@@ -209,6 +209,9 @@ const registerHttpRequestMockHelpers = (
   const setCreateIndexResponse = (response?: HttpResponse, error?: ResponseError) =>
     mockResponse('PUT', `${INTERNAL_API_BASE_PATH}/indices/create`, response, error);
+  const setInferenceModels = (response?: HttpResponse, error?: ResponseError) =>
+    mockResponse('GET', `${API_BASE_PATH}/inference/all`, response, error);
   return {
     setLoadTemplatesResponse,
     setLoadIndicesResponse,
@@ -238,6 +241,7 @@ const registerHttpRequestMockHelpers = (
     setGetFieldsFromIndices,
     setGetPrivilegesResponse,
     setCreateEnrichPolicy,
+    setInferenceModels,
   };
 };


@ -13,7 +13,7 @@ import {
} from '@kbn/test-jest-helpers'; } from '@kbn/test-jest-helpers';
import { HttpSetup } from '@kbn/core/public'; import { HttpSetup } from '@kbn/core/public';
import { act } from 'react-dom/test-utils'; import { act } from 'react-dom/test-utils';
import { keys } from '@elastic/eui';
import { IndexDetailsTabId } from '../../../common/constants'; import { IndexDetailsTabId } from '../../../common/constants';
import { IndexDetailsPage } from '../../../public/application/sections/home/index_list/details_page'; import { IndexDetailsPage } from '../../../public/application/sections/home/index_list/details_page';
import { WithAppDependencies } from '../helpers'; import { WithAppDependencies } from '../helpers';
@ -56,6 +56,12 @@ export interface IndexDetailsPageTestBed extends TestBed {
setSearchBarValue: (searchValue: string) => Promise<void>; setSearchBarValue: (searchValue: string) => Promise<void>;
findSearchResult: () => string; findSearchResult: () => string;
isSemanticTextBannerVisible: () => boolean; isSemanticTextBannerVisible: () => boolean;
selectSemanticTextField: (name: string, type: string) => Promise<void>;
isReferenceFieldVisible: () => void;
selectInferenceIdButtonExists: () => void;
openSelectInferencePopover: () => void;
expectDefaultInferenceModelToExists: () => void;
expectCustomInferenceModelToExists: (customInference: string) => Promise<void>;
}; };
settings: { settings: {
getCodeBlockContent: () => string; getCodeBlockContent: () => string;
@ -228,7 +234,7 @@ export const setup = async ({
component.update(); component.update();
}, },
selectFilterFieldType: async (fieldType: string) => { selectFilterFieldType: async (fieldType: string) => {
expect(testBed.exists('indexDetailsMappingsSelectFilter-text')).toBe(true); expect(testBed.exists(fieldType)).toBe(true);
await act(async () => { await act(async () => {
find(fieldType).simulate('click'); find(fieldType).simulate('click');
}); });
@ -287,6 +293,7 @@ export const setup = async ({
await act(async () => { await act(async () => {
expect(exists('createFieldForm.addButton')).toBe(true); expect(exists('createFieldForm.addButton')).toBe(true);
expect(find('createFieldForm.addButton').props().disabled).toBeFalsy();
find('createFieldForm.addButton').simulate('click'); find('createFieldForm.addButton').simulate('click');
}); });
@ -294,6 +301,41 @@ export const setup = async ({
} }
} }
}, },
selectSemanticTextField: async (name: string, type: string) => {
expect(exists('comboBoxSearchInput')).toBe(true);
const { form } = testBed;
form.setInputValue('nameParameterInput', name);
form.setInputValue('comboBoxSearchInput', type);
await act(async () => {
find('comboBoxSearchInput').simulate('keydown', { key: keys.ENTER });
});
// select semantic_text field
await act(async () => {
expect(exists('fieldTypesOptions-semantic_text')).toBe(true);
find('fieldTypesOptions-semantic_text').simulate('click');
expect(exists('fieldTypesOptions-semantic_text')).toBe(false);
});
},
isReferenceFieldVisible: async () => {
expect(exists('referenceField.select')).toBe(true);
},
selectInferenceIdButtonExists: async () => {
expect(exists('selectInferenceId')).toBe(true);
expect(exists('inferenceIdButton')).toBe(true);
find('inferenceIdButton').simulate('click');
},
openSelectInferencePopover: async () => {
expect(exists('addInferenceEndpointButton')).toBe(true);
expect(exists('manageInferenceEndpointButton')).toBe(true);
},
expectDefaultInferenceModelToExists: async () => {
expect(exists('default-inference_elser_model_2')).toBe(true);
expect(exists('default-inference_e5')).toBe(true);
},
expectCustomInferenceModelToExists: async (customInference: string) => {
expect(exists(customInference)).toBe(true);
},
}; };
const settings = { const settings = {


@ -18,6 +18,7 @@ import {
breadcrumbService, breadcrumbService,
IndexManagementBreadcrumb, IndexManagementBreadcrumb,
} from '../../../public/application/services/breadcrumbs'; } from '../../../public/application/services/breadcrumbs';
import { documentationService } from '../../../public/application/services/documentation';
import { humanizeTimeStamp } from '../../../public/application/sections/home/data_stream_list/humanize_time_stamp'; import { humanizeTimeStamp } from '../../../public/application/sections/home/data_stream_list/humanize_time_stamp';
import { createDataStreamPayload } from '../home/data_streams_tab.helpers'; import { createDataStreamPayload } from '../home/data_streams_tab.helpers';
import { import {
@ -58,6 +59,7 @@ describe('<IndexDetailsPage />', () => {
let httpSetup: ReturnType<typeof setupEnvironment>['httpSetup']; let httpSetup: ReturnType<typeof setupEnvironment>['httpSetup'];
let httpRequestsMockHelpers: ReturnType<typeof setupEnvironment>['httpRequestsMockHelpers']; let httpRequestsMockHelpers: ReturnType<typeof setupEnvironment>['httpRequestsMockHelpers'];
jest.spyOn(breadcrumbService, 'setBreadcrumbs'); jest.spyOn(breadcrumbService, 'setBreadcrumbs');
jest.spyOn(documentationService, 'setup');
beforeEach(async () => { beforeEach(async () => {
const mockEnvironment = setupEnvironment(); const mockEnvironment = setupEnvironment();
@ -571,14 +573,9 @@ describe('<IndexDetailsPage />', () => {
}, },
}; };
beforeEach(async () => { beforeEach(async () => {
httpRequestsMockHelpers.setUpdateIndexMappingsResponse(testIndexName, {
acknowledged: true,
});
await act(async () => { await act(async () => {
testBed = await setup({ httpSetup }); testBed = await setup({ httpSetup });
}); });
testBed.component.update(); testBed.component.update();
await testBed.actions.clickIndexDetailsTab(IndexDetailsSection.Mappings); await testBed.actions.clickIndexDetailsTab(IndexDetailsSection.Mappings);
await testBed.actions.mappings.clickAddFieldButton(); await testBed.actions.mappings.clickAddFieldButton();
@ -634,43 +631,6 @@ describe('<IndexDetailsPage />', () => {
); );
}); });
it('can add a semantic_text field and can save mappings', async () => {
const mockIndexMappingResponseForSemanticText: any = {
...testIndexMappings.mappings,
properties: {
...testIndexMappings.mappings.properties,
sem: {
type: 'semantic_text',
inference_id: 'my-elser',
},
},
};
httpRequestsMockHelpers.setLoadIndexMappingResponse(testIndexName, {
mappings: mockIndexMappingResponseForSemanticText,
});
await testBed.actions.mappings.addNewMappingFieldNameAndType([
{ name: 'sem', type: 'semantic_text' },
]);
await testBed.actions.mappings.clickSaveMappingsButton();
// add field button is available again
expect(testBed.exists('indexDetailsMappingsAddField')).toBe(true);
expect(testBed.find('semField-datatype').props()['data-type-value']).toBe('semantic_text');
expect(httpSetup.get).toHaveBeenCalledTimes(5);
expect(httpSetup.get).toHaveBeenLastCalledWith(
`${API_BASE_PATH}/mapping/${testIndexName}`,
requestOptions
);
// refresh mappings and page re-renders
expect(testBed.exists('indexDetailsMappingsAddField')).toBe(true);
expect(testBed.actions.mappings.isSearchBarDisabled()).toBe(false);
const treeViewContent = testBed.actions.mappings.getTreeViewContent('semField');
expect(treeViewContent).toContain('sem');
await testBed.actions.mappings.clickToggleViewButton();
const jsonContent = testBed.actions.mappings.getCodeBlockContent();
expect(jsonContent).toEqual(
JSON.stringify({ mappings: mockIndexMappingResponseForSemanticText }, null, 2)
);
});
it('there is a callout with error message when save mappings fail', async () => { it('there is a callout with error message when save mappings fail', async () => {
const error = { const error = {
statusCode: 400, statusCode: 400,
@ -685,6 +645,116 @@ describe('<IndexDetailsPage />', () => {
await testBed.actions.mappings.clickSaveMappingsButton(); await testBed.actions.mappings.clickSaveMappingsButton();
expect(testBed.actions.mappings.isSaveMappingsErrorDisplayed()).toBe(true); expect(testBed.actions.mappings.isSaveMappingsErrorDisplayed()).toBe(true);
}); });
describe('Add Semantic text field', () => {
const customInferenceModel = 'my-elser-model';
beforeEach(async () => {
httpRequestsMockHelpers.setInferenceModels({
data: [
{
model_id: customInferenceModel,
task_type: 'sparse_embedding',
service: 'elser',
service_settings: {
num_allocations: 1,
num_threads: 1,
model_id: '.elser_model_2',
},
task_settings: {},
},
],
});
await act(async () => {
testBed = await setup({
httpSetup,
dependencies: {
config: { enableSemanticText: true },
docLinks: {
links: {
ml: '',
enterpriseSearch: '',
},
},
plugins: {
ml: {
mlApi: {
trainedModels: {
getTrainedModels: jest.fn().mockResolvedValue([
{
model_id: '.elser_model_2',
model_type: 'pytorch',
model_package: {
packaged_model_id: customInferenceModel,
model_repository: 'https://ml-models.elastic.co',
minimum_version: '11.0.0',
size: 438123914,
sha256: '',
metadata: {},
tags: [],
vocabulary_file: 'elser_model_2.vocab.json',
},
description: 'Elastic Learned Sparse EncodeR v2',
tags: ['elastic'],
},
]),
getTrainedModelStats: jest.fn().mockResolvedValue({
count: 1,
trained_model_stats: [
{
model_id: '.elser_model_2',
deployment_stats: {
deployment_id: customInferenceModel,
model_id: '.elser_model_2',
threads_per_allocation: 1,
number_of_allocations: 1,
queue_capacity: 1024,
state: 'started',
},
},
{
model_id: '.elser_model_2',
deployment_stats: {
deployment_id: '.elser_model_2',
model_id: '.elser_model_2',
threads_per_allocation: 1,
number_of_allocations: 1,
queue_capacity: 1024,
state: 'started',
},
},
],
}),
},
},
},
},
},
});
});
testBed.component.update();
await testBed.actions.clickIndexDetailsTab(IndexDetailsSection.Mappings);
await testBed.actions.mappings.clickAddFieldButton();
});
it('can select semantic_text field', async () => {
await testBed.actions.mappings.selectSemanticTextField(
'semantic_text_name',
'Semantic text'
);
testBed.actions.mappings.isReferenceFieldVisible();
testBed.actions.mappings.selectInferenceIdButtonExists();
testBed.actions.mappings.openSelectInferencePopover();
testBed.actions.mappings.expectDefaultInferenceModelToExists();
testBed.actions.mappings.expectCustomInferenceModelToExists(
`custom-inference_${customInferenceModel}`
);
// can cancel new field
expect(testBed.exists('cancelButton')).toBe(true);
testBed.find('cancelButton').simulate('click');
});
});
}); });
describe('error loading mappings', () => { describe('error loading mappings', () => {


@@ -7,13 +7,13 @@
 import { registerTestBed } from '@kbn/test-jest-helpers';
 import { act } from 'react-dom/test-utils';
-import { SelectInferenceId } from './select_inference_id';
+import { SelectInferenceId } from '../../../public/application/components/mappings_editor/components/document_fields/field_parameters/select_inference_id';
 const onChangeMock = jest.fn();
 const setValueMock = jest.fn();
 const setNewInferenceEndpointMock = jest.fn();
-jest.mock('../../../../../app_context', () => ({
+jest.mock('../../../public/application/app_context', () => ({
   useAppContext: jest.fn().mockReturnValue({
     core: { application: {} },
     docLinks: {},
@@ -30,6 +30,17 @@ jest.mock('../../../public/application/app_context', () => ({
   }),
 }));
+jest.mock(
+  '../../../public/application/components/component_templates/component_templates_context',
+  () => ({
+    useComponentTemplatesContext: jest.fn().mockReturnValue({
+      toasts: {
+        addError: jest.fn(),
+        addSuccess: jest.fn(),
+      },
+    }),
+  })
+);
 describe('SelectInferenceId', () => {
   let exists: any;
   let find: any;


@@ -32,7 +32,11 @@ export const ReferenceFieldSelects = ({
     Object.keys(data.mappings.properties).forEach((key) => {
       const field = data.mappings.properties[key];
       if (field.type === 'text') {
-        referenceFieldOptions.push({ value: key, inputDisplay: key });
+        referenceFieldOptions.push({
+          value: key,
+          inputDisplay: key,
+          'data-test-subj': `select-reference-field-${key}`,
+        });
       }
     });
   }
@@ -47,14 +51,15 @@ export const ReferenceFieldSelects = ({
     return subscription.unsubscribe;
   }, [subscribe, onChange]);
   return (
-    <Form form={form}>
+    <Form form={form} data-test-subj="referenceField">
       <UseField path="main" config={fieldConfigReferenceField}>
         {(field) => (
           <SuperSelectField
             field={field}
-            euiFieldProps={{ options: referenceFieldOptions }}
+            euiFieldProps={{
+              options: referenceFieldOptions,
+            }}
             data-test-subj={dataTestSubj}
           />
         )}


@ -35,28 +35,40 @@ import {
ModelConfig, ModelConfig,
Service, Service,
} from '@kbn/inference_integration_flyout/types'; } from '@kbn/inference_integration_flyout/types';
import { FormattedMessage } from '@kbn/i18n-react';
import { InferenceFlyoutWrapper } from '@kbn/inference_integration_flyout/components/inference_flyout_wrapper'; import { InferenceFlyoutWrapper } from '@kbn/inference_integration_flyout/components/inference_flyout_wrapper';
import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_models'; import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_models';
import { extractErrorProperties } from '@kbn/ml-error-utils';
import { getFieldConfig } from '../../../lib'; import { getFieldConfig } from '../../../lib';
import { useAppContext } from '../../../../../app_context'; import { useAppContext } from '../../../../../app_context';
import { Form, UseField, useForm } from '../../../shared_imports'; import { Form, UseField, useForm } from '../../../shared_imports';
import { useLoadInferenceModels } from '../../../../../services/api'; import { useLoadInferenceModels } from '../../../../../services/api';
import { getTrainedModelStats } from '../../../../../../hooks/use_details_page_mappings_model_management'; import { getTrainedModelStats } from '../../../../../../hooks/use_details_page_mappings_model_management';
import { InferenceToModelIdMap } from '../fields'; import { InferenceToModelIdMap } from '../fields';
import { useMLModelNotificationToasts } from '../../../../../../hooks/use_ml_model_status_toasts';
import {
CustomInferenceEndpointConfig,
DefaultInferenceModels,
DeploymentState,
} from '../../../types';
const inferenceServiceTypeElasticsearchModelMap: Record<string, ElasticsearchModelDefaultOptions> = const inferenceServiceTypeElasticsearchModelMap: Record<string, ElasticsearchModelDefaultOptions> =
{ {
elser: ElasticsearchModelDefaultOptions.elser, elser: ElasticsearchModelDefaultOptions.elser,
elasticsearch: ElasticsearchModelDefaultOptions.e5, elasticsearch: ElasticsearchModelDefaultOptions.e5,
}; };
const uncheckSelectedModelOption = (options: EuiSelectableOption[]) => {
const checkedOption = options.find(({ checked }) => checked === 'on');
if (checkedOption) {
checkedOption.checked = undefined;
}
};
interface Props { interface Props {
onChange(value: string): void; onChange(value: string): void;
'data-test-subj'?: string; 'data-test-subj'?: string;
setValue: (value: string) => void; setValue: (value: string) => void;
setNewInferenceEndpoint: (newInferenceEndpoint: InferenceToModelIdMap) => void; setNewInferenceEndpoint: (
newInferenceEndpoint: InferenceToModelIdMap,
customInferenceEndpointConfig: CustomInferenceEndpointConfig
) => void;
} }
export const SelectInferenceId = ({ export const SelectInferenceId = ({
onChange, onChange,
@ -76,16 +88,14 @@ export const SelectInferenceId = ({
}); });
}, [ml]); }, [ml]);
const { form } = useForm({ defaultValue: { main: 'elser_model_2' } }); const { form } = useForm({ defaultValue: { main: DefaultInferenceModels.elser_model_2 } });
const { subscribe } = form; const { subscribe } = form;
const [isInferenceFlyoutVisible, setIsInferenceFlyoutVisible] = useState<boolean>(false); const [isInferenceFlyoutVisible, setIsInferenceFlyoutVisible] = useState<boolean>(false);
const [inferenceAddError, setInferenceAddError] = useState<string | undefined>(undefined);
const [availableTrainedModels, setAvailableTrainedModels] = useState< const [availableTrainedModels, setAvailableTrainedModels] = useState<
TrainedModelConfigResponse[] TrainedModelConfigResponse[]
>([]); >([]);
const onFlyoutClose = useCallback(() => { const onFlyoutClose = useCallback(() => {
setInferenceAddError(undefined);
setIsInferenceFlyoutVisible(!isInferenceFlyoutVisible); setIsInferenceFlyoutVisible(!isInferenceFlyoutVisible);
}, [isInferenceFlyoutVisible]); }, [isInferenceFlyoutVisible]);
useEffect(() => { useEffect(() => {
@ -111,16 +121,27 @@ export const SelectInferenceId = ({
const fieldConfigModelId = getFieldConfig('inference_id'); const fieldConfigModelId = getFieldConfig('inference_id');
const defaultInferenceIds: EuiSelectableOption[] = useMemo(() => { const defaultInferenceIds: EuiSelectableOption[] = useMemo(() => {
return [{ checked: 'on', label: 'elser_model_2' }, { label: 'e5' }]; return [
{
checked: 'on',
label: 'elser_model_2',
'data-test-subj': 'default-inference_elser_model_2',
},
{
label: 'e5',
'data-test-subj': 'default-inference_e5',
},
];
}, []); }, []);
const { isLoading, data: models, resendRequest } = useLoadInferenceModels(); const { isLoading, data: models } = useLoadInferenceModels();
const [options, setOptions] = useState<EuiSelectableOption[]>([...defaultInferenceIds]); const [options, setOptions] = useState<EuiSelectableOption[]>([...defaultInferenceIds]);
const inferenceIdOptionsFromModels = useMemo(() => { const inferenceIdOptionsFromModels = useMemo(() => {
const inferenceIdOptions = const inferenceIdOptions =
models?.map((model: InferenceAPIConfigResponse) => ({ models?.map((model: InferenceAPIConfigResponse) => ({
label: model.model_id, label: model.model_id,
'data-test-subj': `custom-inference_${model.model_id}`,
})) || []; })) || [];
return inferenceIdOptions; return inferenceIdOptions;
@ -136,40 +157,48 @@ export const SelectInferenceId = ({
}; };
setOptions(Object.values(mergedOptions)); setOptions(Object.values(mergedOptions));
}, [inferenceIdOptionsFromModels, defaultInferenceIds]); }, [inferenceIdOptionsFromModels, defaultInferenceIds]);
const [isCreateInferenceApiLoading, setIsCreateInferenceApiLoading] = useState(false);
const { showErrorToasts } = useMLModelNotificationToasts();
const onSaveInferenceCallback = useCallback( const onSaveInferenceCallback = useCallback(
async (inferenceId: string, taskType: InferenceTaskType, modelConfig: ModelConfig) => { async (inferenceId: string, taskType: InferenceTaskType, modelConfig: ModelConfig) => {
setIsCreateInferenceApiLoading(true); setIsInferenceFlyoutVisible(!isInferenceFlyoutVisible);
try { try {
await ml?.mlApi?.inferenceModels?.createInferenceEndpoint( const isDeployable =
inferenceId, modelConfig.service === Service.elser || modelConfig.service === Service.elasticsearch;
taskType,
modelConfig const newOption: EuiSelectableOption[] = [
); {
setIsInferenceFlyoutVisible(!isInferenceFlyoutVisible); label: inferenceId,
setIsCreateInferenceApiLoading(false); checked: 'on',
setInferenceAddError(undefined); 'data-test-subj': `custom-inference_${inferenceId}`,
},
];
// uncheck selected endpoint id
uncheckSelectedModelOption(options);
setOptions([...options, ...newOption]);
const trainedModelStats = await ml?.mlApi?.trainedModels.getTrainedModelStats(); const trainedModelStats = await ml?.mlApi?.trainedModels.getTrainedModelStats();
const defaultEndpointId = const defaultEndpointId =
inferenceServiceTypeElasticsearchModelMap[modelConfig.service] || ''; inferenceServiceTypeElasticsearchModelMap[modelConfig.service] || '';
const newModelId: InferenceToModelIdMap = {}; const newModelId: InferenceToModelIdMap = {};
newModelId[inferenceId] = { newModelId[inferenceId] = {
trainedModelId: defaultEndpointId, trainedModelId: defaultEndpointId,
isDeployable: isDeployable,
modelConfig.service === Service.elser || modelConfig.service === Service.elasticsearch, isDeployed:
isDeployed: getTrainedModelStats(trainedModelStats)[defaultEndpointId] === 'deployed', getTrainedModelStats(trainedModelStats)[defaultEndpointId] === DeploymentState.DEPLOYED,
defaultInferenceEndpoint: false,
}; };
resendRequest(); const customInferenceEndpointConfig: CustomInferenceEndpointConfig = {
setNewInferenceEndpoint(newModelId); taskType,
modelConfig,
};
setNewInferenceEndpoint(newModelId, customInferenceEndpointConfig);
} catch (error) { } catch (error) {
const errorObj = extractErrorProperties(error); showErrorToasts(error);
setInferenceAddError(errorObj.message);
setIsCreateInferenceApiLoading(false);
} }
}, },
[isInferenceFlyoutVisible, resendRequest, ml, setNewInferenceEndpoint] [isInferenceFlyoutVisible, ml, setNewInferenceEndpoint, options, showErrorToasts]
); );
useEffect(() => { useEffect(() => {
const subscription = subscribe((updateData) => { const subscription = subscribe((updateData) => {
@ -182,7 +211,7 @@ export const SelectInferenceId = ({
}, [subscribe, onChange]); }, [subscribe, onChange]);
const selectedOptionLabel = options.find((option) => option.checked)?.label; const selectedOptionLabel = options.find((option) => option.checked)?.label;
useEffect(() => { useEffect(() => {
setValue(selectedOptionLabel ?? 'elser_model_2'); setValue(selectedOptionLabel ?? DefaultInferenceModels.elser_model_2);
}, [selectedOptionLabel, setValue]); }, [selectedOptionLabel, setValue]);
const [isInferencePopoverVisible, setIsInferencePopoverVisible] = useState<boolean>(false); const [isInferencePopoverVisible, setIsInferencePopoverVisible] = useState<boolean>(false);
const [inferenceEndpointError, setInferenceEndpointError] = useState<string | undefined>( const [inferenceEndpointError, setInferenceEndpointError] = useState<string | undefined>(
@ -304,7 +333,7 @@ export const SelectInferenceId = ({
data-test-subj={dataTestSubj} data-test-subj={dataTestSubj}
searchable searchable
isLoading={isLoading} isLoading={isLoading}
singleSelection singleSelection="always"
searchProps={{ searchProps={{
compressed: true, compressed: true,
placeholder: i18n.translate( placeholder: i18n.translate(
@ -340,32 +369,6 @@ export const SelectInferenceId = ({
<InferenceFlyoutWrapper <InferenceFlyoutWrapper
elserv2documentationUrl={docLinks.links.ml.nlpElser} elserv2documentationUrl={docLinks.links.ml.nlpElser}
e5documentationUrl={docLinks.links.ml.nlpE5} e5documentationUrl={docLinks.links.ml.nlpE5}
errorCallout={
inferenceAddError && (
<EuiFlexItem grow={false}>
<EuiCallOut
color="danger"
data-test-subj="addInferenceError"
iconType="error"
title={i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.inferenceId.errorTitle',
{
defaultMessage: 'Error adding inference endpoint',
}
)}
>
<EuiText>
<FormattedMessage
id="xpack.idxMgmt.mappingsEditor.parameters.inferenceId.errorDescription"
defaultMessage="Error adding inference endpoint: {errorMessage}"
values={{ errorMessage: inferenceAddError }}
/>
</EuiText>
</EuiCallOut>
<EuiSpacer />
</EuiFlexItem>
)
}
onInferenceEndpointChange={onInferenceEndpointChange} onInferenceEndpointChange={onInferenceEndpointChange}
inferenceEndpointError={inferenceEndpointError} inferenceEndpointError={inferenceEndpointError}
trainedModels={trainedModels} trainedModels={trainedModels}
@ -374,7 +377,6 @@ export const SelectInferenceId = ({
isInferenceFlyoutVisible={isInferenceFlyoutVisible} isInferenceFlyoutVisible={isInferenceFlyoutVisible}
supportedNlpModels={docLinks.links.enterpriseSearch.supportedNlpModels} supportedNlpModels={docLinks.links.enterpriseSearch.supportedNlpModels}
nlpImportModel={docLinks.links.ml.nlpImportModel} nlpImportModel={docLinks.links.ml.nlpImportModel}
isCreateInferenceApiLoading={isCreateInferenceApiLoading}
setInferenceEndpointError={setInferenceEndpointError} setInferenceEndpointError={setInferenceEndpointError}
/> />
)} )}


@ -14,14 +14,20 @@ import {
EuiSpacer, EuiSpacer,
} from '@elastic/eui'; } from '@elastic/eui';
import { i18n } from '@kbn/i18n'; import { i18n } from '@kbn/i18n';
import { ElasticsearchModelDefaultOptions } from '@kbn/inference_integration_flyout/types';
import { MlPluginStart } from '@kbn/ml-plugin/public'; import { MlPluginStart } from '@kbn/ml-plugin/public';
import classNames from 'classnames'; import classNames from 'classnames';
import React, { useCallback, useEffect } from 'react'; import React, { useCallback, useEffect, useState } from 'react';
import { EUI_SIZE, TYPE_DEFINITION } from '../../../../constants'; import { EUI_SIZE, TYPE_DEFINITION } from '../../../../constants';
import { fieldSerializer } from '../../../../lib'; import { fieldSerializer } from '../../../../lib';
import { useDispatch, useMappingsState } from '../../../../mappings_state_context'; import { useDispatch, useMappingsState } from '../../../../mappings_state_context';
import { Form, FormDataProvider, UseField, useForm, useFormData } from '../../../../shared_imports'; import { Form, FormDataProvider, UseField, useForm, useFormData } from '../../../../shared_imports';
import { Field, MainType, NormalizedFields } from '../../../../types'; import {
CustomInferenceEndpointConfig,
Field,
MainType,
NormalizedFields,
} from '../../../../types';
import { NameParameter, SubTypeParameter, TypeParameter } from '../../field_parameters'; import { NameParameter, SubTypeParameter, TypeParameter } from '../../field_parameters';
import { ReferenceFieldSelects } from '../../field_parameters/reference_field_selects'; import { ReferenceFieldSelects } from '../../field_parameters/reference_field_selects';
import { SelectInferenceId } from '../../field_parameters/select_inference_id'; import { SelectInferenceId } from '../../field_parameters/select_inference_id';
@ -32,10 +38,9 @@ import { useSemanticText } from './semantic_text/use_semantic_text';
const formWrapper = (props: any) => <form {...props} />; const formWrapper = (props: any) => <form {...props} />;
export interface InferenceToModelIdMap { export interface InferenceToModelIdMap {
[key: string]: { [key: string]: {
trainedModelId?: string; trainedModelId: ElasticsearchModelDefaultOptions | string;
isDeployed: boolean; isDeployed: boolean;
isDeployable: boolean; isDeployable: boolean;
defaultInferenceEndpoint: boolean;
}; };
} }
@ -88,7 +93,9 @@ export const CreateField = React.memo(function CreateFieldComponent({
return subscription.unsubscribe; return subscription.unsubscribe;
}, [dispatch, subscribe]); }, [dispatch, subscribe]);
const [customInferenceEndpointConfig, setCustomInferenceEndpointConfig] = useState<
CustomInferenceEndpointConfig | undefined
>(undefined);
const cancel = () => { const cancel = () => {
if (isAddingFields && onCancelAddingNewFields) { if (isAddingFields && onCancelAddingNewFields) {
onCancelAddingNewFields(); onCancelAddingNewFields();
@ -125,7 +132,7 @@ export const CreateField = React.memo(function CreateFieldComponent({
form.reset(); form.reset();
if (data.type === 'semantic_text' && !clickOutside) { if (data.type === 'semantic_text' && !clickOutside) {
handleSemanticText(data); handleSemanticText(data, customInferenceEndpointConfig);
} else { } else {
dispatch({ type: 'field.add', value: data }); dispatch({ type: 'field.add', value: data });
} }
@ -283,7 +290,10 @@ export const CreateField = React.memo(function CreateFieldComponent({
}} }}
</FormDataProvider> </FormDataProvider>
{/* Field inference_id for semantic_text field type */} {/* Field inference_id for semantic_text field type */}
<InferenceIdCombo setValue={setInferenceValue} /> <InferenceIdCombo
setValue={setInferenceValue}
setCustomInferenceEndpointConfig={setCustomInferenceEndpointConfig}
/>
{renderFormActions()} {renderFormActions()}
</div> </div>
</div> </div>
@ -311,16 +321,20 @@ function ReferenceFieldCombo({ indexName }: { indexName?: string }) {
interface InferenceProps { interface InferenceProps {
setValue: (value: string) => void; setValue: (value: string) => void;
setCustomInferenceEndpointConfig: (config: CustomInferenceEndpointConfig) => void;
} }
function InferenceIdCombo({ setValue }: InferenceProps) { function InferenceIdCombo({ setValue, setCustomInferenceEndpointConfig }: InferenceProps) {
const { inferenceToModelIdMap } = useMappingsState(); const { inferenceToModelIdMap } = useMappingsState();
const dispatch = useDispatch(); const dispatch = useDispatch();
const [{ type }] = useFormData({ watch: 'type' }); const [{ type }] = useFormData({ watch: 'type' });
// update new inferenceEndpoint // update new inferenceEndpoint
const setNewInferenceEndpoint = useCallback( const setNewInferenceEndpoint = useCallback(
(newInferenceEndpoint: InferenceToModelIdMap) => { (
newInferenceEndpoint: InferenceToModelIdMap,
customInferenceEndpointConfig: CustomInferenceEndpointConfig
) => {
dispatch({ dispatch({
type: 'inferenceToModelIdMap.update', type: 'inferenceToModelIdMap.update',
value: { value: {
@ -330,8 +344,9 @@ function InferenceIdCombo({ setValue }: InferenceProps) {
}, },
}, },
}); });
setCustomInferenceEndpointConfig(customInferenceEndpointConfig);
}, },
[dispatch, inferenceToModelIdMap] [dispatch, inferenceToModelIdMap, setCustomInferenceEndpointConfig]
); );
if (type === undefined || type[0]?.value !== 'semantic_text') { if (type === undefined || type[0]?.value !== 'semantic_text') {


@ -6,7 +6,7 @@
*/ */
import { renderHook } from '@testing-library/react-hooks'; import { renderHook } from '@testing-library/react-hooks';
import { Field } from '../../../../../types'; import { CustomInferenceEndpointConfig, Field } from '../../../../../types';
import { useSemanticText } from './use_semantic_text'; import { useSemanticText } from './use_semantic_text';
import { act } from 'react-dom/test-utils'; import { act } from 'react-dom/test-utils';
@ -15,22 +15,54 @@ const mlMock: any = {
inferenceModels: { inferenceModels: {
createInferenceEndpoint: jest.fn().mockResolvedValue({}), createInferenceEndpoint: jest.fn().mockResolvedValue({}),
}, },
trainedModels: {
startModelAllocation: jest.fn().mockResolvedValue({}),
getTrainedModels: jest.fn().mockResolvedValue([
{
fully_defined: true,
},
]),
},
}, },
}; };
const mockFieldData = { const mockField: Record<string, Field> = {
name: 'name', elser_model_2: {
type: 'semantic_text', name: 'name',
inferenceId: 'elser_model_2', type: 'semantic_text',
} as Field; inferenceId: 'elser_model_2',
},
e5: {
name: 'name',
type: 'semantic_text',
inferenceId: 'e5',
},
openai: {
name: 'name',
type: 'semantic_text',
inferenceId: 'openai',
},
my_elser_endpoint: {
name: 'name',
type: 'semantic_text',
inferenceId: 'my_elser_endpoint',
},
};
const mockConfig: Record<string, CustomInferenceEndpointConfig> = {
openai: {
taskType: 'text_embedding',
modelConfig: {
service: 'openai',
service_settings: {
api_key: 'test',
model_id: 'text-embedding-ada-002',
},
},
},
elser: {
taskType: 'sparse_embedding',
modelConfig: {
service: 'elser',
service_settings: {
num_allocations: 1,
num_threads: 1,
},
},
},
};
const mockDispatch = jest.fn(); const mockDispatch = jest.fn();
@ -38,13 +70,21 @@ jest.mock('../../../../../mappings_state_context', () => ({
useMappingsState: jest.fn().mockReturnValue({ useMappingsState: jest.fn().mockReturnValue({
inferenceToModelIdMap: { inferenceToModelIdMap: {
e5: { e5: {
defaultInferenceEndpoint: false,
isDeployed: false, isDeployed: false,
isDeployable: true, isDeployable: true,
trainedModelId: '.multilingual-e5-small', trainedModelId: '.multilingual-e5-small',
}, },
elser_model_2: { elser_model_2: {
defaultInferenceEndpoint: true, isDeployed: false,
isDeployable: true,
trainedModelId: '.elser_model_2',
},
openai: {
isDeployed: false,
isDeployable: false,
trainedModelId: '',
},
my_elser_endpoint: {
isDeployed: false, isDeployed: false,
isDeployable: true, isDeployable: true,
trainedModelId: '.elser_model_2', trainedModelId: '.elser_model_2',
@ -63,24 +103,108 @@ jest.mock('../../../../../../component_templates/component_templates_context', (
}), }),
})); }));
jest.mock('../../../../../../../services/api', () => ({
getInferenceModels: jest.fn().mockResolvedValue({
data: [
{
model_id: 'e5',
task_type: 'text_embedding',
service: 'elasticsearch',
service_settings: {
num_allocations: 1,
num_threads: 1,
model_id: '.multilingual-e5-small',
},
task_settings: {},
},
],
}),
}));
describe('useSemanticText', () => { describe('useSemanticText', () => {
let form: any; let mockForm: any;
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks(); jest.clearAllMocks();
form = { mockForm = {
getFields: jest.fn().mockReturnValue({ form: {
referenceField: { value: 'title' }, getFields: jest.fn().mockReturnValue({
name: { value: 'sem' }, referenceField: { value: 'title' },
type: { value: [{ value: 'semantic_text' }] }, name: { value: 'sem' },
inferenceId: { value: 'e5' }, type: { value: [{ value: 'semantic_text' }] },
}), inferenceId: { value: 'e5' },
}),
},
thirdPartyModel: {
getFields: jest.fn().mockReturnValue({
referenceField: { value: 'title' },
name: { value: 'semantic_text_openai_endpoint' },
type: { value: [{ value: 'semantic_text' }] },
inferenceId: { value: 'openai' },
}),
},
elasticModelEndpointCreatedfromFlyout: {
getFields: jest.fn().mockReturnValue({
referenceField: { value: 'title' },
name: { value: 'semantic_text_elserServiceType_endpoint' },
type: { value: [{ value: 'semantic_text' }] },
inferenceId: { value: 'my_elser_endpoint' },
}),
},
}; };
}); });
it('should handle semantic text with third party model correctly', async () => {
const { result } = renderHook(() =>
useSemanticText({
form: mockForm.thirdPartyModel,
setErrorsInTrainedModelDeployment: jest.fn(),
ml: mlMock,
})
);
await act(async () => {
result.current.setInferenceValue('openai');
result.current.handleSemanticText(mockField.openai, mockConfig.openai);
});
expect(mockDispatch).toHaveBeenCalledWith({
type: 'field.addSemanticText',
value: mockField.openai,
});
expect(mlMock.mlApi.inferenceModels.createInferenceEndpoint).toHaveBeenCalledWith(
'openai',
'text_embedding',
mockConfig.openai.modelConfig
);
});
it('should handle semantic text with inference endpoint created from flyout correctly', async () => {
const { result } = renderHook(() =>
useSemanticText({
form: mockForm.elasticModelEndpointCreatedfromFlyout,
setErrorsInTrainedModelDeployment: jest.fn(),
ml: mlMock,
})
);
await act(async () => {
result.current.setInferenceValue('my_elser_endpoint');
result.current.handleSemanticText(mockField.my_elser_endpoint, mockConfig.elser);
});
expect(mockDispatch).toHaveBeenCalledWith({
type: 'field.addSemanticText',
value: mockField.my_elser_endpoint,
});
expect(mlMock.mlApi.inferenceModels.createInferenceEndpoint).toHaveBeenCalledWith(
'my_elser_endpoint',
'sparse_embedding',
mockConfig.elser.modelConfig
);
});
it('should populate the values from the form', () => { it('should populate the values from the form', () => {
const { result } = renderHook(() => const { result } = renderHook(() =>
useSemanticText({ form, setErrorsInTrainedModelDeployment: jest.fn(), ml: mlMock }) useSemanticText({
form: mockForm.form,
setErrorsInTrainedModelDeployment: jest.fn(),
ml: mlMock,
})
); );
expect(result.current.referenceFieldComboValue).toBe('title'); expect(result.current.referenceFieldComboValue).toBe('title');
@ -91,23 +215,26 @@ describe('useSemanticText', () => {
it('should handle semantic text correctly', async () => { it('should handle semantic text correctly', async () => {
const { result } = renderHook(() => const { result } = renderHook(() =>
useSemanticText({ form, setErrorsInTrainedModelDeployment: jest.fn(), ml: mlMock }) useSemanticText({
form: mockForm.form,
setErrorsInTrainedModelDeployment: jest.fn(),
ml: mlMock,
})
); );
await act(async () => { await act(async () => {
result.current.handleSemanticText(mockFieldData); result.current.handleSemanticText(mockField.elser_model_2);
}); });
expect(mlMock.mlApi.trainedModels.startModelAllocation).toHaveBeenCalledWith('.elser_model_2');
expect(mockDispatch).toHaveBeenCalledWith({ expect(mockDispatch).toHaveBeenCalledWith({
type: 'field.addSemanticText', type: 'field.addSemanticText',
value: mockFieldData, value: mockField.elser_model_2,
}); });
expect(mlMock.mlApi.inferenceModels.createInferenceEndpoint).toHaveBeenCalledWith( expect(mlMock.mlApi.inferenceModels.createInferenceEndpoint).toHaveBeenCalledWith(
'elser_model_2', 'elser_model_2',
'text_embedding', 'sparse_embedding',
{ {
service: 'elasticsearch', service: 'elser',
service_settings: { service_settings: {
num_allocations: 1, num_allocations: 1,
num_threads: 1, num_threads: 1,
@ -116,68 +243,42 @@ describe('useSemanticText', () => {
} }
); );
}); });
it('does not call create inference endpoint api, if default endpoint already exists', async () => {
it('should invoke the download api if the model does not exist', async () => {
const mlMockWithModelNotDownloaded: any = {
mlApi: {
inferenceModels: {
createInferenceEndpoint: jest.fn(),
},
trainedModels: {
startModelAllocation: jest.fn(),
getTrainedModels: jest.fn().mockResolvedValue([
{
fully_defined: false,
},
]),
installElasticTrainedModelConfig: jest.fn().mockResolvedValue({}),
},
},
};
const { result } = renderHook(() => const { result } = renderHook(() =>
useSemanticText({ useSemanticText({
form, form: mockForm.form,
setErrorsInTrainedModelDeployment: jest.fn(), setErrorsInTrainedModelDeployment: jest.fn(),
ml: mlMockWithModelNotDownloaded, ml: mlMock,
}) })
); );
await act(async () => { await act(async () => {
result.current.handleSemanticText(mockFieldData); result.current.setInferenceValue('e5');
result.current.handleSemanticText(mockField.e5);
}); });
expect( expect(mockDispatch).toHaveBeenCalledWith({
mlMockWithModelNotDownloaded.mlApi.trainedModels.installElasticTrainedModelConfig type: 'field.addSemanticText',
).toHaveBeenCalledWith('.elser_model_2'); value: mockField.e5,
expect(
mlMockWithModelNotDownloaded.mlApi.trainedModels.startModelAllocation
).toHaveBeenCalledWith('.elser_model_2');
expect(
mlMockWithModelNotDownloaded.mlApi.inferenceModels.createInferenceEndpoint
).toHaveBeenCalledWith('elser_model_2', 'text_embedding', {
service: 'elasticsearch',
service_settings: {
num_allocations: 1,
num_threads: 1,
model_id: '.elser_model_2',
},
}); });
expect(mlMock.mlApi.inferenceModels.createInferenceEndpoint).not.toBeCalled();
}); });
it('handles errors correctly', async () => { it('handles errors correctly', async () => {
const mockError = new Error('Test error'); const mockError = new Error('Test error');
mlMock.mlApi?.trainedModels.startModelAllocation.mockImplementationOnce(() => { mlMock.mlApi?.inferenceModels.createInferenceEndpoint.mockImplementationOnce(() => {
throw mockError; throw mockError;
}); });
const setErrorsInTrainedModelDeployment = jest.fn(); const setErrorsInTrainedModelDeployment = jest.fn();
const { result } = renderHook(() => const { result } = renderHook(() =>
useSemanticText({ form, setErrorsInTrainedModelDeployment, ml: mlMock }) useSemanticText({ form: mockForm.form, setErrorsInTrainedModelDeployment, ml: mlMock })
); );
await act(async () => { await act(async () => {
result.current.handleSemanticText(mockFieldData); result.current.handleSemanticText(mockField.elser_model_2);
}); });
expect(setErrorsInTrainedModelDeployment).toHaveBeenCalledWith(expect.any(Function)); expect(setErrorsInTrainedModelDeployment).toHaveBeenCalledWith(expect.any(Function));


@ -7,30 +7,39 @@
import { i18n } from '@kbn/i18n'; import { i18n } from '@kbn/i18n';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { MlPluginStart, TrainedModelConfigResponse } from '@kbn/ml-plugin/public'; import { MlPluginStart } from '@kbn/ml-plugin/public';
import React, { useEffect, useState } from 'react'; import React, { useEffect, useState } from 'react';
import { useComponentTemplatesContext } from '../../../../../../component_templates/component_templates_context'; import { ElasticsearchModelDefaultOptions } from '@kbn/inference_integration_flyout/types';
import { InferenceTaskType } from '@elastic/elasticsearch/lib/api/types';
import { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils';
import { useDispatch, useMappingsState } from '../../../../../mappings_state_context'; import { useDispatch, useMappingsState } from '../../../../../mappings_state_context';
import { FormHook } from '../../../../../shared_imports'; import { FormHook } from '../../../../../shared_imports';
import { Field } from '../../../../../types'; import { CustomInferenceEndpointConfig, DefaultInferenceModels, Field } from '../../../../../types';
import { useMLModelNotificationToasts } from '../../../../../../../../hooks/use_ml_model_status_toasts';
import { getInferenceModels } from '../../../../../../../services/api';
interface UseSemanticTextProps { interface UseSemanticTextProps {
form: FormHook<Field, Field>; form: FormHook<Field, Field>;
ml?: MlPluginStart; ml?: MlPluginStart;
setErrorsInTrainedModelDeployment: React.Dispatch<React.SetStateAction<string[]>> | undefined; setErrorsInTrainedModelDeployment: React.Dispatch<React.SetStateAction<string[]>> | undefined;
} }
interface DefaultInferenceEndpointConfig {
taskType: InferenceTaskType;
service: string;
}
export function useSemanticText(props: UseSemanticTextProps) { export function useSemanticText(props: UseSemanticTextProps) {
const { form, setErrorsInTrainedModelDeployment, ml } = props; const { form, setErrorsInTrainedModelDeployment, ml } = props;
const { inferenceToModelIdMap } = useMappingsState(); const { inferenceToModelIdMap } = useMappingsState();
const { toasts } = useComponentTemplatesContext();
const dispatch = useDispatch(); const dispatch = useDispatch();
const [referenceFieldComboValue, setReferenceFieldComboValue] = useState<string>(); const [referenceFieldComboValue, setReferenceFieldComboValue] = useState<string>();
const [nameValue, setNameValue] = useState<string>(); const [nameValue, setNameValue] = useState<string>();
const [inferenceIdComboValue, setInferenceIdComboValue] = useState<string>(); const [inferenceIdComboValue, setInferenceIdComboValue] = useState<string>();
const [semanticFieldType, setSemanticTextFieldType] = useState<string>(); const [semanticFieldType, setSemanticTextFieldType] = useState<string>();
const [inferenceValue, setInferenceValue] = useState<string>('elser_model_2'); const [inferenceValue, setInferenceValue] = useState<string>(
DefaultInferenceModels.elser_model_2
);
const { showSuccessToasts, showErrorToasts } = useMLModelNotificationToasts();
const useFieldEffect = ( const useFieldEffect = (
semanticTextform: FormHook, semanticTextform: FormHook,
@ -65,113 +74,92 @@ export function useSemanticText(props: UseSemanticTextProps) {
} }
}, [form, inferenceId, inferenceToModelIdMap]); }, [form, inferenceId, inferenceToModelIdMap]);
const isModelDownloaded = useCallback( const createInferenceEndpoint = useCallback(
async (modelId: string) => { async (
try { trainedModelId: ElasticsearchModelDefaultOptions | string,
const response: TrainedModelConfigResponse[] | undefined = data: Field,
await ml?.mlApi?.trainedModels.getTrainedModels(modelId, { customInferenceEndpointConfig?: CustomInferenceEndpointConfig
include: 'definition_status', ) => {
}); if (data.inferenceId === undefined) {
return !!response?.[0]?.fully_defined; throw new Error(
} catch (error) { i18n.translate('xpack.idxMgmt.mappingsEditor.createField.undefinedInferenceIdError', {
if (error.body.statusCode !== 404) { defaultMessage: 'InferenceId is undefined while creating the inference endpoint.',
throw error; })
} );
} }
return false; const defaultInferenceEndpointConfig: DefaultInferenceEndpointConfig = {
}, service:
[ml?.mlApi?.trainedModels] trainedModelId === ElasticsearchModelDefaultOptions.elser ? 'elser' : 'elasticsearch',
); taskType:
trainedModelId === ElasticsearchModelDefaultOptions.elser
const createInferenceEndpoint = ( ? 'sparse_embedding'
trainedModelId: string, : 'text_embedding',
defaultInferenceEndpoint: boolean,
data: Field
) => {
if (data.inferenceId === undefined) {
throw new Error(
i18n.translate('xpack.idxMgmt.mappingsEditor.createField.undefinedInferenceIdError', {
defaultMessage: 'InferenceId is undefined while creating the inference endpoint.',
})
);
}
if (trainedModelId && defaultInferenceEndpoint) {
const modelConfig = {
service: 'elasticsearch',
service_settings: {
num_allocations: 1,
num_threads: 1,
model_id: trainedModelId,
},
}; };
ml?.mlApi?.inferenceModels?.createInferenceEndpoint( const modelConfig = customInferenceEndpointConfig
data.inferenceId, ? customInferenceEndpointConfig.modelConfig
'text_embedding', : {
modelConfig service: defaultInferenceEndpointConfig.service,
); service_settings: {
} num_allocations: 1,
}; num_threads: 1,
model_id: trainedModelId,
},
};
const taskType: InferenceTaskType =
customInferenceEndpointConfig?.taskType ?? defaultInferenceEndpointConfig.taskType;
try {
await ml?.mlApi?.inferenceModels?.createInferenceEndpoint(
data.inferenceId,
taskType,
modelConfig
);
} catch (error) {
throw error;
}
},
[ml?.mlApi?.inferenceModels]
);
const handleSemanticText = async (data: Field) => { const handleSemanticText = async (
data: Field,
customInferenceEndpointConfig?: CustomInferenceEndpointConfig
) => {
data.inferenceId = inferenceValue; data.inferenceId = inferenceValue;
if (data.inferenceId === undefined) { if (data.inferenceId === undefined) {
return; return;
} }
const inferenceData = inferenceToModelIdMap?.[data.inferenceId]; const inferenceData = inferenceToModelIdMap?.[data.inferenceId];
if (!inferenceData) { if (!inferenceData) {
return; return;
} }
const { trainedModelId, defaultInferenceEndpoint, isDeployed, isDeployable } = inferenceData; const { trainedModelId } = inferenceData;
if (isDeployable && trainedModelId) {
try {
const modelDownloaded: boolean = await isModelDownloaded(trainedModelId);
if (isDeployed) {
createInferenceEndpoint(trainedModelId, defaultInferenceEndpoint, data);
} else if (modelDownloaded) {
ml?.mlApi?.trainedModels
.startModelAllocation(trainedModelId)
.then(() => createInferenceEndpoint(trainedModelId, defaultInferenceEndpoint, data));
} else {
ml?.mlApi?.trainedModels
.installElasticTrainedModelConfig(trainedModelId)
.then(() => ml?.mlApi?.trainedModels.startModelAllocation(trainedModelId))
.then(() => createInferenceEndpoint(trainedModelId, defaultInferenceEndpoint, data));
}
toasts?.addSuccess({
title: i18n.translate(
'xpack.idxMgmt.mappingsEditor.createField.modelDeploymentStartedNotification',
{
defaultMessage: 'Model deployment started',
}
),
text: i18n.translate(
'xpack.idxMgmt.mappingsEditor.createField.modelDeploymentNotification',
{
defaultMessage: '1 model is being deployed on your ml_node.',
}
),
});
} catch (error) {
setErrorsInTrainedModelDeployment?.((prevItems) => [...prevItems, trainedModelId]);
toasts?.addError(error.body && error.body.message ? new Error(error.body.message) : error, {
title: i18n.translate(
'xpack.idxMgmt.mappingsEditor.createField.modelDeploymentErrorTitle',
{
defaultMessage: 'Model deployment failed',
}
),
});
}
}
dispatch({ type: 'field.addSemanticText', value: data }); dispatch({ type: 'field.addSemanticText', value: data });
try {
// if model exists already, do not create inference endpoint
const inferenceModels = await getInferenceModels();
const inferenceModel: InferenceAPIConfigResponse[] = inferenceModels.data.some(
(e: InferenceAPIConfigResponse) => e.model_id === inferenceValue
);
if (inferenceModel) {
return;
}
if (trainedModelId) {
// show toasts only if it's elastic models
showSuccessToasts();
}
await createInferenceEndpoint(trainedModelId, data, customInferenceEndpointConfig);
} catch (error) {
// trainedModelId is empty string when it's a third party model
if (trainedModelId) {
setErrorsInTrainedModelDeployment?.((prevItems) => [...prevItems, trainedModelId]);
}
showErrorToasts(error);
}
}; };
return { return {


@@ -26,6 +26,7 @@ export const FIELD_TYPES_OPTIONS = Object.entries(MAIN_DATA_TYPE_DEFINITION).map(
   ([dataType, { label }]) => ({
     value: dataType,
     label,
+    'data-test-subj': `fieldTypesOptions-${dataType}`,
   })
 ) as ComboBoxOption[];


@@ -7,6 +7,8 @@
 import { ReactNode } from 'react';
+import { InferenceTaskType } from '@elastic/elasticsearch/lib/api/types';
+import { ModelConfig } from '@kbn/inference_integration_flyout';
 import { GenericObject } from './mappings_editor';
 import { PARAMETERS_DEFINITION } from '../constants';
@@ -246,3 +248,16 @@ export interface NormalizedRuntimeField {
 export interface NormalizedRuntimeFields {
   [id: string]: NormalizedRuntimeField;
 }
+export enum DefaultInferenceModels {
+  elser_model_2 = 'elser_model_2',
+  e5 = 'e5',
+}
+export enum DeploymentState {
+  'DEPLOYED' = 'deployed',
+  'NOT_DEPLOYED' = 'not_deployed',
+}
+export interface CustomInferenceEndpointConfig {
+  taskType: InferenceTaskType;
+  modelConfig: ModelConfig;
+}


@@ -84,9 +84,9 @@ export function TrainedModelsDeploymentModal({
       onCancel={closeModal}
       onConfirm={refreshModal}
       cancelButtonText={i18n.translate(
-        'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.cancelButtonLabel',
+        'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.closeButtonLabel',
         {
-          defaultMessage: 'Cancel',
+          defaultMessage: 'Close',
         }
       )}
       confirmButtonText={i18n.translate(


@@ -104,13 +104,11 @@ const inferenceToModelIdMap = {
     trainedModelId: '.elser_model_2',
     isDeployed: true,
     isDeployable: true,
-    defaultInferenceEndpoint: false,
   },
   e5: {
     trainedModelId: '.multilingual-e5-small',
     isDeployed: true,
     isDeployable: true,
-    defaultInferenceEndpoint: false,
   },
 } as InferenceToModelIdMap;
@@ -127,13 +125,11 @@ describe('useDetailsPageMappingsModelManagement', () => {
       value: {
         inferenceToModelIdMap: {
           e5: {
-            defaultInferenceEndpoint: false,
             isDeployed: false,
             isDeployable: true,
             trainedModelId: '.multilingual-e5-small',
           },
           elser_model_2: {
-            defaultInferenceEndpoint: true,
             isDeployed: true,
             isDeployable: true,
             trainedModelId: '.elser_model_2',


@ -13,14 +13,18 @@ import { useAppContext } from '../application/app_context';
import { InferenceToModelIdMap } from '../application/components/mappings_editor/components/document_fields/fields'; import { InferenceToModelIdMap } from '../application/components/mappings_editor/components/document_fields/fields';
import { deNormalize } from '../application/components/mappings_editor/lib'; import { deNormalize } from '../application/components/mappings_editor/lib';
import { useDispatch } from '../application/components/mappings_editor/mappings_state_context'; import { useDispatch } from '../application/components/mappings_editor/mappings_state_context';
import { NormalizedFields } from '../application/components/mappings_editor/types'; import {
DefaultInferenceModels,
DeploymentState,
NormalizedFields,
} from '../application/components/mappings_editor/types';
import { getInferenceModels } from '../application/services/api'; import { getInferenceModels } from '../application/services/api';
interface InferenceModel { interface InferenceModel {
data: InferenceAPIConfigResponse[]; data: InferenceAPIConfigResponse[];
} }
type DeploymentStatusType = Record<string, 'deployed' | 'not_deployed'>; type DeploymentStatusType = Record<string, DeploymentState>;
const getCustomInferenceIdMap = ( const getCustomInferenceIdMap = (
deploymentStatsByModelId: DeploymentStatusType, deploymentStatsByModelId: DeploymentStatusType,
@ -39,7 +43,6 @@ const getCustomInferenceIdMap = (
trainedModelId, trainedModelId,
isDeployable: model.service === Service.elser || model.service === Service.elasticsearch, isDeployable: model.service === Service.elser || model.service === Service.elasticsearch,
isDeployed: deploymentStatsByModelId[trainedModelId] === 'deployed', isDeployed: deploymentStatsByModelId[trainedModelId] === 'deployed',
defaultInferenceEndpoint: false,
}; };
return inferenceMap; return inferenceMap;
}, {}); }, {});
@ -50,7 +53,9 @@ export const getTrainedModelStats = (modelStats?: InferenceStatsResponse): Deplo
modelStats?.trained_model_stats.reduce<DeploymentStatusType>((acc, modelStat) => { modelStats?.trained_model_stats.reduce<DeploymentStatusType>((acc, modelStat) => {
if (modelStat.model_id) { if (modelStat.model_id) {
acc[modelStat.model_id] = acc[modelStat.model_id] =
modelStat?.deployment_stats?.state === 'started' ? 'deployed' : 'not_deployed'; modelStat?.deployment_stats?.state === 'started'
? DeploymentState.DEPLOYED
: DeploymentState.NOT_DEPLOYED;
} }
return acc; return acc;
}, {}) || {} }, {}) || {}
@ -59,17 +64,18 @@ export const getTrainedModelStats = (modelStats?: InferenceStatsResponse): Deplo
const getDefaultInferenceIds = (deploymentStatsByModelId: DeploymentStatusType) => { const getDefaultInferenceIds = (deploymentStatsByModelId: DeploymentStatusType) => {
return { return {
elser_model_2: { [DefaultInferenceModels.elser_model_2]: {
trainedModelId: '.elser_model_2', trainedModelId: ElasticsearchModelDefaultOptions.elser,
isDeployable: true, isDeployable: true,
isDeployed: deploymentStatsByModelId['.elser_model_2'] === 'deployed', isDeployed:
defaultInferenceEndpoint: true, deploymentStatsByModelId[ElasticsearchModelDefaultOptions.elser] ===
DeploymentState.DEPLOYED,
}, },
e5: { [DefaultInferenceModels.e5]: {
trainedModelId: '.multilingual-e5-small', trainedModelId: ElasticsearchModelDefaultOptions.e5,
isDeployable: true, isDeployable: true,
isDeployed: deploymentStatsByModelId['.multilingual-e5-small'] === 'deployed', isDeployed:
defaultInferenceEndpoint: true, deploymentStatsByModelId[ElasticsearchModelDefaultOptions.e5] === DeploymentState.DEPLOYED,
}, },
}; };
}; };


@@ -0,0 +1,36 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { i18n } from '@kbn/i18n';
import { ErrorType, extractErrorProperties, MLRequestFailure } from '@kbn/ml-error-utils';
import { useComponentTemplatesContext } from '../application/components/component_templates/component_templates_context';
export function useMLModelNotificationToasts() {
const { toasts } = useComponentTemplatesContext();
const showSuccessToasts = () => {
return toasts.addSuccess({
title: i18n.translate(
'xpack.idxMgmt.mappingsEditor.createField.modelDeploymentStartedNotification',
{
defaultMessage: 'Model deployment started',
}
),
text: i18n.translate('xpack.idxMgmt.mappingsEditor.createField.modelDeploymentNotification', {
defaultMessage: '1 model is being deployed on your ml_node.',
}),
});
};
const showErrorToasts = (error: ErrorType) => {
const errorObj = extractErrorProperties(error);
return toasts.addError(new MLRequestFailure(errorObj, error), {
title: i18n.translate('xpack.idxMgmt.mappingsEditor.createField.modelDeploymentErrorTitle', {
defaultMessage: 'Model deployment failed',
}),
});
};
return { showSuccessToasts, showErrorToasts };
}


@@ -49,9 +49,9 @@
     "@kbn/utility-types",
     "@kbn/inference_integration_flyout",
     "@kbn/ml-plugin",
-    "@kbn/ml-error-utils",
     "@kbn/react-kibana-context-render",
-    "@kbn/react-kibana-mount"
+    "@kbn/react-kibana-mount",
+    "@kbn/ml-error-utils",
   ],
   "exclude": ["target/**/*"]
 }