dev: normalize Claude input

- add merging of same-role messages for all normalized inputs
- generalize Claude-specific system message extraction
- remove redundant behavior from Claude
This commit is contained in:
KernelDeimos 2025-01-31 11:22:06 -05:00
parent 40aa52225f
commit d88dedb66c
3 changed files with 35 additions and 36 deletions

View File

@ -354,6 +354,7 @@ class AIChatService extends BaseService {
} }
if ( parameters.messages ) { if ( parameters.messages ) {
parameters.messages =
Messages.normalize_messages(parameters.messages); Messages.normalize_messages(parameters.messages);
} }

View File

@ -24,6 +24,7 @@ const { whatis } = require("../../util/langutil");
const { PassThrough } = require("stream"); const { PassThrough } = require("stream");
const { TypedValue } = require("../../services/drivers/meta/Runtime"); const { TypedValue } = require("../../services/drivers/meta/Runtime");
const FunctionCalling = require("./lib/FunctionCalling"); const FunctionCalling = require("./lib/FunctionCalling");
const Messages = require("./lib/Messages");
const { TeePromise } = require('@heyputer/putility').libs.promise; const { TeePromise } = require('@heyputer/putility').libs.promise;
const PUTER_PROMPT = ` const PUTER_PROMPT = `
@ -116,41 +117,10 @@ class ClaudeService extends BaseService {
* @returns {TypedValue|Object} Returns either a TypedValue with streaming response or a completion object * @returns {TypedValue|Object} Returns either a TypedValue with streaming response or a completion object
*/ */
async complete ({ messages, stream, model, tools }) { async complete ({ messages, stream, model, tools }) {
const adapted_messages = [];
tools = FunctionCalling.make_claude_tools(tools); tools = FunctionCalling.make_claude_tools(tools);
const system_prompts = []; let system_prompts;
let previous_was_user = false; [system_prompts, messages] = Messages.extract_and_remove_system_messages(messages);
for ( const message of messages ) {
if ( typeof message.content === 'string' ) {
message.content = {
type: 'text',
text: message.content,
};
}
if ( whatis(message.content) !== 'array' ) {
message.content = [message.content];
}
if ( ! message.role ) message.role = 'user';
if ( message.role === 'user' && previous_was_user ) {
const last_msg = adapted_messages[adapted_messages.length-1];
last_msg.content.push(
...(Array.isArray ? message.content : [message.content])
);
continue;
}
if ( message.role === 'system' ) {
system_prompts.push(...message.content);
continue;
}
adapted_messages.push(message);
if ( message.role === 'user' ) {
previous_was_user = true;
} else {
previous_was_user = false;
}
}
if ( stream ) { if ( stream ) {
let usage_promise = new TeePromise(); let usage_promise = new TeePromise();
@ -167,7 +137,7 @@ class ClaudeService extends BaseService {
max_tokens: (model === 'claude-3-5-sonnet-20241022' || model === 'claude-3-5-sonnet-20240620') ? 8192 : 4096, max_tokens: (model === 'claude-3-5-sonnet-20241022' || model === 'claude-3-5-sonnet-20240620') ? 8192 : 4096,
temperature: 0, temperature: 0,
system: PUTER_PROMPT + JSON.stringify(system_prompts), system: PUTER_PROMPT + JSON.stringify(system_prompts),
messages: adapted_messages, messages,
...(tools ? { tools } : {}), ...(tools ? { tools } : {}),
}); });
const counts = { input_tokens: 0, output_tokens: 0 }; const counts = { input_tokens: 0, output_tokens: 0 };
@ -278,7 +248,7 @@ class ClaudeService extends BaseService {
max_tokens: (model === 'claude-3-5-sonnet-20241022' || model === 'claude-3-5-sonnet-20240620') ? 8192 : 4096, max_tokens: (model === 'claude-3-5-sonnet-20241022' || model === 'claude-3-5-sonnet-20240620') ? 8192 : 4096,
temperature: 0, temperature: 0,
system: PUTER_PROMPT + JSON.stringify(system_prompts), system: PUTER_PROMPT + JSON.stringify(system_prompts),
messages: adapted_messages, messages,
...(tools ? { tools } : {}), ...(tools ? { tools } : {}),
}); });
return { return {

View File

@ -45,7 +45,35 @@ module.exports = class Messages {
for ( let i=0 ; i < messages.length ; i++ ) { for ( let i=0 ; i < messages.length ; i++ ) {
messages[i] = this.normalize_single_message(messages[i], params); messages[i] = this.normalize_single_message(messages[i], params);
} }
// If multiple messages are from the same role, merge them
let merged_messages = [];
let current_role = null;
for ( let i=0 ; i < messages.length ; i++ ) {
if ( current_role === messages[i].role ) {
merged_messages[merged_messages.length - 1].content.push(...messages[i].content);
} else {
merged_messages.push(messages[i]);
current_role = messages[i].role;
} }
}
return merged_messages;
}
static extract_and_remove_system_messages (messages) {
let system_messages = [];
let new_messages = [];
for ( let i=0 ; i < messages.length ; i++ ) {
if ( messages[i].role === 'system' ) {
system_messages.push(messages[i]);
} else {
new_messages.push(messages[i]);
}
}
return [system_messages, new_messages];
}
static extract_text (messages) { static extract_text (messages) {
return messages.map(m => { return messages.map(m => {
if ( whatis(m) === 'string' ) { if ( whatis(m) === 'string' ) {