Mirror of https://github.com/elastic/kibana.git
[Security AI] Bedrock prompt tuning and inference corrections (#209011)

parent 48a69daccf
commit 0d415a6d3a

12 changed files with 86 additions and 42 deletions
@@ -50,7 +50,7 @@ const AS_PLAIN_TEXT: EuiComboBoxSingleSelectionShape = { asPlainText: true };
  */
 export const EvaluationSettings: React.FC = React.memo(() => {
   const { actionTypeRegistry, http, setTraceOptions, toasts, traceOptions } = useAssistantContext();
-  const { data: connectors } = useLoadConnectors({ http });
+  const { data: connectors } = useLoadConnectors({ http, inferenceEnabled: true });
   const { mutate: performEvaluation, isLoading: isPerformingEvaluation } = usePerformEvaluation({
     http,
     toasts,

@@ -6,7 +6,7 @@
  */

 export { promptType } from './src/saved_object_mappings';
-export { getPrompt, getPromptsByGroupId } from './src/get_prompt';
+export { getPrompt, getPromptsByGroupId, resolveProviderAndModel } from './src/get_prompt';
 export {
   type PromptArray,
   type Prompt,

@@ -129,15 +129,15 @@ export const getPrompt = async ({
   return prompt;
 };

-const resolveProviderAndModel = async ({
+export const resolveProviderAndModel = async ({
   providedProvider,
   providedModel,
   connectorId,
   actionsClient,
   providedConnector,
 }: {
-  providedProvider: string | undefined;
-  providedModel: string | undefined;
+  providedProvider?: string;
+  providedModel?: string;
   connectorId: string;
   actionsClient: PublicMethodsOf<ActionsClient>;
   providedConnector?: Connector;
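
The newly exported resolver takes a connector id and an actions client, and accepts optional overrides for provider, model, and connector. A minimal usage sketch, assuming the resolver derives the provider (and, per its name, the model) from the connector's configuration; the connector id below is illustrative:

  // Hypothetical call site: derive the provider from the connector itself.
  const { provider } = await resolveProviderAndModel({
    connectorId: 'my-inference-connector', // illustrative id, not from this change
    actionsClient, // PublicMethodsOf<ActionsClient> from the request context
  });
  // provider is e.g. 'bedrock' when the connector is backed by Amazon Bedrock.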
@@ -127,6 +127,10 @@ export const getDefaultAssistantGraph = ({
       value: (x: boolean, y?: boolean) => y ?? x,
       default: () => contentReferencesEnabled,
     },
+    provider: {
+      value: (x: string, y?: string) => y ?? x,
+      default: () => '',
+    },
   };

   // Default node parameters
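
Each entry above is a LangGraph state channel: default seeds the channel and value is the reducer that merges a node's update into the current value. The reducer (x, y) => y ?? x keeps the existing value unless a node writes a defined update. A minimal sketch of that behavior in isolation:

  // Reducer used by the new provider channel: an update wins only if defined.
  const keepLatest = (x: string, y?: string): string => y ?? x;

  keepLatest('bedrock', undefined); // 'bedrock', since no node wrote an update
  keepLatest('', 'gemini'); // 'gemini', since a node supplied a new value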
@@ -56,6 +56,7 @@ describe('streamGraph', () => {
       input: 'input',
       responseLanguage: 'English',
       llmType: 'openai',
+      provider: 'openai',
       connectorId: '123',
     },
     logger: mockLogger,
@@ -291,6 +292,7 @@ describe('streamGraph', () => {
         inputs: {
           ...requestArgs.inputs,
           llmType: 'gemini',
+          provider: 'gemini',
         },
       });

@@ -306,6 +308,7 @@ describe('streamGraph', () => {
         inputs: {
           ...requestArgs.inputs,
           llmType: 'bedrock',
+          provider: 'bedrock',
         },
       });

@@ -136,17 +136,21 @@ export const streamGraph = async ({

   // Stream is from openai functions agent
   let finalMessage = '';
-  const stream = assistantGraph.streamEvents(inputs, {
-    callbacks: [
-      apmTracer,
-      ...(traceOptions?.tracers ?? []),
-      ...(telemetryTracer ? [telemetryTracer] : []),
-    ],
-    runName: DEFAULT_ASSISTANT_GRAPH_ID,
-    streamMode: 'values',
-    tags: traceOptions?.tags ?? [],
-    version: 'v1',
-  });
+  const stream = assistantGraph.streamEvents(
+    inputs,
+    {
+      callbacks: [
+        apmTracer,
+        ...(traceOptions?.tracers ?? []),
+        ...(telemetryTracer ? [telemetryTracer] : []),
+      ],
+      runName: DEFAULT_ASSISTANT_GRAPH_ID,
+      streamMode: 'values',
+      tags: traceOptions?.tags ?? [],
+      version: 'v1',
+    },
+    inputs?.provider === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
+  );

   const pushStreamUpdate = async () => {
     for await (const { event, data, tags } of stream) {
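
The new third argument is streamEvents' filter options from @langchain/core: includeNames restricts the stream to events from runs whose name matches, so for Bedrock only the run named 'Summarizer' is surfaced. A minimal sketch of the filtering behavior (the graph and inputs here are placeholders):

  // Only events produced by a run named 'Summarizer' are yielded.
  const events = graph.streamEvents(
    inputs,
    { version: 'v1' },
    { includeNames: ['Summarizer'] }
  );
  for await (const { event, name } of events) {
    // With the filter above, name === 'Summarizer' for every yielded event.
  }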
@@ -155,8 +159,6 @@ export const streamGraph = async ({
         const chunk = data?.chunk;
         const msg = chunk.message;
         if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) {
-          // I don't think we hit this anymore because of our check for AGENT_NODE_TAG
-          // however, no harm to keep it in
           /* empty */
         } else if (!didEnd) {
           push({ payload: msg.content, type: 'content' });

@@ -20,12 +20,15 @@ import {
 } from 'langchain/agents';
 import { contentReferencesStoreFactoryMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
 import { savedObjectsClientMock } from '@kbn/core-saved-objects-api-server-mocks';
+import { resolveProviderAndModel } from '@kbn/security-ai-prompts';
 jest.mock('./graph');
 jest.mock('./helpers');
 jest.mock('langchain/agents');
 jest.mock('@kbn/langchain/server/tracers/apm');
 jest.mock('@kbn/langchain/server/tracers/telemetry');
+jest.mock('@kbn/security-ai-prompts');
 const getDefaultAssistantGraphMock = getDefaultAssistantGraph as jest.Mock;
+const resolveProviderAndModelMock = resolveProviderAndModel as jest.Mock;
 describe('callAssistantGraph', () => {
   const mockDataClients = {
     anonymizationFieldsDataClient: {
@@ -83,6 +86,9 @@ describe('callAssistantGraph', () => {
     jest.clearAllMocks();
     (mockDataClients?.kbDataClient?.isInferenceEndpointExists as jest.Mock).mockResolvedValue(true);
     getDefaultAssistantGraphMock.mockReturnValue({});
+    resolveProviderAndModelMock.mockResolvedValue({
+      provider: 'bedrock',
+    });
     (invokeGraph as jest.Mock).mockResolvedValue({
       output: 'test-output',
       traceData: {},
@@ -224,5 +230,23 @@ describe('callAssistantGraph', () => {
       expect(createOpenAIToolsAgent).not.toHaveBeenCalled();
       expect(createToolCallingAgent).not.toHaveBeenCalled();
     });
+    it('does not call resolveProviderAndModel when llmType === openai', async () => {
+      const params = { ...defaultParams, llmType: 'openai' };
+      await callAssistantGraph(params);
+
+      expect(resolveProviderAndModelMock).not.toHaveBeenCalled();
+    });
+    it('calls resolveProviderAndModel when llmType === inference', async () => {
+      const params = { ...defaultParams, llmType: 'inference' };
+      await callAssistantGraph(params);
+
+      expect(resolveProviderAndModelMock).toHaveBeenCalled();
+    });
+    it('calls resolveProviderAndModel when llmType === undefined', async () => {
+      const params = { ...defaultParams, llmType: undefined };
+      await callAssistantGraph(params);
+
+      expect(resolveProviderAndModelMock).toHaveBeenCalled();
+    });
   });
 });

@@ -15,6 +15,7 @@ import {
 import { APMTracer } from '@kbn/langchain/server/tracers/apm';
 import { TelemetryTracer } from '@kbn/langchain/server/tracers/telemetry';
 import { pruneContentReferences, MessageMetadata } from '@kbn/elastic-assistant-common';
+import { resolveProviderAndModel } from '@kbn/security-ai-prompts';
 import { promptGroupId } from '../../../prompt/local_prompt_object';
 import { getModelOrOss } from '../../../prompt/helpers';
 import { getPrompt, promptDictionary } from '../../../prompt';
@@ -183,6 +184,13 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
           logger
         )
       : undefined;
+  const { provider } =
+    !llmType || llmType === 'inference'
+      ? await resolveProviderAndModel({
+          connectorId,
+          actionsClient,
+        })
+      : { provider: llmType };
   const assistantGraph = getDefaultAssistantGraph({
     agentRunnable,
     dataClients,
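
The rule encoded above: an explicit llmType such as 'openai', 'bedrock', or 'gemini' already names the provider, while the generic 'inference' connector type (or a missing llmType) requires looking the provider up from the connector. A hypothetical helper restating this hunk, assuming connectorId and actionsClient are in scope as they are in callAssistantGraph:

  // Illustrative restatement of the fallback logic, not part of the change.
  const resolveGraphProvider = async (llmType?: string): Promise<string> => {
    if (!llmType || llmType === 'inference') {
      const { provider } = await resolveProviderAndModel({ connectorId, actionsClient });
      return provider ?? '';
    }
    return llmType; // the explicit llmType doubles as the provider name
  };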
@@ -205,6 +213,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
     isStream,
     isOssModel,
     input: latestMessage[0]?.content as string,
+    provider: provider ?? '',
   };

   if (isStream) {

@@ -21,8 +21,7 @@ interface ModelInputParams extends NodeParamsBase {
  */
 export function modelInput({ logger, state }: ModelInputParams): Partial<AgentState> {
   logger.debug(() => `${NodeType.MODEL_INPUT}: Node state:\n${JSON.stringify(state, null, 2)}`);

-  const hasRespondStep = state.isStream && (state.isOssModel || state.llmType === 'bedrock');
-
+  const hasRespondStep = state.isStream && (state.isOssModel || state.provider === 'bedrock');
   return {
     hasRespondStep,

@@ -25,6 +25,7 @@ export interface GraphInputs {
   isStream?: boolean;
   isOssModel?: boolean;
   input: string;
+  provider: string;
   responseLanguage?: string;
 }

@@ -37,6 +38,7 @@ export interface AgentState extends AgentStateBase {
   isStream: boolean;
   isOssModel: boolean;
   llmType: string;
+  provider: string;
   responseLanguage: string;
   connectorId: string;
   conversation: ConversationResponse | undefined;
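
With these two interface changes every graph invocation carries the resolved provider alongside the raw llmType. A hypothetical fragment, with the GraphInputs fields outside this hunk elided:

  // Illustrative inputs fragment; other required GraphInputs fields omitted.
  const inputs = {
    input: 'Which hosts had failed logins in the last 24 hours?',
    provider: 'bedrock', // resolved from the connector when llmType is 'inference'
    responseLanguage: 'English',
  } satisfies Partial<GraphInputs>;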
@@ -20,7 +20,7 @@ const BASE_GEMINI_PROMPT =
 const KB_CATCH =
   'If the knowledge base tool gives empty results, do your best to answer the question from the perspective of an expert security analyst.';
 export const GEMINI_SYSTEM_PROMPT = `${BASE_GEMINI_PROMPT} ${KB_CATCH} {include_citations_prompt_placeholder}`;
-export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from NaturalLanguageESQLTool as is. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response.`;
+export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response. ALWAYS return the exact response from NaturalLanguageESQLTool verbatim in the final response, without adding further description.`;
 export const GEMINI_USER_PROMPT = `Now, always using the tools at your disposal, step by step, come up with a response to this request:\n\n`;

 export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. ${KNOWLEDGE_HISTORY} You have access to the following tools:

@@ -26,7 +26,7 @@ import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/
 import { getDefaultArguments } from '@kbn/langchain/server';
 import { StructuredTool } from '@langchain/core/tools';
 import {
-  createOpenAIFunctionsAgent,
+  createOpenAIToolsAgent,
   createStructuredChatAgent,
   createToolCallingAgent,
 } from 'langchain/agents';

@@ -331,26 +331,27 @@ export const postEvaluateRoute = (
               savedObjectsClient,
             });

-            const agentRunnable = isOpenAI
-              ? await createOpenAIFunctionsAgent({
-                  llm,
-                  tools,
-                  prompt: formatPrompt(defaultSystemPrompt),
-                  streamRunnable: false,
-                })
-              : llmType && ['bedrock', 'gemini'].includes(llmType)
-              ? createToolCallingAgent({
-                  llm,
-                  tools,
-                  prompt: formatPrompt(defaultSystemPrompt),
-                  streamRunnable: false,
-                })
-              : await createStructuredChatAgent({
-                  llm,
-                  tools,
-                  prompt: formatPromptStructured(defaultSystemPrompt),
-                  streamRunnable: false,
-                });
+            const agentRunnable =
+              isOpenAI || llmType === 'inference'
+                ? await createOpenAIToolsAgent({
+                    llm,
+                    tools,
+                    prompt: formatPrompt(defaultSystemPrompt),
+                    streamRunnable: false,
+                  })
+                : llmType && ['bedrock', 'gemini'].includes(llmType)
+                ? createToolCallingAgent({
+                    llm,
+                    tools,
+                    prompt: formatPrompt(defaultSystemPrompt),
+                    streamRunnable: false,
+                  })
+                : await createStructuredChatAgent({
+                    llm,
+                    tools,
+                    prompt: formatPromptStructured(defaultSystemPrompt),
+                    streamRunnable: false,
+                  });

             return {
               connectorId: connector.id,
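
The evaluation route now mirrors the production agent selection: OpenAI and inference connectors use createOpenAIToolsAgent (the functions-agent variant is dropped here), Bedrock and Gemini use createToolCallingAgent, and everything else falls back to the structured chat agent. A hypothetical restatement of the branch, reduced to the factory choice:

  // Illustrative summary of the selection, not part of the change itself.
  const pickAgentFactory = (isOpenAI: boolean, llmType?: string) =>
    isOpenAI || llmType === 'inference'
      ? createOpenAIToolsAgent // tools API for OpenAI and inference connectors
      : llmType && ['bedrock', 'gemini'].includes(llmType)
      ? createToolCallingAgent // native tool calling for Bedrock and Gemini
      : createStructuredChatAgent; // prompt-structured fallback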