mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 17:59:23 -04:00
[Index management] Make create Inference endpoint from flyout async task (#184615)
Currently, when a new inference endpoint is created from inference
flyout, the flyout stays open until the endpoint is created. This may
take long time when Elasticsearch models - `.elser_model_2` and
`.multilingual-e5-small` is to be downloaded, deployed and inference
endpoint is to be created.
In this PR, When a new inference endpoint is saved, inference flyout is
closed and the new inference endpoint is created by a callback function
in the component.
**Screen Recording**
8eabba1a
-108a-4bf2-813a-66ceb291467c
**Testing instructions**
**update Elasticsearch to latest (only to test save mappings)**
Since ES changes for the semantic_text has been merged to main, this can
be tested against running ES from source or from latest snapshot
- Update local branch with latest Elasticsearch changes from main
- Run the elasticsearch: ./gradlew :run -Drun.license_type=trial
- Manual test in UI
**Frontend**
- enable` xpack.index_management.dev.enableSemanticText` to true in
`config/kibana.dev.yml`
- Add a new field with type - Semantic_text
- Click on drop down menu below `Select an inference endpoint`
- Click Add inference Endpoint
- Type new inference endpoint name and click Save endpoint
- Save endpoint button should close the flyout
- A new success notification toasts is shown with text "1 model is being
deployed on your ml_node."
- Add new field
- Click Save mappings
- should show a modal with model deployment status
- After new endpoint is created, refresh button should hide the modal
and save mappings should update mappings
---------
Co-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
parent
51e84f49ba
commit
c091dd89ff
18 changed files with 601 additions and 309 deletions
|
@ -78,7 +78,7 @@ export interface SaveMappingOnClick {
|
|||
taskType: InferenceTaskType,
|
||||
modelConfig: ModelConfig
|
||||
) => void;
|
||||
isCreateInferenceApiLoading: boolean;
|
||||
isCreateInferenceApiLoading?: boolean;
|
||||
}
|
||||
export interface DocumentationProps {
|
||||
elserv2documentationUrl?: string;
|
||||
|
|
|
@ -54,7 +54,7 @@ export interface ElasticsearchService {
|
|||
export enum Service {
|
||||
cohere = 'cohere',
|
||||
elser = 'elser',
|
||||
huggingFace = 'huggingFace',
|
||||
huggingFace = 'hugging_face',
|
||||
openai = 'openai',
|
||||
elasticsearch = 'elasticsearch',
|
||||
}
|
||||
|
|
|
@ -209,6 +209,9 @@ const registerHttpRequestMockHelpers = (
|
|||
const setCreateIndexResponse = (response?: HttpResponse, error?: ResponseError) =>
|
||||
mockResponse('PUT', `${INTERNAL_API_BASE_PATH}/indices/create`, response, error);
|
||||
|
||||
const setInferenceModels = (response?: HttpResponse, error?: ResponseError) =>
|
||||
mockResponse('GET', `${API_BASE_PATH}/inference/all`, response, error);
|
||||
|
||||
return {
|
||||
setLoadTemplatesResponse,
|
||||
setLoadIndicesResponse,
|
||||
|
@ -238,6 +241,7 @@ const registerHttpRequestMockHelpers = (
|
|||
setGetFieldsFromIndices,
|
||||
setGetPrivilegesResponse,
|
||||
setCreateEnrichPolicy,
|
||||
setInferenceModels,
|
||||
};
|
||||
};
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ import {
|
|||
} from '@kbn/test-jest-helpers';
|
||||
import { HttpSetup } from '@kbn/core/public';
|
||||
import { act } from 'react-dom/test-utils';
|
||||
|
||||
import { keys } from '@elastic/eui';
|
||||
import { IndexDetailsTabId } from '../../../common/constants';
|
||||
import { IndexDetailsPage } from '../../../public/application/sections/home/index_list/details_page';
|
||||
import { WithAppDependencies } from '../helpers';
|
||||
|
@ -56,6 +56,12 @@ export interface IndexDetailsPageTestBed extends TestBed {
|
|||
setSearchBarValue: (searchValue: string) => Promise<void>;
|
||||
findSearchResult: () => string;
|
||||
isSemanticTextBannerVisible: () => boolean;
|
||||
selectSemanticTextField: (name: string, type: string) => Promise<void>;
|
||||
isReferenceFieldVisible: () => void;
|
||||
selectInferenceIdButtonExists: () => void;
|
||||
openSelectInferencePopover: () => void;
|
||||
expectDefaultInferenceModelToExists: () => void;
|
||||
expectCustomInferenceModelToExists: (customInference: string) => Promise<void>;
|
||||
};
|
||||
settings: {
|
||||
getCodeBlockContent: () => string;
|
||||
|
@ -228,7 +234,7 @@ export const setup = async ({
|
|||
component.update();
|
||||
},
|
||||
selectFilterFieldType: async (fieldType: string) => {
|
||||
expect(testBed.exists('indexDetailsMappingsSelectFilter-text')).toBe(true);
|
||||
expect(testBed.exists(fieldType)).toBe(true);
|
||||
await act(async () => {
|
||||
find(fieldType).simulate('click');
|
||||
});
|
||||
|
@ -287,6 +293,7 @@ export const setup = async ({
|
|||
|
||||
await act(async () => {
|
||||
expect(exists('createFieldForm.addButton')).toBe(true);
|
||||
expect(find('createFieldForm.addButton').props().disabled).toBeFalsy();
|
||||
find('createFieldForm.addButton').simulate('click');
|
||||
});
|
||||
|
||||
|
@ -294,6 +301,41 @@ export const setup = async ({
|
|||
}
|
||||
}
|
||||
},
|
||||
selectSemanticTextField: async (name: string, type: string) => {
|
||||
expect(exists('comboBoxSearchInput')).toBe(true);
|
||||
|
||||
const { form } = testBed;
|
||||
form.setInputValue('nameParameterInput', name);
|
||||
form.setInputValue('comboBoxSearchInput', type);
|
||||
await act(async () => {
|
||||
find('comboBoxSearchInput').simulate('keydown', { key: keys.ENTER });
|
||||
});
|
||||
// select semantic_text field
|
||||
await act(async () => {
|
||||
expect(exists('fieldTypesOptions-semantic_text')).toBe(true);
|
||||
find('fieldTypesOptions-semantic_text').simulate('click');
|
||||
expect(exists('fieldTypesOptions-semantic_text')).toBe(false);
|
||||
});
|
||||
},
|
||||
isReferenceFieldVisible: async () => {
|
||||
expect(exists('referenceField.select')).toBe(true);
|
||||
},
|
||||
selectInferenceIdButtonExists: async () => {
|
||||
expect(exists('selectInferenceId')).toBe(true);
|
||||
expect(exists('inferenceIdButton')).toBe(true);
|
||||
find('inferenceIdButton').simulate('click');
|
||||
},
|
||||
openSelectInferencePopover: async () => {
|
||||
expect(exists('addInferenceEndpointButton')).toBe(true);
|
||||
expect(exists('manageInferenceEndpointButton')).toBe(true);
|
||||
},
|
||||
expectDefaultInferenceModelToExists: async () => {
|
||||
expect(exists('default-inference_elser_model_2')).toBe(true);
|
||||
expect(exists('default-inference_e5')).toBe(true);
|
||||
},
|
||||
expectCustomInferenceModelToExists: async (customInference: string) => {
|
||||
expect(exists(customInference)).toBe(true);
|
||||
},
|
||||
};
|
||||
|
||||
const settings = {
|
||||
|
|
|
@ -18,6 +18,7 @@ import {
|
|||
breadcrumbService,
|
||||
IndexManagementBreadcrumb,
|
||||
} from '../../../public/application/services/breadcrumbs';
|
||||
import { documentationService } from '../../../public/application/services/documentation';
|
||||
import { humanizeTimeStamp } from '../../../public/application/sections/home/data_stream_list/humanize_time_stamp';
|
||||
import { createDataStreamPayload } from '../home/data_streams_tab.helpers';
|
||||
import {
|
||||
|
@ -58,6 +59,7 @@ describe('<IndexDetailsPage />', () => {
|
|||
let httpSetup: ReturnType<typeof setupEnvironment>['httpSetup'];
|
||||
let httpRequestsMockHelpers: ReturnType<typeof setupEnvironment>['httpRequestsMockHelpers'];
|
||||
jest.spyOn(breadcrumbService, 'setBreadcrumbs');
|
||||
jest.spyOn(documentationService, 'setup');
|
||||
|
||||
beforeEach(async () => {
|
||||
const mockEnvironment = setupEnvironment();
|
||||
|
@ -571,14 +573,9 @@ describe('<IndexDetailsPage />', () => {
|
|||
},
|
||||
};
|
||||
beforeEach(async () => {
|
||||
httpRequestsMockHelpers.setUpdateIndexMappingsResponse(testIndexName, {
|
||||
acknowledged: true,
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
testBed = await setup({ httpSetup });
|
||||
});
|
||||
|
||||
testBed.component.update();
|
||||
await testBed.actions.clickIndexDetailsTab(IndexDetailsSection.Mappings);
|
||||
await testBed.actions.mappings.clickAddFieldButton();
|
||||
|
@ -634,43 +631,6 @@ describe('<IndexDetailsPage />', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('can add a semantic_text field and can save mappings', async () => {
|
||||
const mockIndexMappingResponseForSemanticText: any = {
|
||||
...testIndexMappings.mappings,
|
||||
properties: {
|
||||
...testIndexMappings.mappings.properties,
|
||||
sem: {
|
||||
type: 'semantic_text',
|
||||
inference_id: 'my-elser',
|
||||
},
|
||||
},
|
||||
};
|
||||
httpRequestsMockHelpers.setLoadIndexMappingResponse(testIndexName, {
|
||||
mappings: mockIndexMappingResponseForSemanticText,
|
||||
});
|
||||
await testBed.actions.mappings.addNewMappingFieldNameAndType([
|
||||
{ name: 'sem', type: 'semantic_text' },
|
||||
]);
|
||||
await testBed.actions.mappings.clickSaveMappingsButton();
|
||||
// add field button is available again
|
||||
expect(testBed.exists('indexDetailsMappingsAddField')).toBe(true);
|
||||
expect(testBed.find('semField-datatype').props()['data-type-value']).toBe('semantic_text');
|
||||
expect(httpSetup.get).toHaveBeenCalledTimes(5);
|
||||
expect(httpSetup.get).toHaveBeenLastCalledWith(
|
||||
`${API_BASE_PATH}/mapping/${testIndexName}`,
|
||||
requestOptions
|
||||
);
|
||||
// refresh mappings and page re-renders
|
||||
expect(testBed.exists('indexDetailsMappingsAddField')).toBe(true);
|
||||
expect(testBed.actions.mappings.isSearchBarDisabled()).toBe(false);
|
||||
const treeViewContent = testBed.actions.mappings.getTreeViewContent('semField');
|
||||
expect(treeViewContent).toContain('sem');
|
||||
await testBed.actions.mappings.clickToggleViewButton();
|
||||
const jsonContent = testBed.actions.mappings.getCodeBlockContent();
|
||||
expect(jsonContent).toEqual(
|
||||
JSON.stringify({ mappings: mockIndexMappingResponseForSemanticText }, null, 2)
|
||||
);
|
||||
});
|
||||
it('there is a callout with error message when save mappings fail', async () => {
|
||||
const error = {
|
||||
statusCode: 400,
|
||||
|
@ -685,6 +645,116 @@ describe('<IndexDetailsPage />', () => {
|
|||
await testBed.actions.mappings.clickSaveMappingsButton();
|
||||
expect(testBed.actions.mappings.isSaveMappingsErrorDisplayed()).toBe(true);
|
||||
});
|
||||
describe('Add Semantic text field', () => {
|
||||
const customInferenceModel = 'my-elser-model';
|
||||
beforeEach(async () => {
|
||||
httpRequestsMockHelpers.setInferenceModels({
|
||||
data: [
|
||||
{
|
||||
model_id: customInferenceModel,
|
||||
task_type: 'sparse_embedding',
|
||||
service: 'elser',
|
||||
service_settings: {
|
||||
num_allocations: 1,
|
||||
num_threads: 1,
|
||||
model_id: '.elser_model_2',
|
||||
},
|
||||
task_settings: {},
|
||||
},
|
||||
],
|
||||
});
|
||||
await act(async () => {
|
||||
testBed = await setup({
|
||||
httpSetup,
|
||||
dependencies: {
|
||||
config: { enableSemanticText: true },
|
||||
docLinks: {
|
||||
links: {
|
||||
ml: '',
|
||||
enterpriseSearch: '',
|
||||
},
|
||||
},
|
||||
plugins: {
|
||||
ml: {
|
||||
mlApi: {
|
||||
trainedModels: {
|
||||
getTrainedModels: jest.fn().mockResolvedValue([
|
||||
{
|
||||
model_id: '.elser_model_2',
|
||||
model_type: 'pytorch',
|
||||
model_package: {
|
||||
packaged_model_id: customInferenceModel,
|
||||
model_repository: 'https://ml-models.elastic.co',
|
||||
minimum_version: '11.0.0',
|
||||
size: 438123914,
|
||||
sha256: '',
|
||||
metadata: {},
|
||||
tags: [],
|
||||
vocabulary_file: 'elser_model_2.vocab.json',
|
||||
},
|
||||
description: 'Elastic Learned Sparse EncodeR v2',
|
||||
tags: ['elastic'],
|
||||
},
|
||||
]),
|
||||
getTrainedModelStats: jest.fn().mockResolvedValue({
|
||||
count: 1,
|
||||
trained_model_stats: [
|
||||
{
|
||||
model_id: '.elser_model_2',
|
||||
|
||||
deployment_stats: {
|
||||
deployment_id: customInferenceModel,
|
||||
model_id: '.elser_model_2',
|
||||
threads_per_allocation: 1,
|
||||
number_of_allocations: 1,
|
||||
queue_capacity: 1024,
|
||||
state: 'started',
|
||||
},
|
||||
},
|
||||
{
|
||||
model_id: '.elser_model_2',
|
||||
|
||||
deployment_stats: {
|
||||
deployment_id: '.elser_model_2',
|
||||
model_id: '.elser_model_2',
|
||||
threads_per_allocation: 1,
|
||||
number_of_allocations: 1,
|
||||
queue_capacity: 1024,
|
||||
state: 'started',
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
testBed.component.update();
|
||||
await testBed.actions.clickIndexDetailsTab(IndexDetailsSection.Mappings);
|
||||
await testBed.actions.mappings.clickAddFieldButton();
|
||||
});
|
||||
it('can select semantic_text field', async () => {
|
||||
await testBed.actions.mappings.selectSemanticTextField(
|
||||
'semantic_text_name',
|
||||
'Semantic text'
|
||||
);
|
||||
|
||||
testBed.actions.mappings.isReferenceFieldVisible();
|
||||
testBed.actions.mappings.selectInferenceIdButtonExists();
|
||||
testBed.actions.mappings.openSelectInferencePopover();
|
||||
testBed.actions.mappings.expectDefaultInferenceModelToExists();
|
||||
testBed.actions.mappings.expectCustomInferenceModelToExists(
|
||||
`custom-inference_${customInferenceModel}`
|
||||
);
|
||||
|
||||
// can cancel new field
|
||||
expect(testBed.exists('cancelButton')).toBe(true);
|
||||
testBed.find('cancelButton').simulate('click');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('error loading mappings', () => {
|
||||
|
|
|
@ -7,13 +7,13 @@
|
|||
|
||||
import { registerTestBed } from '@kbn/test-jest-helpers';
|
||||
import { act } from 'react-dom/test-utils';
|
||||
import { SelectInferenceId } from './select_inference_id';
|
||||
import { SelectInferenceId } from '../../../public/application/components/mappings_editor/components/document_fields/field_parameters/select_inference_id';
|
||||
|
||||
const onChangeMock = jest.fn();
|
||||
const setValueMock = jest.fn();
|
||||
const setNewInferenceEndpointMock = jest.fn();
|
||||
|
||||
jest.mock('../../../../../app_context', () => ({
|
||||
jest.mock('../../../public/application/app_context', () => ({
|
||||
useAppContext: jest.fn().mockReturnValue({
|
||||
core: { application: {} },
|
||||
docLinks: {},
|
||||
|
@ -30,6 +30,17 @@ jest.mock('../../../../../app_context', () => ({
|
|||
}),
|
||||
}));
|
||||
|
||||
jest.mock(
|
||||
'../../../public/application/components/component_templates/component_templates_context',
|
||||
() => ({
|
||||
useComponentTemplatesContext: jest.fn().mockReturnValue({
|
||||
toasts: {
|
||||
addError: jest.fn(),
|
||||
addSuccess: jest.fn(),
|
||||
},
|
||||
}),
|
||||
})
|
||||
);
|
||||
describe('SelectInferenceId', () => {
|
||||
let exists: any;
|
||||
let find: any;
|
|
@ -32,7 +32,11 @@ export const ReferenceFieldSelects = ({
|
|||
Object.keys(data.mappings.properties).forEach((key) => {
|
||||
const field = data.mappings.properties[key];
|
||||
if (field.type === 'text') {
|
||||
referenceFieldOptions.push({ value: key, inputDisplay: key });
|
||||
referenceFieldOptions.push({
|
||||
value: key,
|
||||
inputDisplay: key,
|
||||
'data-test-subj': `select-reference-field-${key}`,
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -47,14 +51,15 @@ export const ReferenceFieldSelects = ({
|
|||
|
||||
return subscription.unsubscribe;
|
||||
}, [subscribe, onChange]);
|
||||
|
||||
return (
|
||||
<Form form={form}>
|
||||
<Form form={form} data-test-subj="referenceField">
|
||||
<UseField path="main" config={fieldConfigReferenceField}>
|
||||
{(field) => (
|
||||
<SuperSelectField
|
||||
field={field}
|
||||
euiFieldProps={{ options: referenceFieldOptions }}
|
||||
euiFieldProps={{
|
||||
options: referenceFieldOptions,
|
||||
}}
|
||||
data-test-subj={dataTestSubj}
|
||||
/>
|
||||
)}
|
||||
|
|
|
@ -35,28 +35,40 @@ import {
|
|||
ModelConfig,
|
||||
Service,
|
||||
} from '@kbn/inference_integration_flyout/types';
|
||||
import { FormattedMessage } from '@kbn/i18n-react';
|
||||
import { InferenceFlyoutWrapper } from '@kbn/inference_integration_flyout/components/inference_flyout_wrapper';
|
||||
import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_models';
|
||||
import { extractErrorProperties } from '@kbn/ml-error-utils';
|
||||
import { getFieldConfig } from '../../../lib';
|
||||
import { useAppContext } from '../../../../../app_context';
|
||||
import { Form, UseField, useForm } from '../../../shared_imports';
|
||||
import { useLoadInferenceModels } from '../../../../../services/api';
|
||||
import { getTrainedModelStats } from '../../../../../../hooks/use_details_page_mappings_model_management';
|
||||
import { InferenceToModelIdMap } from '../fields';
|
||||
import { useMLModelNotificationToasts } from '../../../../../../hooks/use_ml_model_status_toasts';
|
||||
import {
|
||||
CustomInferenceEndpointConfig,
|
||||
DefaultInferenceModels,
|
||||
DeploymentState,
|
||||
} from '../../../types';
|
||||
|
||||
const inferenceServiceTypeElasticsearchModelMap: Record<string, ElasticsearchModelDefaultOptions> =
|
||||
{
|
||||
elser: ElasticsearchModelDefaultOptions.elser,
|
||||
elasticsearch: ElasticsearchModelDefaultOptions.e5,
|
||||
};
|
||||
|
||||
const uncheckSelectedModelOption = (options: EuiSelectableOption[]) => {
|
||||
const checkedOption = options.find(({ checked }) => checked === 'on');
|
||||
if (checkedOption) {
|
||||
checkedOption.checked = undefined;
|
||||
}
|
||||
};
|
||||
interface Props {
|
||||
onChange(value: string): void;
|
||||
'data-test-subj'?: string;
|
||||
setValue: (value: string) => void;
|
||||
setNewInferenceEndpoint: (newInferenceEndpoint: InferenceToModelIdMap) => void;
|
||||
setNewInferenceEndpoint: (
|
||||
newInferenceEndpoint: InferenceToModelIdMap,
|
||||
customInferenceEndpointConfig: CustomInferenceEndpointConfig
|
||||
) => void;
|
||||
}
|
||||
export const SelectInferenceId = ({
|
||||
onChange,
|
||||
|
@ -76,16 +88,14 @@ export const SelectInferenceId = ({
|
|||
});
|
||||
}, [ml]);
|
||||
|
||||
const { form } = useForm({ defaultValue: { main: 'elser_model_2' } });
|
||||
const { form } = useForm({ defaultValue: { main: DefaultInferenceModels.elser_model_2 } });
|
||||
const { subscribe } = form;
|
||||
|
||||
const [isInferenceFlyoutVisible, setIsInferenceFlyoutVisible] = useState<boolean>(false);
|
||||
const [inferenceAddError, setInferenceAddError] = useState<string | undefined>(undefined);
|
||||
const [availableTrainedModels, setAvailableTrainedModels] = useState<
|
||||
TrainedModelConfigResponse[]
|
||||
>([]);
|
||||
const onFlyoutClose = useCallback(() => {
|
||||
setInferenceAddError(undefined);
|
||||
setIsInferenceFlyoutVisible(!isInferenceFlyoutVisible);
|
||||
}, [isInferenceFlyoutVisible]);
|
||||
useEffect(() => {
|
||||
|
@ -111,16 +121,27 @@ export const SelectInferenceId = ({
|
|||
|
||||
const fieldConfigModelId = getFieldConfig('inference_id');
|
||||
const defaultInferenceIds: EuiSelectableOption[] = useMemo(() => {
|
||||
return [{ checked: 'on', label: 'elser_model_2' }, { label: 'e5' }];
|
||||
return [
|
||||
{
|
||||
checked: 'on',
|
||||
label: 'elser_model_2',
|
||||
'data-test-subj': 'default-inference_elser_model_2',
|
||||
},
|
||||
{
|
||||
label: 'e5',
|
||||
'data-test-subj': 'default-inference_e5',
|
||||
},
|
||||
];
|
||||
}, []);
|
||||
|
||||
const { isLoading, data: models, resendRequest } = useLoadInferenceModels();
|
||||
const { isLoading, data: models } = useLoadInferenceModels();
|
||||
|
||||
const [options, setOptions] = useState<EuiSelectableOption[]>([...defaultInferenceIds]);
|
||||
const inferenceIdOptionsFromModels = useMemo(() => {
|
||||
const inferenceIdOptions =
|
||||
models?.map((model: InferenceAPIConfigResponse) => ({
|
||||
label: model.model_id,
|
||||
'data-test-subj': `custom-inference_${model.model_id}`,
|
||||
})) || [];
|
||||
|
||||
return inferenceIdOptions;
|
||||
|
@ -136,40 +157,48 @@ export const SelectInferenceId = ({
|
|||
};
|
||||
setOptions(Object.values(mergedOptions));
|
||||
}, [inferenceIdOptionsFromModels, defaultInferenceIds]);
|
||||
const [isCreateInferenceApiLoading, setIsCreateInferenceApiLoading] = useState(false);
|
||||
|
||||
const { showErrorToasts } = useMLModelNotificationToasts();
|
||||
|
||||
const onSaveInferenceCallback = useCallback(
|
||||
async (inferenceId: string, taskType: InferenceTaskType, modelConfig: ModelConfig) => {
|
||||
setIsCreateInferenceApiLoading(true);
|
||||
setIsInferenceFlyoutVisible(!isInferenceFlyoutVisible);
|
||||
try {
|
||||
await ml?.mlApi?.inferenceModels?.createInferenceEndpoint(
|
||||
inferenceId,
|
||||
taskType,
|
||||
modelConfig
|
||||
);
|
||||
setIsInferenceFlyoutVisible(!isInferenceFlyoutVisible);
|
||||
setIsCreateInferenceApiLoading(false);
|
||||
setInferenceAddError(undefined);
|
||||
const isDeployable =
|
||||
modelConfig.service === Service.elser || modelConfig.service === Service.elasticsearch;
|
||||
|
||||
const newOption: EuiSelectableOption[] = [
|
||||
{
|
||||
label: inferenceId,
|
||||
checked: 'on',
|
||||
'data-test-subj': `custom-inference_${inferenceId}`,
|
||||
},
|
||||
];
|
||||
// uncheck selected endpoint id
|
||||
uncheckSelectedModelOption(options);
|
||||
|
||||
setOptions([...options, ...newOption]);
|
||||
|
||||
const trainedModelStats = await ml?.mlApi?.trainedModels.getTrainedModelStats();
|
||||
const defaultEndpointId =
|
||||
inferenceServiceTypeElasticsearchModelMap[modelConfig.service] || '';
|
||||
const newModelId: InferenceToModelIdMap = {};
|
||||
newModelId[inferenceId] = {
|
||||
trainedModelId: defaultEndpointId,
|
||||
isDeployable:
|
||||
modelConfig.service === Service.elser || modelConfig.service === Service.elasticsearch,
|
||||
isDeployed: getTrainedModelStats(trainedModelStats)[defaultEndpointId] === 'deployed',
|
||||
defaultInferenceEndpoint: false,
|
||||
isDeployable,
|
||||
isDeployed:
|
||||
getTrainedModelStats(trainedModelStats)[defaultEndpointId] === DeploymentState.DEPLOYED,
|
||||
};
|
||||
resendRequest();
|
||||
setNewInferenceEndpoint(newModelId);
|
||||
const customInferenceEndpointConfig: CustomInferenceEndpointConfig = {
|
||||
taskType,
|
||||
modelConfig,
|
||||
};
|
||||
setNewInferenceEndpoint(newModelId, customInferenceEndpointConfig);
|
||||
} catch (error) {
|
||||
const errorObj = extractErrorProperties(error);
|
||||
setInferenceAddError(errorObj.message);
|
||||
setIsCreateInferenceApiLoading(false);
|
||||
showErrorToasts(error);
|
||||
}
|
||||
},
|
||||
[isInferenceFlyoutVisible, resendRequest, ml, setNewInferenceEndpoint]
|
||||
[isInferenceFlyoutVisible, ml, setNewInferenceEndpoint, options, showErrorToasts]
|
||||
);
|
||||
useEffect(() => {
|
||||
const subscription = subscribe((updateData) => {
|
||||
|
@ -182,7 +211,7 @@ export const SelectInferenceId = ({
|
|||
}, [subscribe, onChange]);
|
||||
const selectedOptionLabel = options.find((option) => option.checked)?.label;
|
||||
useEffect(() => {
|
||||
setValue(selectedOptionLabel ?? 'elser_model_2');
|
||||
setValue(selectedOptionLabel ?? DefaultInferenceModels.elser_model_2);
|
||||
}, [selectedOptionLabel, setValue]);
|
||||
const [isInferencePopoverVisible, setIsInferencePopoverVisible] = useState<boolean>(false);
|
||||
const [inferenceEndpointError, setInferenceEndpointError] = useState<string | undefined>(
|
||||
|
@ -304,7 +333,7 @@ export const SelectInferenceId = ({
|
|||
data-test-subj={dataTestSubj}
|
||||
searchable
|
||||
isLoading={isLoading}
|
||||
singleSelection
|
||||
singleSelection="always"
|
||||
searchProps={{
|
||||
compressed: true,
|
||||
placeholder: i18n.translate(
|
||||
|
@ -340,32 +369,6 @@ export const SelectInferenceId = ({
|
|||
<InferenceFlyoutWrapper
|
||||
elserv2documentationUrl={docLinks.links.ml.nlpElser}
|
||||
e5documentationUrl={docLinks.links.ml.nlpE5}
|
||||
errorCallout={
|
||||
inferenceAddError && (
|
||||
<EuiFlexItem grow={false}>
|
||||
<EuiCallOut
|
||||
color="danger"
|
||||
data-test-subj="addInferenceError"
|
||||
iconType="error"
|
||||
title={i18n.translate(
|
||||
'xpack.idxMgmt.mappingsEditor.parameters.inferenceId.errorTitle',
|
||||
{
|
||||
defaultMessage: 'Error adding inference endpoint',
|
||||
}
|
||||
)}
|
||||
>
|
||||
<EuiText>
|
||||
<FormattedMessage
|
||||
id="xpack.idxMgmt.mappingsEditor.parameters.inferenceId.errorDescription"
|
||||
defaultMessage="Error adding inference endpoint: {errorMessage}"
|
||||
values={{ errorMessage: inferenceAddError }}
|
||||
/>
|
||||
</EuiText>
|
||||
</EuiCallOut>
|
||||
<EuiSpacer />
|
||||
</EuiFlexItem>
|
||||
)
|
||||
}
|
||||
onInferenceEndpointChange={onInferenceEndpointChange}
|
||||
inferenceEndpointError={inferenceEndpointError}
|
||||
trainedModels={trainedModels}
|
||||
|
@ -374,7 +377,6 @@ export const SelectInferenceId = ({
|
|||
isInferenceFlyoutVisible={isInferenceFlyoutVisible}
|
||||
supportedNlpModels={docLinks.links.enterpriseSearch.supportedNlpModels}
|
||||
nlpImportModel={docLinks.links.ml.nlpImportModel}
|
||||
isCreateInferenceApiLoading={isCreateInferenceApiLoading}
|
||||
setInferenceEndpointError={setInferenceEndpointError}
|
||||
/>
|
||||
)}
|
||||
|
|
|
@ -14,14 +14,20 @@ import {
|
|||
EuiSpacer,
|
||||
} from '@elastic/eui';
|
||||
import { i18n } from '@kbn/i18n';
|
||||
import { ElasticsearchModelDefaultOptions } from '@kbn/inference_integration_flyout/types';
|
||||
import { MlPluginStart } from '@kbn/ml-plugin/public';
|
||||
import classNames from 'classnames';
|
||||
import React, { useCallback, useEffect } from 'react';
|
||||
import React, { useCallback, useEffect, useState } from 'react';
|
||||
import { EUI_SIZE, TYPE_DEFINITION } from '../../../../constants';
|
||||
import { fieldSerializer } from '../../../../lib';
|
||||
import { useDispatch, useMappingsState } from '../../../../mappings_state_context';
|
||||
import { Form, FormDataProvider, UseField, useForm, useFormData } from '../../../../shared_imports';
|
||||
import { Field, MainType, NormalizedFields } from '../../../../types';
|
||||
import {
|
||||
CustomInferenceEndpointConfig,
|
||||
Field,
|
||||
MainType,
|
||||
NormalizedFields,
|
||||
} from '../../../../types';
|
||||
import { NameParameter, SubTypeParameter, TypeParameter } from '../../field_parameters';
|
||||
import { ReferenceFieldSelects } from '../../field_parameters/reference_field_selects';
|
||||
import { SelectInferenceId } from '../../field_parameters/select_inference_id';
|
||||
|
@ -32,10 +38,9 @@ import { useSemanticText } from './semantic_text/use_semantic_text';
|
|||
const formWrapper = (props: any) => <form {...props} />;
|
||||
export interface InferenceToModelIdMap {
|
||||
[key: string]: {
|
||||
trainedModelId?: string;
|
||||
trainedModelId: ElasticsearchModelDefaultOptions | string;
|
||||
isDeployed: boolean;
|
||||
isDeployable: boolean;
|
||||
defaultInferenceEndpoint: boolean;
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -88,7 +93,9 @@ export const CreateField = React.memo(function CreateFieldComponent({
|
|||
|
||||
return subscription.unsubscribe;
|
||||
}, [dispatch, subscribe]);
|
||||
|
||||
const [customInferenceEndpointConfig, setCustomInferenceEndpointConfig] = useState<
|
||||
CustomInferenceEndpointConfig | undefined
|
||||
>(undefined);
|
||||
const cancel = () => {
|
||||
if (isAddingFields && onCancelAddingNewFields) {
|
||||
onCancelAddingNewFields();
|
||||
|
@ -125,7 +132,7 @@ export const CreateField = React.memo(function CreateFieldComponent({
|
|||
form.reset();
|
||||
|
||||
if (data.type === 'semantic_text' && !clickOutside) {
|
||||
handleSemanticText(data);
|
||||
handleSemanticText(data, customInferenceEndpointConfig);
|
||||
} else {
|
||||
dispatch({ type: 'field.add', value: data });
|
||||
}
|
||||
|
@ -283,7 +290,10 @@ export const CreateField = React.memo(function CreateFieldComponent({
|
|||
}}
|
||||
</FormDataProvider>
|
||||
{/* Field inference_id for semantic_text field type */}
|
||||
<InferenceIdCombo setValue={setInferenceValue} />
|
||||
<InferenceIdCombo
|
||||
setValue={setInferenceValue}
|
||||
setCustomInferenceEndpointConfig={setCustomInferenceEndpointConfig}
|
||||
/>
|
||||
{renderFormActions()}
|
||||
</div>
|
||||
</div>
|
||||
|
@ -311,16 +321,20 @@ function ReferenceFieldCombo({ indexName }: { indexName?: string }) {
|
|||
|
||||
interface InferenceProps {
|
||||
setValue: (value: string) => void;
|
||||
setCustomInferenceEndpointConfig: (config: CustomInferenceEndpointConfig) => void;
|
||||
}
|
||||
|
||||
function InferenceIdCombo({ setValue }: InferenceProps) {
|
||||
function InferenceIdCombo({ setValue, setCustomInferenceEndpointConfig }: InferenceProps) {
|
||||
const { inferenceToModelIdMap } = useMappingsState();
|
||||
const dispatch = useDispatch();
|
||||
const [{ type }] = useFormData({ watch: 'type' });
|
||||
|
||||
// update new inferenceEndpoint
|
||||
const setNewInferenceEndpoint = useCallback(
|
||||
(newInferenceEndpoint: InferenceToModelIdMap) => {
|
||||
(
|
||||
newInferenceEndpoint: InferenceToModelIdMap,
|
||||
customInferenceEndpointConfig: CustomInferenceEndpointConfig
|
||||
) => {
|
||||
dispatch({
|
||||
type: 'inferenceToModelIdMap.update',
|
||||
value: {
|
||||
|
@ -330,8 +344,9 @@ function InferenceIdCombo({ setValue }: InferenceProps) {
|
|||
},
|
||||
},
|
||||
});
|
||||
setCustomInferenceEndpointConfig(customInferenceEndpointConfig);
|
||||
},
|
||||
[dispatch, inferenceToModelIdMap]
|
||||
[dispatch, inferenceToModelIdMap, setCustomInferenceEndpointConfig]
|
||||
);
|
||||
|
||||
if (type === undefined || type[0]?.value !== 'semantic_text') {
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
*/
|
||||
|
||||
import { renderHook } from '@testing-library/react-hooks';
|
||||
import { Field } from '../../../../../types';
|
||||
import { CustomInferenceEndpointConfig, Field } from '../../../../../types';
|
||||
import { useSemanticText } from './use_semantic_text';
|
||||
import { act } from 'react-dom/test-utils';
|
||||
|
||||
|
@ -15,22 +15,54 @@ const mlMock: any = {
|
|||
inferenceModels: {
|
||||
createInferenceEndpoint: jest.fn().mockResolvedValue({}),
|
||||
},
|
||||
trainedModels: {
|
||||
startModelAllocation: jest.fn().mockResolvedValue({}),
|
||||
getTrainedModels: jest.fn().mockResolvedValue([
|
||||
{
|
||||
fully_defined: true,
|
||||
},
|
||||
]),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const mockFieldData = {
|
||||
name: 'name',
|
||||
type: 'semantic_text',
|
||||
inferenceId: 'elser_model_2',
|
||||
} as Field;
|
||||
const mockField: Record<string, Field> = {
|
||||
elser_model_2: {
|
||||
name: 'name',
|
||||
type: 'semantic_text',
|
||||
inferenceId: 'elser_model_2',
|
||||
},
|
||||
e5: {
|
||||
name: 'name',
|
||||
type: 'semantic_text',
|
||||
inferenceId: 'e5',
|
||||
},
|
||||
openai: {
|
||||
name: 'name',
|
||||
type: 'semantic_text',
|
||||
inferenceId: 'openai',
|
||||
},
|
||||
my_elser_endpoint: {
|
||||
name: 'name',
|
||||
type: 'semantic_text',
|
||||
inferenceId: 'my_elser_endpoint',
|
||||
},
|
||||
};
|
||||
|
||||
const mockConfig: Record<string, CustomInferenceEndpointConfig> = {
|
||||
openai: {
|
||||
taskType: 'text_embedding',
|
||||
modelConfig: {
|
||||
service: 'openai',
|
||||
service_settings: {
|
||||
api_key: 'test',
|
||||
model_id: 'text-embedding-ada-002',
|
||||
},
|
||||
},
|
||||
},
|
||||
elser: {
|
||||
taskType: 'sparse_embedding',
|
||||
modelConfig: {
|
||||
service: 'elser',
|
||||
service_settings: {
|
||||
num_allocations: 1,
|
||||
num_threads: 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const mockDispatch = jest.fn();
|
||||
|
||||
|
@ -38,13 +70,21 @@ jest.mock('../../../../../mappings_state_context', () => ({
|
|||
useMappingsState: jest.fn().mockReturnValue({
|
||||
inferenceToModelIdMap: {
|
||||
e5: {
|
||||
defaultInferenceEndpoint: false,
|
||||
isDeployed: false,
|
||||
isDeployable: true,
|
||||
trainedModelId: '.multilingual-e5-small',
|
||||
},
|
||||
elser_model_2: {
|
||||
defaultInferenceEndpoint: true,
|
||||
isDeployed: false,
|
||||
isDeployable: true,
|
||||
trainedModelId: '.elser_model_2',
|
||||
},
|
||||
openai: {
|
||||
isDeployed: false,
|
||||
isDeployable: false,
|
||||
trainedModelId: '',
|
||||
},
|
||||
my_elser_endpoint: {
|
||||
isDeployed: false,
|
||||
isDeployable: true,
|
||||
trainedModelId: '.elser_model_2',
|
||||
|
@ -63,24 +103,108 @@ jest.mock('../../../../../../component_templates/component_templates_context', (
|
|||
}),
|
||||
}));
|
||||
|
||||
jest.mock('../../../../../../../services/api', () => ({
|
||||
getInferenceModels: jest.fn().mockResolvedValue({
|
||||
data: [
|
||||
{
|
||||
model_id: 'e5',
|
||||
task_type: 'text_embedding',
|
||||
service: 'elasticsearch',
|
||||
service_settings: {
|
||||
num_allocations: 1,
|
||||
num_threads: 1,
|
||||
model_id: '.multilingual-e5-small',
|
||||
},
|
||||
task_settings: {},
|
||||
},
|
||||
],
|
||||
}),
|
||||
}));
|
||||
|
||||
describe('useSemanticText', () => {
|
||||
let form: any;
|
||||
let mockForm: any;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
form = {
|
||||
getFields: jest.fn().mockReturnValue({
|
||||
referenceField: { value: 'title' },
|
||||
name: { value: 'sem' },
|
||||
type: { value: [{ value: 'semantic_text' }] },
|
||||
inferenceId: { value: 'e5' },
|
||||
}),
|
||||
mockForm = {
|
||||
form: {
|
||||
getFields: jest.fn().mockReturnValue({
|
||||
referenceField: { value: 'title' },
|
||||
name: { value: 'sem' },
|
||||
type: { value: [{ value: 'semantic_text' }] },
|
||||
inferenceId: { value: 'e5' },
|
||||
}),
|
||||
},
|
||||
thirdPartyModel: {
|
||||
getFields: jest.fn().mockReturnValue({
|
||||
referenceField: { value: 'title' },
|
||||
name: { value: 'semantic_text_openai_endpoint' },
|
||||
type: { value: [{ value: 'semantic_text' }] },
|
||||
inferenceId: { value: 'openai' },
|
||||
}),
|
||||
},
|
||||
elasticModelEndpointCreatedfromFlyout: {
|
||||
getFields: jest.fn().mockReturnValue({
|
||||
referenceField: { value: 'title' },
|
||||
name: { value: 'semantic_text_elserServiceType_endpoint' },
|
||||
type: { value: [{ value: 'semantic_text' }] },
|
||||
inferenceId: { value: 'my_elser_endpoint' },
|
||||
}),
|
||||
},
|
||||
};
|
||||
});
|
||||
it('should handle semantic text with third party model correctly', async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useSemanticText({
|
||||
form: mockForm.thirdPartyModel,
|
||||
setErrorsInTrainedModelDeployment: jest.fn(),
|
||||
ml: mlMock,
|
||||
})
|
||||
);
|
||||
await act(async () => {
|
||||
result.current.setInferenceValue('openai');
|
||||
result.current.handleSemanticText(mockField.openai, mockConfig.openai);
|
||||
});
|
||||
expect(mockDispatch).toHaveBeenCalledWith({
|
||||
type: 'field.addSemanticText',
|
||||
value: mockField.openai,
|
||||
});
|
||||
expect(mlMock.mlApi.inferenceModels.createInferenceEndpoint).toHaveBeenCalledWith(
|
||||
'openai',
|
||||
'text_embedding',
|
||||
mockConfig.openai.modelConfig
|
||||
);
|
||||
});
|
||||
it('should handle semantic text with inference endpoint created from flyout correctly', async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useSemanticText({
|
||||
form: mockForm.elasticModelEndpointCreatedfromFlyout,
|
||||
setErrorsInTrainedModelDeployment: jest.fn(),
|
||||
ml: mlMock,
|
||||
})
|
||||
);
|
||||
await act(async () => {
|
||||
result.current.setInferenceValue('my_elser_endpoint');
|
||||
result.current.handleSemanticText(mockField.my_elser_endpoint, mockConfig.elser);
|
||||
});
|
||||
|
||||
expect(mockDispatch).toHaveBeenCalledWith({
|
||||
type: 'field.addSemanticText',
|
||||
value: mockField.my_elser_endpoint,
|
||||
});
|
||||
expect(mlMock.mlApi.inferenceModels.createInferenceEndpoint).toHaveBeenCalledWith(
|
||||
'my_elser_endpoint',
|
||||
'sparse_embedding',
|
||||
mockConfig.elser.modelConfig
|
||||
);
|
||||
});
|
||||
it('should populate the values from the form', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useSemanticText({ form, setErrorsInTrainedModelDeployment: jest.fn(), ml: mlMock })
|
||||
useSemanticText({
|
||||
form: mockForm.form,
|
||||
setErrorsInTrainedModelDeployment: jest.fn(),
|
||||
ml: mlMock,
|
||||
})
|
||||
);
|
||||
|
||||
expect(result.current.referenceFieldComboValue).toBe('title');
|
||||
|
@ -91,23 +215,26 @@ describe('useSemanticText', () => {
|
|||
|
||||
it('should handle semantic text correctly', async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useSemanticText({ form, setErrorsInTrainedModelDeployment: jest.fn(), ml: mlMock })
|
||||
useSemanticText({
|
||||
form: mockForm.form,
|
||||
setErrorsInTrainedModelDeployment: jest.fn(),
|
||||
ml: mlMock,
|
||||
})
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
result.current.handleSemanticText(mockFieldData);
|
||||
result.current.handleSemanticText(mockField.elser_model_2);
|
||||
});
|
||||
|
||||
expect(mlMock.mlApi.trainedModels.startModelAllocation).toHaveBeenCalledWith('.elser_model_2');
|
||||
expect(mockDispatch).toHaveBeenCalledWith({
|
||||
type: 'field.addSemanticText',
|
||||
value: mockFieldData,
|
||||
value: mockField.elser_model_2,
|
||||
});
|
||||
expect(mlMock.mlApi.inferenceModels.createInferenceEndpoint).toHaveBeenCalledWith(
|
||||
'elser_model_2',
|
||||
'text_embedding',
|
||||
'sparse_embedding',
|
||||
{
|
||||
service: 'elasticsearch',
|
||||
service: 'elser',
|
||||
service_settings: {
|
||||
num_allocations: 1,
|
||||
num_threads: 1,
|
||||
|
@ -116,68 +243,42 @@ describe('useSemanticText', () => {
|
|||
}
|
||||
);
|
||||
});
|
||||
|
||||
it('should invoke the download api if the model does not exist', async () => {
|
||||
const mlMockWithModelNotDownloaded: any = {
|
||||
mlApi: {
|
||||
inferenceModels: {
|
||||
createInferenceEndpoint: jest.fn(),
|
||||
},
|
||||
trainedModels: {
|
||||
startModelAllocation: jest.fn(),
|
||||
getTrainedModels: jest.fn().mockResolvedValue([
|
||||
{
|
||||
fully_defined: false,
|
||||
},
|
||||
]),
|
||||
installElasticTrainedModelConfig: jest.fn().mockResolvedValue({}),
|
||||
},
|
||||
},
|
||||
};
|
||||
it('does not call create inference endpoint api, if default endpoint already exists', async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useSemanticText({
|
||||
form,
|
||||
form: mockForm.form,
|
||||
setErrorsInTrainedModelDeployment: jest.fn(),
|
||||
ml: mlMockWithModelNotDownloaded,
|
||||
ml: mlMock,
|
||||
})
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
result.current.handleSemanticText(mockFieldData);
|
||||
result.current.setInferenceValue('e5');
|
||||
result.current.handleSemanticText(mockField.e5);
|
||||
});
|
||||
|
||||
expect(
|
||||
mlMockWithModelNotDownloaded.mlApi.trainedModels.installElasticTrainedModelConfig
|
||||
).toHaveBeenCalledWith('.elser_model_2');
|
||||
expect(
|
||||
mlMockWithModelNotDownloaded.mlApi.trainedModels.startModelAllocation
|
||||
).toHaveBeenCalledWith('.elser_model_2');
|
||||
expect(
|
||||
mlMockWithModelNotDownloaded.mlApi.inferenceModels.createInferenceEndpoint
|
||||
).toHaveBeenCalledWith('elser_model_2', 'text_embedding', {
|
||||
service: 'elasticsearch',
|
||||
service_settings: {
|
||||
num_allocations: 1,
|
||||
num_threads: 1,
|
||||
model_id: '.elser_model_2',
|
||||
},
|
||||
expect(mockDispatch).toHaveBeenCalledWith({
|
||||
type: 'field.addSemanticText',
|
||||
value: mockField.e5,
|
||||
});
|
||||
|
||||
expect(mlMock.mlApi.inferenceModels.createInferenceEndpoint).not.toBeCalled();
|
||||
});
|
||||
|
||||
it('handles errors correctly', async () => {
|
||||
const mockError = new Error('Test error');
|
||||
mlMock.mlApi?.trainedModels.startModelAllocation.mockImplementationOnce(() => {
|
||||
mlMock.mlApi?.inferenceModels.createInferenceEndpoint.mockImplementationOnce(() => {
|
||||
throw mockError;
|
||||
});
|
||||
|
||||
const setErrorsInTrainedModelDeployment = jest.fn();
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useSemanticText({ form, setErrorsInTrainedModelDeployment, ml: mlMock })
|
||||
useSemanticText({ form: mockForm.form, setErrorsInTrainedModelDeployment, ml: mlMock })
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
result.current.handleSemanticText(mockFieldData);
|
||||
result.current.handleSemanticText(mockField.elser_model_2);
|
||||
});
|
||||
|
||||
expect(setErrorsInTrainedModelDeployment).toHaveBeenCalledWith(expect.any(Function));
|
||||
|
|
|
@ -7,30 +7,39 @@
|
|||
|
||||
import { i18n } from '@kbn/i18n';
|
||||
import { useCallback } from 'react';
|
||||
import { MlPluginStart, TrainedModelConfigResponse } from '@kbn/ml-plugin/public';
|
||||
import { MlPluginStart } from '@kbn/ml-plugin/public';
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { useComponentTemplatesContext } from '../../../../../../component_templates/component_templates_context';
|
||||
import { ElasticsearchModelDefaultOptions } from '@kbn/inference_integration_flyout/types';
|
||||
import { InferenceTaskType } from '@elastic/elasticsearch/lib/api/types';
|
||||
import { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils';
|
||||
import { useDispatch, useMappingsState } from '../../../../../mappings_state_context';
|
||||
import { FormHook } from '../../../../../shared_imports';
|
||||
import { Field } from '../../../../../types';
|
||||
import { CustomInferenceEndpointConfig, DefaultInferenceModels, Field } from '../../../../../types';
|
||||
import { useMLModelNotificationToasts } from '../../../../../../../../hooks/use_ml_model_status_toasts';
|
||||
|
||||
import { getInferenceModels } from '../../../../../../../services/api';
|
||||
interface UseSemanticTextProps {
|
||||
form: FormHook<Field, Field>;
|
||||
ml?: MlPluginStart;
|
||||
setErrorsInTrainedModelDeployment: React.Dispatch<React.SetStateAction<string[]>> | undefined;
|
||||
}
|
||||
interface DefaultInferenceEndpointConfig {
|
||||
taskType: InferenceTaskType;
|
||||
service: string;
|
||||
}
|
||||
|
||||
export function useSemanticText(props: UseSemanticTextProps) {
|
||||
const { form, setErrorsInTrainedModelDeployment, ml } = props;
|
||||
const { inferenceToModelIdMap } = useMappingsState();
|
||||
const { toasts } = useComponentTemplatesContext();
|
||||
const dispatch = useDispatch();
|
||||
|
||||
const [referenceFieldComboValue, setReferenceFieldComboValue] = useState<string>();
|
||||
const [nameValue, setNameValue] = useState<string>();
|
||||
const [inferenceIdComboValue, setInferenceIdComboValue] = useState<string>();
|
||||
const [semanticFieldType, setSemanticTextFieldType] = useState<string>();
|
||||
const [inferenceValue, setInferenceValue] = useState<string>('elser_model_2');
|
||||
const [inferenceValue, setInferenceValue] = useState<string>(
|
||||
DefaultInferenceModels.elser_model_2
|
||||
);
|
||||
const { showSuccessToasts, showErrorToasts } = useMLModelNotificationToasts();
|
||||
|
||||
const useFieldEffect = (
|
||||
semanticTextform: FormHook,
|
||||
|
@ -65,113 +74,92 @@ export function useSemanticText(props: UseSemanticTextProps) {
|
|||
}
|
||||
}, [form, inferenceId, inferenceToModelIdMap]);
|
||||
|
||||
const isModelDownloaded = useCallback(
|
||||
async (modelId: string) => {
|
||||
try {
|
||||
const response: TrainedModelConfigResponse[] | undefined =
|
||||
await ml?.mlApi?.trainedModels.getTrainedModels(modelId, {
|
||||
include: 'definition_status',
|
||||
});
|
||||
return !!response?.[0]?.fully_defined;
|
||||
} catch (error) {
|
||||
if (error.body.statusCode !== 404) {
|
||||
throw error;
|
||||
}
|
||||
const createInferenceEndpoint = useCallback(
|
||||
async (
|
||||
trainedModelId: ElasticsearchModelDefaultOptions | string,
|
||||
data: Field,
|
||||
customInferenceEndpointConfig?: CustomInferenceEndpointConfig
|
||||
) => {
|
||||
if (data.inferenceId === undefined) {
|
||||
throw new Error(
|
||||
i18n.translate('xpack.idxMgmt.mappingsEditor.createField.undefinedInferenceIdError', {
|
||||
defaultMessage: 'InferenceId is undefined while creating the inference endpoint.',
|
||||
})
|
||||
);
|
||||
}
|
||||
return false;
|
||||
},
|
||||
[ml?.mlApi?.trainedModels]
|
||||
);
|
||||
|
||||
const createInferenceEndpoint = (
|
||||
trainedModelId: string,
|
||||
defaultInferenceEndpoint: boolean,
|
||||
data: Field
|
||||
) => {
|
||||
if (data.inferenceId === undefined) {
|
||||
throw new Error(
|
||||
i18n.translate('xpack.idxMgmt.mappingsEditor.createField.undefinedInferenceIdError', {
|
||||
defaultMessage: 'InferenceId is undefined while creating the inference endpoint.',
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
if (trainedModelId && defaultInferenceEndpoint) {
|
||||
const modelConfig = {
|
||||
service: 'elasticsearch',
|
||||
service_settings: {
|
||||
num_allocations: 1,
|
||||
num_threads: 1,
|
||||
model_id: trainedModelId,
|
||||
},
|
||||
const defaultInferenceEndpointConfig: DefaultInferenceEndpointConfig = {
|
||||
service:
|
||||
trainedModelId === ElasticsearchModelDefaultOptions.elser ? 'elser' : 'elasticsearch',
|
||||
taskType:
|
||||
trainedModelId === ElasticsearchModelDefaultOptions.elser
|
||||
? 'sparse_embedding'
|
||||
: 'text_embedding',
|
||||
};
|
||||
|
||||
ml?.mlApi?.inferenceModels?.createInferenceEndpoint(
|
||||
data.inferenceId,
|
||||
'text_embedding',
|
||||
modelConfig
|
||||
);
|
||||
}
|
||||
};
|
||||
const modelConfig = customInferenceEndpointConfig
|
||||
? customInferenceEndpointConfig.modelConfig
|
||||
: {
|
||||
service: defaultInferenceEndpointConfig.service,
|
||||
service_settings: {
|
||||
num_allocations: 1,
|
||||
num_threads: 1,
|
||||
model_id: trainedModelId,
|
||||
},
|
||||
};
|
||||
const taskType: InferenceTaskType =
|
||||
customInferenceEndpointConfig?.taskType ?? defaultInferenceEndpointConfig.taskType;
|
||||
try {
|
||||
await ml?.mlApi?.inferenceModels?.createInferenceEndpoint(
|
||||
data.inferenceId,
|
||||
taskType,
|
||||
modelConfig
|
||||
);
|
||||
} catch (error) {
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[ml?.mlApi?.inferenceModels]
|
||||
);
|
||||
|
||||
const handleSemanticText = async (data: Field) => {
|
||||
const handleSemanticText = async (
|
||||
data: Field,
|
||||
customInferenceEndpointConfig?: CustomInferenceEndpointConfig
|
||||
) => {
|
||||
data.inferenceId = inferenceValue;
|
||||
if (data.inferenceId === undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
const inferenceData = inferenceToModelIdMap?.[data.inferenceId];
|
||||
|
||||
if (!inferenceData) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { trainedModelId, defaultInferenceEndpoint, isDeployed, isDeployable } = inferenceData;
|
||||
|
||||
if (isDeployable && trainedModelId) {
|
||||
try {
|
||||
const modelDownloaded: boolean = await isModelDownloaded(trainedModelId);
|
||||
|
||||
if (isDeployed) {
|
||||
createInferenceEndpoint(trainedModelId, defaultInferenceEndpoint, data);
|
||||
} else if (modelDownloaded) {
|
||||
ml?.mlApi?.trainedModels
|
||||
.startModelAllocation(trainedModelId)
|
||||
.then(() => createInferenceEndpoint(trainedModelId, defaultInferenceEndpoint, data));
|
||||
} else {
|
||||
ml?.mlApi?.trainedModels
|
||||
.installElasticTrainedModelConfig(trainedModelId)
|
||||
.then(() => ml?.mlApi?.trainedModels.startModelAllocation(trainedModelId))
|
||||
.then(() => createInferenceEndpoint(trainedModelId, defaultInferenceEndpoint, data));
|
||||
}
|
||||
toasts?.addSuccess({
|
||||
title: i18n.translate(
|
||||
'xpack.idxMgmt.mappingsEditor.createField.modelDeploymentStartedNotification',
|
||||
{
|
||||
defaultMessage: 'Model deployment started',
|
||||
}
|
||||
),
|
||||
text: i18n.translate(
|
||||
'xpack.idxMgmt.mappingsEditor.createField.modelDeploymentNotification',
|
||||
{
|
||||
defaultMessage: '1 model is being deployed on your ml_node.',
|
||||
}
|
||||
),
|
||||
});
|
||||
} catch (error) {
|
||||
setErrorsInTrainedModelDeployment?.((prevItems) => [...prevItems, trainedModelId]);
|
||||
toasts?.addError(error.body && error.body.message ? new Error(error.body.message) : error, {
|
||||
title: i18n.translate(
|
||||
'xpack.idxMgmt.mappingsEditor.createField.modelDeploymentErrorTitle',
|
||||
{
|
||||
defaultMessage: 'Model deployment failed',
|
||||
}
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const { trainedModelId } = inferenceData;
|
||||
dispatch({ type: 'field.addSemanticText', value: data });
|
||||
|
||||
try {
|
||||
// if model exists already, do not create inference endpoint
|
||||
const inferenceModels = await getInferenceModels();
|
||||
const inferenceModel: InferenceAPIConfigResponse[] = inferenceModels.data.some(
|
||||
(e: InferenceAPIConfigResponse) => e.model_id === inferenceValue
|
||||
);
|
||||
if (inferenceModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (trainedModelId) {
|
||||
// show toasts only if it's elastic models
|
||||
showSuccessToasts();
|
||||
}
|
||||
|
||||
await createInferenceEndpoint(trainedModelId, data, customInferenceEndpointConfig);
|
||||
} catch (error) {
|
||||
// trainedModelId is empty string when it's a third party model
|
||||
if (trainedModelId) {
|
||||
setErrorsInTrainedModelDeployment?.((prevItems) => [...prevItems, trainedModelId]);
|
||||
}
|
||||
showErrorToasts(error);
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
|
|
|
@ -26,6 +26,7 @@ export const FIELD_TYPES_OPTIONS = Object.entries(MAIN_DATA_TYPE_DEFINITION).map
|
|||
([dataType, { label }]) => ({
|
||||
value: dataType,
|
||||
label,
|
||||
'data-test-subj': `fieldTypesOptions-${dataType}`,
|
||||
})
|
||||
) as ComboBoxOption[];
|
||||
|
||||
|
|
|
@ -7,6 +7,8 @@
|
|||
|
||||
import { ReactNode } from 'react';
|
||||
|
||||
import { InferenceTaskType } from '@elastic/elasticsearch/lib/api/types';
|
||||
import { ModelConfig } from '@kbn/inference_integration_flyout';
|
||||
import { GenericObject } from './mappings_editor';
|
||||
|
||||
import { PARAMETERS_DEFINITION } from '../constants';
|
||||
|
@ -246,3 +248,16 @@ export interface NormalizedRuntimeField {
|
|||
export interface NormalizedRuntimeFields {
|
||||
[id: string]: NormalizedRuntimeField;
|
||||
}
|
||||
export enum DefaultInferenceModels {
|
||||
elser_model_2 = 'elser_model_2',
|
||||
e5 = 'e5',
|
||||
}
|
||||
|
||||
export enum DeploymentState {
|
||||
'DEPLOYED' = 'deployed',
|
||||
'NOT_DEPLOYED' = 'not_deployed',
|
||||
}
|
||||
export interface CustomInferenceEndpointConfig {
|
||||
taskType: InferenceTaskType;
|
||||
modelConfig: ModelConfig;
|
||||
}
|
||||
|
|
|
@ -84,9 +84,9 @@ export function TrainedModelsDeploymentModal({
|
|||
onCancel={closeModal}
|
||||
onConfirm={refreshModal}
|
||||
cancelButtonText={i18n.translate(
|
||||
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.cancelButtonLabel',
|
||||
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.closeButtonLabel',
|
||||
{
|
||||
defaultMessage: 'Cancel',
|
||||
defaultMessage: 'Close',
|
||||
}
|
||||
)}
|
||||
confirmButtonText={i18n.translate(
|
||||
|
|
|
@ -104,13 +104,11 @@ const inferenceToModelIdMap = {
|
|||
trainedModelId: '.elser_model_2',
|
||||
isDeployed: true,
|
||||
isDeployable: true,
|
||||
defaultInferenceEndpoint: false,
|
||||
},
|
||||
e5: {
|
||||
trainedModelId: '.multilingual-e5-small',
|
||||
isDeployed: true,
|
||||
isDeployable: true,
|
||||
defaultInferenceEndpoint: false,
|
||||
},
|
||||
} as InferenceToModelIdMap;
|
||||
|
||||
|
@ -127,13 +125,11 @@ describe('useDetailsPageMappingsModelManagement', () => {
|
|||
value: {
|
||||
inferenceToModelIdMap: {
|
||||
e5: {
|
||||
defaultInferenceEndpoint: false,
|
||||
isDeployed: false,
|
||||
isDeployable: true,
|
||||
trainedModelId: '.multilingual-e5-small',
|
||||
},
|
||||
elser_model_2: {
|
||||
defaultInferenceEndpoint: true,
|
||||
isDeployed: true,
|
||||
isDeployable: true,
|
||||
trainedModelId: '.elser_model_2',
|
||||
|
|
|
@ -13,14 +13,18 @@ import { useAppContext } from '../application/app_context';
|
|||
import { InferenceToModelIdMap } from '../application/components/mappings_editor/components/document_fields/fields';
|
||||
import { deNormalize } from '../application/components/mappings_editor/lib';
|
||||
import { useDispatch } from '../application/components/mappings_editor/mappings_state_context';
|
||||
import { NormalizedFields } from '../application/components/mappings_editor/types';
|
||||
import {
|
||||
DefaultInferenceModels,
|
||||
DeploymentState,
|
||||
NormalizedFields,
|
||||
} from '../application/components/mappings_editor/types';
|
||||
import { getInferenceModels } from '../application/services/api';
|
||||
|
||||
interface InferenceModel {
|
||||
data: InferenceAPIConfigResponse[];
|
||||
}
|
||||
|
||||
type DeploymentStatusType = Record<string, 'deployed' | 'not_deployed'>;
|
||||
type DeploymentStatusType = Record<string, DeploymentState>;
|
||||
|
||||
const getCustomInferenceIdMap = (
|
||||
deploymentStatsByModelId: DeploymentStatusType,
|
||||
|
@ -39,7 +43,6 @@ const getCustomInferenceIdMap = (
|
|||
trainedModelId,
|
||||
isDeployable: model.service === Service.elser || model.service === Service.elasticsearch,
|
||||
isDeployed: deploymentStatsByModelId[trainedModelId] === 'deployed',
|
||||
defaultInferenceEndpoint: false,
|
||||
};
|
||||
return inferenceMap;
|
||||
}, {});
|
||||
|
@ -50,7 +53,9 @@ export const getTrainedModelStats = (modelStats?: InferenceStatsResponse): Deplo
|
|||
modelStats?.trained_model_stats.reduce<DeploymentStatusType>((acc, modelStat) => {
|
||||
if (modelStat.model_id) {
|
||||
acc[modelStat.model_id] =
|
||||
modelStat?.deployment_stats?.state === 'started' ? 'deployed' : 'not_deployed';
|
||||
modelStat?.deployment_stats?.state === 'started'
|
||||
? DeploymentState.DEPLOYED
|
||||
: DeploymentState.NOT_DEPLOYED;
|
||||
}
|
||||
return acc;
|
||||
}, {}) || {}
|
||||
|
@ -59,17 +64,18 @@ export const getTrainedModelStats = (modelStats?: InferenceStatsResponse): Deplo
|
|||
|
||||
const getDefaultInferenceIds = (deploymentStatsByModelId: DeploymentStatusType) => {
|
||||
return {
|
||||
elser_model_2: {
|
||||
trainedModelId: '.elser_model_2',
|
||||
[DefaultInferenceModels.elser_model_2]: {
|
||||
trainedModelId: ElasticsearchModelDefaultOptions.elser,
|
||||
isDeployable: true,
|
||||
isDeployed: deploymentStatsByModelId['.elser_model_2'] === 'deployed',
|
||||
defaultInferenceEndpoint: true,
|
||||
isDeployed:
|
||||
deploymentStatsByModelId[ElasticsearchModelDefaultOptions.elser] ===
|
||||
DeploymentState.DEPLOYED,
|
||||
},
|
||||
e5: {
|
||||
trainedModelId: '.multilingual-e5-small',
|
||||
[DefaultInferenceModels.e5]: {
|
||||
trainedModelId: ElasticsearchModelDefaultOptions.e5,
|
||||
isDeployable: true,
|
||||
isDeployed: deploymentStatsByModelId['.multilingual-e5-small'] === 'deployed',
|
||||
defaultInferenceEndpoint: true,
|
||||
isDeployed:
|
||||
deploymentStatsByModelId[ElasticsearchModelDefaultOptions.e5] === DeploymentState.DEPLOYED,
|
||||
},
|
||||
};
|
||||
};
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { i18n } from '@kbn/i18n';
|
||||
import { ErrorType, extractErrorProperties, MLRequestFailure } from '@kbn/ml-error-utils';
|
||||
import { useComponentTemplatesContext } from '../application/components/component_templates/component_templates_context';
|
||||
|
||||
export function useMLModelNotificationToasts() {
|
||||
const { toasts } = useComponentTemplatesContext();
|
||||
const showSuccessToasts = () => {
|
||||
return toasts.addSuccess({
|
||||
title: i18n.translate(
|
||||
'xpack.idxMgmt.mappingsEditor.createField.modelDeploymentStartedNotification',
|
||||
{
|
||||
defaultMessage: 'Model deployment started',
|
||||
}
|
||||
),
|
||||
text: i18n.translate('xpack.idxMgmt.mappingsEditor.createField.modelDeploymentNotification', {
|
||||
defaultMessage: '1 model is being deployed on your ml_node.',
|
||||
}),
|
||||
});
|
||||
};
|
||||
const showErrorToasts = (error: ErrorType) => {
|
||||
const errorObj = extractErrorProperties(error);
|
||||
return toasts.addError(new MLRequestFailure(errorObj, error), {
|
||||
title: i18n.translate('xpack.idxMgmt.mappingsEditor.createField.modelDeploymentErrorTitle', {
|
||||
defaultMessage: 'Model deployment failed',
|
||||
}),
|
||||
});
|
||||
};
|
||||
return { showSuccessToasts, showErrorToasts };
|
||||
}
|
|
@ -49,9 +49,9 @@
|
|||
"@kbn/utility-types",
|
||||
"@kbn/inference_integration_flyout",
|
||||
"@kbn/ml-plugin",
|
||||
"@kbn/ml-error-utils",
|
||||
"@kbn/react-kibana-context-render",
|
||||
"@kbn/react-kibana-mount"
|
||||
"@kbn/react-kibana-mount",
|
||||
"@kbn/ml-error-utils",
|
||||
],
|
||||
"exclude": ["target/**/*"]
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue