[Security solution] Handle Gemini finishReason: SAFETY (#192304)

This commit is contained in:
Steph Milovic 2024-09-11 12:18:44 -06:00 committed by GitHub
parent f22067fbc9
commit 53e88ec03a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 65 additions and 6 deletions

View file

@ -16,6 +16,8 @@ import {
POSSIBLE_ROLES,
Part,
TextPart,
FinishReason,
SafetyRating,
} from '@google/generative-ai';
import { ActionsClient } from '@kbn/actions-plugin/server';
import { PublicMethodsOf } from '@kbn/utility-types';
@ -46,6 +48,12 @@ export interface CustomChatModelInput extends BaseChatModelParams {
maxTokens?: number;
}
// not sure why these properties are not on the type, as they are on the data
interface SafetyReason extends SafetyRating {
blocked: boolean;
severity: string;
}
export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI {
#actionsClient: PublicMethodsOf<ActionsClient>;
#connectorId: string;
@ -100,6 +108,14 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI {
);
}
if (actionResult.data.candidates && actionResult.data.candidates.length > 0) {
// handle bad finish reason
const errorMessage = convertResponseBadFinishReasonToErrorMsg(actionResult.data);
if (errorMessage != null) {
throw new Error(errorMessage);
}
}
return {
response: {
...actionResult.data,
@ -239,6 +255,12 @@ export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI {
yield chunk;
await runManager?.handleLLMNewToken(chunk.text ?? '');
}
} else if (parsedStreamChunk) {
// handle bad finish reason
const errorMessage = convertResponseBadFinishReasonToErrorMsg(parsedStreamChunk);
if (errorMessage != null) {
throw new Error(errorMessage);
}
}
}
}
@ -460,3 +482,40 @@ function messageContentMedia(content: Record<string, unknown>): InlineDataPart {
}
throw new Error('Invalid media content');
}
const badFinishReasons = [FinishReason.RECITATION, FinishReason.SAFETY];
function hadBadFinishReason(candidate: { finishReason?: FinishReason }) {
return !!candidate.finishReason && badFinishReasons.includes(candidate.finishReason);
}
export function convertResponseBadFinishReasonToErrorMsg(
response: EnhancedGenerateContentResponse
): string | null {
if (response.candidates && response.candidates.length > 0) {
const candidate = response.candidates[0];
if (hadBadFinishReason(candidate)) {
if (
candidate.finishReason === FinishReason.SAFETY &&
candidate.safetyRatings &&
(candidate.safetyRatings?.length ?? 0) > 0
) {
const safetyReasons = getSafetyReasons(candidate.safetyRatings as SafetyReason[]);
return `ActionsClientGeminiChatModel: action result status is error. Candidate was blocked due to ${candidate.finishReason} - ${safetyReasons}`;
} else {
return `ActionsClientGeminiChatModel: action result status is error. Candidate was blocked due to ${candidate.finishReason}`;
}
}
}
return null;
}
const getSafetyReasons = (safetyRatings: SafetyReason[]) => {
const reasons = safetyRatings.filter((t: SafetyReason) => t.blocked);
return reasons.reduce(
(acc: string, t: SafetyReason, i: number) =>
`${acc.length ? `${acc} ` : ''}${t.category}: ${t.severity}${
i < reasons.length - 1 ? ',' : ''
}`,
''
);
};

View file

@ -605,15 +605,15 @@ export class ActionExecutor {
.then((tokenTracking) => {
if (tokenTracking != null) {
set(event, 'kibana.action.execution.gen_ai.usage', {
total_tokens: tokenTracking.total_tokens,
prompt_tokens: tokenTracking.prompt_tokens,
completion_tokens: tokenTracking.completion_tokens,
total_tokens: tokenTracking.total_tokens ?? 0,
prompt_tokens: tokenTracking.prompt_tokens ?? 0,
completion_tokens: tokenTracking.completion_tokens ?? 0,
});
analyticsService.reportEvent(GEN_AI_TOKEN_COUNT_EVENT.eventType, {
actionTypeId,
total_tokens: tokenTracking.total_tokens,
prompt_tokens: tokenTracking.prompt_tokens,
completion_tokens: tokenTracking.completion_tokens,
total_tokens: tokenTracking.total_tokens ?? 0,
prompt_tokens: tokenTracking.prompt_tokens ?? 0,
completion_tokens: tokenTracking.completion_tokens ?? 0,
...(actionTypeId === '.gen-ai' && config?.apiProvider != null
? { provider: config?.apiProvider }
: {}),