mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 09:19:04 -04:00
[8.15] [Search] [Playground] Gemini search playground + Robustness for Question Rewriting (#187559) (#187779)
# Backport This will backport the following commits from `main` to `8.15`: - [[Search] [Playground] Gemini search playground + Robustness for Question Rewriting (#187559)](https://github.com/elastic/kibana/pull/187559) <!--- Backport version: 9.4.3 --> ### Questions ? Please refer to the [Backport tool documentation](https://github.com/sqren/backport) <!--BACKPORT [{"author":{"name":"Joe McElroy","email":"joseph.mcelroy@elastic.co"},"sourceCommit":{"committedDate":"2024-07-08T17:18:12Z","message":"[Search] [Playground] Gemini search playground + Robustness for Question Rewriting (#187559)\n\n## Summary\r\n\r\nWork largely based off the work @stephmilovic really nicely put together\r\nin this [draft PR](https://github.com/elastic/kibana/pull/186934)\r\n- Introduce Google Gemini Model support\r\n- Updated bedrock to use the ActionsSimpleChatModel \r\n- Updated the tests\r\n- Made the rewrite question chain more robust\r\n - the prompt is now uses the model specific tags\r\n- the system instruction has been updated to be less wordy, better for\r\nBM25 retrieval\r\n \r\n\r\n4558bc5d
-e0c1-4ff6-b68c-800441f7835e\r\n\r\n### Checklist\r\n\r\nDelete any items that are not applicable to this PR.\r\n\r\n- [ ] Any text added follows [EUI's writing\r\nguidelines](https://elastic.github.io/eui/#/guidelines/writing), uses\r\nsentence case text and includes [i18n\r\nsupport](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md)\r\n- [ ]\r\n[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)\r\nwas added for features that require explanation or tutorials\r\n- [x] [Unit or functional\r\ntests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)\r\nwere updated or added to match the most common scenarios\r\n- [ ] [Flaky Test\r\nRunner](https://ci-stats.kibana.dev/trigger_flaky_test_runner/1) was\r\nused on any tests changed\r\n- [ ] Any UI touched in this PR is usable by keyboard only (learn more\r\nabout [keyboard accessibility](https://webaim.org/techniques/keyboard/))\r\n- [ ] Any UI touched in this PR does not create any new axe failures\r\n(run axe in browser:\r\n[FF](https://addons.mozilla.org/en-US/firefox/addon/axe-devtools/),\r\n[Chrome](https://chrome.google.com/webstore/detail/axe-web-accessibility-tes/lhdoppojpmngadmnindnejefpokejbdd?hl=en-US))\r\n- [ ] If a plugin configuration key changed, check if it needs to be\r\nallowlisted in the cloud and added to the [docker\r\nlist](https://github.com/elastic/kibana/blob/main/src/dev/build/tasks/os_packages/docker_generator/resources/base/bin/kibana-docker)\r\n- [ ] This renders correctly on smaller devices using a responsive\r\nlayout. 
(You can test this [in your\r\nbrowser](https://www.browserstack.com/guide/responsive-testing-on-local-server))\r\n- [ ] This was checked for [cross-browser\r\ncompatibility](https://www.elastic.co/support/matrix#matrix_browsers)\r\n\r\n---------\r\n\r\nCo-authored-by: Steph Milovic <stephanie.milovic@elastic.co>\r\nCo-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>","sha":"0be5528f21fd0442076d3f331c15cc3f34098d17","branchLabelMapping":{"^v8.16.0$":"main","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","Team:EnterpriseSearch","v8.15.0","v8.16.0"],"title":"[Search] [Playground] Gemini search playground + Robustness for Question Rewriting","number":187559,"url":"https://github.com/elastic/kibana/pull/187559","mergeCommit":{"message":"[Search] [Playground] Gemini search playground + Robustness for Question Rewriting (#187559)\n\n## Summary\r\n\r\nWork largely based off the work @stephmilovic really nicely put together\r\nin this [draft PR](https://github.com/elastic/kibana/pull/186934)\r\n- Introduce Google Gemini Model support\r\n- Updated bedrock to use the ActionsSimpleChatModel \r\n- Updated the tests\r\n- Made the rewrite question chain more robust\r\n - the prompt is now uses the model specific tags\r\n- the system instruction has been updated to be less wordy, better for\r\nBM25 retrieval\r\n \r\n\r\n4558bc5d
-e0c1-4ff6-b68c-800441f7835e\r\n\r\n### Checklist\r\n\r\nDelete any items that are not applicable to this PR.\r\n\r\n- [ ] Any text added follows [EUI's writing\r\nguidelines](https://elastic.github.io/eui/#/guidelines/writing), uses\r\nsentence case text and includes [i18n\r\nsupport](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md)\r\n- [ ]\r\n[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)\r\nwas added for features that require explanation or tutorials\r\n- [x] [Unit or functional\r\ntests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)\r\nwere updated or added to match the most common scenarios\r\n- [ ] [Flaky Test\r\nRunner](https://ci-stats.kibana.dev/trigger_flaky_test_runner/1) was\r\nused on any tests changed\r\n- [ ] Any UI touched in this PR is usable by keyboard only (learn more\r\nabout [keyboard accessibility](https://webaim.org/techniques/keyboard/))\r\n- [ ] Any UI touched in this PR does not create any new axe failures\r\n(run axe in browser:\r\n[FF](https://addons.mozilla.org/en-US/firefox/addon/axe-devtools/),\r\n[Chrome](https://chrome.google.com/webstore/detail/axe-web-accessibility-tes/lhdoppojpmngadmnindnejefpokejbdd?hl=en-US))\r\n- [ ] If a plugin configuration key changed, check if it needs to be\r\nallowlisted in the cloud and added to the [docker\r\nlist](https://github.com/elastic/kibana/blob/main/src/dev/build/tasks/os_packages/docker_generator/resources/base/bin/kibana-docker)\r\n- [ ] This renders correctly on smaller devices using a responsive\r\nlayout. 
(You can test this [in your\r\nbrowser](https://www.browserstack.com/guide/responsive-testing-on-local-server))\r\n- [ ] This was checked for [cross-browser\r\ncompatibility](https://www.elastic.co/support/matrix#matrix_browsers)\r\n\r\n---------\r\n\r\nCo-authored-by: Steph Milovic <stephanie.milovic@elastic.co>\r\nCo-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>","sha":"0be5528f21fd0442076d3f331c15cc3f34098d17"}},"sourceBranch":"main","suggestedTargetBranches":["8.15"],"targetPullRequestStates":[{"branch":"8.15","label":"v8.15.0","branchLabelMappingKey":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"},{"branch":"main","label":"v8.16.0","branchLabelMappingKey":"^v8.16.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/187559","number":187559,"mergeCommit":{"message":"[Search] [Playground] Gemini search playground + Robustness for Question Rewriting (#187559)\n\n## Summary\r\n\r\nWork largely based off the work @stephmilovic really nicely put together\r\nin this [draft PR](https://github.com/elastic/kibana/pull/186934)\r\n- Introduce Google Gemini Model support\r\n- Updated bedrock to use the ActionsSimpleChatModel \r\n- Updated the tests\r\n- Made the rewrite question chain more robust\r\n - the prompt is now uses the model specific tags\r\n- the system instruction has been updated to be less wordy, better for\r\nBM25 retrieval\r\n \r\n\r\n4558bc5d
-e0c1-4ff6-b68c-800441f7835e\r\n\r\n### Checklist\r\n\r\nDelete any items that are not applicable to this PR.\r\n\r\n- [ ] Any text added follows [EUI's writing\r\nguidelines](https://elastic.github.io/eui/#/guidelines/writing), uses\r\nsentence case text and includes [i18n\r\nsupport](https://github.com/elastic/kibana/blob/main/packages/kbn-i18n/README.md)\r\n- [ ]\r\n[Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html)\r\nwas added for features that require explanation or tutorials\r\n- [x] [Unit or functional\r\ntests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)\r\nwere updated or added to match the most common scenarios\r\n- [ ] [Flaky Test\r\nRunner](https://ci-stats.kibana.dev/trigger_flaky_test_runner/1) was\r\nused on any tests changed\r\n- [ ] Any UI touched in this PR is usable by keyboard only (learn more\r\nabout [keyboard accessibility](https://webaim.org/techniques/keyboard/))\r\n- [ ] Any UI touched in this PR does not create any new axe failures\r\n(run axe in browser:\r\n[FF](https://addons.mozilla.org/en-US/firefox/addon/axe-devtools/),\r\n[Chrome](https://chrome.google.com/webstore/detail/axe-web-accessibility-tes/lhdoppojpmngadmnindnejefpokejbdd?hl=en-US))\r\n- [ ] If a plugin configuration key changed, check if it needs to be\r\nallowlisted in the cloud and added to the [docker\r\nlist](https://github.com/elastic/kibana/blob/main/src/dev/build/tasks/os_packages/docker_generator/resources/base/bin/kibana-docker)\r\n- [ ] This renders correctly on smaller devices using a responsive\r\nlayout. 
(You can test this [in your\r\nbrowser](https://www.browserstack.com/guide/responsive-testing-on-local-server))\r\n- [ ] This was checked for [cross-browser\r\ncompatibility](https://www.elastic.co/support/matrix#matrix_browsers)\r\n\r\n---------\r\n\r\nCo-authored-by: Steph Milovic <stephanie.milovic@elastic.co>\r\nCo-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>","sha":"0be5528f21fd0442076d3f331c15cc3f34098d17"}}]}] BACKPORT--> Co-authored-by: Joe McElroy <joseph.mcelroy@elastic.co>
This commit is contained in:
parent
12cbac97e5
commit
d11287dc97
14 changed files with 190 additions and 52 deletions
|
@ -27,15 +27,27 @@ export const MODELS: ModelProvider[] = [
|
|||
provider: LLMs.openai,
|
||||
},
|
||||
{
|
||||
name: 'Claude 3 Haiku',
|
||||
name: 'Anthropic Claude 3 Haiku',
|
||||
model: 'anthropic.claude-3-haiku-20240307-v1:0',
|
||||
promptTokenLimit: 200000,
|
||||
provider: LLMs.bedrock,
|
||||
},
|
||||
{
|
||||
name: 'Claude 3 Sonnet',
|
||||
model: 'anthropic.claude-3-haiku-20240307-v1:0',
|
||||
name: 'Anthropic Claude 3.5 Sonnet',
|
||||
model: 'anthropic.claude-3-5-sonnet-20240620-v1:0',
|
||||
promptTokenLimit: 200000,
|
||||
provider: LLMs.bedrock,
|
||||
},
|
||||
{
|
||||
name: 'Google Gemini 1.5 Pro',
|
||||
model: 'gemini-1.5-pro-001',
|
||||
promptTokenLimit: 2097152,
|
||||
provider: LLMs.gemini,
|
||||
},
|
||||
{
|
||||
name: 'Google Gemini 1.5 Flash',
|
||||
model: 'gemini-1.5-flash-001',
|
||||
promptTokenLimit: 2097152,
|
||||
provider: LLMs.gemini,
|
||||
},
|
||||
];
|
||||
|
|
|
@ -45,10 +45,23 @@ const AnthropicPrompt = (systemInstructions: string) => {
|
|||
`;
|
||||
};
|
||||
|
||||
/**
 * Builds the Gemini-flavoured prompt template.
 *
 * `systemInstructions` is interpolated immediately. `{context}` and
 * `{question}` are plain single-brace placeholders that stay in the returned
 * string as-is (presumably filled in later by the prompt-template layer —
 * confirm against the caller).
 *
 * @param systemInstructions - Text placed under the `Instructions:` heading.
 * @returns The prompt template string.
 */
const GeminiPrompt = (systemInstructions: string) => {
  // Template lines stay flush-left: any leading whitespace inside the
  // backticks would become part of the prompt text.
  return `
Instructions:
${systemInstructions}

Context:
{context}

Question: {question}
Answer:
`;
};
|
||||
|
||||
/** Options accepted by `Prompt` when building a model-specific prompt template. */
interface PromptTemplateOptions {
  citations?: boolean;
  context?: boolean;
  /** Selects the prompt builder; `Prompt` falls back to 'openai' when this is omitted. */
  type?: 'openai' | 'mistral' | 'anthropic' | 'gemini';
}
|
||||
|
||||
export const Prompt = (instructions: string, options: PromptTemplateOptions): string => {
|
||||
|
@ -73,5 +86,20 @@ export const Prompt = (instructions: string, options: PromptTemplateOptions): st
|
|||
openai: OpenAIPrompt,
|
||||
mistral: MistralPrompt,
|
||||
anthropic: AnthropicPrompt,
|
||||
gemini: GeminiPrompt,
|
||||
}[options.type || 'openai'](systemInstructions);
|
||||
};
|
||||
|
||||
/** Options for `QuestionRewritePrompt`. */
interface QuestionRewritePromptOptions {
  /** Model family whose prompt builder wraps the rewrite instructions. */
  type: 'openai' | 'mistral' | 'anthropic' | 'gemini';
}

/**
 * Builds the prompt template used to rephrase a follow-up question as a
 * standalone question.
 *
 * The fixed rewrite instructions are handed to the prompt builder that
 * matches `options.type`, so the result carries that model's formatting.
 *
 * @param options - Selects the model-specific prompt builder.
 * @returns The prompt template string produced by the selected builder.
 */
export const QuestionRewritePrompt = (options: QuestionRewritePromptOptions): string => {
  const systemInstructions = `Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. Rewrite the question in the question language. Keep the answer to a single sentence. Do not use quotes.`;
  return {
    openai: OpenAIPrompt,
    mistral: MistralPrompt,
    anthropic: AnthropicPrompt,
    gemini: GeminiPrompt,
    // `type` is required by the interface; the fallback only guards untyped callers.
  }[options.type || 'openai'](systemInstructions);
};
|
||||
|
|
|
@ -40,6 +40,7 @@ export enum LLMs {
|
|||
openai = 'openai',
|
||||
openai_azure = 'openai_azure',
|
||||
bedrock = 'bedrock',
|
||||
gemini = 'gemini',
|
||||
}
|
||||
|
||||
export interface ChatRequestData {
|
||||
|
|
|
@ -88,8 +88,8 @@ describe('useLLMsModels Hook', () => {
|
|||
connectorType: LLMs.bedrock,
|
||||
disabled: false,
|
||||
icon: expect.any(Function),
|
||||
id: 'connectorId2Claude 3 Haiku',
|
||||
name: 'Claude 3 Haiku',
|
||||
id: 'connectorId2Anthropic Claude 3 Haiku',
|
||||
name: 'Anthropic Claude 3 Haiku',
|
||||
showConnectorName: false,
|
||||
value: 'anthropic.claude-3-haiku-20240307-v1:0',
|
||||
promptTokenLimit: 200000,
|
||||
|
@ -100,10 +100,10 @@ describe('useLLMsModels Hook', () => {
|
|||
connectorType: LLMs.bedrock,
|
||||
disabled: false,
|
||||
icon: expect.any(Function),
|
||||
id: 'connectorId2Claude 3 Sonnet',
|
||||
name: 'Claude 3 Sonnet',
|
||||
id: 'connectorId2Anthropic Claude 3.5 Sonnet',
|
||||
name: 'Anthropic Claude 3.5 Sonnet',
|
||||
showConnectorName: false,
|
||||
value: 'anthropic.claude-3-haiku-20240307-v1:0',
|
||||
value: 'anthropic.claude-3-5-sonnet-20240620-v1:0',
|
||||
promptTokenLimit: 200000,
|
||||
},
|
||||
]);
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
*/
|
||||
|
||||
import { i18n } from '@kbn/i18n';
|
||||
import { BedrockLogo, OpenAILogo } from '@kbn/stack-connectors-plugin/public/common';
|
||||
import { BedrockLogo, OpenAILogo, GeminiLogo } from '@kbn/stack-connectors-plugin/public/common';
|
||||
import { ComponentType, useMemo } from 'react';
|
||||
import { LLMs } from '../../common/types';
|
||||
import { LLMModel } from '../types';
|
||||
|
@ -52,6 +52,15 @@ const mapLlmToModels: Record<
|
|||
promptTokenLimit: model.promptTokenLimit,
|
||||
})),
|
||||
},
|
||||
[LLMs.gemini]: {
|
||||
icon: GeminiLogo,
|
||||
getModels: () =>
|
||||
MODELS.filter(({ provider }) => provider === LLMs.gemini).map((model) => ({
|
||||
label: model.name,
|
||||
value: model.model,
|
||||
promptTokenLimit: model.promptTokenLimit,
|
||||
})),
|
||||
},
|
||||
};
|
||||
|
||||
export const useLLMsModels = (): LLMModel[] => {
|
||||
|
|
|
@ -16,6 +16,7 @@ import {
|
|||
OPENAI_CONNECTOR_ID,
|
||||
OpenAiProviderType,
|
||||
BEDROCK_CONNECTOR_ID,
|
||||
GEMINI_CONNECTOR_ID,
|
||||
} from '@kbn/stack-connectors-plugin/public/common';
|
||||
import { UserConfiguredActionConnector } from '@kbn/triggers-actions-ui-plugin/public/types';
|
||||
import { useKibana } from './use_kibana';
|
||||
|
@ -73,6 +74,17 @@ const connectorTypeToLLM: Array<{
|
|||
type: LLMs.bedrock,
|
||||
}),
|
||||
},
|
||||
{
|
||||
actionId: GEMINI_CONNECTOR_ID,
|
||||
match: (connector) => connector.actionTypeId === GEMINI_CONNECTOR_ID,
|
||||
transform: (connector) => ({
|
||||
...connector,
|
||||
title: i18n.translate('xpack.searchPlayground.geminiConnectorTitle', {
|
||||
defaultMessage: 'Gemini',
|
||||
}),
|
||||
type: LLMs.gemini,
|
||||
}),
|
||||
},
|
||||
];
|
||||
|
||||
type PlaygroundConnector = ActionConnector & { title: string; type: LLMs };
|
||||
|
|
|
@ -97,6 +97,7 @@ describe('conversational chain', () => {
|
|||
inputTokensLimit: modelLimit,
|
||||
},
|
||||
prompt: 'you are a QA bot {question} {chat_history} {context}',
|
||||
questionRewritePrompt: 'rewrite question {question} using {chat_history}"',
|
||||
});
|
||||
|
||||
const stream = await conversationalChain.stream(aiClient, chat);
|
||||
|
@ -442,7 +443,7 @@ describe('conversational chain', () => {
|
|||
type: 'retrieved_docs',
|
||||
},
|
||||
],
|
||||
// Even with body_content of 1000, the token count should be below the model limit of 100
|
||||
// Even with body_content of 1000, the token count should be below or equal to model limit of 100
|
||||
expectedTokens: [
|
||||
{ type: 'context_token_count', count: 70 },
|
||||
{ type: 'prompt_token_count', count: 97 },
|
||||
|
|
|
@ -37,6 +37,7 @@ interface RAGOptions {
|
|||
interface ConversationalChainOptions {
|
||||
model: BaseLanguageModel;
|
||||
prompt: string;
|
||||
questionRewritePrompt: string;
|
||||
rag?: RAGOptions;
|
||||
}
|
||||
|
||||
|
@ -46,16 +47,6 @@ interface ContextInputs {
|
|||
question: string;
|
||||
}
|
||||
|
||||
const CONDENSE_QUESTION_TEMPLATE = `Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. Be verbose in your answer.
|
||||
|
||||
Chat History:
|
||||
{chat_history}
|
||||
|
||||
Follow Up Input: {question}
|
||||
Standalone question:`;
|
||||
|
||||
const condenseQuestionPrompt = PromptTemplate.fromTemplate(CONDENSE_QUESTION_TEMPLATE);
|
||||
|
||||
const formatVercelMessages = (chatHistory: VercelChatMessage[]) => {
|
||||
const formattedDialogueTurns = chatHistory.map((message) => {
|
||||
if (message.role === 'user') {
|
||||
|
@ -160,11 +151,21 @@ class ConversationalChainFn {
|
|||
retrievalChain = retriever.pipe(buildContext);
|
||||
}
|
||||
|
||||
let standaloneQuestionChain: Runnable = RunnableLambda.from((input) => input.question);
|
||||
let standaloneQuestionChain: Runnable = RunnableLambda.from((input) => {
|
||||
return input.question;
|
||||
});
|
||||
|
||||
if (previousMessages.length > 0) {
|
||||
const questionRewritePromptTemplate = PromptTemplate.fromTemplate(
|
||||
this.options.questionRewritePrompt
|
||||
);
|
||||
standaloneQuestionChain = RunnableSequence.from([
|
||||
condenseQuestionPrompt,
|
||||
{
|
||||
context: () => '',
|
||||
chat_history: (input) => input.chat_history,
|
||||
question: (input) => input.question,
|
||||
},
|
||||
questionRewritePromptTemplate,
|
||||
this.options.model,
|
||||
new StringOutputParser(),
|
||||
]).withConfig({
|
||||
|
|
|
@ -6,12 +6,13 @@
|
|||
*/
|
||||
|
||||
import { getChatParams } from './get_chat_params';
|
||||
import { ActionsClientChatOpenAI, ActionsClientLlm } from '@kbn/langchain/server';
|
||||
import { ActionsClientChatOpenAI, ActionsClientSimpleChatModel } from '@kbn/langchain/server';
|
||||
import {
|
||||
OPENAI_CONNECTOR_ID,
|
||||
BEDROCK_CONNECTOR_ID,
|
||||
GEMINI_CONNECTOR_ID,
|
||||
} from '@kbn/stack-connectors-plugin/public/common';
|
||||
import { Prompt } from '../../common/prompt';
|
||||
import { Prompt, QuestionRewritePrompt } from '../../common/prompt';
|
||||
import { KibanaRequest, Logger } from '@kbn/core/server';
|
||||
import { PluginStartContract as ActionsPluginStartContract } from '@kbn/actions-plugin/server';
|
||||
|
||||
|
@ -20,12 +21,13 @@ jest.mock('@kbn/langchain/server', () => {
|
|||
return {
|
||||
...original,
|
||||
ActionsClientChatOpenAI: jest.fn(),
|
||||
ActionsClientLlm: jest.fn(),
|
||||
ActionsClientSimpleChatModel: jest.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
jest.mock('../../common/prompt', () => ({
|
||||
Prompt: jest.fn((instructions) => instructions),
|
||||
QuestionRewritePrompt: jest.fn((instructions) => instructions),
|
||||
}));
|
||||
|
||||
jest.mock('uuid', () => ({
|
||||
|
@ -64,10 +66,45 @@ describe('getChatParams', () => {
|
|||
context: true,
|
||||
type: 'openai',
|
||||
});
|
||||
expect(QuestionRewritePrompt).toHaveBeenCalledWith({
|
||||
type: 'openai',
|
||||
});
|
||||
expect(ActionsClientChatOpenAI).toHaveBeenCalledWith(expect.anything());
|
||||
expect(result.chatPrompt).toContain('Hello, world!');
|
||||
});
|
||||
|
||||
it('returns the correct chat model and prompt for Gemini', async () => {
  // Make the connector lookup return a Gemini connector so that
  // getChatParams takes its Gemini branch.
  mockActionsClient.get.mockResolvedValue({ id: '1', actionTypeId: GEMINI_CONNECTOR_ID });

  const result = await getChatParams(
    {
      connectorId: '1',
      model: 'gemini-1.5-pro',
      prompt: 'Hello, world!',
      citations: true,
    },
    { actions, request, logger }
  );
  expect(Prompt).toHaveBeenCalledWith('Hello, world!', {
    citations: true,
    context: true,
    type: 'gemini',
  });
  expect(QuestionRewritePrompt).toHaveBeenCalledWith({
    type: 'gemini',
  });
  // Unlike the Bedrock expectation in this suite, no traceId is expected here.
  expect(ActionsClientSimpleChatModel).toHaveBeenCalledWith({
    temperature: 0,
    llmType: 'gemini',
    logger: expect.anything(),
    model: 'gemini-1.5-pro',
    connectorId: '1',
    actionsClient: expect.anything(),
    streaming: true,
  });
  expect(result.chatPrompt).toContain('Hello, world!');
});
|
||||
|
||||
it('returns the correct chat model and prompt for BEDROCK_CONNECTOR_ID', async () => {
|
||||
mockActionsClient.get.mockResolvedValue({ id: '2', actionTypeId: BEDROCK_CONNECTOR_ID });
|
||||
|
||||
|
@ -86,14 +123,17 @@ describe('getChatParams', () => {
|
|||
context: true,
|
||||
type: 'anthropic',
|
||||
});
|
||||
expect(ActionsClientLlm).toHaveBeenCalledWith({
|
||||
expect(QuestionRewritePrompt).toHaveBeenCalledWith({
|
||||
type: 'anthropic',
|
||||
});
|
||||
expect(ActionsClientSimpleChatModel).toHaveBeenCalledWith({
|
||||
temperature: 0,
|
||||
llmType: 'bedrock',
|
||||
traceId: 'test-uuid',
|
||||
logger: expect.anything(),
|
||||
model: 'custom-model',
|
||||
connectorId: '2',
|
||||
actionsClient: expect.anything(),
|
||||
streaming: true,
|
||||
});
|
||||
expect(result.chatPrompt).toContain('How does it work?');
|
||||
});
|
||||
|
|
|
@ -14,10 +14,11 @@ import { BaseLanguageModel } from '@langchain/core/language_models/base';
|
|||
import type { Connector } from '@kbn/actions-plugin/server/application/connector/types';
|
||||
import {
|
||||
ActionsClientChatOpenAI,
|
||||
ActionsClientLlm,
|
||||
ActionsClientSimpleChatModel,
|
||||
getDefaultArguments,
|
||||
} from '@kbn/langchain/server';
|
||||
import { Prompt } from '../../common/prompt';
|
||||
import { GEMINI_CONNECTOR_ID } from '@kbn/stack-connectors-plugin/common/gemini/constants';
|
||||
import { Prompt, QuestionRewritePrompt } from '../../common/prompt';
|
||||
|
||||
export const getChatParams = async (
|
||||
{
|
||||
|
@ -35,13 +36,20 @@ export const getChatParams = async (
|
|||
logger: Logger;
|
||||
request: KibanaRequest;
|
||||
}
|
||||
): Promise<{ chatModel: BaseLanguageModel; chatPrompt: string; connector: Connector }> => {
|
||||
): Promise<{
|
||||
chatModel: BaseLanguageModel;
|
||||
chatPrompt: string;
|
||||
questionRewritePrompt: string;
|
||||
connector: Connector;
|
||||
}> => {
|
||||
const abortController = new AbortController();
|
||||
const abortSignal = abortController.signal;
|
||||
const actionsClient = await actions.getActionsClientWithRequest(request);
|
||||
const connector = await actionsClient.get({ id: connectorId });
|
||||
let chatModel;
|
||||
let chatPrompt;
|
||||
let questionRewritePrompt;
|
||||
let llmType;
|
||||
|
||||
switch (connector.actionTypeId) {
|
||||
case OPENAI_CONNECTOR_ID:
|
||||
|
@ -62,31 +70,57 @@ export const getChatParams = async (
|
|||
context: true,
|
||||
type: 'openai',
|
||||
});
|
||||
questionRewritePrompt = QuestionRewritePrompt({
|
||||
type: 'openai',
|
||||
});
|
||||
break;
|
||||
case BEDROCK_CONNECTOR_ID:
|
||||
const llmType = 'bedrock';
|
||||
chatModel = new ActionsClientLlm({
|
||||
llmType = 'bedrock';
|
||||
chatModel = new ActionsClientSimpleChatModel({
|
||||
actionsClient,
|
||||
logger,
|
||||
connectorId,
|
||||
model,
|
||||
traceId: uuidv4(),
|
||||
llmType,
|
||||
temperature: getDefaultArguments(llmType).temperature,
|
||||
streaming: true,
|
||||
});
|
||||
chatPrompt = Prompt(prompt, {
|
||||
citations,
|
||||
context: true,
|
||||
type: 'anthropic',
|
||||
});
|
||||
questionRewritePrompt = QuestionRewritePrompt({
|
||||
type: 'anthropic',
|
||||
});
|
||||
break;
|
||||
case GEMINI_CONNECTOR_ID:
|
||||
llmType = 'gemini';
|
||||
chatModel = new ActionsClientSimpleChatModel({
|
||||
actionsClient,
|
||||
logger,
|
||||
connectorId,
|
||||
model,
|
||||
llmType,
|
||||
temperature: getDefaultArguments(llmType).temperature,
|
||||
streaming: true,
|
||||
});
|
||||
chatPrompt = Prompt(prompt, {
|
||||
citations,
|
||||
context: true,
|
||||
type: 'gemini',
|
||||
});
|
||||
questionRewritePrompt = QuestionRewritePrompt({
|
||||
type: 'gemini',
|
||||
});
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
if (!chatModel || !chatPrompt) {
|
||||
if (!chatModel || !chatPrompt || !questionRewritePrompt) {
|
||||
throw new Error('Invalid connector id');
|
||||
}
|
||||
|
||||
return { chatModel, chatPrompt, connector };
|
||||
return { chatModel, chatPrompt, questionRewritePrompt, connector };
|
||||
};
|
||||
|
|
|
@ -27,7 +27,7 @@ import { MODELS } from '../common/models';
|
|||
export function createRetriever(esQuery: string) {
|
||||
return (question: string) => {
|
||||
try {
|
||||
const replacedQuery = esQuery.replace(/{query}/g, question.replace(/"/g, '\\"'));
|
||||
const replacedQuery = esQuery.replace(/\"{query}\"/g, JSON.stringify(question));
|
||||
const query = JSON.parse(replacedQuery);
|
||||
return query;
|
||||
} catch (e) {
|
||||
|
@ -96,7 +96,7 @@ export function defineRoutes({
|
|||
es_client: client.asCurrentUser,
|
||||
} as AssistClientOptionsWithClient);
|
||||
const { messages, data } = await request.body;
|
||||
const { chatModel, chatPrompt, connector } = await getChatParams(
|
||||
const { chatModel, chatPrompt, questionRewritePrompt, connector } = await getChatParams(
|
||||
{
|
||||
connectorId: data.connector_id,
|
||||
model: data.summarization_model,
|
||||
|
@ -133,6 +133,7 @@ export function defineRoutes({
|
|||
inputTokensLimit: modelPromptLimit,
|
||||
},
|
||||
prompt: chatPrompt,
|
||||
questionRewritePrompt,
|
||||
});
|
||||
|
||||
let stream: ReadableStream<Uint8Array>;
|
||||
|
|
|
@ -7,9 +7,13 @@
|
|||
|
||||
import OpenAILogo from '../connector_types/openai/logo';
|
||||
import BedrockLogo from '../connector_types/bedrock/logo';
|
||||
import GeminiLogo from '../connector_types/gemini/logo';
|
||||
|
||||
export { GEMINI_CONNECTOR_ID } from '../../common/gemini/constants';
|
||||
|
||||
export { OPENAI_CONNECTOR_ID, OpenAiProviderType } from '../../common/openai/constants';
|
||||
export { OpenAILogo };
|
||||
export { GeminiLogo };
|
||||
|
||||
import SentinelOneLogo from '../connector_types/sentinelone/logo';
|
||||
|
||||
|
|
|
@ -11,9 +11,8 @@ import { PassThrough } from 'stream';
|
|||
import { IncomingMessage } from 'http';
|
||||
import { SubActionRequestParams } from '@kbn/actions-plugin/server/sub_action_framework/types';
|
||||
import { getGoogleOAuthJwtAccessToken } from '@kbn/actions-plugin/server/lib/get_gcp_oauth_access_token';
|
||||
import { Logger } from '@kbn/core/server';
|
||||
import { ConnectorTokenClientContract } from '@kbn/actions-plugin/server/types';
|
||||
import { ActionsConfigurationUtilities } from '@kbn/actions-plugin/server/actions_config';
|
||||
|
||||
import {
|
||||
RunActionParamsSchema,
|
||||
RunApiResponseSchema,
|
||||
|
@ -39,16 +38,6 @@ import {
|
|||
DEFAULT_TOKEN_LIMIT,
|
||||
} from '../../../common/gemini/constants';
|
||||
import { DashboardActionParamsSchema } from '../../../common/gemini/schema';
|
||||
|
||||
export interface GetAxiosInstanceOpts {
|
||||
connectorId: string;
|
||||
logger: Logger;
|
||||
credentials: string;
|
||||
snServiceUrl: string;
|
||||
connectorTokenClient: ConnectorTokenClientContract;
|
||||
configurationUtilities: ActionsConfigurationUtilities;
|
||||
}
|
||||
|
||||
/** Interfaces to define Gemini model response type */
|
||||
|
||||
interface MessagePart {
|
||||
|
|
|
@ -10,7 +10,10 @@ import {
|
|||
SubActionConnectorType,
|
||||
ValidatorType,
|
||||
} from '@kbn/actions-plugin/server/sub_action_framework/types';
|
||||
import { GenerativeAIForSecurityConnectorFeatureId } from '@kbn/actions-plugin/common';
|
||||
import {
|
||||
GenerativeAIForSearchPlaygroundConnectorFeatureId,
|
||||
GenerativeAIForSecurityConnectorFeatureId,
|
||||
} from '@kbn/actions-plugin/common';
|
||||
import { urlAllowListValidator } from '@kbn/actions-plugin/server';
|
||||
import { ValidatorServices } from '@kbn/actions-plugin/server/types';
|
||||
import { assertURL } from '@kbn/actions-plugin/server/sub_action_framework/helpers/validators';
|
||||
|
@ -29,7 +32,10 @@ export const getConnectorType = (): SubActionConnectorType<Config, Secrets> => (
|
|||
secrets: SecretsSchema,
|
||||
},
|
||||
validators: [{ type: ValidatorType.CONFIG, validator: configValidator }],
|
||||
supportedFeatureIds: [GenerativeAIForSecurityConnectorFeatureId],
|
||||
supportedFeatureIds: [
|
||||
GenerativeAIForSecurityConnectorFeatureId,
|
||||
GenerativeAIForSearchPlaygroundConnectorFeatureId,
|
||||
],
|
||||
minimumLicenseRequired: 'enterprise' as const,
|
||||
renderParameterTemplates,
|
||||
});
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue