# Backport

This will backport the following commits from `main` to `8.x`:

- [[inference] add support for `temperature` parameter (#206479)](https://github.com/elastic/kibana/pull/206479)

### Summary (from #206479)

Fixes https://github.com/elastic/kibana/issues/206542

Adds a `temperature` parameter to the `chatComplete` inference API, and wires it accordingly on all adapters.

### Questions?

Please refer to the [Backport tool documentation](https://github.com/sqren/backport)

Co-authored-by: Pierre Gayvallet <pierre.gayvallet@elastic.co>
This commit is contained in:

- parent: `ba4cedbea4`
- commit: `3ece5f837d`

17 changed files with 113 additions and 8 deletions
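For orientation, this is roughly what callers gain, sketched under the assumption that `chatComplete` and `MessageRole` are available from the inference plugin's public contract (the connector id and import path are illustrative):

```ts
import { MessageRole } from '@kbn/inference-common'; // assumed export location

// Hypothetical call site; `inference` is the plugin's start contract.
const response = await inference.chatComplete({
  connectorId: 'my-connector', // placeholder
  system: 'You are a helpful assistant',
  messages: [{ role: MessageRole.User, content: 'Summarize this document' }],
  temperature: 0.3, // new: optional, 0-1 for all models, defaults to 0
});
```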
```diff
@@ -89,6 +89,11 @@ export type ChatCompleteOptions<
    * The list of messages for the current conversation
    */
   messages: Message[];
+  /**
+   * LLM temperature. All models support the 0-1 range (some support more).
+   * Defaults to 0.
+   */
+  temperature?: number;
   /**
    * Function calling mode, defaults to "native".
    */
```
```diff
@@ -16,6 +16,7 @@ export type ChatCompleteRequestBody = {
   connectorId: string;
   stream?: boolean;
   system?: string;
+  temperature?: number;
   messages: Message[];
   functionCalling?: FunctionCallingMode;
 } & ToolOptions;
```
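For reference, an object literal conforming to the updated body type might look like this (a sketch; the connector id is a placeholder, and `MessageRole` is assumed to come from the same module as `ChatCompleteRequestBody`):

```ts
// Illustrative request body; values are placeholders.
const body: ChatCompleteRequestBody = {
  connectorId: 'my-connector',
  stream: false,
  system: 'You are a helpful assistant',
  temperature: 0.2, // the new optional field; omitted means "use the default"
  messages: [{ role: MessageRole.User, content: 'question' }],
};
```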
```diff
@@ -24,6 +24,7 @@ describe('createChatCompleteApi', () => {
       connectorId: 'my-connector',
       functionCalling: 'native',
       system: 'system',
+      temperature: 0.5,
       messages: [{ role: MessageRole.User, content: 'question' }],
     };
     await chatComplete(params as ChatCompleteOptions);
@@ -44,6 +45,7 @@ describe('createChatCompleteApi', () => {
       connectorId: 'my-connector',
       functionCalling: 'native',
       stream: true,
+      temperature: 0.4,
       system: 'system',
       messages: [{ role: MessageRole.User, content: 'question' }],
     };
```
```diff
@@ -24,6 +24,7 @@ export function createChatCompleteApi({ http }: { http: HttpStart }) {
     system,
     toolChoice,
     tools,
+    temperature,
     functionCalling,
     stream,
   }: ChatCompleteOptions<ToolOptions, boolean>): ChatCompleteCompositeResponse<
@@ -36,6 +37,7 @@ export function createChatCompleteApi({ http }: { http: HttpStart }) {
       messages,
       toolChoice,
       tools,
+      temperature,
       functionCalling,
     };
```
```diff
@@ -430,5 +430,22 @@ describe('bedrockClaudeAdapter', () => {
         }),
       });
     });
+
+    it('propagates the temperature parameter', () => {
+      bedrockClaudeAdapter.chatComplete({
+        logger,
+        executor: executorMock,
+        messages: [{ role: MessageRole.User, content: 'question' }],
+        temperature: 0.9,
+      });
+
+      expect(executorMock.invoke).toHaveBeenCalledTimes(1);
+      expect(executorMock.invoke).toHaveBeenCalledWith({
+        subAction: 'invokeStream',
+        subActionParams: expect.objectContaining({
+          temperature: 0.9,
+        }),
+      });
+    });
   });
 });
```
```diff
@@ -26,7 +26,15 @@ import { processCompletionChunks } from './process_completion_chunks';
 import { addNoToolUsageDirective } from './prompts';
 
 export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
-  chatComplete: ({ executor, system, messages, toolChoice, tools, abortSignal }) => {
+  chatComplete: ({
+    executor,
+    system,
+    messages,
+    toolChoice,
+    tools,
+    temperature = 0,
+    abortSignal,
+  }) => {
     const noToolUsage = toolChoice === ToolChoiceType.none;
 
     const subActionParams = {
@@ -34,7 +42,7 @@ export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
       messages: messagesToBedrock(messages),
       tools: noToolUsage ? [] : toolsToBedrock(tools, messages),
       toolChoice: toolChoiceToBedrock(toolChoice),
-      temperature: 0,
+      temperature,
       stopSequences: ['\n\nHuman:'],
       signal: abortSignal,
     };
```
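All of the adapters use the same idiom: a destructuring default (`temperature = 0`) so the previously hard-coded value stays the fallback when callers omit the parameter. A minimal standalone sketch of those semantics:

```ts
// Destructuring defaults kick in only when the property is undefined.
const pick = ({ temperature = 0 }: { temperature?: number }) => temperature;

pick({});                   // 0   (fallback, matching the old hard-coded value)
pick({ temperature: 0.9 }); // 0.9 (caller-provided value wins)
```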
```diff
@@ -32,6 +32,7 @@ describe('geminiAdapter', () => {
       tools: params.tools,
       toolConfig: params.toolConfig,
       systemInstruction: params.systemInstruction,
+      temperature: params.temperature,
     };
   }
 
@@ -501,5 +502,22 @@ describe('geminiAdapter', () => {
         }),
       });
     });
+
+    it('propagates the temperature parameter', () => {
+      geminiAdapter.chatComplete({
+        logger,
+        executor: executorMock,
+        messages: [{ role: MessageRole.User, content: 'question' }],
+        temperature: 0.6,
+      });
+
+      expect(executorMock.invoke).toHaveBeenCalledTimes(1);
+      expect(executorMock.invoke).toHaveBeenCalledWith({
+        subAction: 'invokeStream',
+        subActionParams: expect.objectContaining({
+          temperature: 0.6,
+        }),
+      });
+    });
   });
 });
```
```diff
@@ -22,7 +22,15 @@ import { processVertexStream } from './process_vertex_stream';
 import type { GenerateContentResponseChunk, GeminiMessage, GeminiToolConfig } from './types';
 
 export const geminiAdapter: InferenceConnectorAdapter = {
-  chatComplete: ({ executor, system, messages, toolChoice, tools, abortSignal }) => {
+  chatComplete: ({
+    executor,
+    system,
+    messages,
+    toolChoice,
+    tools,
+    temperature = 0,
+    abortSignal,
+  }) => {
     return from(
       executor.invoke({
         subAction: 'invokeStream',
@@ -31,7 +39,7 @@ export const geminiAdapter: InferenceConnectorAdapter = {
           systemInstruction: system,
           tools: toolsToGemini(tools),
           toolConfig: toolChoiceToConfig(toolChoice),
-          temperature: 0,
+          temperature,
           signal: abortSignal,
           stopSequences: ['\n\nHuman:'],
         },
```
```diff
@@ -144,5 +144,24 @@ describe('inferenceAdapter', () => {
        }),
      });
    });
+
+    it('propagates the temperature parameter', () => {
+      inferenceAdapter.chatComplete({
+        logger,
+        executor: executorMock,
+        messages: [{ role: MessageRole.User, content: 'question' }],
+        temperature: 0.4,
+      });
+
+      expect(executorMock.invoke).toHaveBeenCalledTimes(1);
+      expect(executorMock.invoke).toHaveBeenCalledWith({
+        subAction: 'unified_completion_stream',
+        subActionParams: expect.objectContaining({
+          body: expect.objectContaining({
+            temperature: 0.4,
+          }),
+        }),
+      });
+    });
   });
 });
```
```diff
@@ -30,6 +30,7 @@ export const inferenceAdapter: InferenceConnectorAdapter = {
     toolChoice,
     tools,
     functionCalling,
+    temperature = 0,
     logger,
     abortSignal,
   }) => {
@@ -44,10 +45,12 @@ export const inferenceAdapter: InferenceConnectorAdapter = {
         tools,
       });
       request = {
+        temperature,
         messages: messagesToOpenAI({ system: wrapped.system, messages: wrapped.messages }),
       };
     } else {
       request = {
+        temperature,
         messages: messagesToOpenAI({ system, messages }),
         tool_choice: toolChoiceToOpenAI(toolChoice),
         tools: toolsToOpenAI(tools),
```
```diff
@@ -354,6 +354,18 @@ describe('openAIAdapter', () => {
        }),
      });
    });
+
+    it('propagates the temperature', () => {
+      openAIAdapter.chatComplete({
+        logger,
+        executor: executorMock,
+        messages: [{ role: MessageRole.User, content: 'question' }],
+        temperature: 0.7,
+      });
+
+      expect(executorMock.invoke).toHaveBeenCalledTimes(1);
+      expect(getRequest().body.temperature).toBe(0.7);
+    });
   });
 
   describe('when handling the response', () => {
```
```diff
@@ -25,11 +25,11 @@ export const openAIAdapter: InferenceConnectorAdapter = {
     messages,
     toolChoice,
     tools,
+    temperature = 0,
     functionCalling,
     logger,
     abortSignal,
   }) => {
-    const stream = true;
     const simulatedFunctionCalling = functionCalling === 'simulated';
 
     let request: Omit<OpenAI.ChatCompletionCreateParams, 'model'> & { model?: string };
@@ -41,12 +41,14 @@ export const openAIAdapter: InferenceConnectorAdapter = {
         tools,
       });
       request = {
-        stream,
+        stream: true,
+        temperature,
         messages: messagesToOpenAI({ system: wrapped.system, messages: wrapped.messages }),
       };
     } else {
       request = {
-        stream,
+        stream: true,
+        temperature,
         messages: messagesToOpenAI({ system, messages }),
         tool_choice: toolChoiceToOpenAI(toolChoice),
         tools: toolsToOpenAI(tools),
@@ -59,7 +61,7 @@ export const openAIAdapter: InferenceConnectorAdapter = {
       subActionParams: {
         body: JSON.stringify(request),
         signal: abortSignal,
-        stream,
+        stream: true,
       },
     })
   ).pipe(
```
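Put together, the body the OpenAI adapter now serializes would look roughly like this for the native function-calling path (a sketch with illustrative values; the exact `tool_choice` and `tools` shapes come from the conversion helpers above):

```ts
// Approximately what JSON.stringify(request) carries after this change.
const request = {
  stream: true,       // now inlined rather than read from a local const
  temperature: 0.7,   // caller-provided, or 0 via the destructuring default
  messages: [{ role: 'user', content: 'question' }],
  tool_choice: undefined, // whatever toolChoiceToOpenAI(toolChoice) returns
  tools: undefined,       // whatever toolsToOpenAI(tools) returns
};
```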
```diff
@@ -84,6 +84,7 @@ describe('createChatCompleteApi', () => {
     await chatComplete({
       connectorId: 'connectorId',
       messages: [{ role: MessageRole.User, content: 'question' }],
+      temperature: 0.7,
     });
 
     expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(1);
@@ -91,6 +92,7 @@ describe('createChatCompleteApi', () => {
       messages: [{ role: MessageRole.User, content: 'question' }],
       executor: inferenceExecutor,
       logger,
+      temperature: 0.7,
     });
   });
```
```diff
@@ -38,6 +38,7 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo
     messages,
     toolChoice,
     tools,
+    temperature,
     system,
     functionCalling,
     stream,
@@ -80,6 +81,7 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo
       messages,
       toolChoice,
       tools,
+      temperature,
       logger,
       functionCalling,
       abortSignal,
```
```diff
@@ -29,6 +29,7 @@ export interface InferenceConnectorAdapter {
     messages: Message[];
     system?: string;
     functionCalling?: FunctionCallingMode;
+    temperature?: number;
     abortSignal?: AbortSignal;
     logger: Logger;
   } & ToolOptions
```
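Since the parameter is threaded through this shared interface, any adapter can opt in the same way. A hypothetical skeleton (types and response mapping are deliberately loose; only the temperature wiring matters here):

```ts
import { from, type Observable } from 'rxjs';

// Hypothetical adapter skeleton, mirroring the adapters above; response
// mapping, tool handling, and precise types are elided for brevity.
const myAdapter = {
  chatComplete: ({
    executor,
    messages,
    temperature = 0, // same fallback the built-in adapters use
    abortSignal,
  }: {
    executor: { invoke: (args: unknown) => Promise<unknown> };
    messages: unknown[];
    temperature?: number;
    abortSignal?: AbortSignal;
  }): Observable<unknown> =>
    from(
      executor.invoke({
        subAction: 'invokeStream', // placeholder subaction name
        subActionParams: {
          messages,
          temperature, // forwarded instead of hard-coded
          signal: abortSignal,
        },
      })
    ),
};
```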
```diff
@@ -87,6 +87,7 @@ const chatCompleteBodySchema: Type<ChatCompleteRequestBody> = schema.object({
   functionCalling: schema.maybe(
     schema.oneOf([schema.literal('native'), schema.literal('simulated')])
   ),
+  temperature: schema.maybe(schema.number()),
 });
 
 export function registerChatCompleteRoute({
```
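On the server boundary, `schema.maybe(schema.number())` keeps the field optional while rejecting non-numeric input. A quick sketch of the expected behavior, assuming standard `@kbn/config-schema` semantics:

```ts
import { schema } from '@kbn/config-schema';

const body = schema.object({
  temperature: schema.maybe(schema.number()),
});

body.validate({});                     // passes; the field is optional
body.validate({ temperature: 0.2 });   // passes
body.validate({ temperature: 'hot' }); // throws: expected value of type [number]
```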
```diff
@@ -25,6 +25,7 @@ export const chatCompleteSuite = (
         .set('kbn-xsrf', 'kibana')
         .send({
           connectorId,
+          temperature: 0.1,
           system: 'Please answer the user question',
           messages: [{ role: 'user', content: '2+2 ?' }],
         })
@@ -154,6 +155,7 @@ export const chatCompleteSuite = (
         .set('kbn-xsrf', 'kibana')
         .send({
           connectorId,
+          temperature: 0.1,
           system: 'Please answer the user question',
           messages: [{ role: 'user', content: '2+2 ?' }],
         })
```
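End to end, a caller could exercise the route the same way this suite does. A hypothetical sketch (the actual route path is not shown in this diff, so `CHAT_COMPLETE_PATH` is a stand-in):

```ts
// Hypothetical; mirrors the functional test above.
await supertest
  .post(CHAT_COMPLETE_PATH) // stand-in for the real route path
  .set('kbn-xsrf', 'kibana')
  .send({
    connectorId,
    temperature: 0.1,
    system: 'Please answer the user question',
    messages: [{ role: 'user', content: '2+2 ?' }],
  })
  .expect(200); // assumed success status
```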