[Security AI] Bedrock prompt tuning and inference corrections (#209011)

Steph Milovic 2025-01-31 16:14:34 -07:00 committed by GitHub
parent 48a69daccf
commit 0d415a6d3a
12 changed files with 86 additions and 42 deletions

View file

@@ -50,7 +50,7 @@ const AS_PLAIN_TEXT: EuiComboBoxSingleSelectionShape = { asPlainText: true };
*/
export const EvaluationSettings: React.FC = React.memo(() => {
const { actionTypeRegistry, http, setTraceOptions, toasts, traceOptions } = useAssistantContext();
const { data: connectors } = useLoadConnectors({ http });
const { data: connectors } = useLoadConnectors({ http, inferenceEnabled: true });
const { mutate: performEvaluation, isLoading: isPerformingEvaluation } = usePerformEvaluation({
http,
toasts,

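The change above lets the evaluation settings screen list inference connectors alongside the regular LLM connectors. A minimal usage sketch, assuming inferenceEnabled is simply an opt-in flag the hook forwards when fetching connectors:

// Hedged sketch: load connectors for the evaluation picker, including
// Elastic Inference connectors via the new inferenceEnabled flag.
const { http } = useAssistantContext();
const { data: connectors, isLoading } = useLoadConnectors({ http, inferenceEnabled: true });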
View file

@@ -6,7 +6,7 @@
*/
export { promptType } from './src/saved_object_mappings';
export { getPrompt, getPromptsByGroupId } from './src/get_prompt';
export { getPrompt, getPromptsByGroupId, resolveProviderAndModel } from './src/get_prompt';
export {
type PromptArray,
type Prompt,

View file

@@ -129,15 +129,15 @@ export const getPrompt = async ({
return prompt;
};
const resolveProviderAndModel = async ({
export const resolveProviderAndModel = async ({
providedProvider,
providedModel,
connectorId,
actionsClient,
providedConnector,
}: {
providedProvider: string | undefined;
providedModel: string | undefined;
providedProvider?: string;
providedModel?: string;
connectorId: string;
actionsClient: PublicMethodsOf<ActionsClient>;
providedConnector?: Connector;

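Making resolveProviderAndModel part of the package's public surface, with its provider/model arguments now optional, lets callers look up the provider behind a connector from its id alone. A hedged usage sketch (the return shape is inferred from how the agent executor destructures it later in this commit):

// Sketch: resolve the underlying provider (e.g. 'bedrock') for an inference connector.
// Only connectorId and actionsClient are supplied; provider/model hints are optional now.
const { provider } = await resolveProviderAndModel({
  connectorId,    // id of the connector backing the conversation
  actionsClient,  // PublicMethodsOf<ActionsClient> scoped to the current request
});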
View file

@@ -127,6 +127,10 @@ export const getDefaultAssistantGraph = ({
value: (x: boolean, y?: boolean) => y ?? x,
default: () => contentReferencesEnabled,
},
provider: {
value: (x: string, y?: string) => y ?? x,
default: () => '',
},
};
// Default node parameters

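The new provider channel uses the same last-write-wins reducer as the neighboring channels: an incoming value y replaces the current value x only when y is defined, and the channel defaults to an empty string. A small illustration of that reducer, assuming it behaves like the other string channels in this graph:

// Reducer sketch for the `provider` channel.
const reduceProvider = (x: string, y?: string) => y ?? x;
reduceProvider('', 'bedrock');        // 'bedrock' — set from GraphInputs on the first update
reduceProvider('bedrock', undefined); // 'bedrock' — later updates that omit provider keep it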
View file

@@ -56,6 +56,7 @@ describe('streamGraph', () => {
input: 'input',
responseLanguage: 'English',
llmType: 'openai',
provider: 'openai',
connectorId: '123',
},
logger: mockLogger,
@@ -291,6 +292,7 @@
inputs: {
...requestArgs.inputs,
llmType: 'gemini',
provider: 'gemini',
},
});
@@ -306,6 +308,7 @@
inputs: {
...requestArgs.inputs,
llmType: 'bedrock',
provider: 'bedrock',
},
});

View file

@@ -136,17 +136,21 @@ export const streamGraph = async ({
// Stream is from openai functions agent
let finalMessage = '';
const stream = assistantGraph.streamEvents(inputs, {
callbacks: [
apmTracer,
...(traceOptions?.tracers ?? []),
...(telemetryTracer ? [telemetryTracer] : []),
],
runName: DEFAULT_ASSISTANT_GRAPH_ID,
streamMode: 'values',
tags: traceOptions?.tags ?? [],
version: 'v1',
});
const stream = assistantGraph.streamEvents(
inputs,
{
callbacks: [
apmTracer,
...(traceOptions?.tracers ?? []),
...(telemetryTracer ? [telemetryTracer] : []),
],
runName: DEFAULT_ASSISTANT_GRAPH_ID,
streamMode: 'values',
tags: traceOptions?.tags ?? [],
version: 'v1',
},
inputs?.provider === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
);
const pushStreamUpdate = async () => {
for await (const { event, data, tags } of stream) {
@@ -155,8 +159,6 @@
const chunk = data?.chunk;
const msg = chunk.message;
if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) {
// I don't think we hit this anymore because of our check for AGENT_NODE_TAG
// however, no harm to keep it in
/* empty */
} else if (!didEnd) {
push({ payload: msg.content, type: 'content' });

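The third argument now passed to streamEvents is LangChain's event-filter options; for Bedrock runs, includeNames: ['Summarizer'] restricts the stream to events attributed to runs named Summarizer, which presumably limits what reaches the client to the summarizer's output rather than every intermediate agent/model event. A reduced sketch of the same pattern, with the graph and node name treated as givens from this commit:

// Sketch: filter streamed events to the 'Summarizer' run when the provider is Bedrock.
const events = assistantGraph.streamEvents(
  inputs,
  { runName: DEFAULT_ASSISTANT_GRAPH_ID, streamMode: 'values', version: 'v1' },
  inputs?.provider === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
);
for await (const { event, data } of events) {
  // With the filter applied, events should originate from runs named 'Summarizer' only.
}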
View file

@@ -20,12 +20,15 @@ import {
} from 'langchain/agents';
import { contentReferencesStoreFactoryMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock';
import { savedObjectsClientMock } from '@kbn/core-saved-objects-api-server-mocks';
import { resolveProviderAndModel } from '@kbn/security-ai-prompts';
jest.mock('./graph');
jest.mock('./helpers');
jest.mock('langchain/agents');
jest.mock('@kbn/langchain/server/tracers/apm');
jest.mock('@kbn/langchain/server/tracers/telemetry');
jest.mock('@kbn/security-ai-prompts');
const getDefaultAssistantGraphMock = getDefaultAssistantGraph as jest.Mock;
const resolveProviderAndModelMock = resolveProviderAndModel as jest.Mock;
describe('callAssistantGraph', () => {
const mockDataClients = {
anonymizationFieldsDataClient: {
@@ -83,6 +86,9 @@ describe('callAssistantGraph', () => {
jest.clearAllMocks();
(mockDataClients?.kbDataClient?.isInferenceEndpointExists as jest.Mock).mockResolvedValue(true);
getDefaultAssistantGraphMock.mockReturnValue({});
resolveProviderAndModelMock.mockResolvedValue({
provider: 'bedrock',
});
(invokeGraph as jest.Mock).mockResolvedValue({
output: 'test-output',
traceData: {},
@@ -224,5 +230,23 @@
expect(createOpenAIToolsAgent).not.toHaveBeenCalled();
expect(createToolCallingAgent).not.toHaveBeenCalled();
});
it('does not call resolveProviderAndModel when llmType === openai', async () => {
const params = { ...defaultParams, llmType: 'openai' };
await callAssistantGraph(params);
expect(resolveProviderAndModelMock).not.toHaveBeenCalled();
});
it('calls resolveProviderAndModel when llmType === inference', async () => {
const params = { ...defaultParams, llmType: 'inference' };
await callAssistantGraph(params);
expect(resolveProviderAndModelMock).toHaveBeenCalled();
});
it('calls resolveProviderAndModel when llmType === undefined', async () => {
const params = { ...defaultParams, llmType: undefined };
await callAssistantGraph(params);
expect(resolveProviderAndModelMock).toHaveBeenCalled();
});
});
});

View file

@@ -15,6 +15,7 @@ import {
import { APMTracer } from '@kbn/langchain/server/tracers/apm';
import { TelemetryTracer } from '@kbn/langchain/server/tracers/telemetry';
import { pruneContentReferences, MessageMetadata } from '@kbn/elastic-assistant-common';
import { resolveProviderAndModel } from '@kbn/security-ai-prompts';
import { promptGroupId } from '../../../prompt/local_prompt_object';
import { getModelOrOss } from '../../../prompt/helpers';
import { getPrompt, promptDictionary } from '../../../prompt';
@@ -183,6 +184,13 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
logger
)
: undefined;
const { provider } =
!llmType || llmType === 'inference'
? await resolveProviderAndModel({
connectorId,
actionsClient,
})
: { provider: llmType };
const assistantGraph = getDefaultAssistantGraph({
agentRunnable,
dataClients,
@@ -205,6 +213,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
isStream,
isOssModel,
input: latestMessage[0]?.content as string,
provider: provider ?? '',
};
if (isStream) {

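Net effect of the resolution block above: an explicit llmType ('openai', 'bedrock', 'gemini') is used as the provider directly, while inference connectors (llmType of 'inference' or undefined) have their real provider looked up from the connector, falling back to an empty string when nothing resolves. A hypothetical condensed form of that decision, using a Bedrock-backed connector as an assumed example:

// Hypothetical condensed form of the logic above.
const resolveProvider = async (llmType?: string): Promise<string> =>
  !llmType || llmType === 'inference'
    ? (await resolveProviderAndModel({ connectorId, actionsClient })).provider ?? ''
    : llmType;
// 'bedrock'   -> 'bedrock' (no lookup)
// 'inference' -> e.g. 'bedrock', resolved from the connector
// undefined   -> e.g. 'bedrock', resolved from the connector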
View file

@@ -21,8 +21,7 @@ interface ModelInputParams extends NodeParamsBase {
*/
export function modelInput({ logger, state }: ModelInputParams): Partial<AgentState> {
logger.debug(() => `${NodeType.MODEL_INPUT}: Node state:\n${JSON.stringify(state, null, 2)}`);
const hasRespondStep = state.isStream && (state.isOssModel || state.llmType === 'bedrock');
const hasRespondStep = state.isStream && (state.isOssModel || state.provider === 'bedrock');
return {
hasRespondStep,

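Switching the predicate from llmType to the resolved provider means a streaming Bedrock model reached through an inference connector now also takes the extra respond step. A minimal illustration, with the state values being assumptions for the example:

// Before, llmType === 'inference' would skip the respond step even when the
// underlying model was Bedrock; with the resolved provider the same state qualifies.
const state = { isStream: true, isOssModel: false, llmType: 'inference', provider: 'bedrock' };
const hasRespondStep = state.isStream && (state.isOssModel || state.provider === 'bedrock'); // true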
View file

@@ -25,6 +25,7 @@ export interface GraphInputs {
isStream?: boolean;
isOssModel?: boolean;
input: string;
provider: string;
responseLanguage?: string;
}
@@ -37,6 +38,7 @@ export interface AgentState extends AgentStateBase {
isStream: boolean;
isOssModel: boolean;
llmType: string;
provider: string;
responseLanguage: string;
connectorId: string;
conversation: ConversationResponse | undefined;

View file

@@ -20,7 +20,7 @@ const BASE_GEMINI_PROMPT =
const KB_CATCH =
'If the knowledge base tool gives empty results, do your best to answer the question from the perspective of an expert security analyst.';
export const GEMINI_SYSTEM_PROMPT = `${BASE_GEMINI_PROMPT} ${KB_CATCH} {include_citations_prompt_placeholder}`;
export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from NaturalLanguageESQLTool as is. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response.`;
export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response. ALWAYS return the exact response from NaturalLanguageESQLTool verbatim in the final response, without adding further description.`;
export const GEMINI_USER_PROMPT = `Now, always using the tools at your disposal, step by step, come up with a response to this request:\n\n`;
export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. ${KNOWLEDGE_HISTORY} You have access to the following tools:

View file

@@ -26,7 +26,7 @@ import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/
import { getDefaultArguments } from '@kbn/langchain/server';
import { StructuredTool } from '@langchain/core/tools';
import {
createOpenAIFunctionsAgent,
createOpenAIToolsAgent,
createStructuredChatAgent,
createToolCallingAgent,
} from 'langchain/agents';
@@ -331,26 +331,27 @@ export const postEvaluateRoute = (
savedObjectsClient,
});
const agentRunnable = isOpenAI
? await createOpenAIFunctionsAgent({
llm,
tools,
prompt: formatPrompt(defaultSystemPrompt),
streamRunnable: false,
})
: llmType && ['bedrock', 'gemini'].includes(llmType)
? createToolCallingAgent({
llm,
tools,
prompt: formatPrompt(defaultSystemPrompt),
streamRunnable: false,
})
: await createStructuredChatAgent({
llm,
tools,
prompt: formatPromptStructured(defaultSystemPrompt),
streamRunnable: false,
});
const agentRunnable =
isOpenAI || llmType === 'inference'
? await createOpenAIToolsAgent({
llm,
tools,
prompt: formatPrompt(defaultSystemPrompt),
streamRunnable: false,
})
: llmType && ['bedrock', 'gemini'].includes(llmType)
? createToolCallingAgent({
llm,
tools,
prompt: formatPrompt(defaultSystemPrompt),
streamRunnable: false,
})
: await createStructuredChatAgent({
llm,
tools,
prompt: formatPromptStructured(defaultSystemPrompt),
streamRunnable: false,
});
return {
connectorId: connector.id,
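Two things change in the evaluation route: inference connectors now share the OpenAI agent path, and that path switches from createOpenAIFunctionsAgent to createOpenAIToolsAgent, the newer tools-API-based factory in langchain/agents. A trimmed sketch of the new factory call, assuming llm, tools, and formatPrompt are in scope as above:

// Sketch: tools-based OpenAI agent used for both OpenAI and inference connectors.
const agent = await createOpenAIToolsAgent({
  llm,                                       // chat model wrapper for the connector
  tools,                                     // StructuredTool[] assembled for the evaluation
  prompt: formatPrompt(defaultSystemPrompt),
  streamRunnable: false,                     // evaluation runs are not streamed
});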