[8.x] [Epic] AI Insights + Assistant - Add "Other" option to the existing OpenAI Connector dropdown list (#8936) (#194831) (#195688)

# Backport

This will backport the following commits from `main` to `8.x`:
- [[Epic] AI Insights + Assistant - Add "Other" option to the
existing OpenAI Connector dropdown list (#8936)
(#194831)](https://github.com/elastic/kibana/pull/194831)

<!--- Backport version: 9.4.3 -->

### Questions ?
Please refer to the [Backport tool
documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Ievgen
Sorokopud","email":"ievgen.sorokopud@elastic.co"},"sourceCommit":{"committedDate":"2024-10-09T22:07:31Z","message":"[Epic]
AI Insights + Assistant - Add \"Other\" option to the existing OpenAI
Connector dropdown list (#8936)
(#194831)","sha":"83a701e837a7a84a86dcc8d359154f900f69676a","branchLabelMapping":{"^v9.0.0$":"main","^v8.16.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["v9.0.0","release_note:feature","Feature:Security
Assistant","Team:Security Generative
AI","v8.16.0","backport:version"],"title":"[Epic] AI Insights +
Assistant - Add \"Other\" option to the existing OpenAI Connector
dropdown list
(#8936)","number":194831,"url":"https://github.com/elastic/kibana/pull/194831","mergeCommit":{"message":"[Epic]
AI Insights + Assistant - Add \"Other\" option to the existing OpenAI
Connector dropdown list (#8936)
(#194831)","sha":"83a701e837a7a84a86dcc8d359154f900f69676a"}},"sourceBranch":"main","suggestedTargetBranches":["8.x"],"targetPullRequestStates":[{"branch":"main","label":"v9.0.0","branchLabelMappingKey":"^v9.0.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/194831","number":194831,"mergeCommit":{"message":"[Epic]
AI Insights + Assistant - Add \"Other\" option to the existing OpenAI
Connector dropdown list (#8936)
(#194831)","sha":"83a701e837a7a84a86dcc8d359154f900f69676a"}},{"branch":"8.x","label":"v8.16.0","branchLabelMappingKey":"^v8.16.0$","isSourceBranch":false,"state":"NOT_CREATED"}]}]
BACKPORT-->

Co-authored-by: Ievgen Sorokopud <ievgen.sorokopud@elastic.co>
This commit is contained in:
Kibana Machine 2024-10-10 10:56:19 +11:00 committed by GitHub
parent ce302a91fa
commit 0035e94287
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 915 additions and 32 deletions

View file

@ -22906,6 +22906,7 @@ components:
enum:
- OpenAI
- Azure OpenAI
- Other
type: string
Security_AI_Assistant_API_Reader:
additionalProperties: true

View file

@ -22906,6 +22906,7 @@ components:
enum:
- OpenAI
- Azure OpenAI
- Other
type: string
Security_AI_Assistant_API_Reader:
additionalProperties: true

View file

@ -30731,6 +30731,7 @@ components:
enum:
- OpenAI
- Azure OpenAI
- Other
type: string
Security_AI_Assistant_API_Reader:
additionalProperties: true

View file

@ -30731,6 +30731,7 @@ components:
enum:
- OpenAI
- Azure OpenAI
- Other
type: string
Security_AI_Assistant_API_Reader:
additionalProperties: true

View file

@ -1194,6 +1194,7 @@ components:
enum:
- OpenAI
- Azure OpenAI
- Other
type: string
Reader:
additionalProperties: true

View file

@ -1194,6 +1194,7 @@ components:
enum:
- OpenAI
- Azure OpenAI
- Other
type: string
Reader:
additionalProperties: true

View file

@ -46,7 +46,7 @@ export const Reader = z.object({}).catchall(z.unknown());
* Provider
*/
export type Provider = z.infer<typeof Provider>;
export const Provider = z.enum(['OpenAI', 'Azure OpenAI']);
export const Provider = z.enum(['OpenAI', 'Azure OpenAI', 'Other']);
export type ProviderEnum = typeof Provider.enum;
export const ProviderEnum = Provider.enum;

View file

@ -34,6 +34,7 @@ components:
enum:
- OpenAI
- Azure OpenAI
- Other
MessageRole:
type: string

View file

@ -18,6 +18,7 @@ import { PRECONFIGURED_CONNECTOR } from './translations';
enum OpenAiProviderType {
OpenAi = 'OpenAI',
AzureAi = 'Azure OpenAI',
Other = 'Other',
}
interface GenAiConfig {

View file

@ -1025,15 +1025,17 @@ describe('actions telemetry', () => {
'.d3security': 2,
'.gen-ai__Azure OpenAI': 3,
'.gen-ai__OpenAI': 1,
'.gen-ai__Other': 1,
};
const { countByType, countGenAiProviderTypes } = getCounts(aggs);
expect(countByType).toEqual({
__d3security: 2,
'__gen-ai': 4,
'__gen-ai': 5,
});
expect(countGenAiProviderTypes).toEqual({
'Azure OpenAI': 3,
OpenAI: 1,
Other: 1,
});
});
});

View file

@ -51,6 +51,7 @@ export const byGenAiProviderTypeSchema: MakeSchemaFrom<ActionsUsage>['count_by_t
// Known providers:
['Azure OpenAI']: { type: 'long' },
['OpenAI']: { type: 'long' },
['Other']: { type: 'long' },
};
export const byServiceProviderTypeSchema: MakeSchemaFrom<ActionsUsage>['count_active_email_connectors_by_service_type'] =

View file

@ -65,5 +65,17 @@ describe('Utils', () => {
const isOpenModel = isOpenSourceModel(connector);
expect(isOpenModel).toEqual(true);
});
it('should return `true` when apiProvider of OpenAiProviderType.Other is specified', async () => {
const connector = {
actionTypeId: '.gen-ai',
config: {
apiUrl: OPENAI_CHAT_URL,
apiProvider: OpenAiProviderType.Other,
},
} as unknown as Connector;
const isOpenModel = isOpenSourceModel(connector);
expect(isOpenModel).toEqual(true);
});
});
});

View file

@ -203,19 +203,25 @@ export const isOpenSourceModel = (connector?: Connector): boolean => {
}
const llmType = getLlmType(connector.actionTypeId);
const connectorApiUrl = connector.config?.apiUrl
? (connector.config.apiUrl as string)
: undefined;
const isOpenAiType = llmType === 'openai';
if (!isOpenAiType) {
return false;
}
const connectorApiProvider = connector.config?.apiProvider
? (connector.config?.apiProvider as OpenAiProviderType)
: undefined;
if (connectorApiProvider === OpenAiProviderType.Other) {
return true;
}
const isOpenAiType = llmType === 'openai';
const isOpenAI =
isOpenAiType &&
(!connectorApiUrl ||
connectorApiUrl === OPENAI_CHAT_URL ||
connectorApiProvider === OpenAiProviderType.AzureAi);
const connectorApiUrl = connector.config?.apiUrl
? (connector.config.apiUrl as string)
: undefined;
return isOpenAiType && !isOpenAI;
return (
!!connectorApiUrl &&
connectorApiUrl !== OPENAI_CHAT_URL &&
connectorApiProvider !== OpenAiProviderType.AzureAi
);
};

View file

@ -57,6 +57,7 @@ export enum APIRoutes {
export enum LLMs {
openai = 'openai',
openai_azure = 'openai_azure',
openai_other = 'openai_other',
bedrock = 'bedrock',
gemini = 'gemini',
}

View file

@ -15,9 +15,10 @@ jest.mock('./use_load_connectors', () => ({
}));
const mockConnectors = [
{ id: 'connectorId1', title: 'OpenAI Connector', type: LLMs.openai },
{ id: 'connectorId2', title: 'OpenAI Azure Connector', type: LLMs.openai_azure },
{ id: 'connectorId2', title: 'Bedrock Connector', type: LLMs.bedrock },
{ id: 'connectorId1', name: 'OpenAI Connector', type: LLMs.openai },
{ id: 'connectorId2', name: 'OpenAI Azure Connector', type: LLMs.openai_azure },
{ id: 'connectorId2', name: 'Bedrock Connector', type: LLMs.bedrock },
{ id: 'connectorId3', name: 'OpenAI OSS Model Connector', type: LLMs.openai_other },
];
const mockUseLoadConnectors = (data: any) => {
(useLoadConnectors as jest.Mock).mockReturnValue({ data });
@ -36,7 +37,7 @@ describe('useLLMsModels Hook', () => {
expect(result.current).toEqual([
{
connectorId: 'connectorId1',
connectorName: undefined,
connectorName: 'OpenAI Connector',
connectorType: LLMs.openai,
disabled: false,
icon: expect.any(Function),
@ -48,7 +49,7 @@ describe('useLLMsModels Hook', () => {
},
{
connectorId: 'connectorId1',
connectorName: undefined,
connectorName: 'OpenAI Connector',
connectorType: LLMs.openai,
disabled: false,
icon: expect.any(Function),
@ -60,7 +61,7 @@ describe('useLLMsModels Hook', () => {
},
{
connectorId: 'connectorId1',
connectorName: undefined,
connectorName: 'OpenAI Connector',
connectorType: LLMs.openai,
disabled: false,
icon: expect.any(Function),
@ -72,19 +73,19 @@ describe('useLLMsModels Hook', () => {
},
{
connectorId: 'connectorId2',
connectorName: undefined,
connectorName: 'OpenAI Azure Connector',
connectorType: LLMs.openai_azure,
disabled: false,
icon: expect.any(Function),
id: 'connectorId2Azure OpenAI ',
name: 'Azure OpenAI ',
id: 'connectorId2OpenAI Azure Connector (Azure OpenAI)',
name: 'OpenAI Azure Connector (Azure OpenAI)',
showConnectorName: false,
value: undefined,
promptTokenLimit: undefined,
},
{
connectorId: 'connectorId2',
connectorName: undefined,
connectorName: 'Bedrock Connector',
connectorType: LLMs.bedrock,
disabled: false,
icon: expect.any(Function),
@ -96,7 +97,7 @@ describe('useLLMsModels Hook', () => {
},
{
connectorId: 'connectorId2',
connectorName: undefined,
connectorName: 'Bedrock Connector',
connectorType: LLMs.bedrock,
disabled: false,
icon: expect.any(Function),
@ -106,6 +107,18 @@ describe('useLLMsModels Hook', () => {
value: 'anthropic.claude-3-5-sonnet-20240620-v1:0',
promptTokenLimit: 200000,
},
{
connectorId: 'connectorId3',
connectorName: 'OpenAI OSS Model Connector',
connectorType: LLMs.openai_other,
disabled: false,
icon: expect.any(Function),
id: 'connectorId3OpenAI OSS Model Connector (OpenAI Compatible Service)',
name: 'OpenAI OSS Model Connector (OpenAI Compatible Service)',
showConnectorName: false,
value: undefined,
promptTokenLimit: undefined,
},
]);
});

View file

@ -34,11 +34,22 @@ const mapLlmToModels: Record<
},
[LLMs.openai_azure]: {
icon: OpenAILogo,
getModels: (connectorName, includeName) => [
getModels: (connectorName) => [
{
label: i18n.translate('xpack.searchPlayground.openAIAzureModel', {
defaultMessage: 'Azure OpenAI {name}',
values: { name: includeName ? `(${connectorName})` : '' },
defaultMessage: '{name} (Azure OpenAI)',
values: { name: connectorName },
}),
},
],
},
[LLMs.openai_other]: {
icon: OpenAILogo,
getModels: (connectorName) => [
{
label: i18n.translate('xpack.searchPlayground.otherOpenAIModel', {
defaultMessage: '{name} (OpenAI Compatible Service)',
values: { name: connectorName },
}),
},
],

View file

@ -71,6 +71,12 @@ describe('useLoadConnectors', () => {
actionTypeId: '.bedrock',
isMissingSecrets: false,
},
{
id: '5',
actionTypeId: '.gen-ai',
isMissingSecrets: false,
config: { apiProvider: OpenAiProviderType.Other },
},
];
mockedLoadConnectors.mockResolvedValue(connectors);
@ -106,6 +112,16 @@ describe('useLoadConnectors', () => {
title: 'Bedrock',
type: 'bedrock',
},
{
actionTypeId: '.gen-ai',
config: {
apiProvider: 'Other',
},
id: '5',
isMissingSecrets: false,
title: 'OpenAI Other',
type: 'openai_other',
},
]);
});
});

View file

@ -63,6 +63,20 @@ const connectorTypeToLLM: Array<{
type: LLMs.openai,
}),
},
{
actionId: OPENAI_CONNECTOR_ID,
actionProvider: OpenAiProviderType.Other,
match: (connector) =>
connector.actionTypeId === OPENAI_CONNECTOR_ID &&
(connector as OpenAIConnector)?.config?.apiProvider === OpenAiProviderType.Other,
transform: (connector) => ({
...connector,
title: i18n.translate('xpack.searchPlayground.openAIOtherConnectorTitle', {
defaultMessage: 'OpenAI Other',
}),
type: LLMs.openai_other,
}),
},
{
actionId: BEDROCK_CONNECTOR_ID,
match: (connector) => connector.actionTypeId === BEDROCK_CONNECTOR_ID,

View file

@ -152,4 +152,41 @@ describe('getChatParams', () => {
)
).rejects.toThrow('Invalid connector id');
});
it('returns the correct chat model and uses the default model when not specified in the params', async () => {
mockActionsClient.get.mockResolvedValue({
id: '2',
actionTypeId: OPENAI_CONNECTOR_ID,
config: { defaultModel: 'local' },
});
const result = await getChatParams(
{
connectorId: '2',
prompt: 'How does it work?',
citations: false,
},
{ actions, request, logger }
);
expect(Prompt).toHaveBeenCalledWith('How does it work?', {
citations: false,
context: true,
type: 'openai',
});
expect(QuestionRewritePrompt).toHaveBeenCalledWith({
type: 'openai',
});
expect(ActionsClientChatOpenAI).toHaveBeenCalledWith({
logger: expect.anything(),
model: 'local',
connectorId: '2',
actionsClient: expect.anything(),
signal: expect.anything(),
traceId: 'test-uuid',
temperature: 0.2,
maxRetries: 0,
});
expect(result.chatPrompt).toContain('How does it work?');
});
});

View file

@ -57,7 +57,7 @@ export const getChatParams = async (
actionsClient,
logger,
connectorId,
model,
model: model || connector?.config?.defaultModel,
traceId: uuidv4(),
signal: abortSignal,
temperature: getDefaultArguments().temperature,

View file

@ -18,6 +18,7 @@ import { isEmpty } from 'lodash/fp';
enum OpenAiProviderType {
OpenAi = 'OpenAI',
AzureAi = 'Azure OpenAI',
Other = 'Other',
}
interface GenAiConfig {

View file

@ -27,6 +27,7 @@ export enum SUB_ACTION {
export enum OpenAiProviderType {
OpenAi = 'OpenAI',
AzureAi = 'Azure OpenAI',
Other = 'Other',
}
export const DEFAULT_TIMEOUT_MS = 120000;

View file

@ -21,6 +21,12 @@ export const ConfigSchema = schema.oneOf([
defaultModel: schema.string({ defaultValue: DEFAULT_OPENAI_MODEL }),
headers: schema.maybe(schema.recordOf(schema.string(), schema.string())),
}),
schema.object({
apiProvider: schema.oneOf([schema.literal(OpenAiProviderType.Other)]),
apiUrl: schema.string(),
defaultModel: schema.string(),
headers: schema.maybe(schema.recordOf(schema.string(), schema.string())),
}),
]);
export const SecretsSchema = schema.object({ apiKey: schema.string() });

View file

@ -53,6 +53,7 @@ describe('useGetDashboard', () => {
it.each([
['Azure OpenAI', 'openai'],
['OpenAI', 'openai'],
['Other', 'openai'],
['Bedrock', 'bedrock'],
])(
'fetches the %p dashboard and sets the dashboard URL with %p',

View file

@ -50,6 +50,17 @@ const azureConnector = {
apiKey: 'thats-a-nice-looking-key',
},
};
const otherOpenAiConnector = {
...openAiConnector,
config: {
apiUrl: 'https://localhost/oss-llm',
apiProvider: OpenAiProviderType.Other,
defaultModel: 'local-model',
},
secrets: {
apiKey: 'thats-a-nice-looking-key',
},
};
const navigateToUrl = jest.fn();
@ -93,6 +104,24 @@ describe('ConnectorFields renders', () => {
expect(getAllByTestId('azure-ai-api-keys-doc')[0]).toBeInTheDocument();
});
test('other open ai connector fields are rendered', async () => {
const { getAllByTestId } = render(
<ConnectorFormTestProvider connector={otherOpenAiConnector}>
<ConnectorFields readOnly={false} isEdit={false} registerPreSubmitValidator={() => {}} />
</ConnectorFormTestProvider>
);
expect(getAllByTestId('config.apiUrl-input')[0]).toBeInTheDocument();
expect(getAllByTestId('config.apiUrl-input')[0]).toHaveValue(
otherOpenAiConnector.config.apiUrl
);
expect(getAllByTestId('config.apiProvider-select')[0]).toBeInTheDocument();
expect(getAllByTestId('config.apiProvider-select')[0]).toHaveValue(
otherOpenAiConnector.config.apiProvider
);
expect(getAllByTestId('other-ai-api-doc')[0]).toBeInTheDocument();
expect(getAllByTestId('other-ai-api-keys-doc')[0]).toBeInTheDocument();
});
describe('Dashboard link', () => {
it('Does not render if isEdit is false and dashboardUrl is defined', async () => {
const { queryByTestId } = render(

View file

@ -24,6 +24,8 @@ import * as i18n from './translations';
import {
azureAiConfig,
azureAiSecrets,
otherOpenAiConfig,
otherOpenAiSecrets,
openAiConfig,
openAiSecrets,
providerOptions,
@ -85,6 +87,14 @@ const ConnectorFields: React.FC<ActionConnectorFieldsProps> = ({ readOnly, isEdi
secretsFormSchema={azureAiSecrets}
/>
)}
{config != null && config.apiProvider === OpenAiProviderType.Other && (
<SimpleConnectorForm
isEdit={isEdit}
readOnly={readOnly}
configFormSchema={otherOpenAiConfig}
secretsFormSchema={otherOpenAiSecrets}
/>
)}
{isEdit && (
<DashboardLink
connectorId={id}

View file

@ -92,6 +92,41 @@ export const azureAiConfig: ConfigFieldSchema[] = [
},
];
export const otherOpenAiConfig: ConfigFieldSchema[] = [
{
id: 'apiUrl',
label: i18n.API_URL_LABEL,
isUrlField: true,
helpText: (
<FormattedMessage
defaultMessage="The Other (OpenAI Compatible Service) endpoint URL. For more information on the URL, refer to the {genAiAPIUrlDocs}."
id="xpack.stackConnectors.components.genAi.otherOpenAiDocumentation"
values={{
genAiAPIUrlDocs: (
<EuiLink
data-test-subj="other-ai-api-doc"
href="https://www.elastic.co/guide/en/security/current/connect-to-byo-llm.html"
target="_blank"
>
{`${i18n.OTHER_OPENAI} ${i18n.DOCUMENTATION}`}
</EuiLink>
),
}}
/>
),
},
{
id: 'defaultModel',
label: i18n.DEFAULT_MODEL_LABEL,
helpText: (
<FormattedMessage
defaultMessage="If a request does not include a model, it uses the default."
id="xpack.stackConnectors.components.genAi.otherOpenAiDocumentationModel"
/>
),
},
];
export const openAiSecrets: SecretsFieldSchema[] = [
{
id: 'apiKey',
@ -142,6 +177,31 @@ export const azureAiSecrets: SecretsFieldSchema[] = [
},
];
export const otherOpenAiSecrets: SecretsFieldSchema[] = [
{
id: 'apiKey',
label: i18n.API_KEY_LABEL,
isPasswordField: true,
helpText: (
<FormattedMessage
defaultMessage="The Other (OpenAI Compatible Service) API key for HTTP Basic authentication. For more details about generating Other model API keys, refer to the {genAiAPIKeyDocs}."
id="xpack.stackConnectors.components.genAi.otherOpenAiApiKeyDocumentation"
values={{
genAiAPIKeyDocs: (
<EuiLink
data-test-subj="other-ai-api-keys-doc"
href="https://www.elastic.co/guide/en/security/current/connect-to-byo-llm.html"
target="_blank"
>
{`${i18n.OTHER_OPENAI} ${i18n.DOCUMENTATION}`}
</EuiLink>
),
}}
/>
),
},
];
export const providerOptions = [
{
value: OpenAiProviderType.OpenAi,
@ -153,4 +213,9 @@ export const providerOptions = [
text: i18n.AZURE_AI,
label: i18n.AZURE_AI,
},
{
value: OpenAiProviderType.Other,
text: i18n.OTHER_OPENAI,
label: i18n.OTHER_OPENAI,
},
];

View file

@ -37,7 +37,7 @@ describe('Gen AI Params Fields renders', () => {
expect(getByTestId('bodyJsonEditor')).toHaveProperty('value', '{"message": "test"}');
expect(getByTestId('bodyAddVariableButton')).toBeInTheDocument();
});
test.each([OpenAiProviderType.OpenAi, OpenAiProviderType.AzureAi])(
test.each([OpenAiProviderType.OpenAi, OpenAiProviderType.AzureAi, OpenAiProviderType.Other])(
'useEffect handles the case when subAction and subActionParams are undefined and apiProvider is %p',
(apiProvider) => {
const actionParams = {
@ -79,6 +79,9 @@ describe('Gen AI Params Fields renders', () => {
if (apiProvider === OpenAiProviderType.AzureAi) {
expect(editAction).toHaveBeenCalledWith('subActionParams', { body: DEFAULT_BODY_AZURE }, 0);
}
if (apiProvider === OpenAiProviderType.Other) {
expect(editAction).toHaveBeenCalledWith('subActionParams', { body: DEFAULT_BODY }, 0);
}
}
);

View file

@ -47,6 +47,10 @@ export const AZURE_AI = i18n.translate('xpack.stackConnectors.components.genAi.a
defaultMessage: 'Azure OpenAI',
});
export const OTHER_OPENAI = i18n.translate('xpack.stackConnectors.components.genAi.otherAi', {
defaultMessage: 'Other (OpenAI Compatible Service)',
});
export const DOCUMENTATION = i18n.translate(
'xpack.stackConnectors.components.genAi.documentation',
{

View file

@ -53,7 +53,11 @@ export const configValidator = (configObject: Config, validatorServices: Validat
const { apiProvider } = configObject;
if (apiProvider !== OpenAiProviderType.OpenAi && apiProvider !== OpenAiProviderType.AzureAi) {
if (
apiProvider !== OpenAiProviderType.OpenAi &&
apiProvider !== OpenAiProviderType.AzureAi &&
apiProvider !== OpenAiProviderType.Other
) {
throw new Error(
`API Provider is not supported${
apiProvider && (apiProvider as OpenAiProviderType).length ? `: ${apiProvider}` : ``

View file

@ -0,0 +1,116 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { sanitizeRequest, getRequestWithStreamOption } from './other_openai_utils';
describe('Other (OpenAI Compatible Service) Utils', () => {
describe('sanitizeRequest', () => {
it('sets stream to false when stream is set to true in the body', () => {
const body = {
model: 'mistral',
stream: true,
messages: [
{
role: 'user',
content: 'This is a test',
},
],
};
const sanitizedBodyString = sanitizeRequest(JSON.stringify(body));
expect(sanitizedBodyString).toEqual(
`{\"model\":\"mistral\",\"stream\":false,\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}]}`
);
});
it('sets stream to false when stream is not defined in the body', () => {
const body = {
model: 'mistral',
messages: [
{
role: 'user',
content: 'This is a test',
},
],
};
const sanitizedBodyString = sanitizeRequest(JSON.stringify(body));
expect(sanitizedBodyString).toEqual(
`{\"model\":\"mistral\",\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],\"stream\":false}`
);
});
it('sets stream to false when stream is set to false in the body', () => {
const body = {
model: 'mistral',
stream: false,
messages: [
{
role: 'user',
content: 'This is a test',
},
],
};
const sanitizedBodyString = sanitizeRequest(JSON.stringify(body));
expect(sanitizedBodyString).toEqual(
`{\"model\":\"mistral\",\"stream\":false,\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}]}`
);
});
it('does nothing when body is malformed JSON', () => {
const bodyString = `{\"model\":\"mistral\",\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],,}`;
const sanitizedBodyString = sanitizeRequest(bodyString);
expect(sanitizedBodyString).toEqual(bodyString);
});
});
describe('getRequestWithStreamOption', () => {
it('sets stream parameter when stream is not defined in the body', () => {
const body = {
model: 'mistral',
messages: [
{
role: 'user',
content: 'This is a test',
},
],
};
const sanitizedBodyString = getRequestWithStreamOption(JSON.stringify(body), true);
expect(sanitizedBodyString).toEqual(
`{\"model\":\"mistral\",\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],\"stream\":true}`
);
});
it('overrides stream parameter if defined in body', () => {
const body = {
model: 'mistral',
stream: true,
messages: [
{
role: 'user',
content: 'This is a test',
},
],
};
const sanitizedBodyString = getRequestWithStreamOption(JSON.stringify(body), false);
expect(sanitizedBodyString).toEqual(
`{\"model\":\"mistral\",\"stream\":false,\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}]}`
);
});
it('does nothing when body is malformed JSON', () => {
const bodyString = `{\"model\":\"mistral\",\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],,}`;
const sanitizedBodyString = getRequestWithStreamOption(bodyString, false);
expect(sanitizedBodyString).toEqual(bodyString);
});
});
});

View file

@ -0,0 +1,39 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
/**
* Sanitizes the Other (OpenAI Compatible Service) request body to set stream to false
* so users cannot specify a streaming response when the framework
* is not prepared to handle streaming
*
* The stream parameter is accepted in the ChatCompletion
* API and the Completion API only
*/
export const sanitizeRequest = (body: string): string => {
return getRequestWithStreamOption(body, false);
};
/**
* Intercepts the Other (OpenAI Compatible Service) request body to set the stream parameter
*
* The stream parameter is accepted in the ChatCompletion
* API and the Completion API only
*/
export const getRequestWithStreamOption = (body: string, stream: boolean): string => {
try {
const jsonBody = JSON.parse(body);
if (jsonBody) {
jsonBody.stream = stream;
}
return JSON.stringify(jsonBody);
} catch (err) {
// swallow the error
}
return body;
};

View file

@ -19,8 +19,14 @@ import {
sanitizeRequest as azureAiSanitizeRequest,
getRequestWithStreamOption as azureAiGetRequestWithStreamOption,
} from './azure_openai_utils';
import {
sanitizeRequest as otherOpenAiSanitizeRequest,
getRequestWithStreamOption as otherOpenAiGetRequestWithStreamOption,
} from './other_openai_utils';
jest.mock('./openai_utils');
jest.mock('./azure_openai_utils');
jest.mock('./other_openai_utils');
describe('Utils', () => {
const azureAiUrl =
@ -38,6 +44,7 @@ describe('Utils', () => {
describe('sanitizeRequest', () => {
const mockOpenAiSanitizeRequest = openAiSanitizeRequest as jest.Mock;
const mockAzureAiSanitizeRequest = azureAiSanitizeRequest as jest.Mock;
const mockOtherOpenAiSanitizeRequest = otherOpenAiSanitizeRequest as jest.Mock;
beforeEach(() => {
jest.clearAllMocks();
});
@ -50,24 +57,36 @@ describe('Utils', () => {
DEFAULT_OPENAI_MODEL
);
expect(mockAzureAiSanitizeRequest).not.toHaveBeenCalled();
expect(mockOtherOpenAiSanitizeRequest).not.toHaveBeenCalled();
});
it('calls other_openai_utils sanitizeRequest when provider is Other OpenAi', () => {
sanitizeRequest(OpenAiProviderType.Other, OPENAI_CHAT_URL, bodyString, DEFAULT_OPENAI_MODEL);
expect(mockOtherOpenAiSanitizeRequest).toHaveBeenCalledWith(bodyString);
expect(mockOpenAiSanitizeRequest).not.toHaveBeenCalled();
expect(mockAzureAiSanitizeRequest).not.toHaveBeenCalled();
});
it('calls azure_openai_utils sanitizeRequest when provider is AzureAi', () => {
sanitizeRequest(OpenAiProviderType.AzureAi, azureAiUrl, bodyString);
expect(mockAzureAiSanitizeRequest).toHaveBeenCalledWith(azureAiUrl, bodyString);
expect(mockOpenAiSanitizeRequest).not.toHaveBeenCalled();
expect(mockOtherOpenAiSanitizeRequest).not.toHaveBeenCalled();
});
it('does not call any helper fns when provider is unrecognized', () => {
sanitizeRequest('foo', OPENAI_CHAT_URL, bodyString);
expect(mockOpenAiSanitizeRequest).not.toHaveBeenCalled();
expect(mockAzureAiSanitizeRequest).not.toHaveBeenCalled();
expect(mockOtherOpenAiSanitizeRequest).not.toHaveBeenCalled();
});
});
describe('getRequestWithStreamOption', () => {
const mockOpenAiGetRequestWithStreamOption = openAiGetRequestWithStreamOption as jest.Mock;
const mockAzureAiGetRequestWithStreamOption = azureAiGetRequestWithStreamOption as jest.Mock;
const mockOtherOpenAiGetRequestWithStreamOption =
otherOpenAiGetRequestWithStreamOption as jest.Mock;
beforeEach(() => {
jest.clearAllMocks();
});
@ -88,6 +107,15 @@ describe('Utils', () => {
DEFAULT_OPENAI_MODEL
);
expect(mockAzureAiGetRequestWithStreamOption).not.toHaveBeenCalled();
expect(mockOtherOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled();
});
it('calls other_openai_utils getRequestWithStreamOption when provider is Other OpenAi', () => {
getRequestWithStreamOption(OpenAiProviderType.Other, OPENAI_CHAT_URL, bodyString, true);
expect(mockOtherOpenAiGetRequestWithStreamOption).toHaveBeenCalledWith(bodyString, true);
expect(mockOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled();
expect(mockAzureAiGetRequestWithStreamOption).not.toHaveBeenCalled();
});
it('calls azure_openai_utils getRequestWithStreamOption when provider is AzureAi', () => {
@ -99,6 +127,7 @@ describe('Utils', () => {
true
);
expect(mockOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled();
expect(mockOtherOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled();
});
it('does not call any helper fns when provider is unrecognized', () => {
@ -110,6 +139,7 @@ describe('Utils', () => {
);
expect(mockOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled();
expect(mockAzureAiGetRequestWithStreamOption).not.toHaveBeenCalled();
expect(mockOtherOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled();
});
});
@ -127,6 +157,19 @@ describe('Utils', () => {
});
});
it('returns correct axios options when provider is other openai and stream is false', () => {
expect(getAxiosOptions(OpenAiProviderType.Other, 'api-abc', false)).toEqual({
headers: { Authorization: `Bearer api-abc`, ['content-type']: 'application/json' },
});
});
it('returns correct axios options when provider is other openai and stream is true', () => {
expect(getAxiosOptions(OpenAiProviderType.Other, 'api-abc', true)).toEqual({
headers: { Authorization: `Bearer api-abc`, ['content-type']: 'application/json' },
responseType: 'stream',
});
});
it('returns correct axios options when provider is azure openai and stream is false', () => {
expect(getAxiosOptions(OpenAiProviderType.AzureAi, 'api-abc', false)).toEqual({
headers: { ['api-key']: `api-abc`, ['content-type']: 'application/json' },

View file

@ -16,6 +16,10 @@ import {
sanitizeRequest as azureAiSanitizeRequest,
getRequestWithStreamOption as azureAiGetRequestWithStreamOption,
} from './azure_openai_utils';
import {
sanitizeRequest as otherOpenAiSanitizeRequest,
getRequestWithStreamOption as otherOpenAiGetRequestWithStreamOption,
} from './other_openai_utils';
export const sanitizeRequest = (
provider: string,
@ -28,6 +32,8 @@ export const sanitizeRequest = (
return openAiSanitizeRequest(url, body, defaultModel!);
case OpenAiProviderType.AzureAi:
return azureAiSanitizeRequest(url, body);
case OpenAiProviderType.Other:
return otherOpenAiSanitizeRequest(body);
default:
return body;
}
@ -42,7 +48,7 @@ export function getRequestWithStreamOption(
): string;
export function getRequestWithStreamOption(
provider: OpenAiProviderType.AzureAi,
provider: OpenAiProviderType.AzureAi | OpenAiProviderType.Other,
url: string,
body: string,
stream: boolean
@ -68,6 +74,8 @@ export function getRequestWithStreamOption(
return openAiGetRequestWithStreamOption(url, body, stream, defaultModel!);
case OpenAiProviderType.AzureAi:
return azureAiGetRequestWithStreamOption(url, body, stream);
case OpenAiProviderType.Other:
return otherOpenAiGetRequestWithStreamOption(body, stream);
default:
return body;
}
@ -81,6 +89,7 @@ export const getAxiosOptions = (
const responseType = stream ? { responseType: 'stream' as ResponseType } : {};
switch (provider) {
case OpenAiProviderType.OpenAi:
case OpenAiProviderType.Other:
return {
headers: { Authorization: `Bearer ${apiKey}`, ['content-type']: 'application/json' },
...responseType,

View file

@ -20,6 +20,9 @@ import { RunActionResponseSchema, StreamingResponseSchema } from '../../../commo
import { initDashboard } from '../lib/gen_ai/create_gen_ai_dashboard';
import { PassThrough, Transform } from 'stream';
import { ConnectorUsageCollector } from '@kbn/actions-plugin/server/types';
const DEFAULT_OTHER_OPENAI_MODEL = 'local-model';
jest.mock('../lib/gen_ai/create_gen_ai_dashboard');
const mockTee = jest.fn();
@ -713,6 +716,431 @@ describe('OpenAIConnector', () => {
});
});
describe('Other OpenAI', () => {
const connector = new OpenAIConnector({
configurationUtilities: actionsConfigMock.create(),
connector: { id: '1', type: OPENAI_CONNECTOR_ID },
config: {
apiUrl: 'http://localhost:1234/v1/chat/completions',
apiProvider: OpenAiProviderType.Other,
defaultModel: DEFAULT_OTHER_OPENAI_MODEL,
headers: {
'X-My-Custom-Header': 'foo',
Authorization: 'override',
},
},
secrets: { apiKey: '123' },
logger,
services: actionsMock.createServices(),
});
const sampleOpenAiBody = {
model: DEFAULT_OTHER_OPENAI_MODEL,
messages: [
{
role: 'user',
content: 'Hello world',
},
],
};
beforeEach(() => {
// @ts-ignore
connector.request = mockRequest;
jest.clearAllMocks();
});
describe('runApi', () => {
it('the Other OpenAI API call is successful with correct parameters', async () => {
const response = await connector.runApi(
{ body: JSON.stringify(sampleOpenAiBody) },
connectorUsageCollector
);
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith(
{
...mockDefaults,
url: 'http://localhost:1234/v1/chat/completions',
data: JSON.stringify({
...sampleOpenAiBody,
stream: false,
model: DEFAULT_OTHER_OPENAI_MODEL,
}),
headers: {
Authorization: 'Bearer 123',
'X-My-Custom-Header': 'foo',
'content-type': 'application/json',
},
},
connectorUsageCollector
);
expect(response).toEqual(mockResponse.data);
});
it('overrides stream parameter if set in the body', async () => {
const body = {
model: 'llama-3.1',
messages: [
{
role: 'user',
content: 'Hello world',
},
],
};
const response = await connector.runApi(
{
body: JSON.stringify({
...body,
stream: true,
}),
},
connectorUsageCollector
);
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith(
{
...mockDefaults,
url: 'http://localhost:1234/v1/chat/completions',
data: JSON.stringify({
...body,
stream: false,
}),
headers: {
Authorization: 'Bearer 123',
'X-My-Custom-Header': 'foo',
'content-type': 'application/json',
},
},
connectorUsageCollector
);
expect(response).toEqual(mockResponse.data);
});
it('errors during API calls are properly handled', async () => {
// @ts-ignore
connector.request = mockError;
await expect(
connector.runApi({ body: JSON.stringify(sampleOpenAiBody) }, connectorUsageCollector)
).rejects.toThrow('API Error');
});
});
describe('streamApi', () => {
it('the Other OpenAI API call is successful with correct parameters when stream = false', async () => {
const response = await connector.streamApi(
{
body: JSON.stringify(sampleOpenAiBody),
stream: false,
},
connectorUsageCollector
);
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith(
{
url: 'http://localhost:1234/v1/chat/completions',
method: 'post',
responseSchema: RunActionResponseSchema,
data: JSON.stringify({
...sampleOpenAiBody,
stream: false,
}),
headers: {
Authorization: 'Bearer 123',
'X-My-Custom-Header': 'foo',
'content-type': 'application/json',
},
},
connectorUsageCollector
);
expect(response).toEqual(mockResponse.data);
});
it('the Other OpenAI API call is successful with correct parameters when stream = true', async () => {
const response = await connector.streamApi(
{
body: JSON.stringify(sampleOpenAiBody),
stream: true,
},
connectorUsageCollector
);
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith(
{
responseType: 'stream',
url: 'http://localhost:1234/v1/chat/completions',
method: 'post',
responseSchema: StreamingResponseSchema,
data: JSON.stringify({
...sampleOpenAiBody,
stream: true,
model: DEFAULT_OTHER_OPENAI_MODEL,
}),
headers: {
Authorization: 'Bearer 123',
'X-My-Custom-Header': 'foo',
'content-type': 'application/json',
},
},
connectorUsageCollector
);
expect(response).toEqual({
headers: { 'Content-Type': 'dont-compress-this' },
...mockResponse.data,
});
});
it('overrides stream parameter if set in the body with explicit stream parameter', async () => {
const body = {
model: 'llama-3.1',
messages: [
{
role: 'user',
content: 'Hello world',
},
],
};
const response = await connector.streamApi(
{
body: JSON.stringify({
...body,
stream: false,
}),
stream: true,
},
connectorUsageCollector
);
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith(
{
responseType: 'stream',
url: 'http://localhost:1234/v1/chat/completions',
method: 'post',
responseSchema: StreamingResponseSchema,
data: JSON.stringify({
...body,
stream: true,
}),
headers: {
Authorization: 'Bearer 123',
'X-My-Custom-Header': 'foo',
'content-type': 'application/json',
},
},
connectorUsageCollector
);
expect(response).toEqual({
headers: { 'Content-Type': 'dont-compress-this' },
...mockResponse.data,
});
});
it('errors during API calls are properly handled', async () => {
// @ts-ignore
connector.request = mockError;
await expect(
connector.streamApi(
{ body: JSON.stringify(sampleOpenAiBody), stream: true },
connectorUsageCollector
)
).rejects.toThrow('API Error');
});
});
describe('invokeStream', () => {
const mockStream = (
dataToStream: string[] = [
'data: {"object":"chat.completion.chunk","choices":[{"delta":{"content":"My"}}]}\ndata: {"object":"chat.completion.chunk","choices":[{"delta":{"content":" new"}}]}',
]
) => {
const streamMock = createStreamMock();
dataToStream.forEach((chunk) => {
streamMock.write(chunk);
});
streamMock.complete();
mockRequest = jest.fn().mockResolvedValue({ ...mockResponse, data: streamMock.transform });
return mockRequest;
};
beforeEach(() => {
// @ts-ignore
connector.request = mockStream();
});
it('the API call is successful with correct request parameters', async () => {
await connector.invokeStream(sampleOpenAiBody, connectorUsageCollector);
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith(
{
url: 'http://localhost:1234/v1/chat/completions',
method: 'post',
responseSchema: StreamingResponseSchema,
responseType: 'stream',
data: JSON.stringify({
...sampleOpenAiBody,
stream: true,
}),
headers: {
Authorization: 'Bearer 123',
'X-My-Custom-Header': 'foo',
'content-type': 'application/json',
},
},
connectorUsageCollector
);
});
it('signal is properly passed to streamApi', async () => {
const signal = jest.fn();
await connector.invokeStream({ ...sampleOpenAiBody, signal }, connectorUsageCollector);
expect(mockRequest).toHaveBeenCalledWith(
{
url: 'http://localhost:1234/v1/chat/completions',
method: 'post',
responseSchema: StreamingResponseSchema,
responseType: 'stream',
data: JSON.stringify({
...sampleOpenAiBody,
stream: true,
}),
headers: {
Authorization: 'Bearer 123',
'X-My-Custom-Header': 'foo',
'content-type': 'application/json',
},
signal,
},
connectorUsageCollector
);
});
it('timeout is properly passed to streamApi', async () => {
const timeout = 180000;
await connector.invokeStream({ ...sampleOpenAiBody, timeout }, connectorUsageCollector);
expect(mockRequest).toHaveBeenCalledWith(
{
url: 'http://localhost:1234/v1/chat/completions',
method: 'post',
responseSchema: StreamingResponseSchema,
responseType: 'stream',
data: JSON.stringify({
...sampleOpenAiBody,
stream: true,
}),
headers: {
Authorization: 'Bearer 123',
'X-My-Custom-Header': 'foo',
'content-type': 'application/json',
},
timeout,
},
connectorUsageCollector
);
});
it('errors during API calls are properly handled', async () => {
// @ts-ignore
connector.request = mockError;
await expect(
connector.invokeStream(sampleOpenAiBody, connectorUsageCollector)
).rejects.toThrow('API Error');
});
it('responds with a readable stream', async () => {
// @ts-ignore
connector.request = mockStream();
const response = await connector.invokeStream(sampleOpenAiBody, connectorUsageCollector);
expect(response instanceof PassThrough).toEqual(true);
});
});
describe('invokeAI', () => {
it('the API call is successful with correct parameters', async () => {
const response = await connector.invokeAI(sampleOpenAiBody, connectorUsageCollector);
expect(mockRequest).toBeCalledTimes(1);
expect(mockRequest).toHaveBeenCalledWith(
{
...mockDefaults,
url: 'http://localhost:1234/v1/chat/completions',
data: JSON.stringify({
...sampleOpenAiBody,
stream: false,
model: DEFAULT_OTHER_OPENAI_MODEL,
}),
headers: {
Authorization: 'Bearer 123',
'X-My-Custom-Header': 'foo',
'content-type': 'application/json',
},
},
connectorUsageCollector
);
expect(response.message).toEqual(mockResponseString);
expect(response.usage.total_tokens).toEqual(9);
});
it('signal is properly passed to runApi', async () => {
const signal = jest.fn();
await connector.invokeAI({ ...sampleOpenAiBody, signal }, connectorUsageCollector);
expect(mockRequest).toHaveBeenCalledWith(
{
...mockDefaults,
url: 'http://localhost:1234/v1/chat/completions',
data: JSON.stringify({
...sampleOpenAiBody,
stream: false,
model: DEFAULT_OTHER_OPENAI_MODEL,
}),
headers: {
Authorization: 'Bearer 123',
'X-My-Custom-Header': 'foo',
'content-type': 'application/json',
},
signal,
},
connectorUsageCollector
);
});
it('timeout is properly passed to runApi', async () => {
const timeout = 180000;
await connector.invokeAI({ ...sampleOpenAiBody, timeout }, connectorUsageCollector);
expect(mockRequest).toHaveBeenCalledWith(
{
...mockDefaults,
url: 'http://localhost:1234/v1/chat/completions',
data: JSON.stringify({
...sampleOpenAiBody,
stream: false,
model: DEFAULT_OTHER_OPENAI_MODEL,
}),
headers: {
Authorization: 'Bearer 123',
'X-My-Custom-Header': 'foo',
'content-type': 'application/json',
},
timeout,
},
connectorUsageCollector
);
});
it('errors during API calls are properly handled', async () => {
// @ts-ignore
connector.request = mockError;
await expect(connector.invokeAI(sampleOpenAiBody, connectorUsageCollector)).rejects.toThrow(
'API Error'
);
});
});
});
describe('AzureAI', () => {
const connector = new OpenAIConnector({
configurationUtilities: actionsConfigMock.create(),

View file

@ -73,6 +73,9 @@
},
"[OpenAI]": {
"type": "long"
},
"[Other]": {
"type": "long"
}
}
},

View file

@ -147,7 +147,7 @@ export default function genAiTest({ getService }: FtrProviderContext) {
statusCode: 400,
error: 'Bad Request',
message:
'error validating action type config: types that failed validation:\n- [0.apiProvider]: expected at least one defined value but got [undefined]\n- [1.apiProvider]: expected at least one defined value but got [undefined]',
'error validating action type config: types that failed validation:\n- [0.apiProvider]: expected at least one defined value but got [undefined]\n- [1.apiProvider]: expected at least one defined value but got [undefined]\n- [2.apiProvider]: expected at least one defined value but got [undefined]',
});
});
});
@ -168,7 +168,7 @@ export default function genAiTest({ getService }: FtrProviderContext) {
statusCode: 400,
error: 'Bad Request',
message:
'error validating action type config: types that failed validation:\n- [0.apiProvider]: expected value to equal [Azure OpenAI]\n- [1.apiUrl]: expected value of type [string] but got [undefined]',
'error validating action type config: types that failed validation:\n- [0.apiProvider]: expected value to equal [Azure OpenAI]\n- [1.apiUrl]: expected value of type [string] but got [undefined]\n- [2.apiProvider]: expected value to equal [Other]',
});
});
});