[ML] Update ELSER version for Elastic Assistant (#167522)

## Summary

Update Elastic Assistant to utilize a new version of ELSER.

### Checklist

- [ ]
[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)
was added for features that require explanation or tutorials
- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
This commit is contained in:
Dima Arnautov 2023-10-03 23:27:27 +02:00 committed by GitHub
parent 8b6ba3d15f
commit e8c0942672
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 114 additions and 68 deletions

View file

@ -9,7 +9,8 @@
"browser": false,
"requiredPlugins": [
"actions",
"data"
"data",
"ml"
]
}
}

View file

@ -23,6 +23,7 @@ export const callAgentExecutor = async ({
llmType,
logger,
request,
elserId,
}: AgentExecutorParams): AgentExecutorResponse => {
const llm = new ActionsClientLlm({ actions, connectorId, request, llmType, logger });
@ -38,7 +39,7 @@ export const callAgentExecutor = async ({
});
// ELSER backed ElasticsearchStore for Knowledge Base
const esStore = new ElasticsearchStore(esClient, KNOWLEDGE_BASE_INDEX_PATTERN, logger);
const esStore = new ElasticsearchStore(esClient, KNOWLEDGE_BASE_INDEX_PATTERN, logger, elserId);
const chain = RetrievalQAChain.fromLLM(llm, esStore.asRetriever());
const tools: Tool[] = [

View file

@ -20,6 +20,7 @@ export interface AgentExecutorParams {
llmType?: string;
logger: Logger;
request: KibanaRequest<unknown, unknown, RequestBody>;
elserId?: string;
}
export type AgentExecutorResponse = Promise<ResponseBody>;

View file

@ -12,7 +12,10 @@ import {
Plugin,
Logger,
IContextProvider,
KibanaRequest,
SavedObjectsClientContract,
} from '@kbn/core/server';
import { once } from 'lodash';
import {
ElasticAssistantPluginSetup,
@ -20,6 +23,7 @@ import {
ElasticAssistantPluginStart,
ElasticAssistantPluginStartDependencies,
ElasticAssistantRequestHandlerContext,
GetElser,
} from './types';
import {
deleteKnowledgeBaseRoute,
@ -73,12 +77,19 @@ export class ElasticAssistantPlugin
)
);
const getElserId: GetElser = once(
async (request: KibanaRequest, savedObjectsClient: SavedObjectsClientContract) => {
return (await plugins.ml.trainedModelsProvider(request, savedObjectsClient).getELSER())
.name;
}
);
// Knowledge Base
deleteKnowledgeBaseRoute(router);
getKnowledgeBaseStatusRoute(router);
postKnowledgeBaseRoute(router);
getKnowledgeBaseStatusRoute(router, getElserId);
postKnowledgeBaseRoute(router, getElserId);
// Actions Connector Execute (LLM Wrapper)
postActionsConnectorExecuteRoute(router);
postActionsConnectorExecuteRoute(router, getElserId);
// Evaluate
postEvaluateRoute(router);
return {

View file

@ -5,8 +5,6 @@
* 2.0.
*/
// Note: using default ELSER model ID so when setup by user in UI, all defaults can be accepted and everything works
export const ELSER_MODEL_ID = '.elser_model_1';
export const MODEL_EVALUATION_RESULTS_INDEX_PATTERN =
'.kibana-elastic-ai-assistant-evaluation-results';
export const KNOWLEDGE_BASE_INDEX_PATTERN = '.kibana-elastic-ai-assistant-kb';

View file

@ -18,11 +18,13 @@ describe('Get Knowledge Base Status Route', () => {
clients.core.elasticsearch.client = elasticsearchServiceMock.createScopedClusterClient();
const mockGetElser = jest.fn().mockResolvedValue('.elser_model_2');
beforeEach(() => {
server = serverMock.create();
({ context } = requestContextMock.createTools());
getKnowledgeBaseStatusRoute(server.router);
getKnowledgeBaseStatusRoute(server.router, mockGetElser);
});
describe('Status codes', () => {

View file

@ -7,20 +7,14 @@
import { IRouter } from '@kbn/core/server';
import { transformError } from '@kbn/securitysolution-es-utils';
import type { GetKnowledgeBaseStatusResponse } from '@kbn/elastic-assistant';
import { buildResponse } from '../../lib/build_response';
import { buildRouteValidation } from '../../schemas/common';
import { ElasticAssistantRequestHandlerContext } from '../../types';
import { ElasticAssistantRequestHandlerContext, GetElser } from '../../types';
import { KNOWLEDGE_BASE } from '../../../common/constants';
import { GetKnowledgeBaseStatusPathParams } from '../../schemas/knowledge_base/get_knowledge_base_status';
import { ElasticsearchStore } from '../../lib/langchain/elasticsearch_store/elasticsearch_store';
import {
ELSER_MODEL_ID,
ESQL_DOCS_LOADED_QUERY,
ESQL_RESOURCE,
KNOWLEDGE_BASE_INDEX_PATTERN,
} from './constants';
import { ESQL_DOCS_LOADED_QUERY, ESQL_RESOURCE, KNOWLEDGE_BASE_INDEX_PATTERN } from './constants';
/**
* Get the status of the Knowledge Base index, pipeline, and resources (collection of documents)
@ -28,7 +22,8 @@ import {
* @param router IRouter for registering routes
*/
export const getKnowledgeBaseStatusRoute = (
router: IRouter<ElasticAssistantRequestHandlerContext>
router: IRouter<ElasticAssistantRequestHandlerContext>,
getElser: GetElser
) => {
router.get(
{
@ -56,7 +51,9 @@ export const getKnowledgeBaseStatusRoute = (
const indexExists = await esStore.indexExists();
const pipelineExists = await esStore.pipelineExists();
const modelExists = await esStore.isModelInstalled(ELSER_MODEL_ID);
const elserId = await getElser(request, (await context.core).savedObjects.getClient());
const modelExists = await esStore.isModelInstalled(elserId);
const body: GetKnowledgeBaseStatusResponse = {
elser_exists: modelExists,

View file

@ -18,11 +18,13 @@ describe('Post Knowledge Base Route', () => {
clients.core.elasticsearch.client = elasticsearchServiceMock.createScopedClusterClient();
const mockGetElser = jest.fn().mockResolvedValue('.elser_model_2');
beforeEach(() => {
server = serverMock.create();
({ context } = requestContextMock.createTools());
postKnowledgeBaseRoute(server.router);
postKnowledgeBaseRoute(server.router, mockGetElser);
});
describe('Status codes', () => {

View file

@ -7,10 +7,9 @@
import { IRouter } from '@kbn/core/server';
import { transformError } from '@kbn/securitysolution-es-utils';
import { buildResponse } from '../../lib/build_response';
import { buildRouteValidation } from '../../schemas/common';
import { ElasticAssistantRequestHandlerContext } from '../../types';
import { ElasticAssistantRequestHandlerContext, GetElser } from '../../types';
import { KNOWLEDGE_BASE } from '../../../common/constants';
import { ElasticsearchStore } from '../../lib/langchain/elasticsearch_store/elasticsearch_store';
import { ESQL_DOCS_LOADED_QUERY, ESQL_RESOURCE, KNOWLEDGE_BASE_INDEX_PATTERN } from './constants';
@ -21,7 +20,10 @@ import { loadESQL } from '../../lib/langchain/content_loaders/esql_loader';
* Load Knowledge Base index, pipeline, and resources (collection of documents)
* @param router
*/
export const postKnowledgeBaseRoute = (router: IRouter<ElasticAssistantRequestHandlerContext>) => {
export const postKnowledgeBaseRoute = (
router: IRouter<ElasticAssistantRequestHandlerContext>,
getElser: GetElser
) => {
router.post(
{
path: KNOWLEDGE_BASE,
@ -44,7 +46,13 @@ export const postKnowledgeBaseRoute = (router: IRouter<ElasticAssistantRequestHa
// Get a scoped esClient for creating the Knowledge Base index, pipeline, and documents
const esClient = (await context.core).elasticsearch.client.asCurrentUser;
const esStore = new ElasticsearchStore(esClient, KNOWLEDGE_BASE_INDEX_PATTERN, logger);
const elserId = await getElser(request, (await context.core).savedObjects.getClient());
const esStore = new ElasticsearchStore(
esClient,
KNOWLEDGE_BASE_INDEX_PATTERN,
logger,
elserId
);
// Pre-check on index/pipeline
let indexExists = await esStore.indexExists();

View file

@ -14,6 +14,7 @@ import { postActionsConnectorExecuteRoute } from './post_actions_connector_execu
import { ElasticAssistantRequestHandlerContext } from '../types';
import { elasticsearchServiceMock } from '@kbn/core-elasticsearch-server-mocks';
import { loggingSystemMock } from '@kbn/core-logging-server-mocks';
import { coreMock } from '@kbn/core/server/mocks';
jest.mock('../lib/build_response', () => ({
buildResponse: jest.fn().mockImplementation((x) => x),
@ -54,6 +55,7 @@ const mockContext = {
elasticsearch: {
client: elasticsearchServiceMock.createScopedClusterClient(),
},
savedObjects: coreMock.createRequestHandlerContext().savedObjects,
},
};
@ -89,6 +91,8 @@ const mockResponse = {
};
describe('postActionsConnectorExecuteRoute', () => {
const mockGetElser = jest.fn().mockResolvedValue('.elser_model_2');
beforeEach(() => {
jest.clearAllMocks();
});
@ -109,7 +113,8 @@ describe('postActionsConnectorExecuteRoute', () => {
};
await postActionsConnectorExecuteRoute(
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
mockGetElser
);
});
@ -131,7 +136,8 @@ describe('postActionsConnectorExecuteRoute', () => {
};
await postActionsConnectorExecuteRoute(
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
mockGetElser
);
});
});

View file

@ -7,7 +7,6 @@
import { IRouter, Logger } from '@kbn/core/server';
import { transformError } from '@kbn/securitysolution-es-utils';
import { POST_ACTIONS_CONNECTOR_EXECUTE } from '../../common/constants';
import { getLangChainMessages } from '../lib/langchain/helpers';
import { buildResponse } from '../lib/build_response';
@ -16,11 +15,12 @@ import {
PostActionsConnectorExecuteBody,
PostActionsConnectorExecutePathParams,
} from '../schemas/post_actions_connector_execute';
import { ElasticAssistantRequestHandlerContext } from '../types';
import { ElasticAssistantRequestHandlerContext, GetElser } from '../types';
import { callAgentExecutor } from '../lib/langchain/execute_custom_llm_chain';
export const postActionsConnectorExecuteRoute = (
router: IRouter<ElasticAssistantRequestHandlerContext>
router: IRouter<ElasticAssistantRequestHandlerContext>,
getElser: GetElser
) => {
router.post(
{
@ -48,6 +48,8 @@ export const postActionsConnectorExecuteRoute = (
request.body.params.subActionParams.messages
);
const elserId = await getElser(request, (await context.core).savedObjects.getClient());
const langChainResponseBody = await callAgentExecutor({
actions,
connectorId,
@ -55,6 +57,7 @@ export const postActionsConnectorExecuteRoute = (
langChainMessages,
logger,
request,
elserId,
});
return response.ok({

View file

@ -9,7 +9,13 @@ import type {
PluginSetupContract as ActionsPluginSetup,
PluginStartContract as ActionsPluginStart,
} from '@kbn/actions-plugin/server';
import { CustomRequestHandlerContext, Logger } from '@kbn/core/server';
import type {
CustomRequestHandlerContext,
KibanaRequest,
Logger,
SavedObjectsClientContract,
} from '@kbn/core/server';
import { type MlPluginSetup } from '@kbn/ml-plugin/server';
/** The plugin setup interface */
export interface ElasticAssistantPluginSetup {
@ -23,6 +29,7 @@ export interface ElasticAssistantPluginStart {
export interface ElasticAssistantPluginSetupDependencies {
actions: ActionsPluginSetup;
ml: MlPluginSetup;
}
export interface ElasticAssistantPluginStartDependencies {
actions: ActionsPluginStart;
@ -39,3 +46,8 @@ export interface ElasticAssistantApiRequestHandlerContext {
export type ElasticAssistantRequestHandlerContext = CustomRequestHandlerContext<{
elasticAssistant: ElasticAssistantApiRequestHandlerContext;
}>;
export type GetElser = (
request: KibanaRequest,
savedObjectsClient: SavedObjectsClientContract
) => Promise<string> | never;

View file

@ -31,6 +31,7 @@
"@kbn/logging",
"@kbn/std",
"@kbn/stack-connectors-plugin",
"@kbn/ml-plugin",
],
"exclude": [
"target/**/*",

View file

@ -7,6 +7,8 @@
import { loggingSystemMock } from '@kbn/core/server/mocks';
import { mlPluginServerMock } from '@kbn/ml-plugin/server/mocks';
import { ConfigType } from '..';
export const mockLogger = loggingSystemMock.createLogger().get();
@ -18,9 +20,7 @@ export const mockRequestHandler = {
},
};
export const mockMl = {
trainedModelsProvider: jest.fn(),
};
export const mockMl = mlPluginServerMock.createSetupContract();
export const mockConfig = {
host: 'http://localhost:3002',
@ -38,5 +38,5 @@ export const mockDependencies = {
config: mockConfig,
log: mockLogger,
enterpriseSearchRequestHandler: mockRequestHandler as any,
ml: mockMl as any,
ml: mockMl,
};

View file

@ -7,9 +7,14 @@
import { MockRouter, mockDependencies } from '../../__mocks__';
import { RequestHandlerContext } from '@kbn/core/server';
import type {
KibanaRequest,
RequestHandlerContext,
SavedObjectsClientContract,
} from '@kbn/core/server';
import type { MlPluginSetup, MlTrainedModels } from '@kbn/ml-plugin/server';
import { mlPluginServerMock } from '@kbn/ml-plugin/server/mocks';
import { ErrorCode } from '../../../common/types/error_codes';
@ -181,20 +186,13 @@ describe('Enterprise Search Managed Indices', () => {
path: '/internal/enterprise_search/indices/{indexName}/ml_inference/pipeline_processors',
});
mockTrainedModelsProvider = {
getTrainedModels: jest.fn(),
getTrainedModelsStats: jest.fn(),
startTrainedModelDeployment: jest.fn(),
stopTrainedModelDeployment: jest.fn(),
inferTrainedModel: jest.fn(),
deleteTrainedModel: jest.fn(),
updateTrainedModelDeployment: jest.fn(),
putTrainedModel: jest.fn(),
} as MlTrainedModels;
mockMl = mlPluginServerMock.createSetupContract();
mockTrainedModelsProvider = mockMl.trainedModelsProvider(
{} as KibanaRequest,
{} as SavedObjectsClientContract
);
mockMl = {
trainedModelsProvider: () => Promise.resolve(mockTrainedModelsProvider),
} as unknown as jest.Mocked<MlPluginSetup>;
mlPluginServerMock.createSetupContract();
registerIndexRoutes({
...mockDependencies,
@ -1007,20 +1005,11 @@ describe('Enterprise Search Managed Indices', () => {
path: '/internal/enterprise_search/pipelines/ml_inference',
});
mockTrainedModelsProvider = {
getTrainedModels: jest.fn(),
getTrainedModelsStats: jest.fn(),
startTrainedModelDeployment: jest.fn(),
stopTrainedModelDeployment: jest.fn(),
inferTrainedModel: jest.fn(),
deleteTrainedModel: jest.fn(),
updateTrainedModelDeployment: jest.fn(),
putTrainedModel: jest.fn(),
} as MlTrainedModels;
mockMl = {
trainedModelsProvider: () => Promise.resolve(mockTrainedModelsProvider),
} as unknown as jest.Mocked<MlPluginSetup>;
mockMl = mlPluginServerMock.createSetupContract();
mockTrainedModelsProvider = mockMl.trainedModelsProvider(
{} as KibanaRequest,
{} as SavedObjectsClientContract
);
registerIndexRoutes({
...mockDependencies,

View file

@ -394,7 +394,7 @@ export function registerIndexRoutes({
savedObjects: { client: savedObjectsClient },
} = await context.core;
const trainedModelsProvider = ml
? await ml.trainedModelsProvider(request, savedObjectsClient)
? ml.trainedModelsProvider(request, savedObjectsClient)
: undefined;
const mlInferencePipelineProcessorConfigs = await fetchMlInferencePipelineProcessors(

View file

@ -4,6 +4,7 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { createTrainedModelsProviderMock } from './shared_services/providers/__mocks__/trained_models';
import { createJobServiceProviderMock } from './shared_services/providers/__mocks__/jobs_service';
import { createAnomalyDetectorsProviderMock } from './shared_services/providers/__mocks__/anomaly_detectors';
import { createMockMlSystemProvider } from './shared_services/providers/__mocks__/system';
@ -20,6 +21,7 @@ const createSetupContract = () =>
modulesProvider: createModulesProviderMock(),
resultsServiceProvider: createResultsServiceProviderMock(),
alertingServiceProvider: createAlertingServiceProviderMock(),
trainedModelsProvider: createTrainedModelsProviderMock(),
} as unknown as jest.Mocked<MlPluginSetup>);
const createStartContract = () => jest.fn();

View file

@ -5,8 +5,19 @@
* 2.0.
*/
import { TrainedModels } from '../../../shared';
const trainedModelsServiceMock = {
getTrainedModels: jest.fn().mockResolvedValue([]),
getTrainedModelsStats: jest.fn().mockResolvedValue([]),
startTrainedModelDeployment: jest.fn(),
stopTrainedModelDeployment: jest.fn(),
inferTrainedModel: jest.fn(),
deleteTrainedModel: jest.fn(),
updateTrainedModelDeployment: jest.fn(),
putTrainedModel: jest.fn(),
getELSER: jest.fn().mockResolvedValue({ name: '' }),
} as jest.Mocked<TrainedModels>;
export const createTrainedModelsProviderMock = () =>
jest.fn(() => ({
getTrainedModels: jest.fn(),
getTrainedModelStats: jest.fn(),
}));
jest.fn().mockReturnValue(trainedModelsServiceMock);

View file

@ -8,7 +8,7 @@
import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import type { CloudSetup } from '@kbn/cloud-plugin/server';
import type { KibanaRequest, SavedObjectsClientContract } from '@kbn/core/server';
import type { GetElserOptions } from '@kbn/ml-trained-models-utils';
import type { GetElserOptions, ModelDefinitionResponse } from '@kbn/ml-trained-models-utils';
import type {
MlInferTrainedModelRequest,
MlStopTrainedModelDeploymentRequest,
@ -47,6 +47,7 @@ export interface TrainedModelsProvider {
putTrainedModel(
params: estypes.MlPutTrainedModelRequest
): Promise<estypes.MlPutTrainedModelResponse>;
getELSER(params?: GetElserOptions): Promise<ModelDefinitionResponse>;
};
}
@ -122,7 +123,7 @@ export function getTrainedModelsProvider(
return mlClient.putTrainedModel(params);
});
},
async getELSER(params: GetElserOptions) {
async getELSER(params?: GetElserOptions) {
return await guards
.isFullLicense()
.hasMlCapabilities(['canGetTrainedModels'])