mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 01:13:23 -04:00
[Security solution] Handle Gemini finishReason: SAFETY
(#192304)
This commit is contained in:
parent
f22067fbc9
commit
53e88ec03a
2 changed files with 65 additions and 6 deletions
|
@ -16,6 +16,8 @@ import {
|
|||
POSSIBLE_ROLES,
|
||||
Part,
|
||||
TextPart,
|
||||
FinishReason,
|
||||
SafetyRating,
|
||||
} from '@google/generative-ai';
|
||||
import { ActionsClient } from '@kbn/actions-plugin/server';
|
||||
import { PublicMethodsOf } from '@kbn/utility-types';
|
||||
|
@ -46,6 +48,12 @@ export interface CustomChatModelInput extends BaseChatModelParams {
|
|||
maxTokens?: number;
|
||||
}
|
||||
|
||||
/**
 * Safety rating entry as actually present on Gemini response data.
 * NOTE(review): `blocked` and `severity` appear on the wire data but are
 * missing from the SDK's published `SafetyRating` type — confirm against
 * the `@google/generative-ai` package before relying on them elsewhere.
 */
interface SafetyReason extends SafetyRating {
  blocked: boolean;
  severity: string;
}
|
||||
|
||||
export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI {
|
||||
#actionsClient: PublicMethodsOf<ActionsClient>;
|
||||
#connectorId: string;
|
||||
|
@ -100,6 +108,14 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI {
|
|||
);
|
||||
}
|
||||
|
||||
if (actionResult.data.candidates && actionResult.data.candidates.length > 0) {
|
||||
// handle bad finish reason
|
||||
const errorMessage = convertResponseBadFinishReasonToErrorMsg(actionResult.data);
|
||||
if (errorMessage != null) {
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
response: {
|
||||
...actionResult.data,
|
||||
|
@ -239,6 +255,12 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI {
|
|||
yield chunk;
|
||||
await runManager?.handleLLMNewToken(chunk.text ?? '');
|
||||
}
|
||||
} else if (parsedStreamChunk) {
|
||||
// handle bad finish reason
|
||||
const errorMessage = convertResponseBadFinishReasonToErrorMsg(parsedStreamChunk);
|
||||
if (errorMessage != null) {
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -460,3 +482,40 @@ function messageContentMedia(content: Record<string, unknown>): InlineDataPart {
|
|||
}
|
||||
throw new Error('Invalid media content');
|
||||
}
|
||||
|
||||
const badFinishReasons = [FinishReason.RECITATION, FinishReason.SAFETY];
|
||||
function hadBadFinishReason(candidate: { finishReason?: FinishReason }) {
|
||||
return !!candidate.finishReason && badFinishReasons.includes(candidate.finishReason);
|
||||
}
|
||||
|
||||
export function convertResponseBadFinishReasonToErrorMsg(
|
||||
response: EnhancedGenerateContentResponse
|
||||
): string | null {
|
||||
if (response.candidates && response.candidates.length > 0) {
|
||||
const candidate = response.candidates[0];
|
||||
if (hadBadFinishReason(candidate)) {
|
||||
if (
|
||||
candidate.finishReason === FinishReason.SAFETY &&
|
||||
candidate.safetyRatings &&
|
||||
(candidate.safetyRatings?.length ?? 0) > 0
|
||||
) {
|
||||
const safetyReasons = getSafetyReasons(candidate.safetyRatings as SafetyReason[]);
|
||||
return `ActionsClientGeminiChatModel: action result status is error. Candidate was blocked due to ${candidate.finishReason} - ${safetyReasons}`;
|
||||
} else {
|
||||
return `ActionsClientGeminiChatModel: action result status is error. Candidate was blocked due to ${candidate.finishReason}`;
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
const getSafetyReasons = (safetyRatings: SafetyReason[]) => {
|
||||
const reasons = safetyRatings.filter((t: SafetyReason) => t.blocked);
|
||||
return reasons.reduce(
|
||||
(acc: string, t: SafetyReason, i: number) =>
|
||||
`${acc.length ? `${acc} ` : ''}${t.category}: ${t.severity}${
|
||||
i < reasons.length - 1 ? ',' : ''
|
||||
}`,
|
||||
''
|
||||
);
|
||||
};
|
||||
|
|
|
@ -605,15 +605,15 @@ export class ActionExecutor {
|
|||
.then((tokenTracking) => {
|
||||
if (tokenTracking != null) {
|
||||
set(event, 'kibana.action.execution.gen_ai.usage', {
|
||||
total_tokens: tokenTracking.total_tokens,
|
||||
prompt_tokens: tokenTracking.prompt_tokens,
|
||||
completion_tokens: tokenTracking.completion_tokens,
|
||||
total_tokens: tokenTracking.total_tokens ?? 0,
|
||||
prompt_tokens: tokenTracking.prompt_tokens ?? 0,
|
||||
completion_tokens: tokenTracking.completion_tokens ?? 0,
|
||||
});
|
||||
analyticsService.reportEvent(GEN_AI_TOKEN_COUNT_EVENT.eventType, {
|
||||
actionTypeId,
|
||||
total_tokens: tokenTracking.total_tokens,
|
||||
prompt_tokens: tokenTracking.prompt_tokens,
|
||||
completion_tokens: tokenTracking.completion_tokens,
|
||||
total_tokens: tokenTracking.total_tokens ?? 0,
|
||||
prompt_tokens: tokenTracking.prompt_tokens ?? 0,
|
||||
completion_tokens: tokenTracking.completion_tokens ?? 0,
|
||||
...(actionTypeId === '.gen-ai' && config?.apiProvider != null
|
||||
? { provider: config?.apiProvider }
|
||||
: {}),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue