[Security solution] Add feature flag for AI streaming (#172505)

This commit is contained in:
Steph Milovic 2023-12-04 17:31:18 -07:00 committed by GitHub
parent b876253e7b
commit 39caf945fa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 88 additions and 96 deletions

View file

@ -34,7 +34,13 @@ const apiConfig: Conversation['apiConfig'] = {
const messages: Message[] = [
{ content: 'This is a test', role: 'user', timestamp: new Date().toLocaleString() },
];
const fetchConnectorArgs: FetchConnectorExecuteAction = {
assistantLangChain: true,
http: mockHttp,
messages,
apiConfig,
assistantStreamingEnabled: true,
};
describe('API tests', () => {
beforeEach(() => {
jest.clearAllMocks();
@ -42,14 +48,7 @@ describe('API tests', () => {
describe('fetchConnectorExecuteAction', () => {
it('calls the internal assistant API when assistantLangChain is true', async () => {
const testProps: FetchConnectorExecuteAction = {
assistantLangChain: true,
http: mockHttp,
messages,
apiConfig,
};
await fetchConnectorExecuteAction(testProps);
await fetchConnectorExecuteAction(fetchConnectorArgs);
expect(mockHttp.fetch).toHaveBeenCalledWith(
'/internal/elastic_assistant/actions/connector/foo/_execute',
@ -62,12 +61,10 @@ describe('API tests', () => {
);
});
it('calls the actions connector api when assistantLangChain is false', async () => {
it('calls the actions connector api with streaming when assistantStreamingEnabled is true when assistantLangChain is false', async () => {
const testProps: FetchConnectorExecuteAction = {
...fetchConnectorArgs,
assistantLangChain: false,
http: mockHttp,
messages,
apiConfig,
};
await fetchConnectorExecuteAction(testProps);
@ -84,29 +81,62 @@ describe('API tests', () => {
);
});
it('calls the actions connector api with invoke when assistantStreamingEnabled is false when assistantLangChain is false', async () => {
const testProps: FetchConnectorExecuteAction = {
...fetchConnectorArgs,
assistantLangChain: false,
assistantStreamingEnabled: false,
};
await fetchConnectorExecuteAction(testProps);
expect(mockHttp.fetch).toHaveBeenCalledWith(
'/internal/elastic_assistant/actions/connector/foo/_execute',
{
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"},"assistantLangChain":false}',
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
signal: undefined,
}
);
});
it('returns API_ERROR when the response status is error and langchain is on', async () => {
(mockHttp.fetch as jest.Mock).mockResolvedValue({ status: 'error' });
const result = await fetchConnectorExecuteAction(fetchConnectorArgs);
expect(result).toEqual({ response: API_ERROR, isStream: false, isError: true });
});
it('returns API_ERROR + error message on non streaming responses', async () => {
(mockHttp.fetch as jest.Mock).mockResolvedValue({
status: 'error',
service_message: 'an error message',
});
const testProps: FetchConnectorExecuteAction = {
assistantLangChain: true,
http: mockHttp,
messages,
apiConfig,
...fetchConnectorArgs,
assistantLangChain: false,
assistantStreamingEnabled: false,
};
const result = await fetchConnectorExecuteAction(testProps);
expect(result).toEqual({ response: API_ERROR, isStream: false, isError: true });
expect(result).toEqual({
response: `${API_ERROR}\n\nan error message`,
isStream: false,
isError: true,
});
});
it('returns API_ERROR when the response status is error, langchain is off, and response is not a reader', async () => {
(mockHttp.fetch as jest.Mock).mockResolvedValue({ status: 'error' });
const testProps: FetchConnectorExecuteAction = {
...fetchConnectorArgs,
assistantLangChain: false,
http: mockHttp,
messages,
apiConfig,
};
const result = await fetchConnectorExecuteAction(testProps);
@ -124,10 +154,8 @@ describe('API tests', () => {
response: { body: { getReader: jest.fn().mockImplementation(() => mockReader) } },
});
const testProps: FetchConnectorExecuteAction = {
...fetchConnectorArgs,
assistantLangChain: false,
http: mockHttp,
messages,
apiConfig,
};
const result = await fetchConnectorExecuteAction(testProps);
@ -141,14 +169,8 @@ describe('API tests', () => {
it('returns API_ERROR when there are no choices', async () => {
(mockHttp.fetch as jest.Mock).mockResolvedValue({ status: 'ok', data: '' });
const testProps: FetchConnectorExecuteAction = {
assistantLangChain: true,
http: mockHttp,
messages,
apiConfig,
};
const result = await fetchConnectorExecuteAction(testProps);
const result = await fetchConnectorExecuteAction(fetchConnectorArgs);
expect(result).toEqual({ response: API_ERROR, isStream: false, isError: true });
});
@ -161,14 +183,7 @@ describe('API tests', () => {
data: response,
});
const testProps: FetchConnectorExecuteAction = {
assistantLangChain: true, // <-- requires response parsing
http: mockHttp,
messages,
apiConfig,
};
const result = await fetchConnectorExecuteAction(testProps);
const result = await fetchConnectorExecuteAction(fetchConnectorArgs);
expect(result).toEqual({
response: 'value from action_input',
@ -185,14 +200,7 @@ describe('API tests', () => {
data: response,
});
const testProps: FetchConnectorExecuteAction = {
assistantLangChain: true, // <-- requires response parsing
http: mockHttp,
messages,
apiConfig,
};
const result = await fetchConnectorExecuteAction(testProps);
const result = await fetchConnectorExecuteAction(fetchConnectorArgs);
expect(result).toEqual({ response, isStream: false, isError: false });
});
@ -205,27 +213,19 @@ describe('API tests', () => {
data: response,
});
const testProps: FetchConnectorExecuteAction = {
assistantLangChain: true, // <-- requires response parsing
http: mockHttp,
messages,
apiConfig,
};
const result = await fetchConnectorExecuteAction(testProps);
const result = await fetchConnectorExecuteAction(fetchConnectorArgs);
expect(result).toEqual({ response, isStream: false, isError: false });
});
});
const knowledgeBaseArgs = {
resource: 'a-resource',
http: mockHttp,
};
describe('getKnowledgeBaseStatus', () => {
it('calls the knowledge base API when correct resource path', async () => {
const testProps = {
resource: 'a-resource',
http: mockHttp,
};
await getKnowledgeBaseStatus(testProps);
await getKnowledgeBaseStatus(knowledgeBaseArgs);
expect(mockHttp.fetch).toHaveBeenCalledWith(
'/internal/elastic_assistant/knowledge_base/a-resource',
@ -236,27 +236,20 @@ describe('API tests', () => {
);
});
it('returns error when error is an error', async () => {
const testProps = {
resource: 'a-resource',
http: mockHttp,
};
const error = 'simulated error';
(mockHttp.fetch as jest.Mock).mockImplementation(() => {
throw new Error(error);
});
await expect(getKnowledgeBaseStatus(testProps)).resolves.toThrowError('simulated error');
await expect(getKnowledgeBaseStatus(knowledgeBaseArgs)).resolves.toThrowError(
'simulated error'
);
});
});
describe('postKnowledgeBase', () => {
it('calls the knowledge base API when correct resource path', async () => {
const testProps = {
resource: 'a-resource',
http: mockHttp,
};
await postKnowledgeBase(testProps);
await postKnowledgeBase(knowledgeBaseArgs);
expect(mockHttp.fetch).toHaveBeenCalledWith(
'/internal/elastic_assistant/knowledge_base/a-resource',
@ -267,27 +260,18 @@ describe('API tests', () => {
);
});
it('returns error when error is an error', async () => {
const testProps = {
resource: 'a-resource',
http: mockHttp,
};
const error = 'simulated error';
(mockHttp.fetch as jest.Mock).mockImplementation(() => {
throw new Error(error);
});
await expect(postKnowledgeBase(testProps)).resolves.toThrowError('simulated error');
await expect(postKnowledgeBase(knowledgeBaseArgs)).resolves.toThrowError('simulated error');
});
});
describe('deleteKnowledgeBase', () => {
it('calls the knowledge base API when correct resource path', async () => {
const testProps = {
resource: 'a-resource',
http: mockHttp,
};
await deleteKnowledgeBase(testProps);
await deleteKnowledgeBase(knowledgeBaseArgs);
expect(mockHttp.fetch).toHaveBeenCalledWith(
'/internal/elastic_assistant/knowledge_base/a-resource',
@ -298,16 +282,12 @@ describe('API tests', () => {
);
});
it('returns error when error is an error', async () => {
const testProps = {
resource: 'a-resource',
http: mockHttp,
};
const error = 'simulated error';
(mockHttp.fetch as jest.Mock).mockImplementation(() => {
throw new Error(error);
});
await expect(deleteKnowledgeBase(testProps)).resolves.toThrowError('simulated error');
await expect(deleteKnowledgeBase(knowledgeBaseArgs)).resolves.toThrowError('simulated error');
});
});
@ -350,16 +330,12 @@ describe('API tests', () => {
});
});
it('returns error when error is an error', async () => {
const testProps = {
resource: 'a-resource',
http: mockHttp,
};
const error = 'simulated error';
(mockHttp.fetch as jest.Mock).mockImplementation(() => {
throw new Error(error);
});
await expect(postEvaluation(testProps)).resolves.toThrowError('simulated error');
await expect(postEvaluation(knowledgeBaseArgs)).resolves.toThrowError('simulated error');
});
});
});

View file

@ -20,6 +20,7 @@ export interface FetchConnectorExecuteAction {
http: HttpSetup;
messages: Message[];
signal?: AbortSignal | undefined;
assistantStreamingEnabled: boolean;
}
export interface FetchConnectorExecuteResponse {
@ -38,6 +39,7 @@ export const fetchConnectorExecuteAction = async ({
messages,
apiConfig,
signal,
assistantStreamingEnabled,
}: FetchConnectorExecuteAction): Promise<FetchConnectorExecuteResponse> => {
const outboundMessages = messages.map((msg) => ({
role: msg.role,
@ -62,7 +64,7 @@ export const fetchConnectorExecuteAction = async ({
// tracked here: https://github.com/elastic/security-team/issues/7363
// In part 3 I will make enhancements to langchain to introduce streaming
// Once implemented, invokeAI can be removed
const isStream = !assistantLangChain;
const isStream = assistantStreamingEnabled && !assistantLangChain;
const requestBody = isStream
? {
params: {

View file

@ -29,7 +29,7 @@ interface UseSendMessages {
}
export const useSendMessages = (): UseSendMessages => {
const { knowledgeBase } = useAssistantContext();
const { assistantStreamingEnabled, knowledgeBase } = useAssistantContext();
const [isLoading, setIsLoading] = useState(false);
const sendMessages = useCallback(
@ -41,12 +41,13 @@ export const useSendMessages = (): UseSendMessages => {
http,
messages,
apiConfig,
assistantStreamingEnabled,
});
} finally {
setIsLoading(false);
}
},
[knowledgeBase.assistantLangChain]
[assistantStreamingEnabled, knowledgeBase.assistantLangChain]
);
return { isLoading, sendMessages };

View file

@ -28,6 +28,7 @@ const ContextWrapper: React.FC = ({ children }) => (
<AssistantProvider
actionTypeRegistry={actionTypeRegistry}
assistantAvailability={mockAssistantAvailability}
assistantStreamingEnabled
augmentMessageCodeBlocks={jest.fn()}
baseAllow={[]}
baseAllowReplacement={[]}

View file

@ -51,6 +51,7 @@ type ShowAssistantOverlay = ({
export interface AssistantProviderProps {
actionTypeRegistry: ActionTypeRegistryContract;
assistantAvailability: AssistantAvailability;
assistantStreamingEnabled?: boolean;
assistantTelemetry?: AssistantTelemetry;
augmentMessageCodeBlocks: (currentConversation: Conversation) => CodeBlockDetails[][];
baseAllow: string[];
@ -95,6 +96,7 @@ export interface AssistantProviderProps {
export interface UseAssistantContext {
actionTypeRegistry: ActionTypeRegistryContract;
assistantAvailability: AssistantAvailability;
assistantStreamingEnabled: boolean;
assistantTelemetry?: AssistantTelemetry;
augmentMessageCodeBlocks: (currentConversation: Conversation) => CodeBlockDetails[][];
allQuickPrompts: QuickPrompt[];
@ -155,6 +157,7 @@ const AssistantContext = React.createContext<UseAssistantContext | undefined>(un
export const AssistantProvider: React.FC<AssistantProviderProps> = ({
actionTypeRegistry,
assistantAvailability,
assistantStreamingEnabled = false,
assistantTelemetry,
augmentMessageCodeBlocks,
baseAllow,
@ -284,6 +287,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
() => ({
actionTypeRegistry,
assistantAvailability,
assistantStreamingEnabled,
assistantTelemetry,
augmentMessageCodeBlocks,
allQuickPrompts: localStorageQuickPrompts ?? [],
@ -324,6 +328,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
[
actionTypeRegistry,
assistantAvailability,
assistantStreamingEnabled,
assistantTelemetry,
augmentMessageCodeBlocks,
baseAllow,

View file

@ -46,6 +46,11 @@ export const allowedExperimentalValues = Object.freeze({
*/
extendedRuleExecutionLoggingEnabled: false,
/**
* Enables streaming for Security AI Assistant - non-langchain only (knowledge base off)
*/
assistantStreamingEnabled: false,
/**
* Enables the SOC trends timerange and stats on D&R page
*/

View file

@ -36,6 +36,7 @@ export const AssistantProvider: React.FC = ({ children }) => {
} = useKibana().services;
const basePath = useBasePath();
const isModelEvaluationEnabled = useIsExperimentalFeatureEnabled('assistantModelEvaluation');
const assistantStreamingEnabled = useIsExperimentalFeatureEnabled('assistantStreamingEnabled');
const { conversations, setConversations } = useConversationStore();
const getInitialConversation = useCallback(() => {
@ -68,6 +69,7 @@ export const AssistantProvider: React.FC = ({ children }) => {
getInitialConversations={getInitialConversation}
getComments={getComments}
http={http}
assistantStreamingEnabled={assistantStreamingEnabled}
modelEvaluatorEnabled={isModelEvaluationEnabled}
nameSpace={nameSpace}
setConversations={setConversations}