mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 01:13:23 -04:00
[Security solution] Add feature flag for AI streaming (#172505)
This commit is contained in:
parent
b876253e7b
commit
39caf945fa
7 changed files with 88 additions and 96 deletions
|
@ -34,7 +34,13 @@ const apiConfig: Conversation['apiConfig'] = {
|
|||
const messages: Message[] = [
|
||||
{ content: 'This is a test', role: 'user', timestamp: new Date().toLocaleString() },
|
||||
];
|
||||
|
||||
const fetchConnectorArgs: FetchConnectorExecuteAction = {
|
||||
assistantLangChain: true,
|
||||
http: mockHttp,
|
||||
messages,
|
||||
apiConfig,
|
||||
assistantStreamingEnabled: true,
|
||||
};
|
||||
describe('API tests', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
@ -42,14 +48,7 @@ describe('API tests', () => {
|
|||
|
||||
describe('fetchConnectorExecuteAction', () => {
|
||||
it('calls the internal assistant API when assistantLangChain is true', async () => {
|
||||
const testProps: FetchConnectorExecuteAction = {
|
||||
assistantLangChain: true,
|
||||
http: mockHttp,
|
||||
messages,
|
||||
apiConfig,
|
||||
};
|
||||
|
||||
await fetchConnectorExecuteAction(testProps);
|
||||
await fetchConnectorExecuteAction(fetchConnectorArgs);
|
||||
|
||||
expect(mockHttp.fetch).toHaveBeenCalledWith(
|
||||
'/internal/elastic_assistant/actions/connector/foo/_execute',
|
||||
|
@ -62,12 +61,10 @@ describe('API tests', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('calls the actions connector api when assistantLangChain is false', async () => {
|
||||
it('calls the actions connector api with streaming when assistantStreamingEnabled is true when assistantLangChain is false', async () => {
|
||||
const testProps: FetchConnectorExecuteAction = {
|
||||
...fetchConnectorArgs,
|
||||
assistantLangChain: false,
|
||||
http: mockHttp,
|
||||
messages,
|
||||
apiConfig,
|
||||
};
|
||||
|
||||
await fetchConnectorExecuteAction(testProps);
|
||||
|
@ -84,29 +81,62 @@ describe('API tests', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('calls the actions connector api with invoke when assistantStreamingEnabled is false when assistantLangChain is false', async () => {
|
||||
const testProps: FetchConnectorExecuteAction = {
|
||||
...fetchConnectorArgs,
|
||||
assistantLangChain: false,
|
||||
assistantStreamingEnabled: false,
|
||||
};
|
||||
|
||||
await fetchConnectorExecuteAction(testProps);
|
||||
|
||||
expect(mockHttp.fetch).toHaveBeenCalledWith(
|
||||
'/internal/elastic_assistant/actions/connector/foo/_execute',
|
||||
{
|
||||
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"},"assistantLangChain":false}',
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
signal: undefined,
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
it('returns API_ERROR when the response status is error and langchain is on', async () => {
|
||||
(mockHttp.fetch as jest.Mock).mockResolvedValue({ status: 'error' });
|
||||
|
||||
const result = await fetchConnectorExecuteAction(fetchConnectorArgs);
|
||||
|
||||
expect(result).toEqual({ response: API_ERROR, isStream: false, isError: true });
|
||||
});
|
||||
|
||||
it('returns API_ERROR + error message on non streaming responses', async () => {
|
||||
(mockHttp.fetch as jest.Mock).mockResolvedValue({
|
||||
status: 'error',
|
||||
service_message: 'an error message',
|
||||
});
|
||||
const testProps: FetchConnectorExecuteAction = {
|
||||
assistantLangChain: true,
|
||||
http: mockHttp,
|
||||
messages,
|
||||
apiConfig,
|
||||
...fetchConnectorArgs,
|
||||
assistantLangChain: false,
|
||||
assistantStreamingEnabled: false,
|
||||
};
|
||||
|
||||
const result = await fetchConnectorExecuteAction(testProps);
|
||||
|
||||
expect(result).toEqual({ response: API_ERROR, isStream: false, isError: true });
|
||||
expect(result).toEqual({
|
||||
response: `${API_ERROR}\n\nan error message`,
|
||||
isStream: false,
|
||||
isError: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('returns API_ERROR when the response status is error, langchain is off, and response is not a reader', async () => {
|
||||
(mockHttp.fetch as jest.Mock).mockResolvedValue({ status: 'error' });
|
||||
|
||||
const testProps: FetchConnectorExecuteAction = {
|
||||
...fetchConnectorArgs,
|
||||
assistantLangChain: false,
|
||||
http: mockHttp,
|
||||
messages,
|
||||
apiConfig,
|
||||
};
|
||||
|
||||
const result = await fetchConnectorExecuteAction(testProps);
|
||||
|
@ -124,10 +154,8 @@ describe('API tests', () => {
|
|||
response: { body: { getReader: jest.fn().mockImplementation(() => mockReader) } },
|
||||
});
|
||||
const testProps: FetchConnectorExecuteAction = {
|
||||
...fetchConnectorArgs,
|
||||
assistantLangChain: false,
|
||||
http: mockHttp,
|
||||
messages,
|
||||
apiConfig,
|
||||
};
|
||||
|
||||
const result = await fetchConnectorExecuteAction(testProps);
|
||||
|
@ -141,14 +169,8 @@ describe('API tests', () => {
|
|||
|
||||
it('returns API_ERROR when there are no choices', async () => {
|
||||
(mockHttp.fetch as jest.Mock).mockResolvedValue({ status: 'ok', data: '' });
|
||||
const testProps: FetchConnectorExecuteAction = {
|
||||
assistantLangChain: true,
|
||||
http: mockHttp,
|
||||
messages,
|
||||
apiConfig,
|
||||
};
|
||||
|
||||
const result = await fetchConnectorExecuteAction(testProps);
|
||||
const result = await fetchConnectorExecuteAction(fetchConnectorArgs);
|
||||
|
||||
expect(result).toEqual({ response: API_ERROR, isStream: false, isError: true });
|
||||
});
|
||||
|
@ -161,14 +183,7 @@ describe('API tests', () => {
|
|||
data: response,
|
||||
});
|
||||
|
||||
const testProps: FetchConnectorExecuteAction = {
|
||||
assistantLangChain: true, // <-- requires response parsing
|
||||
http: mockHttp,
|
||||
messages,
|
||||
apiConfig,
|
||||
};
|
||||
|
||||
const result = await fetchConnectorExecuteAction(testProps);
|
||||
const result = await fetchConnectorExecuteAction(fetchConnectorArgs);
|
||||
|
||||
expect(result).toEqual({
|
||||
response: 'value from action_input',
|
||||
|
@ -185,14 +200,7 @@ describe('API tests', () => {
|
|||
data: response,
|
||||
});
|
||||
|
||||
const testProps: FetchConnectorExecuteAction = {
|
||||
assistantLangChain: true, // <-- requires response parsing
|
||||
http: mockHttp,
|
||||
messages,
|
||||
apiConfig,
|
||||
};
|
||||
|
||||
const result = await fetchConnectorExecuteAction(testProps);
|
||||
const result = await fetchConnectorExecuteAction(fetchConnectorArgs);
|
||||
|
||||
expect(result).toEqual({ response, isStream: false, isError: false });
|
||||
});
|
||||
|
@ -205,27 +213,19 @@ describe('API tests', () => {
|
|||
data: response,
|
||||
});
|
||||
|
||||
const testProps: FetchConnectorExecuteAction = {
|
||||
assistantLangChain: true, // <-- requires response parsing
|
||||
http: mockHttp,
|
||||
messages,
|
||||
apiConfig,
|
||||
};
|
||||
|
||||
const result = await fetchConnectorExecuteAction(testProps);
|
||||
const result = await fetchConnectorExecuteAction(fetchConnectorArgs);
|
||||
|
||||
expect(result).toEqual({ response, isStream: false, isError: false });
|
||||
});
|
||||
});
|
||||
|
||||
const knowledgeBaseArgs = {
|
||||
resource: 'a-resource',
|
||||
http: mockHttp,
|
||||
};
|
||||
describe('getKnowledgeBaseStatus', () => {
|
||||
it('calls the knowledge base API when correct resource path', async () => {
|
||||
const testProps = {
|
||||
resource: 'a-resource',
|
||||
http: mockHttp,
|
||||
};
|
||||
|
||||
await getKnowledgeBaseStatus(testProps);
|
||||
await getKnowledgeBaseStatus(knowledgeBaseArgs);
|
||||
|
||||
expect(mockHttp.fetch).toHaveBeenCalledWith(
|
||||
'/internal/elastic_assistant/knowledge_base/a-resource',
|
||||
|
@ -236,27 +236,20 @@ describe('API tests', () => {
|
|||
);
|
||||
});
|
||||
it('returns error when error is an error', async () => {
|
||||
const testProps = {
|
||||
resource: 'a-resource',
|
||||
http: mockHttp,
|
||||
};
|
||||
const error = 'simulated error';
|
||||
(mockHttp.fetch as jest.Mock).mockImplementation(() => {
|
||||
throw new Error(error);
|
||||
});
|
||||
|
||||
await expect(getKnowledgeBaseStatus(testProps)).resolves.toThrowError('simulated error');
|
||||
await expect(getKnowledgeBaseStatus(knowledgeBaseArgs)).resolves.toThrowError(
|
||||
'simulated error'
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('postKnowledgeBase', () => {
|
||||
it('calls the knowledge base API when correct resource path', async () => {
|
||||
const testProps = {
|
||||
resource: 'a-resource',
|
||||
http: mockHttp,
|
||||
};
|
||||
|
||||
await postKnowledgeBase(testProps);
|
||||
await postKnowledgeBase(knowledgeBaseArgs);
|
||||
|
||||
expect(mockHttp.fetch).toHaveBeenCalledWith(
|
||||
'/internal/elastic_assistant/knowledge_base/a-resource',
|
||||
|
@ -267,27 +260,18 @@ describe('API tests', () => {
|
|||
);
|
||||
});
|
||||
it('returns error when error is an error', async () => {
|
||||
const testProps = {
|
||||
resource: 'a-resource',
|
||||
http: mockHttp,
|
||||
};
|
||||
const error = 'simulated error';
|
||||
(mockHttp.fetch as jest.Mock).mockImplementation(() => {
|
||||
throw new Error(error);
|
||||
});
|
||||
|
||||
await expect(postKnowledgeBase(testProps)).resolves.toThrowError('simulated error');
|
||||
await expect(postKnowledgeBase(knowledgeBaseArgs)).resolves.toThrowError('simulated error');
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteKnowledgeBase', () => {
|
||||
it('calls the knowledge base API when correct resource path', async () => {
|
||||
const testProps = {
|
||||
resource: 'a-resource',
|
||||
http: mockHttp,
|
||||
};
|
||||
|
||||
await deleteKnowledgeBase(testProps);
|
||||
await deleteKnowledgeBase(knowledgeBaseArgs);
|
||||
|
||||
expect(mockHttp.fetch).toHaveBeenCalledWith(
|
||||
'/internal/elastic_assistant/knowledge_base/a-resource',
|
||||
|
@ -298,16 +282,12 @@ describe('API tests', () => {
|
|||
);
|
||||
});
|
||||
it('returns error when error is an error', async () => {
|
||||
const testProps = {
|
||||
resource: 'a-resource',
|
||||
http: mockHttp,
|
||||
};
|
||||
const error = 'simulated error';
|
||||
(mockHttp.fetch as jest.Mock).mockImplementation(() => {
|
||||
throw new Error(error);
|
||||
});
|
||||
|
||||
await expect(deleteKnowledgeBase(testProps)).resolves.toThrowError('simulated error');
|
||||
await expect(deleteKnowledgeBase(knowledgeBaseArgs)).resolves.toThrowError('simulated error');
|
||||
});
|
||||
});
|
||||
|
||||
|
@ -350,16 +330,12 @@ describe('API tests', () => {
|
|||
});
|
||||
});
|
||||
it('returns error when error is an error', async () => {
|
||||
const testProps = {
|
||||
resource: 'a-resource',
|
||||
http: mockHttp,
|
||||
};
|
||||
const error = 'simulated error';
|
||||
(mockHttp.fetch as jest.Mock).mockImplementation(() => {
|
||||
throw new Error(error);
|
||||
});
|
||||
|
||||
await expect(postEvaluation(testProps)).resolves.toThrowError('simulated error');
|
||||
await expect(postEvaluation(knowledgeBaseArgs)).resolves.toThrowError('simulated error');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -20,6 +20,7 @@ export interface FetchConnectorExecuteAction {
|
|||
http: HttpSetup;
|
||||
messages: Message[];
|
||||
signal?: AbortSignal | undefined;
|
||||
assistantStreamingEnabled: boolean;
|
||||
}
|
||||
|
||||
export interface FetchConnectorExecuteResponse {
|
||||
|
@ -38,6 +39,7 @@ export const fetchConnectorExecuteAction = async ({
|
|||
messages,
|
||||
apiConfig,
|
||||
signal,
|
||||
assistantStreamingEnabled,
|
||||
}: FetchConnectorExecuteAction): Promise<FetchConnectorExecuteResponse> => {
|
||||
const outboundMessages = messages.map((msg) => ({
|
||||
role: msg.role,
|
||||
|
@ -62,7 +64,7 @@ export const fetchConnectorExecuteAction = async ({
|
|||
// tracked here: https://github.com/elastic/security-team/issues/7363
|
||||
// In part 3 I will make enhancements to langchain to introduce streaming
|
||||
// Once implemented, invokeAI can be removed
|
||||
const isStream = !assistantLangChain;
|
||||
const isStream = assistantStreamingEnabled && !assistantLangChain;
|
||||
const requestBody = isStream
|
||||
? {
|
||||
params: {
|
||||
|
|
|
@ -29,7 +29,7 @@ interface UseSendMessages {
|
|||
}
|
||||
|
||||
export const useSendMessages = (): UseSendMessages => {
|
||||
const { knowledgeBase } = useAssistantContext();
|
||||
const { assistantStreamingEnabled, knowledgeBase } = useAssistantContext();
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
|
||||
const sendMessages = useCallback(
|
||||
|
@ -41,12 +41,13 @@ export const useSendMessages = (): UseSendMessages => {
|
|||
http,
|
||||
messages,
|
||||
apiConfig,
|
||||
assistantStreamingEnabled,
|
||||
});
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
},
|
||||
[knowledgeBase.assistantLangChain]
|
||||
[assistantStreamingEnabled, knowledgeBase.assistantLangChain]
|
||||
);
|
||||
|
||||
return { isLoading, sendMessages };
|
||||
|
|
|
@ -28,6 +28,7 @@ const ContextWrapper: React.FC = ({ children }) => (
|
|||
<AssistantProvider
|
||||
actionTypeRegistry={actionTypeRegistry}
|
||||
assistantAvailability={mockAssistantAvailability}
|
||||
assistantStreamingEnabled
|
||||
augmentMessageCodeBlocks={jest.fn()}
|
||||
baseAllow={[]}
|
||||
baseAllowReplacement={[]}
|
||||
|
|
|
@ -51,6 +51,7 @@ type ShowAssistantOverlay = ({
|
|||
export interface AssistantProviderProps {
|
||||
actionTypeRegistry: ActionTypeRegistryContract;
|
||||
assistantAvailability: AssistantAvailability;
|
||||
assistantStreamingEnabled?: boolean;
|
||||
assistantTelemetry?: AssistantTelemetry;
|
||||
augmentMessageCodeBlocks: (currentConversation: Conversation) => CodeBlockDetails[][];
|
||||
baseAllow: string[];
|
||||
|
@ -95,6 +96,7 @@ export interface AssistantProviderProps {
|
|||
export interface UseAssistantContext {
|
||||
actionTypeRegistry: ActionTypeRegistryContract;
|
||||
assistantAvailability: AssistantAvailability;
|
||||
assistantStreamingEnabled: boolean;
|
||||
assistantTelemetry?: AssistantTelemetry;
|
||||
augmentMessageCodeBlocks: (currentConversation: Conversation) => CodeBlockDetails[][];
|
||||
allQuickPrompts: QuickPrompt[];
|
||||
|
@ -155,6 +157,7 @@ const AssistantContext = React.createContext<UseAssistantContext | undefined>(un
|
|||
export const AssistantProvider: React.FC<AssistantProviderProps> = ({
|
||||
actionTypeRegistry,
|
||||
assistantAvailability,
|
||||
assistantStreamingEnabled = false,
|
||||
assistantTelemetry,
|
||||
augmentMessageCodeBlocks,
|
||||
baseAllow,
|
||||
|
@ -284,6 +287,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
|
|||
() => ({
|
||||
actionTypeRegistry,
|
||||
assistantAvailability,
|
||||
assistantStreamingEnabled,
|
||||
assistantTelemetry,
|
||||
augmentMessageCodeBlocks,
|
||||
allQuickPrompts: localStorageQuickPrompts ?? [],
|
||||
|
@ -324,6 +328,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
|
|||
[
|
||||
actionTypeRegistry,
|
||||
assistantAvailability,
|
||||
assistantStreamingEnabled,
|
||||
assistantTelemetry,
|
||||
augmentMessageCodeBlocks,
|
||||
baseAllow,
|
||||
|
|
|
@ -46,6 +46,11 @@ export const allowedExperimentalValues = Object.freeze({
|
|||
*/
|
||||
extendedRuleExecutionLoggingEnabled: false,
|
||||
|
||||
/**
|
||||
* Enables streaming for Security AI Assistant - non-langchain only (knowledge base off)
|
||||
*/
|
||||
assistantStreamingEnabled: false,
|
||||
|
||||
/**
|
||||
* Enables the SOC trends timerange and stats on D&R page
|
||||
*/
|
||||
|
|
|
@ -36,6 +36,7 @@ export const AssistantProvider: React.FC = ({ children }) => {
|
|||
} = useKibana().services;
|
||||
const basePath = useBasePath();
|
||||
const isModelEvaluationEnabled = useIsExperimentalFeatureEnabled('assistantModelEvaluation');
|
||||
const assistantStreamingEnabled = useIsExperimentalFeatureEnabled('assistantStreamingEnabled');
|
||||
|
||||
const { conversations, setConversations } = useConversationStore();
|
||||
const getInitialConversation = useCallback(() => {
|
||||
|
@ -68,6 +69,7 @@ export const AssistantProvider: React.FC = ({ children }) => {
|
|||
getInitialConversations={getInitialConversation}
|
||||
getComments={getComments}
|
||||
http={http}
|
||||
assistantStreamingEnabled={assistantStreamingEnabled}
|
||||
modelEvaluatorEnabled={isModelEvaluationEnabled}
|
||||
nameSpace={nameSpace}
|
||||
setConversations={setConversations}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue