[8.x] [inference] add support for `temperature` parameter (#206479) (#206573)

# Backport

This will backport the following commits from `main` to `8.x`:
- [[inference] add support for `temperature` parameter
(#206479)](https://github.com/elastic/kibana/pull/206479)

<!--- Backport version: 9.4.3 -->

### Questions?
Please refer to the [Backport tool
documentation](https://github.com/sqren/backport)

### Original PR summary

Fix https://github.com/elastic/kibana/issues/206542

Add a `temperature` parameter to the `chatComplete` inference API, and wire it accordingly on all adapters.
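For context, here is a minimal sketch of the new option from a caller's perspective. It is illustrative only: the `inferenceClient` variable, the connector id, and the import path are assumptions, not part of this PR.

```ts
import { MessageRole } from '@kbn/inference-common'; // assumed import path

// `inferenceClient` stands in for whichever chatComplete-capable client
// your plugin obtains from the inference plugin's start contract.
const response = await inferenceClient.chatComplete({
  connectorId: 'my-genai-connector', // illustrative connector id
  system: 'Please answer the user question',
  // New in this change: forwarded to all adapters; defaults to 0 when omitted.
  // All models support the 0-1 range (some support more).
  temperature: 0.2,
  messages: [{ role: MessageRole.User, content: '2+2 ?' }],
});
```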

Co-authored-by: Pierre Gayvallet <pierre.gayvallet@elastic.co>
Authored by Kibana Machine on 2025-01-15 01:37:43 +11:00, committed by GitHub
parent ba4cedbea4
commit 3ece5f837d
17 changed files with 113 additions and 8 deletions


@@ -89,6 +89,11 @@ export type ChatCompleteOptions<
    * The list of messages for the current conversation
    */
   messages: Message[];
+  /**
+   * LLM temperature. All models support the 0-1 range (some support more).
+   * Defaults to 0.
+   */
+  temperature?: number;
   /**
    * Function calling mode, defaults to "native".
    */


@@ -16,6 +16,7 @@ export type ChatCompleteRequestBody = {
   connectorId: string;
   stream?: boolean;
   system?: string;
+  temperature?: number;
   messages: Message[];
   functionCalling?: FunctionCallingMode;
 } & ToolOptions;


@@ -24,6 +24,7 @@ describe('createChatCompleteApi', () => {
       connectorId: 'my-connector',
       functionCalling: 'native',
       system: 'system',
+      temperature: 0.5,
       messages: [{ role: MessageRole.User, content: 'question' }],
     };
     await chatComplete(params as ChatCompleteOptions);
@@ -44,6 +45,7 @@ describe('createChatCompleteApi', () => {
       connectorId: 'my-connector',
       functionCalling: 'native',
       stream: true,
+      temperature: 0.4,
       system: 'system',
       messages: [{ role: MessageRole.User, content: 'question' }],
     };


@@ -24,6 +24,7 @@ export function createChatCompleteApi({ http }: { http: HttpStart }) {
     system,
     toolChoice,
     tools,
+    temperature,
     functionCalling,
     stream,
   }: ChatCompleteOptions<ToolOptions, boolean>): ChatCompleteCompositeResponse<
@@ -36,6 +37,7 @@ export function createChatCompleteApi({ http }: { http: HttpStart }) {
       messages,
       toolChoice,
       tools,
+      temperature,
       functionCalling,
     };


@@ -430,5 +430,22 @@ describe('bedrockClaudeAdapter', () => {
         }),
       });
     });
+
+    it('propagates the temperature parameter', () => {
+      bedrockClaudeAdapter.chatComplete({
+        logger,
+        executor: executorMock,
+        messages: [{ role: MessageRole.User, content: 'question' }],
+        temperature: 0.9,
+      });
+
+      expect(executorMock.invoke).toHaveBeenCalledTimes(1);
+      expect(executorMock.invoke).toHaveBeenCalledWith({
+        subAction: 'invokeStream',
+        subActionParams: expect.objectContaining({
+          temperature: 0.9,
+        }),
+      });
+    });
   });
 });


@@ -26,7 +26,15 @@ import { processCompletionChunks } from './process_completion_chunks';
 import { addNoToolUsageDirective } from './prompts';

 export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
-  chatComplete: ({ executor, system, messages, toolChoice, tools, abortSignal }) => {
+  chatComplete: ({
+    executor,
+    system,
+    messages,
+    toolChoice,
+    tools,
+    temperature = 0,
+    abortSignal,
+  }) => {
     const noToolUsage = toolChoice === ToolChoiceType.none;

     const subActionParams = {
@@ -34,7 +42,7 @@ export const bedrockClaudeAdapter: InferenceConnectorAdapter = {
       messages: messagesToBedrock(messages),
       tools: noToolUsage ? [] : toolsToBedrock(tools, messages),
       toolChoice: toolChoiceToBedrock(toolChoice),
-      temperature: 0,
+      temperature,
       stopSequences: ['\n\nHuman:'],
       signal: abortSignal,
     };


@@ -32,6 +32,7 @@ describe('geminiAdapter', () => {
       tools: params.tools,
       toolConfig: params.toolConfig,
       systemInstruction: params.systemInstruction,
+      temperature: params.temperature,
     };
   }

@@ -501,5 +502,22 @@ describe('geminiAdapter', () => {
         }),
       });
     });
+
+    it('propagates the temperature parameter', () => {
+      geminiAdapter.chatComplete({
+        logger,
+        executor: executorMock,
+        messages: [{ role: MessageRole.User, content: 'question' }],
+        temperature: 0.6,
+      });
+
+      expect(executorMock.invoke).toHaveBeenCalledTimes(1);
+      expect(executorMock.invoke).toHaveBeenCalledWith({
+        subAction: 'invokeStream',
+        subActionParams: expect.objectContaining({
+          temperature: 0.6,
+        }),
+      });
+    });
   });
 });


@@ -22,7 +22,15 @@ import { processVertexStream } from './process_vertex_stream';
 import type { GenerateContentResponseChunk, GeminiMessage, GeminiToolConfig } from './types';

 export const geminiAdapter: InferenceConnectorAdapter = {
-  chatComplete: ({ executor, system, messages, toolChoice, tools, abortSignal }) => {
+  chatComplete: ({
+    executor,
+    system,
+    messages,
+    toolChoice,
+    tools,
+    temperature = 0,
+    abortSignal,
+  }) => {
     return from(
       executor.invoke({
         subAction: 'invokeStream',
@@ -31,7 +39,7 @@ export const geminiAdapter: InferenceConnectorAdapter = {
           systemInstruction: system,
           tools: toolsToGemini(tools),
           toolConfig: toolChoiceToConfig(toolChoice),
-          temperature: 0,
+          temperature,
           signal: abortSignal,
           stopSequences: ['\n\nHuman:'],
         },


@@ -144,5 +144,24 @@ describe('inferenceAdapter', () => {
        }),
      });
    });
+
+    it('propagates the temperature parameter', () => {
+      inferenceAdapter.chatComplete({
+        logger,
+        executor: executorMock,
+        messages: [{ role: MessageRole.User, content: 'question' }],
+        temperature: 0.4,
+      });
+
+      expect(executorMock.invoke).toHaveBeenCalledTimes(1);
+      expect(executorMock.invoke).toHaveBeenCalledWith({
+        subAction: 'unified_completion_stream',
+        subActionParams: expect.objectContaining({
+          body: expect.objectContaining({
+            temperature: 0.4,
+          }),
+        }),
+      });
+    });
   });
 });


@@ -30,6 +30,7 @@ export const inferenceAdapter: InferenceConnectorAdapter = {
     toolChoice,
     tools,
     functionCalling,
+    temperature = 0,
     logger,
     abortSignal,
   }) => {
@@ -44,10 +45,12 @@ export const inferenceAdapter: InferenceConnectorAdapter = {
         tools,
       });
       request = {
+        temperature,
         messages: messagesToOpenAI({ system: wrapped.system, messages: wrapped.messages }),
       };
     } else {
       request = {
+        temperature,
         messages: messagesToOpenAI({ system, messages }),
         tool_choice: toolChoiceToOpenAI(toolChoice),
         tools: toolsToOpenAI(tools),


@@ -354,6 +354,18 @@ describe('openAIAdapter', () => {
         }),
       });
     });
+
+    it('propagates the temperature', () => {
+      openAIAdapter.chatComplete({
+        logger,
+        executor: executorMock,
+        messages: [{ role: MessageRole.User, content: 'question' }],
+        temperature: 0.7,
+      });
+
+      expect(executorMock.invoke).toHaveBeenCalledTimes(1);
+      expect(getRequest().body.temperature).toBe(0.7);
+    });
   });

   describe('when handling the response', () => {


@@ -25,11 +25,11 @@ export const openAIAdapter: InferenceConnectorAdapter = {
     messages,
     toolChoice,
     tools,
+    temperature = 0,
     functionCalling,
     logger,
     abortSignal,
   }) => {
-    const stream = true;
     const simulatedFunctionCalling = functionCalling === 'simulated';
     let request: Omit<OpenAI.ChatCompletionCreateParams, 'model'> & { model?: string };
@@ -41,12 +41,14 @@ export const openAIAdapter: InferenceConnectorAdapter = {
         tools,
       });
       request = {
-        stream,
+        stream: true,
+        temperature,
         messages: messagesToOpenAI({ system: wrapped.system, messages: wrapped.messages }),
       };
     } else {
       request = {
-        stream,
+        stream: true,
+        temperature,
         messages: messagesToOpenAI({ system, messages }),
         tool_choice: toolChoiceToOpenAI(toolChoice),
         tools: toolsToOpenAI(tools),
@@ -59,7 +61,7 @@ export const openAIAdapter: InferenceConnectorAdapter = {
       subActionParams: {
         body: JSON.stringify(request),
         signal: abortSignal,
-        stream,
+        stream: true,
       },
     })
   ).pipe(


@@ -84,6 +84,7 @@ describe('createChatCompleteApi', () => {
       await chatComplete({
         connectorId: 'connectorId',
         messages: [{ role: MessageRole.User, content: 'question' }],
+        temperature: 0.7,
       });

       expect(inferenceAdapter.chatComplete).toHaveBeenCalledTimes(1);
@@ -91,6 +92,7 @@ describe('createChatCompleteApi', () => {
         messages: [{ role: MessageRole.User, content: 'question' }],
         executor: inferenceExecutor,
         logger,
+        temperature: 0.7,
       });
     });


@@ -38,6 +38,7 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo
     messages,
     toolChoice,
     tools,
+    temperature,
     system,
     functionCalling,
     stream,
@@ -80,6 +81,7 @@ export function createChatCompleteApi({ request, actions, logger }: CreateChatCo
       messages,
       toolChoice,
       tools,
+      temperature,
       logger,
       functionCalling,
       abortSignal,


@@ -29,6 +29,7 @@ export interface InferenceConnectorAdapter {
       messages: Message[];
       system?: string;
       functionCalling?: FunctionCallingMode;
+      temperature?: number;
       abortSignal?: AbortSignal;
       logger: Logger;
     } & ToolOptions


@@ -87,6 +87,7 @@ const chatCompleteBodySchema: Type<ChatCompleteRequestBody> = schema.object({
   functionCalling: schema.maybe(
     schema.oneOf([schema.literal('native'), schema.literal('simulated')])
   ),
+  temperature: schema.maybe(schema.number()),
 });

 export function registerChatCompleteRoute({


@@ -25,6 +25,7 @@ export const chatCompleteSuite = (
           .set('kbn-xsrf', 'kibana')
           .send({
             connectorId,
+            temperature: 0.1,
             system: 'Please answer the user question',
             messages: [{ role: 'user', content: '2+2 ?' }],
           })
@@ -154,6 +155,7 @@ export const chatCompleteSuite = (
           .set('kbn-xsrf', 'kibana')
           .send({
             connectorId,
+            temperature: 0.1,
             system: 'Please answer the user question',
             messages: [{ role: 'user', content: '2+2 ?' }],
          })