Mirror of https://github.com/elastic/kibana.git (synced 2025-04-23 09:19:04 -04:00)
[Obs AI Assistant] Use architecture-specific elser model (#205851)
Closes https://github.com/elastic/kibana/issues/205852. When installing the Obs knowledge base, the inference endpoint is always created with the `.elser_model_2` model. For Linux with an x86-64 CPU an optimised version of ELSER exists (`.elser_model_2_linux-x86_64`), and we should use it when possible. After this change the inference endpoint uses `.elser_model_2_linux-x86_64` on supported hardware.
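Once the knowledge base has been set up, the easiest way to confirm which ELSER variant was picked is to ask the Inference API for the endpoint created below (`obs_ai_assistant_kb_inference`). This is not part of the PR — just a minimal sketch using the Elasticsearch JS client; the node URL and the exact response shape are assumptions.

import { Client } from '@elastic/elasticsearch';

// Minimal sketch (not part of this change): check which ELSER variant the
// knowledge-base inference endpoint references. Node URL and response shape are assumed.
const client = new Client({ node: 'http://localhost:9200' });

async function printKnowledgeBaseModelId() {
  // GET _inference/obs_ai_assistant_kb_inference
  const response = await client.transport.request<{
    endpoints: Array<{ inference_id: string; service_settings: { model_id?: string } }>;
  }>({
    method: 'GET',
    path: '/_inference/obs_ai_assistant_kb_inference',
  });

  for (const endpoint of response.endpoints) {
    // Expect ".elser_model_2_linux-x86_64" on Linux x86-64, ".elser_model_2" elsewhere.
    console.log(endpoint.inference_id, endpoint.service_settings.model_id);
  }
}

printKnowledgeBaseModelId().catch((error) => console.error(error));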
parent 8eb326d596
commit ad3b9880c7
5 changed files with 63 additions and 44 deletions
@@ -80,6 +80,7 @@ import {
 } from '../task_manager_definitions/register_migrate_knowledge_base_entries_task';
 import { ObservabilityAIAssistantPluginStartDependencies } from '../../types';
 import { ObservabilityAIAssistantConfig } from '../../config';
+import { getElserModelId } from '../knowledge_base_service/get_elser_model_id';

 const MAX_FUNCTION_CALLS = 8;

@@ -660,6 +661,10 @@ export class ObservabilityAIAssistantClient {
   setupKnowledgeBase = async (modelId: string | undefined) => {
     const { esClient, core, logger, knowledgeBaseService } = this.dependencies;

+    if (!modelId) {
+      modelId = await getElserModelId({ core, logger });
+    }
+
     // setup the knowledge base
     const res = await knowledgeBaseService.setup(esClient, modelId);

@@ -16,13 +16,13 @@ export const AI_ASSISTANT_KB_INFERENCE_ID = 'obs_ai_assistant_kb_inference';
 export async function createInferenceEndpoint({
   esClient,
   logger,
-  modelId = '.elser_model_2',
+  modelId,
 }: {
   esClient: {
     asCurrentUser: ElasticsearchClient;
   };
   logger: Logger;
-  modelId: string | undefined;
+  modelId: string;
 }) {
   try {
     logger.debug(`Creating inference endpoint "${AI_ASSISTANT_KB_INFERENCE_ID}"`);
@@ -0,0 +1,53 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import type { Logger } from '@kbn/logging';
+import { CoreSetup } from '@kbn/core-lifecycle-server';
+import { firstValueFrom } from 'rxjs';
+import { ObservabilityAIAssistantPluginStartDependencies } from '../../types';
+
+export async function getElserModelId({
+  core,
+  logger,
+}: {
+  core: CoreSetup<ObservabilityAIAssistantPluginStartDependencies>;
+  logger: Logger;
+}) {
+  const defaultModelId = '.elser_model_2';
+  const [_, pluginsStart] = await core.getStartServices();
+
+  // Wait for the license to be available so the ML plugin's guards pass once we ask for ELSER stats
+  const license = await firstValueFrom(pluginsStart.licensing.license$);
+  if (!license.hasAtLeast('enterprise')) {
+    return defaultModelId;
+  }
+
+  try {
+    // Wait for the ML plugin's dependency on the internal saved objects client to be ready
+    const { ml } = await core.plugins.onSetup<{
+      ml: {
+        trainedModelsProvider: (
+          request: {},
+          soClient: {}
+        ) => { getELSER: () => Promise<{ model_id: string }> };
+      };
+    }>('ml');
+
+    if (!ml.found) {
+      throw new Error('Could not find ML plugin');
+    }
+
+    const elserModelDefinition = await ml.contract
+      .trainedModelsProvider({} as any, {} as any) // request, savedObjectsClient (but we fake it to use the internal user)
+      .getELSER();
+
+    return elserModelDefinition.model_id;
+  } catch (error) {
+    logger.error(`Failed to resolve ELSER model definition: ${error}`);
+    return defaultModelId;
+  }
+}
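The actual model resolution is delegated to the ML plugin's trainedModelsProvider(...).getELSER(), which returns the model id best suited to the cluster's hardware. For intuition only, the selection roughly comes down to inspecting the OS and CPU architecture of the nodes and preferring the optimised build when it is supported. The sketch below illustrates that idea with the Elasticsearch JS client; it is an assumption-laden illustration, not the ML plugin's real implementation.

import { Client } from '@elastic/elasticsearch';

// Rough sketch of architecture-based ELSER selection — NOT the ML plugin's actual logic.
// It inspects node OS info and prefers the linux-x86_64 build when every node supports it.
async function pickElserModelId(client: Client): Promise<string> {
  const defaultModelId = '.elser_model_2';
  const optimizedModelId = '.elser_model_2_linux-x86_64';

  // GET _nodes/os returns os.name and os.arch per node.
  const info = await client.nodes.info({ metric: ['os'] });

  const nodes = Object.values(info.nodes ?? {});
  const allLinuxX64 =
    nodes.length > 0 &&
    nodes.every(
      (node) =>
        node.os?.name === 'Linux' && (node.os?.arch === 'amd64' || node.os?.arch === 'x86_64')
    );

  return allLinuxX64 ? optimizedModelId : defaultModelId;
}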
@@ -58,7 +58,7 @@ export class KnowledgeBaseService {
       asCurrentUser: ElasticsearchClient;
       asInternalUser: ElasticsearchClient;
     },
-    modelId: string | undefined
+    modelId: string
   ) {
     await deleteInferenceEndpoint({ esClient }).catch((e) => {}); // ensure existing inference endpoint is deleted
     return createInferenceEndpoint({ esClient, logger: this.dependencies.logger, modelId });
@@ -10,10 +10,10 @@ import { IUiSettingsClient } from '@kbn/core-ui-settings-server';
 import { isEmpty, orderBy, compact } from 'lodash';
 import type { Logger } from '@kbn/logging';
 import { CoreSetup } from '@kbn/core-lifecycle-server';
-import { firstValueFrom } from 'rxjs';
 import { RecalledEntry } from '.';
 import { aiAssistantSearchConnectorIndexPattern } from '../../../common';
 import { ObservabilityAIAssistantPluginStartDependencies } from '../../types';
+import { getElserModelId } from './get_elser_model_id';

 export async function recallFromSearchConnectors({
   queries,
@@ -128,7 +128,7 @@ async function recallFromLegacyConnectors({
 }): Promise<RecalledEntry[]> {
   const ML_INFERENCE_PREFIX = 'ml.inference.';

-  const modelIdPromise = getElserModelId(core, logger); // pre-fetch modelId in parallel with fieldCaps
+  const modelIdPromise = getElserModelId({ core, logger }); // pre-fetch modelId in parallel with fieldCaps
   const fieldCaps = await esClient.asCurrentUser.fieldCaps({
     index: connectorIndices,
     fields: `${ML_INFERENCE_PREFIX}*`,
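One detail preserved here: the model-id lookup is kicked off before the fieldCaps request and only awaited once it is needed, so the two independent calls overlap instead of running back to back. A generic sketch of that pattern follows; the helper names are placeholders, not actual Kibana functions.

// Generic sketch of the "start early, await late" concurrency pattern used above.
// fetchModelId and fetchFieldCaps are hypothetical stand-ins, not real Kibana helpers.
async function recallSketch(
  fetchModelId: () => Promise<string>,
  fetchFieldCaps: () => Promise<string[]>
) {
  const modelIdPromise = fetchModelId(); // started immediately, not awaited yet
  const fields = await fetchFieldCaps(); // resolves while the model id lookup is in flight

  const modelId = await modelIdPromise; // usually already settled by the time we get here
  return fields.filter((field) => field.endsWith(modelId));
}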
@@ -230,42 +230,3 @@ async function getConnectorIndices(

   return connectorIndices;
 }
-
-async function getElserModelId(
-  core: CoreSetup<ObservabilityAIAssistantPluginStartDependencies>,
-  logger: Logger
-) {
-  const defaultModelId = '.elser_model_2';
-  const [_, pluginsStart] = await core.getStartServices();
-
-  // Wait for the license to be available so the ML plugin's guards pass once we ask for ELSER stats
-  const license = await firstValueFrom(pluginsStart.licensing.license$);
-  if (!license.hasAtLeast('enterprise')) {
-    return defaultModelId;
-  }
-
-  try {
-    // Wait for the ML plugin's dependency on the internal saved objects client to be ready
-    const { ml } = await core.plugins.onSetup('ml');
-
-    if (!ml.found) {
-      throw new Error('Could not find ML plugin');
-    }
-
-    const elserModelDefinition = await (
-      ml.contract as {
-        trainedModelsProvider: (
-          request: {},
-          soClient: {}
-        ) => { getELSER: () => Promise<{ model_id: string }> };
-      }
-    )
-      .trainedModelsProvider({} as any, {} as any) // request, savedObjectsClient (but we fake it to use the internal user)
-      .getELSER();
-
-    return elserModelDefinition.model_id;
-  } catch (error) {
-    logger.error(`Failed to resolve ELSER model definition: ${error}`);
-    return defaultModelId;
-  }
-}