[8.15] [Semantic text] Make semantic text work with non-root level fields (#187154) (#188080)

# Backport

This will backport the following commits from `main` to `8.15`:
- [[Semantic text] Make semantic text work with non-root level fields
(#187154)](https://github.com/elastic/kibana/pull/187154)

<!--- Backport version: 9.4.3 -->

### Questions ?
Please refer to the [Backport tool
documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Sander
Philipse","email":"94373878+sphilipse@users.noreply.github.com"},"sourceCommit":{"committedDate":"2024-07-11T10:35:56Z","message":"[Semantic
text] Make semantic text work with non-root level fields (#187154)\n\n##
Summary\r\n\r\nThis makes semantic text work with non-root level
reference fields. It\r\nalso correctly adds copy_to to existing copy_to
fields instead of\r\nreplacing them, and streamlines a lot of the
code.\r\n\r\nTo test these changes:\r\n\r\n- Create an index\r\n- Go to
the index mappings page
at\r\n`app/management/data/index_management/indices/index_details?{yourIndexName}=blah&tab=mappings`\r\n-
Add an object field with a text field inside\r\n- Add a semantic text
field referencing that text field\r\n- If you're on a Macbook, create a
new inference endpoint with the model\r\n`.elser_model_2` instead of
using the default inference endpoint.\r\n- Add a second semantic text
field referencing that text field\r\n- Save your mappings\r\n- Use JSON
view to verify that the newly created text field contains a\r\n`copy_to`
field referencing both newly created semantic text fields\r\n- Verify
that the newly created semantic text fields are also in the\r\nJSON
view\r\n\r\n\r\n\r\n### Checklist\r\n\r\nDelete any items that are not
applicable to this PR.\r\n\r\n- [x] Any text added follows [EUI's
writing\r\nguidelines](https://elastic.github.io/eui/#/guidelines/writing),
uses\r\nsentence case text and includes
[i18n\r\nsupport](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md)\r\n-
[x] [Unit or
functional\r\ntests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)\r\nwere
updated or added to match the most common scenarios\r\n- [x] Any UI
touched in this PR is usable by keyboard only (learn more\r\nabout
[keyboard accessibility](https://webaim.org/techniques/keyboard/))\r\n-
[x] Any UI touched in this PR does not create any new axe
failures\r\n(run axe in
browser:\r\n[FF](https://addons.mozilla.org/en-US/firefox/addon/axe-devtools/),\r\n[Chrome](https://chrome.google.com/webstore/detail/axe-web-accessibility-tes/lhdoppojpmngadmnindnejefpokejbdd?hl=en-US))\r\n-
[x] This renders correctly on smaller devices using a
responsive\r\nlayout. (You can test this [in
your\r\nbrowser](https://www.browserstack.com/guide/responsive-testing-on-local-server))\r\n-
[x] This was checked for
[cross-browser\r\ncompatibility](https://www.elastic.co/support/matrix#matrix_browsers)","sha":"460b52077ffa26673b1a40fff87a7ee182f0c9db","branchLabelMapping":{"^v8.16.0$":"main","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","Team:Search","v8.15.0","v8.16.0"],"title":"[Semantic
text] Make semantic text work with non-root level
fields","number":187154,"url":"https://github.com/elastic/kibana/pull/187154","mergeCommit":{"message":"[Semantic
text] Make semantic text work with non-root level fields (#187154)\n\n##
Summary\r\n\r\nThis makes semantic text work with non-root level
reference fields. It\r\nalso correctly adds copy_to to existing copy_to
fields instead of\r\nreplacing them, and streamlines a lot of the
code.\r\n\r\nTo test these changes:\r\n\r\n- Create an index\r\n- Go to
the index mappings page
at\r\n`app/management/data/index_management/indices/index_details?{yourIndexName}=blah&tab=mappings`\r\n-
Add an object field with a text field inside\r\n- Add a semantic text
field referencing that text field\r\n- If you're on a Macbook, create a
new inference endpoint with the model\r\n`.elser_model_2` instead of
using the default inference endpoint.\r\n- Add a second semantic text
field referencing that text field\r\n- Save your mappings\r\n- Use JSON
view to verify that the newly created text field contains a\r\n`copy_to`
field referencing both newly created semantic text fields\r\n- Verify
that the newly created semantic text fields are also in the\r\nJSON
view\r\n\r\n\r\n\r\n### Checklist\r\n\r\nDelete any items that are not
applicable to this PR.\r\n\r\n- [x] Any text added follows [EUI's
writing\r\nguidelines](https://elastic.github.io/eui/#/guidelines/writing),
uses\r\nsentence case text and includes
[i18n\r\nsupport](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md)\r\n-
[x] [Unit or
functional\r\ntests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)\r\nwere
updated or added to match the most common scenarios\r\n- [x] Any UI
touched in this PR is usable by keyboard only (learn more\r\nabout
[keyboard accessibility](https://webaim.org/techniques/keyboard/))\r\n-
[x] Any UI touched in this PR does not create any new axe
failures\r\n(run axe in
browser:\r\n[FF](https://addons.mozilla.org/en-US/firefox/addon/axe-devtools/),\r\n[Chrome](https://chrome.google.com/webstore/detail/axe-web-accessibility-tes/lhdoppojpmngadmnindnejefpokejbdd?hl=en-US))\r\n-
[x] This renders correctly on smaller devices using a
responsive\r\nlayout. (You can test this [in
your\r\nbrowser](https://www.browserstack.com/guide/responsive-testing-on-local-server))\r\n-
[x] This was checked for
[cross-browser\r\ncompatibility](https://www.elastic.co/support/matrix#matrix_browsers)","sha":"460b52077ffa26673b1a40fff87a7ee182f0c9db"}},"sourceBranch":"main","suggestedTargetBranches":["8.15"],"targetPullRequestStates":[{"branch":"8.15","label":"v8.15.0","branchLabelMappingKey":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"},{"branch":"main","label":"v8.16.0","branchLabelMappingKey":"^v8.16.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/187154","number":187154,"mergeCommit":{"message":"[Semantic
text] Make semantic text work with non-root level fields (#187154)\n\n##
Summary\r\n\r\nThis makes semantic text work with non-root level
reference fields. It\r\nalso correctly adds copy_to to existing copy_to
fields instead of\r\nreplacing them, and streamlines a lot of the
code.\r\n\r\nTo test these changes:\r\n\r\n- Create an index\r\n- Go to
the index mappings page
at\r\n`app/management/data/index_management/indices/index_details?{yourIndexName}=blah&tab=mappings`\r\n-
Add an object field with a text field inside\r\n- Add a semantic text
field referencing that text field\r\n- If you're on a Macbook, create a
new inference endpoint with the model\r\n`.elser_model_2` instead of
using the default inference endpoint.\r\n- Add a second semantic text
field referencing that text field\r\n- Save your mappings\r\n- Use JSON
view to verify that the newly created text field contains a\r\n`copy_to`
field referencing both newly created semantic text fields\r\n- Verify
that the newly created semantic text fields are also in the\r\nJSON
view\r\n\r\n\r\n\r\n### Checklist\r\n\r\nDelete any items that are not
applicable to this PR.\r\n\r\n- [x] Any text added follows [EUI's
writing\r\nguidelines](https://elastic.github.io/eui/#/guidelines/writing),
uses\r\nsentence case text and includes
[i18n\r\nsupport](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md)\r\n-
[x] [Unit or
functional\r\ntests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)\r\nwere
updated or added to match the most common scenarios\r\n- [x] Any UI
touched in this PR is usable by keyboard only (learn more\r\nabout
[keyboard accessibility](https://webaim.org/techniques/keyboard/))\r\n-
[x] Any UI touched in this PR does not create any new axe
failures\r\n(run axe in
browser:\r\n[FF](https://addons.mozilla.org/en-US/firefox/addon/axe-devtools/),\r\n[Chrome](https://chrome.google.com/webstore/detail/axe-web-accessibility-tes/lhdoppojpmngadmnindnejefpokejbdd?hl=en-US))\r\n-
[x] This renders correctly on smaller devices using a
responsive\r\nlayout. (You can test this [in
your\r\nbrowser](https://www.browserstack.com/guide/responsive-testing-on-local-server))\r\n-
[x] This was checked for
[cross-browser\r\ncompatibility](https://www.elastic.co/support/matrix#matrix_browsers)","sha":"460b52077ffa26673b1a40fff87a7ee182f0c9db"}}]}]
BACKPORT-->

Co-authored-by: Sander Philipse <94373878+sphilipse@users.noreply.github.com>
This commit is contained in:
Kibana Machine 2024-07-11 14:18:07 +02:00 committed by GitHub
parent 279152a274
commit 109901226d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 1535 additions and 1044 deletions

View file

@ -70,10 +70,7 @@ export const ElasticsearchModels: React.FC<ElasticsearchModelsProps> = ({
}
}, [numberOfAllocations, numberOfThreads, serviceType]);
const elasticSearchModelTypesDescriptions: Record<
ElasticsearchModelDefaultOptions | string,
ElasticsearchModelDescriptions
> = {
const elasticSearchModelTypesDescriptions: Record<string, ElasticsearchModelDescriptions> = {
[ElasticsearchModelDefaultOptions.elser]: {
description: i18n.translate(
'xpack.ml.addInferenceEndpoint.elasticsearchModels.elser.description',

View file

@ -27,4 +27,13 @@ export {
ELASTIC_MODEL_TYPE,
MODEL_STATE,
type ModelState,
ELSER_LINUX_OPTIMIZED_MODEL_ID,
ELSER_MODEL_ID,
E5_LINUX_OPTIMIZED_MODEL_ID,
E5_MODEL_ID,
LANG_IDENT_MODEL_ID,
LATEST_ELSER_VERSION,
LATEST_ELSER_MODEL_ID,
LATEST_E5_MODEL_ID,
ElserModels,
} from './src/constants/trained_models';

View file

@ -7,6 +7,18 @@
import { i18n } from '@kbn/i18n';
export const ELSER_MODEL_ID = '.elser_model_2';
export const ELSER_LINUX_OPTIMIZED_MODEL_ID = '.elser_model_2_linux-x86_64';
export const E5_MODEL_ID = '.multilingual-e5-small';
export const E5_LINUX_OPTIMIZED_MODEL_ID = '.multilingual-e5-small_linux-x86_64';
export const LANG_IDENT_MODEL_ID = 'lang_ident_model_1';
export const ELSER_ID_V1 = '.elser_model_1' as const;
export const LATEST_ELSER_VERSION: ElserVersion = 2;
export const LATEST_ELSER_MODEL_ID = ELSER_LINUX_OPTIMIZED_MODEL_ID;
export const LATEST_E5_MODEL_ID = E5_LINUX_OPTIMIZED_MODEL_ID;
export const ElserModels = [ELSER_MODEL_ID, ELSER_LINUX_OPTIMIZED_MODEL_ID, ELSER_ID_V1];
export const DEPLOYMENT_STATE = {
STARTED: 'started',
STARTING: 'starting',
@ -46,10 +58,8 @@ export const BUILT_IN_MODEL_TAG = 'prepackaged';
export const ELASTIC_MODEL_TAG = 'elastic';
export const ELSER_ID_V1 = '.elser_model_1' as const;
export const ELASTIC_MODEL_DEFINITIONS: Record<string, ModelDefinition> = Object.freeze({
'.elser_model_1': {
[ELSER_ID_V1]: {
modelName: 'elser',
hidden: true,
version: 1,
@ -63,7 +73,7 @@ export const ELASTIC_MODEL_DEFINITIONS: Record<string, ModelDefinition> = Object
}),
type: ['elastic', 'pytorch', 'text_expansion'],
},
'.elser_model_2': {
[ELSER_MODEL_ID]: {
modelName: 'elser',
version: 2,
default: true,
@ -77,7 +87,7 @@ export const ELASTIC_MODEL_DEFINITIONS: Record<string, ModelDefinition> = Object
}),
type: ['elastic', 'pytorch', 'text_expansion'],
},
'.elser_model_2_linux-x86_64': {
[ELSER_LINUX_OPTIMIZED_MODEL_ID]: {
modelName: 'elser',
version: 2,
os: 'Linux',
@ -92,7 +102,7 @@ export const ELASTIC_MODEL_DEFINITIONS: Record<string, ModelDefinition> = Object
}),
type: ['elastic', 'pytorch', 'text_expansion'],
},
'.multilingual-e5-small': {
[E5_MODEL_ID]: {
modelName: 'e5',
version: 1,
default: true,
@ -108,7 +118,7 @@ export const ELASTIC_MODEL_DEFINITIONS: Record<string, ModelDefinition> = Object
licenseUrl: 'https://huggingface.co/elastic/multilingual-e5-small',
type: ['pytorch', 'text_embedding'],
},
'.multilingual-e5-small_linux-x86_64': {
[E5_LINUX_OPTIMIZED_MODEL_ID]: {
modelName: 'e5',
version: 1,
os: 'Linux',
@ -178,23 +188,17 @@ export interface GetModelDownloadConfigOptions {
version?: ElserVersion;
}
export interface LocalInferenceServiceSettings {
service: 'elser' | 'elasticsearch';
service_settings: {
num_allocations: number;
num_threads: number;
model_id: string;
};
}
export type InferenceServiceSettings =
| {
service: 'elser';
service_settings: {
num_allocations: number;
num_threads: number;
model_id: string;
};
}
| {
service: 'elasticsearch';
service_settings: {
num_allocations: number;
num_threads: number;
model_id: string;
};
}
| LocalInferenceServiceSettings
| {
service: 'openai';
service_settings: {

View file

@ -9,15 +9,16 @@ import { mockLogger } from '../../__mocks__';
import { MlTrainedModels } from '@kbn/ml-plugin/server';
import { MlModelDeploymentState } from '../../../common/types/ml';
import { fetchMlModels } from './fetch_ml_models';
import {
E5_LINUX_OPTIMIZED_MODEL_ID,
E5_MODEL_ID,
ELSER_LINUX_OPTIMIZED_MODEL_ID,
ELSER_MODEL_ID,
} from './utils';
} from '@kbn/ml-trained-models-utils';
import { MlModelDeploymentState } from '../../../common/types/ml';
import { fetchMlModels } from './fetch_ml_models';
describe('fetchMlModels', () => {
const mockTrainedModelsProvider = {

View file

@ -11,6 +11,14 @@ import { Logger } from '@kbn/core/server';
import { i18n } from '@kbn/i18n';
import { MlTrainedModels } from '@kbn/ml-plugin/server';
import {
E5_LINUX_OPTIMIZED_MODEL_ID,
E5_MODEL_ID,
ELSER_LINUX_OPTIMIZED_MODEL_ID,
ELSER_MODEL_ID,
LANG_IDENT_MODEL_ID,
} from '@kbn/ml-trained-models-utils';
import { getMlModelTypesForModelConfig } from '../../../common/ml_inference_pipeline';
import { MlModelDeploymentState, MlModel } from '../../../common/types/ml';
@ -18,15 +26,10 @@ import { MlModelDeploymentState, MlModel } from '../../../common/types/ml';
import {
BASE_MODEL,
ELSER_LINUX_OPTIMIZED_MODEL_PLACEHOLDER,
ELSER_MODEL_ID,
ELSER_MODEL_PLACEHOLDER,
E5_LINUX_OPTIMIZED_MODEL_PLACEHOLDER,
E5_MODEL_ID,
E5_MODEL_PLACEHOLDER,
LANG_IDENT_MODEL_ID,
MODEL_TITLES_BY_TYPE,
E5_LINUX_OPTIMIZED_MODEL_ID,
ELSER_LINUX_OPTIMIZED_MODEL_ID,
} from './utils';
let compatibleElserModelId = ELSER_MODEL_ID;

View file

@ -8,13 +8,14 @@
import { i18n } from '@kbn/i18n';
import { SUPPORTED_PYTORCH_TASKS } from '@kbn/ml-trained-models-utils';
import { MlModelDeploymentState, MlModel } from '../../../common/types/ml';
import {
E5_LINUX_OPTIMIZED_MODEL_ID,
E5_MODEL_ID,
ELSER_LINUX_OPTIMIZED_MODEL_ID,
ELSER_MODEL_ID,
} from '@kbn/ml-trained-models-utils';
export const ELSER_MODEL_ID = '.elser_model_2';
export const ELSER_LINUX_OPTIMIZED_MODEL_ID = '.elser_model_2_linux-x86_64';
export const E5_MODEL_ID = '.multilingual-e5-small';
export const E5_LINUX_OPTIMIZED_MODEL_ID = '.multilingual-e5-small_linux-x86_64';
export const LANG_IDENT_MODEL_ID = 'lang_ident_model_1';
import { MlModelDeploymentState, MlModel } from '../../../common/types/ml';
export const MODEL_TITLES_BY_TYPE: Record<string, string | undefined> = {
fill_mask: i18n.translate('xpack.enterpriseSearch.content.ml_inference.fill_mask', {

View file

@ -61,7 +61,7 @@ export interface IndexDetailsPageTestBed extends TestBed {
selectInferenceIdButtonExists: () => void;
openSelectInferencePopover: () => void;
expectDefaultInferenceModelToExists: () => void;
expectCustomInferenceModelToExists: (customInference: string) => Promise<void>;
expectCustomInferenceModelToExists: (customInference: string) => void;
};
settings: {
getCodeBlockContent: () => string;
@ -317,23 +317,23 @@ export const setup = async ({
expect(exists('fieldTypesOptions-semantic_text')).toBe(false);
});
},
isReferenceFieldVisible: async () => {
expect(exists('referenceField.select')).toBe(true);
isReferenceFieldVisible: () => {
expect(exists('referenceFieldSelect')).toBe(true);
},
selectInferenceIdButtonExists: async () => {
selectInferenceIdButtonExists: () => {
expect(exists('selectInferenceId')).toBe(true);
expect(exists('inferenceIdButton')).toBe(true);
find('inferenceIdButton').simulate('click');
},
openSelectInferencePopover: async () => {
openSelectInferencePopover: () => {
expect(exists('addInferenceEndpointButton')).toBe(true);
expect(exists('manageInferenceEndpointButton')).toBe(true);
},
expectDefaultInferenceModelToExists: async () => {
expect(exists('default-inference_elser_model_2')).toBe(true);
expect(exists('default-inference_e5')).toBe(true);
expectDefaultInferenceModelToExists: () => {
expect(exists('custom-inference_elser_model_2')).toBe(true);
expect(exists('custom-inference_e5')).toBe(true);
},
expectCustomInferenceModelToExists: async (customInference: string) => {
expectCustomInferenceModelToExists: (customInference: string) => {
expect(exists(customInference)).toBe(true);
},
};

View file

@ -69,6 +69,7 @@ describe('<IndexDetailsPage />', () => {
httpRequestsMockHelpers.setLoadIndexStatsResponse(testIndexName, testIndexStats);
httpRequestsMockHelpers.setLoadIndexMappingResponse(testIndexName, testIndexMappings);
httpRequestsMockHelpers.setLoadIndexSettingsResponse(testIndexName, testIndexSettings);
httpRequestsMockHelpers.setInferenceModels([]);
await act(async () => {
testBed = await setup({
@ -692,6 +693,7 @@ describe('<IndexDetailsPage />', () => {
ml: {
mlApi: {
trainedModels: {
getModelsDownloadStatus: jest.fn().mockResolvedValue({}),
getTrainedModels: jest.fn().mockResolvedValue([
{
model_id: '.elser_model_2',

View file

@ -5,13 +5,20 @@
* 2.0.
*/
import {
Form,
useForm,
} from '../../../public/application/components/mappings_editor/shared_imports';
import { registerTestBed } from '@kbn/test-jest-helpers';
import { act } from 'react-dom/test-utils';
import { SelectInferenceId } from '../../../public/application/components/mappings_editor/components/document_fields/field_parameters/select_inference_id';
import {
SelectInferenceId,
SelectInferenceIdProps,
} from '../../../public/application/components/mappings_editor/components/document_fields/field_parameters/select_inference_id';
import React from 'react';
const onChangeMock = jest.fn();
const setValueMock = jest.fn();
const setNewInferenceEndpointMock = jest.fn();
const createInferenceEndpointMock = jest.fn();
const mockDispatch = jest.fn();
jest.mock('../../../public/application/app_context', () => ({
useAppContext: jest.fn().mockReturnValue({
@ -41,23 +48,40 @@ jest.mock(
}),
})
);
jest.mock('../../../public/application/components/mappings_editor/mappings_state_context', () => ({
useMappingsState: () => ({ inferenceToModelIdMap: {} }),
useDispatch: () => mockDispatch,
}));
function getTestForm(Component: React.FC<SelectInferenceIdProps>) {
return (defaultProps: SelectInferenceIdProps) => {
const { form } = useForm();
form.setFieldValue('inference_id', 'elser_model_2');
return (
<Form form={form}>
<Component {...(defaultProps as any)} />
</Form>
);
};
}
describe('SelectInferenceId', () => {
let exists: any;
let find: any;
beforeAll(async () => {
const setup = registerTestBed(SelectInferenceId, {
defaultProps: {
onChange: onChangeMock,
'data-test-subj': 'data-inference-endpoint-list',
setValue: setValueMock,
setNewInferenceEndpoint: setNewInferenceEndpointMock,
},
const defaultProps: SelectInferenceIdProps = {
'data-test-subj': 'data-inference-endpoint-list',
createInferenceEndpoint: createInferenceEndpointMock,
};
const setup = registerTestBed(getTestForm(SelectInferenceId), {
defaultProps,
memoryRouter: { wrapComponent: false },
});
await act(async () => {
const testBed = setup();
const testBed = await setup();
exists = testBed.exists;
find = testBed.find;
});

View file

@ -6,26 +6,159 @@
*/
import { registerTestBed } from '@kbn/test-jest-helpers';
import { TrainedModelsDeploymentModal } from '../../../public/application/sections/home/index_list/details_page/trained_models_deployment_modal';
import {
TrainedModelsDeploymentModal,
TrainedModelsDeploymentModalProps,
} from '../../../public/application/sections/home/index_list/details_page/trained_models_deployment_modal';
import { act } from 'react-dom/test-utils';
import * as mappingsContext from '../../../public/application/components/mappings_editor/mappings_state_context';
import { NormalizedField } from '../../../public/application/components/mappings_editor/types';
const refreshModal = jest.fn();
const setIsModalVisible = jest.fn();
const tryAgainForErrorModal = jest.fn();
const setIsVisibleForErrorModal = jest.fn();
jest.mock('../../../public/hooks/use_ml_model_status_toasts', () => ({
useMLModelNotificationToasts: jest.fn().mockReturnValue({
showErrorToasts: jest.fn(),
}),
}));
jest.mock('../../../public/application/app_context', () => ({
useAppContext: jest.fn().mockReturnValue({
url: undefined,
plugins: {
ml: {
mlApi: {
trainedModels: {
getModelsDownloadStatus: jest.fn().mockResolvedValue({}),
getTrainedModels: jest.fn().mockResolvedValue([
{
model_id: '.elser_model_2',
model_type: 'pytorch',
model_package: {
packaged_model_id: 'elser_model_2',
model_repository: 'https://ml-models.elastic.co',
minimum_version: '11.0.0',
size: 438123914,
sha256: '',
metadata: {},
tags: [],
vocabulary_file: 'elser_model_2.vocab.json',
},
description: 'Elastic Learned Sparse EncodeR v2',
tags: ['elastic'],
},
]),
getTrainedModelStats: jest.fn().mockResolvedValue({
count: 1,
trained_model_stats: [
{
model_id: '.elser_model_2',
deployment_stats: {
deployment_id: 'elser_model_2',
model_id: '.elser_model_2',
threads_per_allocation: 1,
number_of_allocations: 1,
queue_capacity: 1024,
state: 'started',
},
},
],
}),
},
},
},
},
}),
}));
jest.mock('../../../public/application/components/mappings_editor/mappings_state_context');
const mappingsContextMocked = jest.mocked(mappingsContext);
const defaultState = {
inferenceToModelIdMap: {
e5: {
isDeployed: false,
isDeployable: true,
trainedModelId: '.multilingual-e5-small',
},
elser_model_2: {
isDeployed: false,
isDeployable: true,
trainedModelId: '.elser_model_2',
},
openai: {
isDeployed: false,
isDeployable: false,
trainedModelId: '',
},
my_elser_endpoint: {
isDeployed: false,
isDeployable: true,
trainedModelId: '.elser_model_2',
},
},
fields: {
aliases: {},
byId: {},
rootLevelFields: [],
maxNestedDepth: 0,
},
mappingViewFields: { byId: {} },
} as any;
const setErrorsInTrainedModelDeployment = jest.fn().mockReturnValue(undefined);
const fetchData = jest.fn().mockReturnValue(undefined);
describe('When semantic_text is enabled', () => {
describe('When there is no error in the model deployment', () => {
const setup = registerTestBed(TrainedModelsDeploymentModal, {
defaultProps: {
setIsModalVisible,
refreshModal,
pendingDeployments: ['.elser-test-3'],
errorsInTrainedModelDeployment: [],
},
const setup = (defaultProps: Partial<TrainedModelsDeploymentModalProps>) =>
registerTestBed(TrainedModelsDeploymentModal, {
defaultProps,
memoryRouter: { wrapComponent: false },
})();
beforeEach(() => {
jest.clearAllMocks();
});
describe('When there are no pending deployments and no errors in the model deployment', () => {
mappingsContextMocked.useMappingsState.mockReturnValue(defaultState);
const { exists } = setup({
errorsInTrainedModelDeployment: {},
fetchData,
setErrorsInTrainedModelDeployment: () => undefined,
});
it('should not display the modal', () => {
expect(exists('trainedModelsDeploymentModal')).toBe(false);
});
});
describe('When there are pending deployments in the model deployment', () => {
mappingsContextMocked.useMappingsState.mockReturnValue({
...defaultState,
fields: {
...defaultState.fields,
byId: {
new_field: {
id: 'new_field',
isMultiField: false,
path: ['new_field'],
source: {
name: 'new_field',
type: 'semantic_text',
reference_field: 'title',
inference_id: 'elser_model_2',
},
} as NormalizedField,
},
rootLevelFields: ['new_field'],
},
} as any);
const { exists, find } = setup({
errorsInTrainedModelDeployment: {},
fetchData,
setErrorsInTrainedModelDeployment,
});
const { exists, find } = setup();
it('should display the modal', () => {
expect(exists('trainedModelsDeploymentModal')).toBe(true);
@ -37,55 +170,61 @@ describe('When semantic_text is enabled', () => {
);
});
it('should call refresh method if refresh button is pressed', async () => {
it('should call fetch data if refresh button is pressed', async () => {
await act(async () => {
find('confirmModalConfirmButton').simulate('click');
});
expect(refreshModal.mock.calls).toHaveLength(1);
});
it('should call setIsModalVisible method if cancel button is pressed', async () => {
await act(async () => {
find('confirmModalCancelButton').simulate('click');
});
expect(setIsModalVisible).toHaveBeenLastCalledWith(false);
expect(fetchData.mock.calls).toHaveLength(1);
});
});
describe('When there is error in the model deployment', () => {
const setup = registerTestBed(TrainedModelsDeploymentModal, {
defaultProps: {
setIsModalVisible: setIsVisibleForErrorModal,
refreshModal: tryAgainForErrorModal,
pendingDeployments: ['.elser-test-3'],
errorsInTrainedModelDeployment: ['.elser-test-3'],
mappingsContextMocked.useMappingsState.mockReturnValue({
...defaultState,
fields: {
...defaultState.fields,
byId: {
new_field: {
id: 'new_field',
isMultiField: false,
path: ['new_field'],
source: {
name: 'new_field',
type: 'semantic_text',
reference_field: 'title',
inference_id: 'elser_model_2',
},
} as NormalizedField,
},
rootLevelFields: ['new_field'],
},
memoryRouter: { wrapComponent: false },
});
const { exists, find } = setup();
it('should display the modal', () => {
expect(exists('trainedModelsErroredDeploymentModal')).toBe(true);
} as any);
const { find } = setup({
fetchData,
errorsInTrainedModelDeployment: { '.elser_model_2': 'Error' },
setErrorsInTrainedModelDeployment,
});
it('should contain content related to semantic_text', () => {
expect(find('trainedModelsErrorDeploymentModalText').text()).toContain(
'There was an error when trying to deploy'
);
it('should display text related to errored deployments', () => {
expect(find('trainedModelsDeploymentModalText').text()).toContain('There was an error');
});
it('should display only the errored deployment', () => {
expect(find('trainedModelsDeploymentModal').text()).toContain('.elser_model_2');
expect(find('trainedModelsDeploymentModal').text()).not.toContain('valid-model');
});
it("should call refresh method if 'Try again' button is pressed", async () => {
await act(async () => {
find('confirmModalConfirmButton').simulate('click');
});
expect(tryAgainForErrorModal.mock.calls).toHaveLength(1);
expect(fetchData.mock.calls).toHaveLength(1);
});
it('should call setIsVisibleForErrorModal method if cancel button is pressed', async () => {
await act(async () => {
find('confirmModalCancelButton').simulate('click');
});
expect(setIsVisibleForErrorModal).toHaveBeenLastCalledWith(false);
});
});
});

View file

@ -380,8 +380,11 @@ export const setup = (
props: any = { onUpdate() {} },
appDependencies?: any
): MappingsEditorTestBed => {
const defaultAppDependencies = {
plugins: {},
};
const setupTestBed = registerTestBed<TestSubjects>(
WithAppDependencies(MappingsEditor, appDependencies),
WithAppDependencies(MappingsEditor, appDependencies ?? defaultAppDependencies),
{
memoryRouter: {
wrapComponent: false,

View file

@ -28,6 +28,7 @@ describe('Mappings editor: core', () => {
let onChangeHandler: jest.Mock = jest.fn();
let getMappingsEditorData = getMappingsEditorDataFactory(onChangeHandler);
let testBed: MappingsEditorTestBed;
const appDependencies = { plugins: { ml: { mlApi: {} } } };
beforeAll(() => {
jest.useFakeTimers({ legacyFakeTimers: true });
@ -55,7 +56,7 @@ describe('Mappings editor: core', () => {
};
await act(async () => {
testBed = setup({ value: defaultMappings, onChange: onChangeHandler });
testBed = setup({ value: defaultMappings, onChange: onChangeHandler }, appDependencies);
});
const { component } = testBed;
@ -95,7 +96,7 @@ describe('Mappings editor: core', () => {
};
await act(async () => {
testBed = setup({ onChange: onChangeHandler, value });
testBed = setup({ onChange: onChangeHandler, value }, appDependencies);
});
const { component, exists } = testBed;
@ -115,7 +116,7 @@ describe('Mappings editor: core', () => {
},
};
await act(async () => {
testBed = setup({ onChange: onChangeHandler, value });
testBed = setup({ onChange: onChangeHandler, value }, appDependencies);
});
const { component, exists } = testBed;
@ -137,6 +138,7 @@ describe('Mappings editor: core', () => {
config: {
enableMappingsSourceFieldSection: true,
},
...appDependencies,
};
beforeEach(async () => {
@ -295,6 +297,7 @@ describe('Mappings editor: core', () => {
config: {
enableMappingsSourceFieldSection: true,
},
...appDependencies,
};
beforeEach(async () => {
@ -472,7 +475,7 @@ describe('Mappings editor: core', () => {
},
};
await act(async () => {
testBed = setup({ onChange: onChangeHandler, value });
testBed = setup({ onChange: onChangeHandler, value }, appDependencies);
});
const { component, exists } = testBed;
@ -494,7 +497,7 @@ describe('Mappings editor: core', () => {
},
};
await act(async () => {
testBed = setup({ onChange: onChangeHandler, value });
testBed = setup({ onChange: onChangeHandler, value }, appDependencies);
});
const { component } = testBed;

View file

@ -7,6 +7,7 @@
import React, { useCallback, useMemo } from 'react';
import { i18n } from '@kbn/i18n';
import { TextField, UseField, FieldConfig } from '../../../shared_imports';
import { validateUniqueName } from '../../../lib';
import { PARAMETERS_DEFINITION } from '../../../constants';
@ -14,7 +15,11 @@ import { useMappingsState } from '../../../mappings_state_context';
const { validations, ...rest } = PARAMETERS_DEFINITION.name.fieldConfig as FieldConfig;
export const NameParameter = () => {
interface NameParameterProps {
isSemanticText?: boolean;
}
export const NameParameter: React.FC<NameParameterProps> = ({ isSemanticText }) => {
const {
fields: { rootLevelFields, byId },
documentFields: { fieldToAddFieldTo, fieldToEdit },
@ -32,6 +37,11 @@ export const NameParameter = () => {
const nameConfig: FieldConfig = useMemo(
() => ({
...rest,
label: isSemanticText
? i18n.translate('xpack.idxMgmt.mappingsEditor.semanticTextNameFieldLabel', {
defaultMessage: 'New field name',
})
: rest.label,
validations: [
...validations!,
{
@ -39,7 +49,7 @@ export const NameParameter = () => {
},
],
}),
[uniqueNameValidator]
[isSemanticText, uniqueNameValidator]
);
return (

View file

@ -5,65 +5,45 @@
* 2.0.
*/
import React, { useEffect } from 'react';
import React from 'react';
import { useLoadIndexMappings } from '../../../../../services';
import { getFieldConfig } from '../../../lib';
import { Form, SuperSelectField, UseField, useForm } from '../../../shared_imports';
import { useMappingsState } from '../../../mappings_state_context';
import { SuperSelectField, UseField } from '../../../shared_imports';
import { SuperSelectOption } from '../../../types';
interface Props {
onChange(value: string): void;
'data-test-subj'?: string;
indexName?: string;
}
export const ReferenceFieldSelects = () => {
const { fields, mappingViewFields } = useMappingsState();
export const ReferenceFieldSelects = ({
onChange,
'data-test-subj': dataTestSubj,
indexName,
}: Props) => {
const { form } = useForm();
const { subscribe } = form;
const allFields = {
byId: {
...mappingViewFields.byId,
...fields.byId,
},
rootLevelFields: [],
aliases: {},
maxNestedDepth: 0,
};
const { data } = useLoadIndexMappings(indexName ?? '');
const referenceFieldOptions: SuperSelectOption[] = [];
if (data && data.mappings && data.mappings.properties) {
Object.keys(data.mappings.properties).forEach((key) => {
const field = data.mappings.properties[key];
if (field.type === 'text') {
referenceFieldOptions.push({
value: key,
inputDisplay: key,
'data-test-subj': `select-reference-field-${key}`,
});
}
});
}
const referenceFieldOptions: SuperSelectOption[] = Object.values(allFields.byId)
.filter((field) => field.source.type === 'text')
.map((field) => ({
value: field.path.join('.'),
inputDisplay: field.path.join('.'),
'data-test-subj': `select-reference-field-${field.path.join('.')}}`,
}));
const fieldConfigReferenceField = getFieldConfig('reference_field');
useEffect(() => {
const subscription = subscribe((updateData) => {
const formData = updateData.data.internal;
const value = formData.main;
onChange(value);
});
return subscription.unsubscribe;
}, [subscribe, onChange]);
return (
<Form form={form} data-test-subj="referenceField">
<UseField path="main" config={fieldConfigReferenceField}>
{(field) => (
<SuperSelectField
field={field}
euiFieldProps={{
options: referenceFieldOptions,
}}
data-test-subj={dataTestSubj}
/>
)}
</UseField>
</Form>
<UseField path="reference_field" config={fieldConfigReferenceField}>
{(field) => (
<SuperSelectField
field={field}
euiFieldProps={{
options: referenceFieldOptions,
}}
data-test-subj="referenceFieldSelect"
/>
)}
</UseField>
);
};

View file

@ -24,63 +24,74 @@ import {
import { i18n } from '@kbn/i18n';
import React, { useEffect, useState, useCallback, useMemo } from 'react';
import {
InferenceAPIConfigResponse,
SUPPORTED_PYTORCH_TASKS,
TRAINED_MODEL_TYPE,
} from '@kbn/ml-trained-models-utils';
import { SUPPORTED_PYTORCH_TASKS, TRAINED_MODEL_TYPE } from '@kbn/ml-trained-models-utils';
import { InferenceTaskType } from '@elastic/elasticsearch/lib/api/types';
import {
ElasticsearchModelDefaultOptions,
ModelConfig,
Service,
} from '@kbn/inference_integration_flyout/types';
import { ModelConfig } from '@kbn/inference_integration_flyout/types';
import { InferenceFlyoutWrapper } from '@kbn/inference_integration_flyout/components/inference_flyout_wrapper';
import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_models';
import { getFieldConfig } from '../../../lib';
import { useAppContext } from '../../../../../app_context';
import { Form, UseField, useForm } from '../../../shared_imports';
import { useLoadInferenceEndpoints } from '../../../../../services/api';
import { getTrainedModelStats } from '../../../../../../hooks/use_details_page_mappings_model_management';
import { InferenceToModelIdMap } from '../fields';
import { useMLModelNotificationToasts } from '../../../../../../hooks/use_ml_model_status_toasts';
import {
CustomInferenceEndpointConfig,
DefaultInferenceModels,
DeploymentState,
} from '../../../types';
import { CustomInferenceEndpointConfig } from '../../../types';
import { UseField } from '../../../shared_imports';
const inferenceServiceTypeElasticsearchModelMap: Record<string, ElasticsearchModelDefaultOptions> =
{
elser: ElasticsearchModelDefaultOptions.elser,
elasticsearch: ElasticsearchModelDefaultOptions.e5,
};
const uncheckSelectedModelOption = (options: EuiSelectableOption[]) => {
const checkedOption = options.find(({ checked }) => checked === 'on');
if (checkedOption) {
checkedOption.checked = undefined;
}
};
interface Props {
onChange(value: string): void;
export interface SelectInferenceIdProps {
createInferenceEndpoint: (
trainedModelId: string,
inferenceId: string,
modelConfig: CustomInferenceEndpointConfig
) => Promise<void>;
'data-test-subj'?: string;
setValue: (value: string) => void;
setNewInferenceEndpoint: (
newInferenceEndpoint: InferenceToModelIdMap,
customInferenceEndpointConfig: CustomInferenceEndpointConfig
) => void;
}
export const SelectInferenceId = ({
onChange,
type SelectInferenceIdContentProps = SelectInferenceIdProps & {
setValue: (value: string) => void;
value: string;
};
const defaultEndpoints = [
{
model_id: 'elser_model_2',
},
{
model_id: 'e5',
},
];
export const SelectInferenceId: React.FC<SelectInferenceIdProps> = ({
createInferenceEndpoint,
'data-test-subj': dataTestSubj,
}: SelectInferenceIdProps) => {
const config = getFieldConfig('inference_id');
return (
<UseField path="inference_id" fieldConfig={config}>
{(field) => {
return (
<SelectInferenceIdContent
createInferenceEndpoint={createInferenceEndpoint}
data-test-subj={dataTestSubj}
value={field.value as string}
setValue={field.setValue}
/>
);
}}
</UseField>
);
};
const SelectInferenceIdContent: React.FC<SelectInferenceIdContentProps> = ({
createInferenceEndpoint,
'data-test-subj': dataTestSubj,
setValue,
setNewInferenceEndpoint,
}: Props) => {
value,
}) => {
const {
core: { application },
docLinks,
plugins: { ml },
} = useAppContext();
const config = getFieldConfig('inference_id');
const getMlTrainedModelPageUrl = useCallback(async () => {
return await ml?.locator?.getUrl({
@ -88,9 +99,6 @@ export const SelectInferenceId = ({
});
}, [ml]);
const { form } = useForm({ defaultValue: { main: DefaultInferenceModels.elser_model_2 } });
const { subscribe } = form;
const [isInferenceFlyoutVisible, setIsInferenceFlyoutVisible] = useState<boolean>(false);
const [availableTrainedModels, setAvailableTrainedModels] = useState<
TrainedModelConfigResponse[]
@ -118,101 +126,57 @@ export const SelectInferenceId = ({
return availableTrainedModelsList;
}, [availableTrainedModels]);
const [isSaveInferenceLoading, setIsSaveInferenceLoading] = useState<boolean>(false);
const fieldConfigModelId = getFieldConfig('inference_id');
const defaultInferenceIds: EuiSelectableOption[] = useMemo(() => {
return [
{
const { isLoading, data: endpoints, resendRequest } = useLoadInferenceEndpoints();
const options: EuiSelectableOption[] = useMemo(() => {
const missingDefaultEndpoints = defaultEndpoints.filter(
(endpoint) => !(endpoints || []).find((e) => e.model_id === endpoint.model_id)
);
const newOptions: EuiSelectableOption[] = [
...(endpoints || []),
...missingDefaultEndpoints,
].map((endpoint) => ({
label: endpoint.model_id,
'data-test-subj': `custom-inference_${endpoint.model_id}`,
checked: value === endpoint.model_id ? 'on' : undefined,
}));
if (value && !newOptions.find((option) => option.label === value)) {
// Sometimes we create a new endpoint but the backend is slow in updating so we need to optimistically update
const newOption: EuiSelectableOption = {
label: value,
checked: 'on',
label: 'elser_model_2',
'data-test-subj': 'default-inference_elser_model_2',
},
{
label: 'e5',
'data-test-subj': 'default-inference_e5',
},
];
}, []);
const { isLoading, data: models } = useLoadInferenceEndpoints();
const [options, setOptions] = useState<EuiSelectableOption[]>([...defaultInferenceIds]);
const inferenceIdOptionsFromModels = useMemo(() => {
const inferenceIdOptions =
models?.map((model: InferenceAPIConfigResponse) => ({
label: model.model_id,
'data-test-subj': `custom-inference_${model.model_id}`,
})) || [];
return inferenceIdOptions;
}, [models]);
useEffect(() => {
const mergedOptions = {
...inferenceIdOptionsFromModels.reduce(
(acc, option) => ({ ...acc, [option.label]: option }),
{}
),
...defaultInferenceIds.reduce((acc, option) => ({ ...acc, [option.label]: option }), {}),
};
setOptions(Object.values(mergedOptions));
}, [inferenceIdOptionsFromModels, defaultInferenceIds]);
'data-test-subj': `custom-inference_${value}`,
};
return [...newOptions, newOption];
}
return newOptions;
}, [endpoints, value]);
const { showErrorToasts } = useMLModelNotificationToasts();
const onSaveInferenceCallback = useCallback(
async (inferenceId: string, taskType: InferenceTaskType, modelConfig: ModelConfig) => {
setIsInferenceFlyoutVisible(!isInferenceFlyoutVisible);
try {
const isDeployable =
modelConfig.service === Service.elser || modelConfig.service === Service.elasticsearch;
const newOption: EuiSelectableOption[] = [
{
label: inferenceId,
checked: 'on',
'data-test-subj': `custom-inference_${inferenceId}`,
},
];
// uncheck selected endpoint id
uncheckSelectedModelOption(options);
setOptions([...options, ...newOption]);
const trainedModelStats = await ml?.mlApi?.trainedModels.getTrainedModelStats();
const defaultEndpointId =
inferenceServiceTypeElasticsearchModelMap[modelConfig.service] || '';
const newModelId: InferenceToModelIdMap = {};
newModelId[inferenceId] = {
trainedModelId: defaultEndpointId,
isDeployable,
isDeployed:
getTrainedModelStats(trainedModelStats)[defaultEndpointId] === DeploymentState.DEPLOYED,
};
const customInferenceEndpointConfig: CustomInferenceEndpointConfig = {
const trainedModelId = modelConfig.service_settings.model_id || '';
const customModelConfig = {
taskType,
modelConfig,
};
setNewInferenceEndpoint(newModelId, customInferenceEndpointConfig);
setIsSaveInferenceLoading(true);
await createInferenceEndpoint(trainedModelId, inferenceId, customModelConfig);
resendRequest();
setValue(inferenceId);
setIsInferenceFlyoutVisible(!isInferenceFlyoutVisible);
setIsSaveInferenceLoading(false);
} catch (error) {
showErrorToasts(error);
setIsSaveInferenceLoading(false);
}
},
[isInferenceFlyoutVisible, ml, setNewInferenceEndpoint, options, showErrorToasts]
[createInferenceEndpoint, setValue, isInferenceFlyoutVisible, showErrorToasts, resendRequest]
);
useEffect(() => {
const subscription = subscribe((updateData) => {
const formData = updateData.data.internal;
const value = formData.main;
onChange(value);
});
return subscription.unsubscribe;
}, [subscribe, onChange]);
const selectedOptionLabel = options.find((option) => option.checked)?.label;
useEffect(() => {
setValue(selectedOptionLabel ?? DefaultInferenceModels.elser_model_2);
}, [selectedOptionLabel, setValue]);
const [isInferencePopoverVisible, setIsInferencePopoverVisible] = useState<boolean>(false);
const [inferenceEndpointError, setInferenceEndpointError] = useState<string | undefined>(
undefined
@ -221,7 +185,15 @@ export const SelectInferenceId = ({
async (inferenceId: string) => {
const modelsExist = options.some((i) => i.label === inferenceId);
if (modelsExist) {
setInferenceEndpointError('Inference Endpoint id already exists');
setInferenceEndpointError(
i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.inferenceId.popover.defaultLabel',
{
defaultMessage: 'Inference endpoint {inferenceId} already exists',
values: { inferenceId },
}
)
);
} else {
setInferenceEndpointError(undefined);
}
@ -229,139 +201,133 @@ export const SelectInferenceId = ({
[options]
);
const inferencePopover = () => {
return (
<EuiPopover
button={
<>
<UseField path="main" config={fieldConfigModelId}>
{(field) => (
<>
<EuiText size="xs">
<p>
<strong>{field.label}</strong>
</p>
</EuiText>
<EuiSpacer size="xs" />
<EuiButton
iconType="arrowDown"
iconSide="right"
color="text"
data-test-subj="inferenceIdButton"
onClick={() => {
setIsInferencePopoverVisible(!isInferencePopoverVisible);
}}
>
{selectedOptionLabel ||
i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.inferenceId.popover.defaultLabel',
{
defaultMessage: 'No model selected',
}
)}
</EuiButton>
</>
)}
</UseField>
</>
}
isOpen={isInferencePopoverVisible}
panelPaddingSize="m"
closePopover={() => setIsInferencePopoverVisible(!isInferencePopoverVisible)}
>
<EuiContextMenuPanel>
<EuiContextMenuItem
key="addInferenceEndpoint"
icon="plusInCircle"
size="s"
data-test-subj="addInferenceEndpointButton"
const selectedOptionLabel = options.find((option) => option.checked)?.label;
const inferencePopover = () => (
<EuiPopover
button={
<>
<EuiText size="xs">
<p>
<strong>{config.label}</strong>
</p>
</EuiText>
<EuiSpacer size="xs" />
<EuiButton
iconType="arrowDown"
iconSide="right"
color="text"
data-test-subj="inferenceIdButton"
onClick={() => {
setIsInferenceFlyoutVisible(!isInferenceFlyoutVisible);
setInferenceEndpointError(undefined);
setIsInferencePopoverVisible(!isInferencePopoverVisible);
}}
>
{i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.inferenceId.popover.addInferenceEndpointButton',
{
defaultMessage: 'Add inference Endpoint',
}
)}
</EuiContextMenuItem>
<EuiHorizontalRule margin="none" />
<EuiContextMenuItem
key="manageInferenceEndpointButton"
icon="gear"
size="s"
data-test-subj="manageInferenceEndpointButton"
onClick={async () => {
const mlTrainedPageUrl = await getMlTrainedModelPageUrl();
if (typeof mlTrainedPageUrl === 'string') {
application.navigateToUrl(mlTrainedPageUrl);
}
}}
>
{i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.inferenceId.popover.manageInferenceEndpointButton',
{
defaultMessage: 'Manage Inference Endpoint',
}
)}
</EuiContextMenuItem>
</EuiContextMenuPanel>
<EuiHorizontalRule margin="none" />
<EuiPanel color="transparent" paddingSize="s">
<EuiTitle size="xxxs">
<h3>
{i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.inferenceId.popover.selectable.Label',
{selectedOptionLabel ||
i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.inferenceId.popover.alreadyExistsLabel',
{
defaultMessage: 'Existing endpoints',
defaultMessage: 'No inference endpoint selected',
}
)}
</h3>
</EuiTitle>
<EuiSpacer size="xs" />
</EuiButton>
</>
}
isOpen={isInferencePopoverVisible}
panelPaddingSize="m"
closePopover={() => setIsInferencePopoverVisible(!isInferencePopoverVisible)}
>
<EuiContextMenuPanel>
<EuiContextMenuItem
key="addInferenceEndpoint"
icon="plusInCircle"
size="s"
data-test-subj="addInferenceEndpointButton"
onClick={() => {
setIsInferenceFlyoutVisible(!isInferenceFlyoutVisible);
setInferenceEndpointError(undefined);
setIsInferencePopoverVisible(!isInferencePopoverVisible);
}}
>
{i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.inferenceId.popover.addInferenceEndpointButton',
{
defaultMessage: 'Add Inference Endpoint',
}
)}
</EuiContextMenuItem>
<EuiHorizontalRule margin="none" />
<EuiContextMenuItem
key="manageInferenceEndpointButton"
icon="gear"
size="s"
data-test-subj="manageInferenceEndpointButton"
onClick={async () => {
const mlTrainedPageUrl = await getMlTrainedModelPageUrl();
if (typeof mlTrainedPageUrl === 'string') {
application.navigateToUrl(mlTrainedPageUrl);
}
}}
>
{i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.inferenceId.popover.manageInferenceEndpointButton',
{
defaultMessage: 'Manage Inference Endpoints',
}
)}
</EuiContextMenuItem>
</EuiContextMenuPanel>
<EuiHorizontalRule margin="none" />
<EuiPanel color="transparent" paddingSize="s">
<EuiTitle size="xxxs">
<h3>
{i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.inferenceId.popover.selectable.Label',
{
defaultMessage: 'Existing endpoints',
}
)}
</h3>
</EuiTitle>
<EuiSpacer size="xs" />
<EuiSelectable
aria-label={i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.inferenceId.popover.selectable.ariaLabel',
<EuiSelectable
aria-label={i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.inferenceId.popover.selectable.ariaLabel',
{
defaultMessage: 'Existing endpoints',
}
)}
data-test-subj={dataTestSubj}
searchable
isLoading={isLoading}
singleSelection="always"
searchProps={{
compressed: true,
placeholder: i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.inferenceId.popover.selectable.placeholder',
{
defaultMessage: 'Search',
}
)}
data-test-subj={dataTestSubj}
searchable
isLoading={isLoading}
singleSelection="always"
searchProps={{
compressed: true,
placeholder: i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.inferenceId.popover.selectable.placeholder',
{
defaultMessage: 'Search',
}
),
}}
options={options}
onChange={(newOptions) => {
setOptions(newOptions);
setIsInferencePopoverVisible(!isInferencePopoverVisible);
}}
>
{(list, search) => (
<>
{search}
{list}
</>
)}
</EuiSelectable>
</EuiPanel>
</EuiPopover>
);
};
),
}}
options={options}
onChange={(newOptions) => {
setValue(newOptions.find((option) => option.checked)?.label || '');
}}
>
{(list, search) => (
<>
{search}
{list}
</>
)}
</EuiSelectable>
</EuiPanel>
</EuiPopover>
);
return (
<Form form={form}>
<>
<EuiSpacer />
<EuiFlexGroup data-test-subj="selectInferenceId">
<EuiFlexItem grow={false}>
{inferencePopover()}
@ -378,6 +344,7 @@ export const SelectInferenceId = ({
supportedNlpModels={docLinks.links.enterpriseSearch.supportedNlpModels}
nlpImportModel={docLinks.links.ml.nlpImportModel}
setInferenceEndpointError={setInferenceEndpointError}
isCreateInferenceApiLoading={isSaveInferenceLoading}
/>
)}
</EuiFlexItem>
@ -395,6 +362,6 @@ export const SelectInferenceId = ({
/>
</EuiFlexItem>
</EuiFlexGroup>
</Form>
</>
);
};

View file

@ -14,20 +14,16 @@ import {
EuiSpacer,
} from '@elastic/eui';
import { i18n } from '@kbn/i18n';
import { ElasticsearchModelDefaultOptions } from '@kbn/inference_integration_flyout/types';
import { TrainedModelStat } from '@kbn/ml-plugin/common/types/trained_models';
import { MlPluginStart } from '@kbn/ml-plugin/public';
import classNames from 'classnames';
import React, { useCallback, useEffect, useState } from 'react';
import React, { useEffect } from 'react';
import { EUI_SIZE, TYPE_DEFINITION } from '../../../../constants';
import { fieldSerializer } from '../../../../lib';
import { useDispatch, useMappingsState } from '../../../../mappings_state_context';
import { Form, FormDataProvider, UseField, useForm, useFormData } from '../../../../shared_imports';
import {
CustomInferenceEndpointConfig,
Field,
MainType,
NormalizedFields,
} from '../../../../types';
import { isSemanticTextField } from '../../../../lib/utils';
import { useDispatch } from '../../../../mappings_state_context';
import { Form, FormDataProvider, useForm, useFormData } from '../../../../shared_imports';
import { Field, MainType, NormalizedFields } from '../../../../types';
import { NameParameter, SubTypeParameter, TypeParameter } from '../../field_parameters';
import { ReferenceFieldSelects } from '../../field_parameters/reference_field_selects';
import { SelectInferenceId } from '../../field_parameters/select_inference_id';
@ -38,9 +34,11 @@ import { useSemanticText } from './semantic_text/use_semantic_text';
const formWrapper = (props: any) => <form {...props} />;
export interface InferenceToModelIdMap {
[key: string]: {
trainedModelId: ElasticsearchModelDefaultOptions | string;
trainedModelId: string;
isDeployed: boolean;
isDeployable: boolean;
isDownloading: boolean;
modelStats?: TrainedModelStat; // third-party models don't have model stats
};
}
@ -48,7 +46,9 @@ export interface SemanticTextInfo {
isSemanticTextEnabled?: boolean;
indexName?: string;
ml?: MlPluginStart;
setErrorsInTrainedModelDeployment: React.Dispatch<React.SetStateAction<string[]>>;
setErrorsInTrainedModelDeployment: React.Dispatch<
React.SetStateAction<Record<string, string | undefined>>
>;
}
interface Props {
allFields: NormalizedFields['byId'];
@ -73,13 +73,13 @@ export const CreateField = React.memo(function CreateFieldComponent({
isAddingFields,
semanticTextInfo,
}: Props) {
const { isSemanticTextEnabled, indexName, ml, setErrorsInTrainedModelDeployment } =
semanticTextInfo ?? {};
const { isSemanticTextEnabled, ml, setErrorsInTrainedModelDeployment } = semanticTextInfo ?? {};
const dispatch = useDispatch();
const { form } = useForm<Field>({
serializer: fieldSerializer,
options: { stripEmptyFields: false },
id: 'create-field',
});
useFormData({ form });
@ -93,9 +93,6 @@ export const CreateField = React.memo(function CreateFieldComponent({
return subscription.unsubscribe;
}, [dispatch, subscribe]);
const [customInferenceEndpointConfig, setCustomInferenceEndpointConfig] = useState<
CustomInferenceEndpointConfig | undefined
>(undefined);
const cancel = () => {
if (isAddingFields && onCancelAddingNewFields) {
onCancelAddingNewFields();
@ -104,19 +101,14 @@ export const CreateField = React.memo(function CreateFieldComponent({
}
};
const {
referenceFieldComboValue,
nameValue,
inferenceIdComboValue,
setInferenceValue,
semanticFieldType,
handleSemanticText,
} = useSemanticText({
const { createInferenceEndpoint, handleSemanticText } = useSemanticText({
form,
setErrorsInTrainedModelDeployment,
ml,
});
const isSemanticText = form.getFormData().type === 'semantic_text';
const submitForm = async (
e?: React.FormEvent,
exitAfter: boolean = false,
@ -128,11 +120,9 @@ export const CreateField = React.memo(function CreateFieldComponent({
const { isValid, data } = await form.submit();
if (isValid) {
form.reset();
if (data.type === 'semantic_text' && !clickOutside) {
handleSemanticText(data, customInferenceEndpointConfig);
if (isValid && !clickOutside) {
if (isSemanticTextField(data)) {
handleSemanticText(data);
} else {
dispatch({ type: 'field.add', value: data });
}
@ -140,6 +130,7 @@ export const CreateField = React.memo(function CreateFieldComponent({
if (exitAfter) {
cancel();
}
form.reset();
}
};
@ -187,23 +178,19 @@ export const CreateField = React.memo(function CreateFieldComponent({
</FormDataProvider>
{/* Field reference_field for semantic_text field type */}
<ReferenceFieldCombo indexName={indexName} />
{isSemanticText && (
<EuiFlexItem grow={false}>
<ReferenceFieldSelects />
</EuiFlexItem>
)}
{/* Field name */}
<EuiFlexItem>
<NameParameter />
<NameParameter isSemanticText={isSemanticText} />
</EuiFlexItem>
</EuiFlexGroup>
);
const isAddFieldButtonDisabled = (): boolean => {
if (semanticFieldType) {
return !referenceFieldComboValue || !nameValue || !inferenceIdComboValue;
}
return false;
};
const renderFormActions = () => (
<EuiFlexGroup gutterSize="s" justifyContent="flexEnd">
{(isCancelable !== false || isAddingFields) && (
@ -222,7 +209,7 @@ export const CreateField = React.memo(function CreateFieldComponent({
onClick={submitForm}
type="submit"
data-test-subj="addButton"
isDisabled={isAddFieldButtonDisabled()}
isDisabled={form.getErrors().length > 0}
>
{isMultiField
? i18n.translate('xpack.idxMgmt.mappingsEditor.createField.addMultiFieldButtonLabel', {
@ -289,11 +276,10 @@ export const CreateField = React.memo(function CreateFieldComponent({
);
}}
</FormDataProvider>
{/* Field inference_id for semantic_text field type */}
<InferenceIdCombo
setValue={setInferenceValue}
setCustomInferenceEndpointConfig={setCustomInferenceEndpointConfig}
/>
{isSemanticText && (
<SelectInferenceId createInferenceEndpoint={createInferenceEndpoint} />
)}
{renderFormActions()}
</div>
</div>
@ -302,69 +288,3 @@ export const CreateField = React.memo(function CreateFieldComponent({
</>
);
});
function ReferenceFieldCombo({ indexName }: { indexName?: string }) {
const [{ type }] = useFormData({ watch: 'type' });
if (type === undefined || type[0]?.value !== 'semantic_text') {
return null;
}
return (
<EuiFlexItem grow={false}>
<UseField path="referenceField">
{(field) => <ReferenceFieldSelects onChange={field.setValue} indexName={indexName} />}
</UseField>
</EuiFlexItem>
);
}
interface InferenceProps {
setValue: (value: string) => void;
setCustomInferenceEndpointConfig: (config: CustomInferenceEndpointConfig) => void;
}
function InferenceIdCombo({ setValue, setCustomInferenceEndpointConfig }: InferenceProps) {
const { inferenceToModelIdMap } = useMappingsState();
const dispatch = useDispatch();
const [{ type }] = useFormData({ watch: 'type' });
// update new inferenceEndpoint
const setNewInferenceEndpoint = useCallback(
(
newInferenceEndpoint: InferenceToModelIdMap,
customInferenceEndpointConfig: CustomInferenceEndpointConfig
) => {
dispatch({
type: 'inferenceToModelIdMap.update',
value: {
inferenceToModelIdMap: {
...inferenceToModelIdMap,
...newInferenceEndpoint,
},
},
});
setCustomInferenceEndpointConfig(customInferenceEndpointConfig);
},
[dispatch, inferenceToModelIdMap, setCustomInferenceEndpointConfig]
);
if (type === undefined || type[0]?.value !== 'semantic_text') {
return null;
}
return (
<>
<EuiSpacer />
<UseField path="inferenceId">
{(field) => (
<SelectInferenceId
onChange={field.setValue}
setValue={setValue}
setNewInferenceEndpoint={setNewInferenceEndpoint}
/>
)}
</UseField>
</>
);
}

View file

@ -6,10 +6,37 @@
*/
import { renderHook } from '@testing-library/react-hooks';
import { CustomInferenceEndpointConfig, Field } from '../../../../../types';
import { CustomInferenceEndpointConfig, SemanticTextField } from '../../../../../types';
import { useSemanticText } from './use_semantic_text';
import { act } from 'react-dom/test-utils';
jest.mock('../../../../../../../../hooks/use_details_page_mappings_model_management', () => ({
useDetailsPageMappingsModelManagement: () => ({
fetchInferenceToModelIdMap: () => ({
e5: {
isDeployed: false,
isDeployable: true,
trainedModelId: '.multilingual-e5-small',
},
elser_model_2: {
isDeployed: false,
isDeployable: true,
trainedModelId: '.elser_model_2',
},
openai: {
isDeployed: false,
isDeployable: false,
trainedModelId: '',
},
my_elser_endpoint: {
isDeployed: false,
isDeployable: true,
trainedModelId: '.elser_model_2',
},
}),
}),
}));
const mlMock: any = {
mlApi: {
inferenceModels: {
@ -18,26 +45,30 @@ const mlMock: any = {
},
};
const mockField: Record<string, Field> = {
const mockField: Record<string, SemanticTextField> = {
elser_model_2: {
name: 'name',
type: 'semantic_text',
inferenceId: 'elser_model_2',
inference_id: 'elser_model_2',
reference_field: 'title',
},
e5: {
name: 'name',
type: 'semantic_text',
inferenceId: 'e5',
inference_id: 'e5',
reference_field: 'title',
},
openai: {
name: 'name',
type: 'semantic_text',
inferenceId: 'openai',
inference_id: 'openai',
reference_field: 'title',
},
my_elser_endpoint: {
name: 'name',
type: 'semantic_text',
inferenceId: 'my_elser_endpoint',
inference_id: 'my_elser_endpoint',
reference_field: 'title',
},
};
@ -90,6 +121,10 @@ jest.mock('../../../../../mappings_state_context', () => ({
trainedModelId: '.elser_model_2',
},
},
fields: {
byId: {},
},
mappingViewFields: { byId: {} },
}),
useDispatch: () => mockDispatch,
}));
@ -128,28 +163,31 @@ describe('useSemanticText', () => {
jest.clearAllMocks();
mockForm = {
form: {
getFields: jest.fn().mockReturnValue({
referenceField: { value: 'title' },
name: { value: 'sem' },
type: { value: [{ value: 'semantic_text' }] },
inferenceId: { value: 'e5' },
getFormData: jest.fn().mockReturnValue({
referenceField: 'title',
name: 'sem',
type: 'semantic_text',
inferenceId: 'e5',
}),
setFieldValue: jest.fn(),
},
thirdPartyModel: {
getFields: jest.fn().mockReturnValue({
referenceField: { value: 'title' },
name: { value: 'semantic_text_openai_endpoint' },
type: { value: [{ value: 'semantic_text' }] },
inferenceId: { value: 'openai' },
getFormData: jest.fn().mockReturnValue({
referenceField: 'title',
name: 'semantic_text_openai_endpoint',
type: 'semantic_text',
inferenceId: 'openai',
}),
setFieldValue: jest.fn(),
},
elasticModelEndpointCreatedfromFlyout: {
getFields: jest.fn().mockReturnValue({
referenceField: { value: 'title' },
name: { value: 'semantic_text_elserServiceType_endpoint' },
type: { value: [{ value: 'semantic_text' }] },
inferenceId: { value: 'my_elser_endpoint' },
getFormData: jest.fn().mockReturnValue({
referenceField: 'title',
name: 'semantic_text_elserServiceType_endpoint',
type: 'semantic_text',
inferenceId: 'my_elser_endpoint',
}),
setFieldValue: jest.fn(),
},
};
});
@ -162,11 +200,10 @@ describe('useSemanticText', () => {
})
);
await act(async () => {
result.current.setInferenceValue('openai');
result.current.handleSemanticText(mockField.openai, mockConfig.openai);
});
expect(mockDispatch).toHaveBeenCalledWith({
type: 'field.addSemanticText',
type: 'field.add',
value: mockField.openai,
});
expect(mlMock.mlApi.inferenceModels.createInferenceEndpoint).toHaveBeenCalledWith(
@ -184,12 +221,11 @@ describe('useSemanticText', () => {
})
);
await act(async () => {
result.current.setInferenceValue('my_elser_endpoint');
result.current.handleSemanticText(mockField.my_elser_endpoint, mockConfig.elser);
});
expect(mockDispatch).toHaveBeenCalledWith({
type: 'field.addSemanticText',
type: 'field.add',
value: mockField.my_elser_endpoint,
});
expect(mlMock.mlApi.inferenceModels.createInferenceEndpoint).toHaveBeenCalledWith(
@ -198,20 +234,6 @@ describe('useSemanticText', () => {
mockConfig.elser.modelConfig
);
});
it('should populate the values from the form', () => {
const { result } = renderHook(() =>
useSemanticText({
form: mockForm.form,
setErrorsInTrainedModelDeployment: jest.fn(),
ml: mlMock,
})
);
expect(result.current.referenceFieldComboValue).toBe('title');
expect(result.current.nameValue).toBe('sem');
expect(result.current.inferenceIdComboValue).toBe('e5');
expect(result.current.semanticFieldType).toBe('semantic_text');
});
it('should handle semantic text correctly', async () => {
const { result } = renderHook(() =>
@ -227,7 +249,7 @@ describe('useSemanticText', () => {
});
expect(mockDispatch).toHaveBeenCalledWith({
type: 'field.addSemanticText',
type: 'field.add',
value: mockField.elser_model_2,
});
expect(mlMock.mlApi.inferenceModels.createInferenceEndpoint).toHaveBeenCalledWith(
@ -253,12 +275,11 @@ describe('useSemanticText', () => {
);
await act(async () => {
result.current.setInferenceValue('e5');
result.current.handleSemanticText(mockField.e5);
});
expect(mockDispatch).toHaveBeenCalledWith({
type: 'field.addSemanticText',
type: 'field.add',
value: mockField.e5,
});

View file

@ -5,23 +5,26 @@
* 2.0.
*/
import { i18n } from '@kbn/i18n';
import { useCallback } from 'react';
import { MlPluginStart } from '@kbn/ml-plugin/public';
import React, { useEffect, useState } from 'react';
import { ElasticsearchModelDefaultOptions } from '@kbn/inference_integration_flyout/types';
import React, { useEffect } from 'react';
import { InferenceTaskType } from '@elastic/elasticsearch/lib/api/types';
import { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils';
import { ElserModels } from '@kbn/ml-trained-models-utils';
import { i18n } from '@kbn/i18n';
import { useDetailsPageMappingsModelManagement } from '../../../../../../../../hooks/use_details_page_mappings_model_management';
import { useDispatch, useMappingsState } from '../../../../../mappings_state_context';
import { FormHook } from '../../../../../shared_imports';
import { CustomInferenceEndpointConfig, DefaultInferenceModels, Field } from '../../../../../types';
import { CustomInferenceEndpointConfig, Field, SemanticTextField } from '../../../../../types';
import { useMLModelNotificationToasts } from '../../../../../../../../hooks/use_ml_model_status_toasts';
import { getInferenceEndpoints } from '../../../../../../../services/api';
import { getFieldByPathName } from '../../../../../lib/utils';
interface UseSemanticTextProps {
form: FormHook<Field, Field>;
ml?: MlPluginStart;
setErrorsInTrainedModelDeployment: React.Dispatch<React.SetStateAction<string[]>> | undefined;
setErrorsInTrainedModelDeployment?: React.Dispatch<
React.SetStateAction<Record<string, string | undefined>>
>;
}
interface DefaultInferenceEndpointConfig {
taskType: InferenceTaskType;
@ -30,70 +33,52 @@ interface DefaultInferenceEndpointConfig {
export function useSemanticText(props: UseSemanticTextProps) {
const { form, setErrorsInTrainedModelDeployment, ml } = props;
const { inferenceToModelIdMap } = useMappingsState();
const { fields, mappingViewFields } = useMappingsState();
const { fetchInferenceToModelIdMap } = useDetailsPageMappingsModelManagement();
const dispatch = useDispatch();
const [referenceFieldComboValue, setReferenceFieldComboValue] = useState<string>();
const [nameValue, setNameValue] = useState<string>();
const [inferenceIdComboValue, setInferenceIdComboValue] = useState<string>();
const [semanticFieldType, setSemanticTextFieldType] = useState<string>();
const [inferenceValue, setInferenceValue] = useState<string>(
DefaultInferenceModels.elser_model_2
);
const { showSuccessToasts, showErrorToasts } = useMLModelNotificationToasts();
const { showSuccessToasts, showErrorToasts, showSuccessfullyDeployedToast } =
useMLModelNotificationToasts();
const useFieldEffect = (
semanticTextform: FormHook,
fieldName: string,
setState: React.Dispatch<React.SetStateAction<string | undefined>>
) => {
const fieldValue = semanticTextform.getFields()?.[fieldName]?.value;
useEffect(() => {
if (typeof fieldValue === 'string') {
setState(fieldValue);
const fieldTypeValue = form.getFormData()?.type;
useEffect(() => {
if (fieldTypeValue === 'semantic_text') {
const allFields = {
byId: {
...fields.byId,
...mappingViewFields.byId,
},
rootLevelFields: [],
aliases: {},
maxNestedDepth: 0,
};
const defaultName = getFieldByPathName(allFields, 'semantic_text') ? '' : 'semantic_text';
const referenceField =
Object.values(allFields.byId)
.find((field) => field.source.type === 'text')
?.path.join('.') || '';
if (!form.getFormData().name) {
form.setFieldValue('name', defaultName);
}
if (!form.getFormData().reference_field) {
form.setFieldValue('reference_field', referenceField);
}
if (!form.getFormData().inference_id) {
form.setFieldValue('inference_id', 'elser_model_2');
}
}, [semanticTextform, fieldValue, setState]);
};
useFieldEffect(form, 'referenceField', setReferenceFieldComboValue);
useFieldEffect(form, 'name', setNameValue);
const fieldTypeValue = form.getFields()?.type?.value;
useEffect(() => {
if (!Array.isArray(fieldTypeValue) || fieldTypeValue.length === 0) {
return;
}
setSemanticTextFieldType(
fieldTypeValue[0]?.value === 'semantic_text' ? fieldTypeValue[0].value : undefined
);
}, [form, fieldTypeValue]);
const inferenceId = form.getFields()?.inferenceId?.value;
useEffect(() => {
if (typeof inferenceId === 'string') {
setInferenceIdComboValue(inferenceId);
}
}, [form, inferenceId, inferenceToModelIdMap]);
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [fieldTypeValue]);
const createInferenceEndpoint = useCallback(
async (
trainedModelId: ElasticsearchModelDefaultOptions | string,
data: Field,
trainedModelId: string,
inferenceId: string,
customInferenceEndpointConfig?: CustomInferenceEndpointConfig
) => {
if (data.inferenceId === undefined) {
throw new Error(
i18n.translate('xpack.idxMgmt.mappingsEditor.createField.undefinedInferenceIdError', {
defaultMessage: 'Inference ID is undefined',
})
);
}
const isElser = ElserModels.includes(trainedModelId);
const defaultInferenceEndpointConfig: DefaultInferenceEndpointConfig = {
service:
trainedModelId === ElasticsearchModelDefaultOptions.elser ? 'elser' : 'elasticsearch',
taskType:
trainedModelId === ElasticsearchModelDefaultOptions.elser
? 'sparse_embedding'
: 'text_embedding',
service: isElser ? 'elser' : 'elasticsearch',
taskType: isElser ? 'sparse_embedding' : 'text_embedding',
};
const modelConfig = customInferenceEndpointConfig
@ -108,65 +93,69 @@ export function useSemanticText(props: UseSemanticTextProps) {
};
const taskType: InferenceTaskType =
customInferenceEndpointConfig?.taskType ?? defaultInferenceEndpointConfig.taskType;
try {
await ml?.mlApi?.inferenceModels?.createInferenceEndpoint(
data.inferenceId,
taskType,
modelConfig
);
} catch (error) {
throw error;
}
await ml?.mlApi?.inferenceModels?.createInferenceEndpoint(inferenceId, taskType, modelConfig);
},
[ml?.mlApi?.inferenceModels]
);
const handleSemanticText = async (
data: Field,
data: SemanticTextField,
customInferenceEndpointConfig?: CustomInferenceEndpointConfig
) => {
data.inferenceId = inferenceValue;
if (data.inferenceId === undefined) {
return;
}
const inferenceData = inferenceToModelIdMap?.[data.inferenceId];
const modelIdMap = await fetchInferenceToModelIdMap();
const inferenceId = data.inference_id;
const inferenceData = modelIdMap?.[inferenceId];
if (!inferenceData) {
return;
throw new Error(
i18n.translate('xpack.idxMgmt.mappingsEditor.semanticText.inferenceError', {
defaultMessage: 'No inference model found for inference ID {inferenceId}',
})
);
}
const { trainedModelId } = inferenceData;
dispatch({ type: 'field.addSemanticText', value: data });
dispatch({ type: 'field.add', value: data });
const inferenceEndpoints = await getInferenceEndpoints();
const hasInferenceEndpoint = inferenceEndpoints.data?.some(
(inference) => inference.model_id === inferenceId
);
// if inference endpoint exists already, do not create new inference endpoint
if (hasInferenceEndpoint) {
return;
}
try {
// if inference endpoint exists already, do not create inference endpoint
const inferenceModels = await getInferenceEndpoints();
const inferenceModel: InferenceAPIConfigResponse[] = inferenceModels.data.some(
(e: InferenceAPIConfigResponse) => e.model_id === inferenceValue
);
if (inferenceModel) {
return;
}
// Only show toast if it's an internal Elastic model that hasn't been deployed yet
if (trainedModelId && inferenceData.isDeployable && !inferenceData.isDeployed) {
showSuccessToasts(trainedModelId);
}
await createInferenceEndpoint(trainedModelId, data, customInferenceEndpointConfig);
await createInferenceEndpoint(
trainedModelId,
data.inference_id,
customInferenceEndpointConfig
);
if (trainedModelId) {
// clear error because we've succeeded here
setErrorsInTrainedModelDeployment?.((prevItems) => ({
...prevItems,
[trainedModelId]: undefined,
}));
}
showSuccessfullyDeployedToast(trainedModelId);
} catch (error) {
// trainedModelId is empty string when it's a third party model
if (trainedModelId) {
setErrorsInTrainedModelDeployment?.((prevItems) => [...prevItems, trainedModelId]);
setErrorsInTrainedModelDeployment?.((prevItems) => ({
...prevItems,
[trainedModelId]: error,
}));
}
showErrorToasts(error);
}
};
return {
referenceFieldComboValue,
nameValue,
inferenceIdComboValue,
semanticFieldType,
createInferenceEndpoint,
handleSemanticText,
setInferenceValue,
};
}

View file

@ -1041,7 +1041,6 @@ export const PARAMETERS_DEFINITION: { [key in ParameterName]: ParameterDefinitio
},
schema: t.number,
},
dims: {
fieldConfig: {
defaultValue: '',
@ -1070,20 +1069,46 @@ export const PARAMETERS_DEFINITION: { [key in ParameterName]: ParameterDefinitio
},
reference_field: {
fieldConfig: {
defaultValue: '',
label: i18n.translate('xpack.idxMgmt.mappingsEditor.parameters.referenceFieldLabel', {
defaultMessage: 'Reference field',
}),
helpText: i18n.translate('xpack.idxMgmt.mappingsEditor.parameters.referenceFieldHelpText', {
defaultMessage: 'Reference field for model inference.',
}),
validations: [
{
validator: emptyField(
i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.validations.referenceFieldIsRequiredErrorMessage',
{
defaultMessage: 'Reference field is required.',
}
)
),
},
],
},
schema: t.string,
},
inference_id: {
fieldConfig: {
defaultValue: 'elser_model_2',
label: i18n.translate('xpack.idxMgmt.mappingsEditor.parameters.inferenceIdLabel', {
defaultMessage: 'Select an inference endpoint:',
}),
validations: [
{
validator: emptyField(
i18n.translate(
'xpack.idxMgmt.mappingsEditor.parameters.validations.inferenceIdIsRequiredErrorMessage',
{
defaultMessage: 'Inference ID is required.',
}
)
),
},
],
},
schema: t.string,
},

View file

@ -18,6 +18,7 @@ import {
getFieldsFromState,
getAllFieldTypesFromState,
getFieldsMatchingFilterFromState,
getStateWithCopyToFields,
} from './utils';
const fieldsWithnestedFields: NormalizedFields = {
@ -420,6 +421,7 @@ describe('utils', () => {
selectedDataTypes: ['Boolean'],
},
inferenceToModelIdMap: {},
mappingViewFields: { byId: {}, rootLevelFields: [], aliases: {}, maxNestedDepth: 0 },
};
test('returns list of matching fields with search term', () => {
expect(getFieldsMatchingFilterFromState(sampleState, ['Boolean'])).toEqual({
@ -442,5 +444,203 @@ describe('utils', () => {
},
});
});
describe('getStateWithCopyToFields', () => {
test('returns state if there is no semantic text field', () => {
const state = {
fields: {
byId: {
'88ebcfdb-19b7-4458-9ea2-9488df54453d': {
id: '88ebcfdb-19b7-4458-9ea2-9488df54453d',
isMultiField: false,
source: {
name: 'title',
type: 'text',
},
},
},
},
} as any;
expect(getStateWithCopyToFields(state)).toEqual(state);
});
test('returns state if semantic text field has no reference field', () => {
const state = {
fields: {
byId: {
'88ebcfdb-19b7-4458-9ea2-9488df54453d': {
id: '88ebcfdb-19b7-4458-9ea2-9488df54453d',
isMultiField: false,
source: {
name: 'title',
type: 'semantic_text',
inference_id: 'id',
},
},
},
},
} as any;
expect(getStateWithCopyToFields(state)).toEqual(state);
});
test('adds text field with copy to to state if semantic text field has reference field', () => {
const state = {
fields: {
byId: {
'88ebcfdb-19b7-4458-9ea2-9488df54453d': {
id: '88ebcfdb-19b7-4458-9ea2-9488df54453d',
isMultiField: false,
path: ['title'],
source: {
name: 'title',
type: 'semantic_text',
inference_id: 'id',
reference_field: 'new',
},
},
'new-field': {
id: 'new-field',
isMultiField: false,
path: ['new'],
source: {
name: 'new',
type: 'text',
},
},
},
rootLevelFields: ['88ebcfdb-19b7-4458-9ea2-9488df54453d'],
},
} as any;
const expectedState = {
fields: {
byId: {
'88ebcfdb-19b7-4458-9ea2-9488df54453d': {
id: '88ebcfdb-19b7-4458-9ea2-9488df54453d',
isMultiField: false,
path: ['title'],
source: {
name: 'title',
type: 'semantic_text',
inference_id: 'id',
},
},
'new-field': {
id: 'new-field',
isMultiField: false,
path: ['new'],
source: {
name: 'new',
type: 'text',
copy_to: ['title'],
},
},
},
rootLevelFields: ['88ebcfdb-19b7-4458-9ea2-9488df54453d', 'new-field'],
},
} as any;
expect(getStateWithCopyToFields(state)).toEqual(expectedState);
});
test('adds nested text field with copy to to state if semantic text field has reference field', () => {
const state = {
fields: {
byId: {
'88ebcfdb-19b7-4458-9ea2-9488df54453d': {
id: '88ebcfdb-19b7-4458-9ea2-9488df54453d',
isMultiField: false,
path: ['title'],
source: {
name: 'title',
type: 'semantic_text',
inference_id: 'id',
reference_field: 'existing.new',
},
},
},
rootLevelFields: ['88ebcfdb-19b7-4458-9ea2-9488df54453d'],
},
mappingViewFields: {
byId: {
existing: {
id: 'existing',
isMultiField: false,
path: ['existing'],
source: {
name: 'existing',
type: 'object',
},
},
'new-field': {
id: 'new-field',
parentId: 'existing',
isMultiField: false,
path: ['existing', 'new'],
source: {
name: 'new',
type: 'text',
},
},
},
},
} as any;
const expectedState = {
fields: {
byId: {
'88ebcfdb-19b7-4458-9ea2-9488df54453d': {
id: '88ebcfdb-19b7-4458-9ea2-9488df54453d',
isMultiField: false,
path: ['title'],
source: {
name: 'title',
type: 'semantic_text',
inference_id: 'id',
},
},
existing: {
id: 'existing',
isMultiField: false,
path: ['existing'],
source: {
name: 'existing',
type: 'object',
},
},
'new-field': {
id: 'new-field',
isMultiField: false,
parentId: 'existing',
path: ['existing', 'new'],
source: {
name: 'new',
type: 'text',
copy_to: ['title'],
},
},
},
rootLevelFields: ['88ebcfdb-19b7-4458-9ea2-9488df54453d', 'existing'],
},
mappingViewFields: {
byId: {
existing: {
id: 'existing',
isMultiField: false,
path: ['existing'],
source: {
name: 'existing',
type: 'object',
},
},
'new-field': {
id: 'new-field',
parentId: 'existing',
isMultiField: false,
path: ['existing', 'new'],
source: {
name: 'new',
type: 'text',
},
},
},
},
} as any;
expect(getStateWithCopyToFields(state)).toEqual(expectedState);
});
});
});
});

View file

@ -7,6 +7,7 @@
import { v4 as uuidv4 } from 'uuid';
import { cloneDeep } from 'lodash';
import {
ChildFieldName,
ComboBoxOption,
@ -23,6 +24,7 @@ import {
ParameterName,
RuntimeFields,
SubType,
SemanticTextField,
} from '../types';
import {
@ -685,3 +687,88 @@ export const getAllFieldTypesFromState = (allFields: Fields): DataType[] => {
const fields: DataType[] = [];
return getallFieldsIncludingNestedFields(allFields, fields).filter(filterUnique);
};
export function isSemanticTextField(field: Partial<Field>): field is SemanticTextField {
return Boolean(field.inference_id && field.type === 'semantic_text');
}
/**
* Returns deep copy of state with `copy_to` added to text fields that are referenced by new semantic text fields
* @param state
* @returns state
*/
export function getStateWithCopyToFields(state: State): State {
// Make sure we don't accidentally modify existing state
let updatedState = cloneDeep(state);
for (const field of Object.values(updatedState.fields.byId)) {
if (field.source.type === 'semantic_text' && field.source.reference_field) {
// Check fields already added to the list of to-update fields first
// API will not accept reference_field so removing it now
const { reference_field: referenceField, ...source } = field.source;
if (typeof referenceField !== 'string') {
// should never happen
throw new Error('Reference field is not a string');
}
field.source = source;
const existingTextField =
getFieldByPathName(updatedState.fields, referenceField) ||
getFieldByPathName(updatedState.mappingViewFields || { byId: {} }, referenceField);
if (existingTextField) {
// Add copy_to to existing text field's copy_to array
const updatedTextField: NormalizedField = {
...existingTextField,
source: {
...existingTextField.source,
copy_to: existingTextField.source.copy_to
? [
...(Array.isArray(existingTextField.source.copy_to)
? existingTextField.source.copy_to
: [existingTextField.source.copy_to]),
field.path.join('.'),
]
: [field.path.join('.')],
},
};
updatedState = {
...updatedState,
fields: {
...updatedState.fields,
byId: {
...updatedState.fields.byId,
[existingTextField.id]: updatedTextField,
},
},
};
if (existingTextField.parentId) {
let currentField = existingTextField;
let hasParent = true;
while (hasParent) {
if (!currentField.parentId) {
// reached the top of the tree, push current field to root level fields
updatedState.fields.rootLevelFields.push(currentField.id);
hasParent = false;
} else if (updatedState.fields.byId[currentField.parentId]) {
// parent is already in state, don't need to do anything
hasParent = false;
} else {
// parent is not in state yet
updatedState.fields.byId[currentField.parentId] =
updatedState.mappingViewFields.byId[currentField.parentId];
currentField = updatedState.fields.byId[currentField.parentId];
}
}
} else {
updatedState.fields.rootLevelFields.push(existingTextField.id);
}
} else {
throw new Error(`Semantic text field ${field.path.join('.')} has invalid reference field`);
}
}
}
return updatedState;
}
export const getFieldByPathName = (fields: NormalizedFields, name: string) => {
return Object.values(fields.byId).find((field) => field.path.join('.') === name);
};

View file

@ -60,6 +60,7 @@ export const StateProvider: React.FC<{ children?: React.ReactNode }> = ({ childr
selectedDataTypes: [],
},
inferenceToModelIdMap: {},
mappingViewFields: { byId: {}, rootLevelFields: [], aliases: {}, maxNestedDepth: 0 },
};
const [state, dispatch] = useReducer(reducer, initialState);

View file

@ -214,6 +214,9 @@ export const reducer = (state: State, action: Action): State => {
},
};
}
case 'editor.replaceViewMappings': {
return { ...state, mappingViewFields: action.value.fields };
}
case 'configuration.update': {
const nextState = {
...state,
@ -323,28 +326,6 @@ export const reducer = (state: State, action: Action): State => {
case 'field.add': {
return addFieldToState(action.value, state);
}
case 'field.addSemanticText': {
const addTexFieldWithCopyToActionValue: Field = {
name: action.value.referenceField as string,
type: 'text',
copy_to: [action.value.name],
};
// Add text field to state with copy_to of semantic_text field
let updatedState = addFieldToState(addTexFieldWithCopyToActionValue, state);
const addSemanticTextFieldActionValue: Field = {
name: action.value.name,
inference_id: action.value.inferenceId,
type: 'semantic_text',
};
// Add semantic_text field to state and reset fieldToAddFieldTo
updatedState = addFieldToState(addSemanticTextFieldActionValue, updatedState);
updatedState.documentFields.fieldToAddFieldTo = undefined;
return updatedState;
}
case 'field.remove': {
const field = state.fields.byId[action.value];
const { id, hasChildFields, hasMultiFields } = field;

View file

@ -186,9 +186,6 @@ interface FieldBasic {
subType?: SubType;
properties?: { [key: string]: Omit<Field, 'name'> };
fields?: { [key: string]: Omit<Field, 'name'> };
referenceField?: string;
inferenceId?: string;
inference_id?: string;
// other* exist together as a holder of types that the mappings editor does not yet know about but
// enables the user to create mappings with them.
@ -201,6 +198,8 @@ type FieldParams = {
export type Field = FieldBasic & Partial<FieldParams>;
export type SemanticTextField = Field & { inference_id: string; reference_field: string };
export interface FieldMeta {
childFieldsName: ChildFieldName | undefined;
canHaveChildFields: boolean;

View file

@ -108,10 +108,12 @@ export interface State {
};
templates: TemplatesFormState;
inferenceToModelIdMap?: InferenceToModelIdMap;
mappingViewFields: NormalizedFields; // state of the incoming index mappings, separate from the editor state above
}
export type Action =
| { type: 'editor.replaceMappings'; value: { [key: string]: any } }
| { type: 'editor.replaceViewMappings'; value: { fields: NormalizedFields } }
| {
type: 'inferenceToModelIdMap.update';
value: { inferenceToModelIdMap?: InferenceToModelIdMap };
@ -122,7 +124,6 @@ export type Action =
| { type: 'templates.save'; value: MappingsTemplates }
| { type: 'fieldForm.update'; value: OnFormUpdateArg<any> }
| { type: 'field.add'; value: Field }
| { type: 'field.addSemanticText'; value: Field }
| { type: 'field.remove'; value: string }
| { type: 'field.edit'; value: Field }
| { type: 'field.toggleExpand'; value: { fieldId: string; isExpanded?: boolean } }

View file

@ -8,6 +8,7 @@
import { useEffect, useMemo } from 'react';
import { EuiSelectableOption } from '@elastic/eui';
import { cloneDeep } from 'lodash';
import {
DocumentFieldsStatus,
Field,
@ -182,6 +183,12 @@ export const useMappingsStateListener = ({ onChange, value, status }: Args) => {
},
},
});
dispatch({
type: 'editor.replaceViewMappings',
value: {
fields: cloneDeep(parsedFieldsDefaultValue),
},
});
}, [
value,
parsedFieldsDefaultValue,

View file

@ -29,6 +29,10 @@ import { FormattedMessage } from '@kbn/i18n-react';
import React, { FunctionComponent, useCallback, useEffect, useMemo, useState } from 'react';
import { ILicense } from '@kbn/licensing-plugin/public';
import { useUnsavedChangesPrompt } from '@kbn/unsaved-changes-prompt';
import {
getStateWithCopyToFields,
isSemanticTextField,
} from '../../../../components/mappings_editor/lib/utils';
import { Index } from '../../../../../../common';
import { useDetailsPageMappingsModelManagement } from '../../../../../hooks/use_details_page_mappings_model_management';
import { useAppContext } from '../../../../app_context';
@ -73,7 +77,6 @@ export const DetailsPageMappingsContent: FunctionComponent<{
http,
},
plugins: { ml, licensing },
url,
config,
overlays,
history,
@ -89,9 +92,9 @@ export const DetailsPageMappingsContent: FunctionComponent<{
}, [licensing]);
const { enableSemanticText: isSemanticTextEnabled } = config;
const [errorsInTrainedModelDeployment, setErrorsInTrainedModelDeployment] = useState<string[]>(
[]
);
const [errorsInTrainedModelDeployment, setErrorsInTrainedModelDeployment] = useState<
Record<string, string | undefined>
>({});
const hasMLPermissions = capabilities?.ml?.canGetTrainedModels ? true : false;
@ -151,11 +154,10 @@ export const DetailsPageMappingsContent: FunctionComponent<{
[jsonData]
);
const [hasSavedFields, setHasSavedFields] = useState<boolean>(false);
useMappingsStateListener({ value: parsedDefaultValue, status: 'disabled' });
const { fetchInferenceToModelIdMap, pendingDeployments } = useDetailsPageMappingsModelManagement(
state.fields,
state.inferenceToModelIdMap
);
const { fetchInferenceToModelIdMap } = useDetailsPageMappingsModelManagement();
const onCancelAddingNewFields = useCallback(() => {
setAddingFields(!isAddingFields);
@ -198,8 +200,6 @@ export const DetailsPageMappingsContent: FunctionComponent<{
});
}, [dispatch, isAddingFields, state]);
const [isModalVisible, setIsModalVisible] = useState(false);
useEffect(() => {
if (!isSemanticTextEnabled || !hasMLPermissions) {
return;
@ -213,7 +213,7 @@ export const DetailsPageMappingsContent: FunctionComponent<{
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
const refreshModal = useCallback(async () => {
const fetchInferenceData = useCallback(async () => {
try {
if (!isSemanticTextEnabled) {
return;
@ -230,35 +230,45 @@ export const DetailsPageMappingsContent: FunctionComponent<{
}, [fetchInferenceToModelIdMap, isSemanticTextEnabled, hasMLPermissions]);
const updateMappings = useCallback(async () => {
const hasSemanticText = hasSemanticTextField(state.fields);
try {
if (isSemanticTextEnabled && hasMLPermissions) {
if (isSemanticTextEnabled && hasMLPermissions && hasSemanticText) {
await fetchInferenceToModelIdMap();
if (pendingDeployments.length > 0) {
setIsModalVisible(true);
return;
}
}
const denormalizedFields = deNormalize(state.fields);
const fields = hasSemanticText ? getStateWithCopyToFields(state).fields : state.fields;
const { error } = await updateIndexMappings(indexName, denormalizedFields);
const denormalizedFields = deNormalize(fields);
if (!error) {
notificationService.showSuccessToast(
i18n.translate('xpack.idxMgmt.indexDetails.mappings.successfullyUpdatedIndexMappings', {
defaultMessage: 'Updated index mapping',
})
const inferenceIdsInPendingList = Object.values(deNormalize(fields))
.filter(isSemanticTextField)
.map((field) => field.inference_id)
.filter(
(inferenceId: string) =>
state.inferenceToModelIdMap?.[inferenceId] &&
!state.inferenceToModelIdMap?.[inferenceId].isDeployed
);
refetchMapping();
} else {
setSaveMappingError(error.message);
setHasSavedFields(true);
if (inferenceIdsInPendingList.length === 0) {
const { error } = await updateIndexMappings(indexName, denormalizedFields);
if (!error) {
notificationService.showSuccessToast(
i18n.translate('xpack.idxMgmt.indexDetails.mappings.successfullyUpdatedIndexMappings', {
defaultMessage: 'Updated index mapping',
})
);
refetchMapping();
setHasSavedFields(false);
} else {
setSaveMappingError(error.message);
}
}
} catch (exception) {
setSaveMappingError(exception.message);
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [state.fields, pendingDeployments]);
}, [state.fields]);
const onSearchChange = useCallback(
(value: string) => {
@ -494,7 +504,7 @@ export const DetailsPageMappingsContent: FunctionComponent<{
>
<FormattedMessage
id="xpack.idxMgmt.indexDetails.mappings.saveMappings"
defaultMessage="Save mappings"
defaultMessage="Save mapping"
/>
</EuiButton>
)}
@ -601,15 +611,17 @@ export const DetailsPageMappingsContent: FunctionComponent<{
</EuiFlexItem>
</EuiFlexGroup>
</EuiFlexGroup>
{isModalVisible && isSemanticTextEnabled && (
{isSemanticTextEnabled && isAddingFields && hasSavedFields && (
<TrainedModelsDeploymentModal
pendingDeployments={pendingDeployments}
fetchData={fetchInferenceData}
errorsInTrainedModelDeployment={errorsInTrainedModelDeployment}
setIsModalVisible={setIsModalVisible}
refreshModal={refreshModal}
url={url}
setErrorsInTrainedModelDeployment={setErrorsInTrainedModelDeployment}
/>
)}
</>
);
};
function hasSemanticTextField(fields: NormalizedFields): boolean {
return Object.values(fields.byId).some((field) => field.source.type === 'semantic_text');
}

View file

@ -8,179 +8,175 @@ import { EuiConfirmModal, useGeneratedHtmlId, EuiHealth } from '@elastic/eui';
import React from 'react';
import { EuiLink } from '@elastic/eui';
import { useEffect, useState } from 'react';
import type { SharePluginStart } from '@kbn/share-plugin/public';
import { useEffect, useMemo, useState } from 'react';
import { i18n } from '@kbn/i18n';
import { isSemanticTextField } from '../../../../components/mappings_editor/lib/utils';
import { deNormalize } from '../../../../components/mappings_editor/lib';
import { useMLModelNotificationToasts } from '../../../../../hooks/use_ml_model_status_toasts';
import { useMappingsState } from '../../../../components/mappings_editor/mappings_state_context';
import { useAppContext } from '../../../../app_context';
interface SemanticTextProps {
setIsModalVisible: (isVisible: boolean) => void;
refreshModal: () => void;
pendingDeployments: Array<string | undefined>;
errorsInTrainedModelDeployment: string[];
url?: SharePluginStart['url'];
export interface TrainedModelsDeploymentModalProps {
fetchData: () => void;
errorsInTrainedModelDeployment: Record<string, string | undefined>;
setErrorsInTrainedModelDeployment: React.Dispatch<
React.SetStateAction<Record<string, string | undefined>>
>;
}
const ML_APP_LOCATOR = 'ML_APP_LOCATOR';
const TRAINED_MODELS_MANAGE = 'trained_models';
export function TrainedModelsDeploymentModal({
setIsModalVisible,
refreshModal,
pendingDeployments = [],
errorsInTrainedModelDeployment = [],
url,
}: SemanticTextProps) {
errorsInTrainedModelDeployment = {},
fetchData,
setErrorsInTrainedModelDeployment,
}: TrainedModelsDeploymentModalProps) {
const { fields, inferenceToModelIdMap } = useMappingsState();
const {
plugins: { ml },
url,
} = useAppContext();
const modalTitleId = useGeneratedHtmlId();
const [isModalVisible, setIsModalVisible] = useState<boolean>(false);
const closeModal = () => setIsModalVisible(false);
const [mlManagementPageUrl, setMlManagementPageUrl] = useState<string>('');
const { showErrorToasts } = useMLModelNotificationToasts();
useEffect(() => {
setIsModalVisible(pendingDeployments.length > 0);
}, [pendingDeployments, setIsModalVisible]);
useEffect(() => {
let isCancelled = false;
const mlLocator = url?.locators.get(ML_APP_LOCATOR);
const generateUrl = async () => {
if (mlLocator) {
const mlURL = await mlLocator.getUrl({
page: TRAINED_MODELS_MANAGE,
});
if (!isCancelled) {
setMlManagementPageUrl(mlURL);
}
setMlManagementPageUrl(mlURL);
}
};
generateUrl();
return () => {
isCancelled = true;
};
}, [url]);
const ErroredDeployments = pendingDeployments.filter(
(deployment) => deployment !== undefined && errorsInTrainedModelDeployment.includes(deployment)
);
const inferenceIdsInPendingList = useMemo(() => {
return Object.values(deNormalize(fields))
.filter(isSemanticTextField)
.map((field) => field.inference_id);
}, [fields]);
const PendingModelsDeploymentModal = () => {
const pendingDeploymentsList = pendingDeployments.map((deployment, index) => (
<li key={index}>
<EuiHealth textSize="xs" color="warning">
{deployment}
</EuiHealth>
</li>
));
const [pendingDeployments, setPendingDeployments] = useState<string[]>([]);
return (
<EuiConfirmModal
aria-labelledby={modalTitleId}
style={{ width: 600 }}
title={i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.titleLabel',
{
defaultMessage: 'Models still deploying',
}
)}
titleProps={{ id: modalTitleId }}
onCancel={closeModal}
onConfirm={refreshModal}
cancelButtonText={i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.closeButtonLabel',
{
defaultMessage: 'Close',
}
)}
confirmButtonText={i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.refreshButtonLabel',
{
defaultMessage: 'Refresh',
}
)}
defaultFocusedButton="confirm"
data-test-subj="trainedModelsDeploymentModal"
>
<p data-test-subj="trainedModelsDeploymentModalText">
{i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.textAboutDeploymentsNotCompleted',
{
defaultMessage:
'Some fields are referencing models that have not yet completed deployment. Deployment may take a few minutes to complete.',
}
)}
</p>
<ul style={{ listStyleType: 'none' }}>{pendingDeploymentsList}</ul>
<EuiLink href={mlManagementPageUrl} target="_blank">
{i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.textTrainedModelManagementLink',
{
defaultMessage: 'Go to Trained Model Management',
}
)}
</EuiLink>
</EuiConfirmModal>
);
const startModelAllocation = async (trainedModelId: string) => {
try {
await ml?.mlApi?.trainedModels.startModelAllocation(trainedModelId);
} catch (error) {
setErrorsInTrainedModelDeployment((previousState) => ({
...previousState,
[trainedModelId]: error.message,
}));
showErrorToasts(error);
setIsModalVisible(true);
}
};
const ErroredModelsDeploymentModal = () => {
const pendingDeploymentsList = pendingDeployments.map((deployment, index) => (
<li key={index}>
<EuiHealth textSize="xs" color="danger">
{deployment}
</EuiHealth>
</li>
));
return (
<EuiConfirmModal
aria-labelledby={modalTitleId}
style={{ width: 600 }}
title={i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.deploymentErrorTitle',
{
defaultMessage: 'Models could not be deployed',
}
)}
titleProps={{ id: modalTitleId }}
onCancel={closeModal}
onConfirm={refreshModal}
cancelButtonText={i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.deploymentErrorCancelButtonLabel',
{
defaultMessage: 'Cancel',
}
)}
confirmButtonText={i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.deploymentErrorTryAgainButtonLabel',
{
defaultMessage: 'Try again',
}
)}
defaultFocusedButton="confirm"
data-test-subj="trainedModelsErroredDeploymentModal"
>
<p data-test-subj="trainedModelsErrorDeploymentModalText">
{i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.deploymentErrorText',
{
defaultMessage: 'There was an error when trying to deploy the following models.',
}
)}
</p>
<ul style={{ listStyleType: 'none' }}>{pendingDeploymentsList}</ul>
<EuiLink href={mlManagementPageUrl} target="_blank">
{i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.deploymentErrorTrainedModelManagementLink',
{
defaultMessage: 'Go to Trained Model Management',
}
)}
</EuiLink>
</EuiConfirmModal>
useEffect(() => {
const models = inferenceIdsInPendingList.map(
(inferenceId) => inferenceToModelIdMap?.[inferenceId]
);
};
for (const model of models) {
if (model && !model.isDownloading && !model.isDeployed) {
// Sometimes the model gets stuck in a ready to deploy state, so we need to trigger deployment manually
startModelAllocation(model.trainedModelId);
}
}
const pendingModels = models
.map((model) => {
return model?.trainedModelId && !model?.isDeployed ? model?.trainedModelId : '';
})
.filter((trainedModelId) => !!trainedModelId);
const uniqueDeployments = pendingModels.filter(
(deployment, index) => pendingModels.indexOf(deployment) === index
);
setPendingDeployments(uniqueDeployments);
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [inferenceIdsInPendingList, inferenceToModelIdMap]);
return ErroredDeployments.length > 0 ? (
<ErroredModelsDeploymentModal />
) : (
<PendingModelsDeploymentModal />
const erroredDeployments = pendingDeployments.filter(
(deployment) => errorsInTrainedModelDeployment[deployment]
);
useEffect(() => {
if (erroredDeployments.length > 0 || pendingDeployments.length > 0) {
setIsModalVisible(true);
}
}, [erroredDeployments.length, pendingDeployments.length]);
return isModalVisible ? (
<EuiConfirmModal
aria-labelledby={modalTitleId}
style={{ width: 600 }}
title={
erroredDeployments.length > 0
? i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.deploymentErrorTitle',
{
defaultMessage: 'Models could not be deployed',
}
)
: i18n.translate('xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.titleLabel', {
defaultMessage: 'Models still deploying',
})
}
titleProps={{ id: modalTitleId }}
onCancel={closeModal}
onConfirm={fetchData}
cancelButtonText={i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.closeButtonLabel',
{
defaultMessage: 'Close',
}
)}
confirmButtonText={i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.refreshButtonLabel',
{
defaultMessage: 'Refresh',
}
)}
defaultFocusedButton="confirm"
data-test-subj="trainedModelsDeploymentModal"
>
<p data-test-subj="trainedModelsDeploymentModalText">
{erroredDeployments.length > 0
? i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.deploymentErrorText',
{
defaultMessage: 'There was an error when trying to deploy the following models.',
}
)
: i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.textAboutDeploymentsNotCompleted',
{
defaultMessage:
'Some fields are referencing models that have not yet completed deployment. Deployment may take a few minutes to complete.',
}
)}
</p>
<ul style={{ listStyleType: 'none' }}>
{(erroredDeployments.length > 0 ? erroredDeployments : pendingDeployments).map(
(deployment) => (
<li key={deployment}>
<EuiHealth textSize="xs" color="danger">
{deployment}
</EuiHealth>
</li>
)
)}
</ul>
<EuiLink href={mlManagementPageUrl} target="_blank">
{i18n.translate(
'xpack.idxMgmt.indexDetails.trainedModelsDeploymentModal.textTrainedModelManagementLink',
{
defaultMessage: 'Go to Trained Model Management',
}
)}
</EuiLink>
</EuiConfirmModal>
) : null;
}

View file

@ -434,6 +434,7 @@ export function createIndex(indexName: string) {
}),
});
}
export function updateIndexMappings(indexName: string, newFields: Fields) {
return sendRequest({
path: `${API_BASE_PATH}/mapping/${encodeURIComponent(indexName)}`,
@ -443,7 +444,7 @@ export function updateIndexMappings(indexName: string, newFields: Fields) {
}
export function getInferenceEndpoints() {
return sendRequest({
return sendRequest<InferenceAPIConfigResponse[]>({
path: `${API_BASE_PATH}/inference/all`,
method: 'get',
});

View file

@ -6,7 +6,6 @@
*/
import { renderHook } from '@testing-library/react-hooks';
import { InferenceToModelIdMap } from '../application/components/mappings_editor/components/document_fields/fields';
import { NormalizedFields } from '../application/components/mappings_editor/types';
import { useDetailsPageMappingsModelManagement } from './use_details_page_mappings_model_management';
@ -16,13 +15,24 @@ jest.mock('../application/app_context', () => ({
ml: {
mlApi: {
trainedModels: {
getModelsDownloadStatus: jest.fn().mockResolvedValue({
'.elser_model_2_linux-x86_64': {},
}),
getTrainedModelStats: jest.fn().mockResolvedValue({
trained_model_stats: [
{
model_id: '.elser_model_2',
model_id: '.elser_model_2-x86_64',
deployment_stats: {
deployment_id: 'elser_model_2',
model_id: '.elser_model_2',
model_id: '.elser_model_2-x86_64',
state: 'not started',
},
},
{
model_id: '.multilingual-e5-small',
deployment_stats: {
deployment_id: 'e5',
model_id: '.multilingual-e5-small',
state: 'started',
},
},
@ -55,68 +65,73 @@ jest.mock('../application/services/api', () => ({
jest.mock('../application/components/mappings_editor/mappings_state_context', () => ({
useDispatch: () => mockDispatch,
useMappingsState: () => ({
fields: {
byId: {
'88ebcfdb-19b7-4458-9ea2-9488df54453d': {
id: '88ebcfdb-19b7-4458-9ea2-9488df54453d',
isMultiField: false,
source: {
name: 'title',
type: 'text',
copy_to: ['semantic'],
},
path: ['title'],
nestedDepth: 0,
childFieldsName: 'fields',
canHaveChildFields: false,
hasChildFields: false,
canHaveMultiFields: true,
hasMultiFields: false,
isExpanded: false,
},
'c5d86c82-ea07-4457-b469-3ffd4b96db81': {
id: 'c5d86c82-ea07-4457-b469-3ffd4b96db81',
isMultiField: false,
source: {
name: 'semantic',
inference_id: 'elser_model_2',
type: 'semantic_text',
},
path: ['semantic'],
nestedDepth: 0,
childFieldsName: 'fields',
canHaveChildFields: false,
hasChildFields: false,
canHaveMultiFields: true,
hasMultiFields: false,
isExpanded: false,
},
},
aliases: {},
rootLevelFields: [
'88ebcfdb-19b7-4458-9ea2-9488df54453d',
'c5d86c82-ea07-4457-b469-3ffd4b96db81',
],
maxNestedDepth: 2,
} as NormalizedFields,
inferenceToModelIdMap: {
elser_model_2: {
trainedModelId: '.elser_model_2',
isDeployed: false,
isDeployable: true,
isDownloading: false,
},
e5: {
trainedModelId: '.multilingual-e5-small',
isDeployed: true,
isDeployable: true,
isDownloading: false,
},
},
}),
}));
const mockDispatch = jest.fn();
const fields = {
byId: {
'88ebcfdb-19b7-4458-9ea2-9488df54453d': {
id: '88ebcfdb-19b7-4458-9ea2-9488df54453d',
isMultiField: false,
source: {
name: 'title',
type: 'text',
copy_to: ['semantic'],
},
path: ['title'],
nestedDepth: 0,
childFieldsName: 'fields',
canHaveChildFields: false,
hasChildFields: false,
canHaveMultiFields: true,
hasMultiFields: false,
isExpanded: false,
},
'c5d86c82-ea07-4457-b469-3ffd4b96db81': {
id: 'c5d86c82-ea07-4457-b469-3ffd4b96db81',
isMultiField: false,
source: {
name: 'semantic',
inference_id: 'elser_model_2',
type: 'semantic_text',
},
path: ['semantic'],
nestedDepth: 0,
childFieldsName: 'fields',
canHaveChildFields: false,
hasChildFields: false,
canHaveMultiFields: true,
hasMultiFields: false,
isExpanded: false,
},
},
aliases: {},
rootLevelFields: ['88ebcfdb-19b7-4458-9ea2-9488df54453d', 'c5d86c82-ea07-4457-b469-3ffd4b96db81'],
maxNestedDepth: 2,
} as NormalizedFields;
const inferenceToModelIdMap = {
elser_model_2: {
trainedModelId: '.elser_model_2',
isDeployed: true,
isDeployable: true,
},
e5: {
trainedModelId: '.multilingual-e5-small',
isDeployed: true,
isDeployable: true,
},
} as InferenceToModelIdMap;
const mockDispatch = jest.fn();
describe('useDetailsPageMappingsModelManagement', () => {
it('should call the dispatch with correct parameters', async () => {
const { result } = renderHook(() =>
useDetailsPageMappingsModelManagement(fields, inferenceToModelIdMap)
);
const { result } = renderHook(() => useDetailsPageMappingsModelManagement());
await result.current.fetchInferenceToModelIdMap();
@ -125,14 +140,22 @@ describe('useDetailsPageMappingsModelManagement', () => {
value: {
inferenceToModelIdMap: {
e5: {
isDeployed: false,
isDeployable: true,
trainedModelId: '.multilingual-e5-small',
},
elser_model_2: {
isDeployed: true,
isDeployable: true,
trainedModelId: '.elser_model_2',
trainedModelId: '.multilingual-e5-small',
isDownloading: false,
modelStats: {
deployment_id: 'e5',
model_id: '.multilingual-e5-small',
state: 'started',
},
},
elser_model_2: {
isDeployed: false,
isDeployable: true,
trainedModelId: '.elser_model_2_linux-x86_64',
isDownloading: true,
modelStats: undefined,
},
},
},

View file

@ -5,134 +5,130 @@
* 2.0.
*/
import { ElasticsearchModelDefaultOptions, Service } from '@kbn/inference_integration_flyout/types';
import { InferenceStatsResponse } from '@kbn/ml-plugin/public/application/services/ml_api_service/trained_models';
import type { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils';
import { useCallback, useMemo } from 'react';
import { useAppContext } from '../application/app_context';
import { InferenceToModelIdMap } from '../application/components/mappings_editor/components/document_fields/fields';
import { deNormalize } from '../application/components/mappings_editor/lib';
import { useDispatch } from '../application/components/mappings_editor/mappings_state_context';
import { Service } from '@kbn/inference_integration_flyout/types';
import { ModelDownloadState, TrainedModelStat } from '@kbn/ml-plugin/common/types/trained_models';
import { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils';
import {
DefaultInferenceModels,
DeploymentState,
NormalizedFields,
} from '../application/components/mappings_editor/types';
LATEST_ELSER_VERSION,
InferenceServiceSettings,
LocalInferenceServiceSettings,
LATEST_ELSER_MODEL_ID,
LATEST_E5_MODEL_ID,
ElserVersion,
} from '@kbn/ml-trained-models-utils/src/constants/trained_models';
import { useCallback } from 'react';
import { AppDependencies, useAppContext } from '../application/app_context';
import { InferenceToModelIdMap } from '../application/components/mappings_editor/components/document_fields/fields';
import { useDispatch } from '../application/components/mappings_editor/mappings_state_context';
import { DefaultInferenceModels } from '../application/components/mappings_editor/types';
import { getInferenceEndpoints } from '../application/services/api';
interface InferenceModel {
data: InferenceAPIConfigResponse[];
function isLocalModel(model: InferenceServiceSettings): model is LocalInferenceServiceSettings {
return Boolean((model as LocalInferenceServiceSettings).service_settings.model_id);
}
type DeploymentStatusType = Record<string, DeploymentState>;
const getCustomInferenceIdMap = (
deploymentStatsByModelId: DeploymentStatusType,
models?: InferenceModel
) => {
return models?.data.reduce<InferenceToModelIdMap>((inferenceMap, model) => {
const inferenceId = model.model_id;
const trainedModelId =
'model_id' in model.service_settings &&
(model.service_settings.model_id === ElasticsearchModelDefaultOptions.elser ||
model.service_settings.model_id === ElasticsearchModelDefaultOptions.e5)
? model.service_settings.model_id
: '';
inferenceMap[inferenceId] = {
trainedModelId,
isDeployable: model.service === Service.elser || model.service === Service.elasticsearch,
isDeployed: deploymentStatsByModelId[trainedModelId] === 'deployed',
};
models: InferenceAPIConfigResponse[],
modelStatsById: Record<string, TrainedModelStat['deployment_stats'] | undefined>,
downloadStates: Record<string, ModelDownloadState | undefined>,
elser: string,
e5: string
): InferenceToModelIdMap => {
const inferenceIdMap = models.reduce<InferenceToModelIdMap>((inferenceMap, model) => {
const inferenceEntry = isLocalModel(model)
? {
trainedModelId: model.service_settings.model_id, // third-party models don't have trained model ids
isDeployable: model.service === Service.elser || model.service === Service.elasticsearch,
isDeployed: modelStatsById[model.service_settings.model_id]?.state === 'started',
isDownloading: Boolean(downloadStates[model.service_settings.model_id]),
modelStats: modelStatsById[model.service_settings.model_id],
}
: {
trainedModelId: '',
isDeployable: false,
isDeployed: false,
isDownloading: false,
modelStats: undefined,
};
inferenceMap[model.model_id] = inferenceEntry;
return inferenceMap;
}, {});
};
export const getTrainedModelStats = (modelStats?: InferenceStatsResponse): DeploymentStatusType => {
return (
modelStats?.trained_model_stats.reduce<DeploymentStatusType>((acc, modelStat) => {
if (modelStat.model_id) {
acc[modelStat.model_id] =
modelStat?.deployment_stats?.state === 'started'
? DeploymentState.DEPLOYED
: DeploymentState.NOT_DEPLOYED;
}
return acc;
}, {}) || {}
);
};
const getDefaultInferenceIds = (deploymentStatsByModelId: DeploymentStatusType) => {
return {
const defaultInferenceIds = {
[DefaultInferenceModels.elser_model_2]: {
trainedModelId: ElasticsearchModelDefaultOptions.elser,
trainedModelId: elser,
isDeployable: true,
isDeployed:
deploymentStatsByModelId[ElasticsearchModelDefaultOptions.elser] ===
DeploymentState.DEPLOYED,
isDeployed: modelStatsById[elser]?.state === 'started',
isDownloading: Boolean(downloadStates[elser]),
modelStats: modelStatsById[elser],
},
[DefaultInferenceModels.e5]: {
trainedModelId: ElasticsearchModelDefaultOptions.e5,
trainedModelId: e5,
isDeployable: true,
isDeployed:
deploymentStatsByModelId[ElasticsearchModelDefaultOptions.e5] === DeploymentState.DEPLOYED,
isDeployed: modelStatsById[e5]?.state === 'started',
isDownloading: Boolean(downloadStates[e5]),
modelStats: modelStatsById[e5],
},
};
return { ...defaultInferenceIds, ...inferenceIdMap };
};
export const useDetailsPageMappingsModelManagement = (
fields: NormalizedFields,
inferenceToModelIdMap?: InferenceToModelIdMap
) => {
async function getCuratedModelConfig(
ml: AppDependencies['plugins']['ml'] | undefined,
model: string,
version?: ElserVersion
) {
if (ml?.mlApi) {
try {
const result = await ml.mlApi.trainedModels.getCuratedModelConfig(
model,
version ? { version } : undefined
);
return result.model_id;
} catch (e) {
// pass through and return default models below
}
}
return model === 'elser' ? LATEST_ELSER_MODEL_ID : LATEST_E5_MODEL_ID;
}
export const useDetailsPageMappingsModelManagement = () => {
const {
plugins: { ml },
} = useAppContext();
const dispatch = useDispatch();
const fetchInferenceModelsAndTrainedModelStats = useCallback(async () => {
const fetchInferenceToModelIdMap = useCallback<() => Promise<InferenceToModelIdMap>>(async () => {
const inferenceModels = await getInferenceEndpoints();
const trainedModelStats = await ml?.mlApi?.trainedModels.getTrainedModelStats();
return { inferenceModels, trainedModelStats };
}, [ml]);
const fetchInferenceToModelIdMap = useCallback(async () => {
const { inferenceModels, trainedModelStats } = await fetchInferenceModelsAndTrainedModelStats();
const deploymentStatsByModelId = getTrainedModelStats(trainedModelStats);
const defaultInferenceIds = getDefaultInferenceIds(deploymentStatsByModelId);
const modelIdMap = getCustomInferenceIdMap(deploymentStatsByModelId, inferenceModels);
const downloadStates = await ml?.mlApi?.trainedModels.getModelsDownloadStatus();
const elser = await getCuratedModelConfig(ml, 'elser', LATEST_ELSER_VERSION);
const e5 = await getCuratedModelConfig(ml, 'e5');
const modelStatsById =
trainedModelStats?.trained_model_stats.reduce<
Record<string, TrainedModelStat['deployment_stats'] | undefined>
>((acc, { model_id: modelId, deployment_stats: stats }) => {
if (modelId && stats) {
acc[modelId] = stats;
}
return acc;
}, {}) || {};
const modelIdMap = getCustomInferenceIdMap(
inferenceModels.data || [],
modelStatsById,
downloadStates || {},
elser,
e5
);
dispatch({
type: 'inferenceToModelIdMap.update',
value: { inferenceToModelIdMap: { ...defaultInferenceIds, ...modelIdMap } },
value: { inferenceToModelIdMap: modelIdMap },
});
}, [dispatch, fetchInferenceModelsAndTrainedModelStats]);
const inferenceIdsInPendingList = useMemo(() => {
return Object.values(deNormalize(fields))
.filter((field) => field.type === 'semantic_text' && field.inference_id)
.map((field) => field.inference_id);
}, [fields]);
const pendingDeployments = useMemo(() => {
return inferenceIdsInPendingList
.map((inferenceId) => {
if (inferenceId === undefined) {
return undefined;
}
const trainedModelId = inferenceToModelIdMap?.[inferenceId]?.trainedModelId ?? '';
return trainedModelId && !inferenceToModelIdMap?.[inferenceId]?.isDeployed
? trainedModelId
: undefined;
})
.filter((trainedModelId) => !!trainedModelId);
}, [inferenceIdsInPendingList, inferenceToModelIdMap]);
return modelIdMap;
}, [dispatch, ml]);
return {
pendingDeployments,
fetchInferenceToModelIdMap,
fetchInferenceModelsAndTrainedModelStats,
};
};

View file

@ -27,6 +27,22 @@ export function useMLModelNotificationToasts() {
}),
});
};
const showSuccessfullyDeployedToast = (modelName: string) => {
return toasts.addSuccess({
title: i18n.translate(
'xpack.idxMgmt.mappingsEditor.createField.modelDeploymentStartedNotification',
{
defaultMessage: 'Model deployment started',
}
),
text: i18n.translate('xpack.idxMgmt.mappingsEditor.createField.modelDeployedNotification', {
defaultMessage: 'Model {modelName} has been deployed on your machine learning node.',
values: {
modelName,
},
}),
});
};
const showErrorToasts = (error: ErrorType) => {
const errorObj = extractErrorProperties(error);
return toasts.addError(new MLRequestFailure(errorObj, error), {
@ -35,5 +51,5 @@ export function useMLModelNotificationToasts() {
}),
});
};
return { showSuccessToasts, showErrorToasts };
return { showSuccessToasts, showErrorToasts, showSuccessfullyDeployedToast };
}

View file

@ -18,17 +18,18 @@ export function inferenceModelsApiProvider(httpService: HttpService) {
* @param taskType - Inference Task type. Either sparse_embedding or text_embedding
* @param modelConfig - Model configuration based on service type
*/
createInferenceEndpoint(
async createInferenceEndpoint(
inferenceId: string,
taskType: InferenceTaskType,
modelConfig: ModelConfig
) {
return httpService.http<estypes.InferencePutModelResponse>({
const result = await httpService.http<estypes.InferencePutModelResponse>({
path: `${ML_INTERNAL_BASE_PATH}/_inference/${taskType}/${inferenceId}`,
method: 'PUT',
body: JSON.stringify(modelConfig),
version: '1',
});
return result;
},
};
}

View file

@ -177,6 +177,18 @@ export function trainedModelsApiProvider(httpService: HttpService) {
});
},
/**
* Gets model config based on the cluster OS and CPU architecture.
*/
getCuratedModelConfig(modelName: string, options?: GetModelDownloadConfigOptions) {
return httpService.http<ModelDefinitionResponse>({
path: `${ML_INTERNAL_BASE_PATH}/trained_models/curated_model_config/${modelName}`,
method: 'GET',
...(options ? { query: options as HttpFetchQuery } : {}),
version: '1',
});
},
getTrainedModelsNodesOverview() {
return httpService.http<NodesOverviewResponse>({
path: `${ML_INTERNAL_BASE_PATH}/model_management/nodes_overview`,

View file

@ -12,6 +12,7 @@ import { createInferenceSchema } from './schemas/inference_schema';
import { modelsProvider } from '../models/model_management';
import { wrapError } from '../client/error_wrapper';
import { ML_INTERNAL_BASE_PATH } from '../../common/constants/app';
import { syncSavedObjectsFactory } from '../saved_objects';
export function inferenceModelRoutes(
{ router, routeGuard }: RouteInitialization,
@ -42,20 +43,24 @@ export function inferenceModelRoutes(
},
},
},
routeGuard.fullLicenseAPIGuard(async ({ client, mlClient, request, response }) => {
try {
const { inferenceId, taskType } = request.params;
const body = await modelsProvider(client, mlClient, cloud).createInferenceEndpoint(
inferenceId,
taskType as InferenceTaskType,
request.body as InferenceModelConfig
);
return response.ok({
body,
});
} catch (e) {
return response.customError(wrapError(e));
routeGuard.fullLicenseAPIGuard(
async ({ client, mlClient, request, response, mlSavedObjectService }) => {
try {
const { inferenceId, taskType } = request.params;
const body = await modelsProvider(client, mlClient, cloud).createInferenceEndpoint(
inferenceId,
taskType as InferenceTaskType,
request.body as InferenceModelConfig
);
const { syncSavedObjects } = syncSavedObjectsFactory(client, mlSavedObjectService);
await syncSavedObjects(false);
return response.ok({
body,
});
} catch (e) {
return response.customError(wrapError(e));
}
}
})
)
);
}

View file

@ -96,3 +96,9 @@ export const createIngestPipelineSchema = schema.object({
export const modelDownloadsQuery = schema.object({
version: schema.maybe(schema.oneOf([schema.literal('1'), schema.literal('2')])),
});
export const curatedModelsParamsSchema = schema.object({
modelName: schema.string(),
});
export const curatedModelsQuerySchema = schema.object({ version: schema.maybe(schema.number()) });

View file

@ -10,7 +10,11 @@ import { groupBy } from 'lodash';
import { schema } from '@kbn/config-schema';
import type { ErrorType } from '@kbn/ml-error-utils';
import type { CloudSetup } from '@kbn/cloud-plugin/server';
import type { ElserVersion, InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils';
import type {
ElasticCuratedModelName,
ElserVersion,
InferenceAPIConfigResponse,
} from '@kbn/ml-trained-models-utils';
import { isDefined } from '@kbn/ml-is-defined';
import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import { type MlFeatures, ML_INTERNAL_BASE_PATH } from '../../common/constants/app';
@ -30,6 +34,8 @@ import {
updateDeploymentParamsSchema,
createIngestPipelineSchema,
modelDownloadsQuery,
curatedModelsParamsSchema,
curatedModelsQuerySchema,
} from './schemas/inference_schema';
import type { PipelineDefinition } from '../../common/types/trained_models';
import { type TrainedModelConfigResponse } from '../../common/types/trained_models';
@ -920,6 +926,49 @@ export function trainedModelsRoutes(
try {
const body = await modelsProvider(client, mlClient, cloud).getModelsDownloadStatus();
return response.ok({
body,
});
} catch (e) {
return response.customError(wrapError(e));
}
}
)
);
/**
* @apiGroup TrainedModels
*
* @api {get} /internal/ml/trained_models/curated_model_config Gets curated model config
* @apiName ModelsCuratedConfigs
* @apiDescription Gets curated model config for the specified model based on cluster architecture
*/
router.versioned
.get({
path: `${ML_INTERNAL_BASE_PATH}/trained_models/curated_model_config/{modelName}`,
access: 'internal',
options: {
tags: ['access:ml:canGetTrainedModels'],
},
})
.addVersion(
{
version: '1',
validate: {
request: {
params: curatedModelsParamsSchema,
query: curatedModelsQuerySchema,
},
},
},
routeGuard.fullLicenseAPIGuard(
async ({ client, mlClient, request, response, mlSavedObjectService }) => {
try {
const body = await modelsProvider(client, mlClient, cloud).getCuratedModelConfig(
request.params.modelName as ElasticCuratedModelName,
{ version: request.query.version as ElserVersion }
);
return response.ok({
body,
});