mirror of
https://github.com/HeyPuter/puter.git
synced 2025-01-23 14:20:22 +08:00
dev: implement ai model fallback
This commit is contained in:
parent
0d86233b45
commit
3f4efb9948
@ -1,7 +1,14 @@
|
||||
const BaseService = require("../../services/BaseService");
|
||||
const { Context } = require("../../util/context");
|
||||
|
||||
const MAX_FALLBACKS = 3 + 1; // includes first attempt
|
||||
|
||||
class AIChatService extends BaseService {
|
||||
static MODULES = {
|
||||
kv: globalThis.kv,
|
||||
uuidv4: require('uuid').v4,
|
||||
}
|
||||
|
||||
_construct () {
|
||||
this.providers = [];
|
||||
|
||||
@ -10,19 +17,23 @@ class AIChatService extends BaseService {
|
||||
this.detail_model_map = {};
|
||||
}
|
||||
_init () {
|
||||
const svc_driver = this.services.get('driver')
|
||||
|
||||
for ( const provider of this.providers ) {
|
||||
svc_driver.register_service_alias('ai-chat', provider.service_name);
|
||||
}
|
||||
this.kvkey = this.modules.uuidv4();
|
||||
}
|
||||
|
||||
async ['__on_boot.consolidation'] () {
|
||||
{
|
||||
const svc_driver = this.services.get('driver')
|
||||
for ( const provider of this.providers ) {
|
||||
svc_driver.register_service_alias('ai-chat',
|
||||
provider.service_name);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: get models and pricing for each model
|
||||
for ( const provider of this.providers ) {
|
||||
const delegate = this.services.get(provider.service_name)
|
||||
.as('puter-chat-completion');
|
||||
|
||||
|
||||
// Populate simple model list
|
||||
{
|
||||
const models = await (async () => {
|
||||
@ -54,20 +65,30 @@ class AIChatService extends BaseService {
|
||||
});
|
||||
}
|
||||
this.detail_model_list.push(...annotated_models);
|
||||
for ( const model of annotated_models ) {
|
||||
if ( this.detail_model_map[model.id] ) {
|
||||
let array = this.detail_model_map[model.id];
|
||||
// replace with array
|
||||
if ( ! Array.isArray(array) ) {
|
||||
array = [array];
|
||||
this.detail_model_map[model.id] = array;
|
||||
}
|
||||
|
||||
array.push(model);
|
||||
continue;
|
||||
const set_or_push = (key, model) => {
|
||||
// Typical case: no conflict
|
||||
if ( ! this.detail_model_map[key] ) {
|
||||
this.detail_model_map[key] = model;
|
||||
return;
|
||||
}
|
||||
|
||||
this.detail_model_map[model.id] = model;
|
||||
// Conflict: model name will map to an array
|
||||
let array = this.detail_model_map[key];
|
||||
if ( ! Array.isArray(array) ) {
|
||||
array = [array];
|
||||
this.detail_model_map[key] = array;
|
||||
}
|
||||
|
||||
array.push(model);
|
||||
};
|
||||
for ( const model of annotated_models ) {
|
||||
set_or_push(model.id, model);
|
||||
|
||||
if ( ! model.aliases ) continue;
|
||||
|
||||
for ( const alias of model.aliases ) {
|
||||
set_or_push(alias, model);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -97,7 +118,7 @@ class AIChatService extends BaseService {
|
||||
},
|
||||
async complete (parameters) {
|
||||
const client_driver_call = Context.get('client_driver_call');
|
||||
let { test_mode, intended_service } = client_driver_call;
|
||||
let { test_mode, intended_service, response_metadata } = client_driver_call;
|
||||
|
||||
this.log.noticeme('AIChatService.complete', { intended_service, parameters, test_mode });
|
||||
const svc_event = this.services.get('event');
|
||||
@ -123,13 +144,78 @@ class AIChatService extends BaseService {
|
||||
}
|
||||
|
||||
const svc_driver = this.services.get('driver');
|
||||
const ret = await svc_driver.call_new_({
|
||||
actor: Context.get('actor'),
|
||||
service_name: intended_service,
|
||||
iface: 'puter-chat-completion',
|
||||
method: 'complete',
|
||||
args: parameters,
|
||||
});
|
||||
let ret, error, errors = [];
|
||||
try {
|
||||
ret = await svc_driver.call_new_({
|
||||
actor: Context.get('actor'),
|
||||
service_name: intended_service,
|
||||
iface: 'puter-chat-completion',
|
||||
method: 'complete',
|
||||
args: parameters,
|
||||
});
|
||||
} catch (e) {
|
||||
const tried = [];
|
||||
let model = this.get_model_from_request(parameters);
|
||||
|
||||
// TODO: if conflict models exist, add service name
|
||||
tried.push(model);
|
||||
|
||||
error = e;
|
||||
errors.push(e);
|
||||
this.log.error('error calling service', {
|
||||
intended_service,
|
||||
model,
|
||||
error: e,
|
||||
});
|
||||
while ( !! error ) {
|
||||
const fallback = this.get_fallback_model({
|
||||
model, tried,
|
||||
});
|
||||
|
||||
if ( ! fallback ) {
|
||||
throw new Error('no fallback model available');
|
||||
}
|
||||
|
||||
const {
|
||||
fallback_service_name,
|
||||
fallback_model_name,
|
||||
} = fallback;
|
||||
|
||||
this.log.warn('model fallback', {
|
||||
intended_service,
|
||||
fallback_service_name,
|
||||
fallback_model_name
|
||||
});
|
||||
|
||||
try {
|
||||
ret = await svc_driver.call_new_({
|
||||
actor: Context.get('actor'),
|
||||
service_name: fallback_service_name,
|
||||
iface: 'puter-chat-completion',
|
||||
method: 'complete',
|
||||
args: {
|
||||
...parameters,
|
||||
model: fallback_model_name,
|
||||
},
|
||||
});
|
||||
error = null;
|
||||
response_metadata.fallback = {
|
||||
service: fallback_service_name,
|
||||
model: fallback_model_name,
|
||||
tried: tried,
|
||||
};
|
||||
} catch (e) {
|
||||
error = e;
|
||||
errors.push(e);
|
||||
tried.push(fallback_model_name);
|
||||
this.log.error('error calling fallback', {
|
||||
intended_service,
|
||||
model,
|
||||
error: e,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
ret.result.via_ai_chat_service = true;
|
||||
|
||||
const username = Context.get('actor').type?.user?.username;
|
||||
@ -190,6 +276,85 @@ class AIChatService extends BaseService {
|
||||
const service = this.services.get(client_driver_call.intended_service);
|
||||
return service.as('puter-chat-completion');
|
||||
}
|
||||
|
||||
/**
|
||||
* Find an appropriate fallback model by sorting the list of models
|
||||
* by the euclidean distance of the input/output prices and selecting
|
||||
* the first one that is not in the tried list.
|
||||
*
|
||||
* @param {*} param0
|
||||
* @returns
|
||||
*/
|
||||
get_fallback_model ({ model, tried }) {
|
||||
let target_model = this.detail_model_map[model];
|
||||
if ( ! target_model ) {
|
||||
this.log.error('could not find model', { model });
|
||||
throw new Error('could not find model');
|
||||
}
|
||||
if ( Array.isArray(target_model) ) {
|
||||
// TODO: better conflict resolution
|
||||
this.log.noticeme('conflict exists', { model, target_model });
|
||||
target_model = target_model[0];
|
||||
}
|
||||
|
||||
// First check KV for the sorted list
|
||||
let sorted_models = this.modules.kv.get(
|
||||
`${this.kvkey}:fallbacks:${model}`);
|
||||
|
||||
if ( ! sorted_models ) {
|
||||
// Calculate the sorted list
|
||||
const models = this.detail_model_list;
|
||||
|
||||
sorted_models = models.sort((a, b) => {
|
||||
return Math.sqrt(
|
||||
Math.pow(a.cost.input - target_model.cost.input, 2) +
|
||||
Math.pow(a.cost.output - target_model.cost.output, 2)
|
||||
) - Math.sqrt(
|
||||
Math.pow(b.cost.input - target_model.cost.input, 2) +
|
||||
Math.pow(b.cost.output - target_model.cost.output, 2)
|
||||
);
|
||||
});
|
||||
|
||||
sorted_models = sorted_models.slice(0, MAX_FALLBACKS);
|
||||
|
||||
this.modules.kv.set(
|
||||
`${this.kvkey}:fallbacks:${model}`, sorted_models);
|
||||
}
|
||||
|
||||
for ( const model of sorted_models ) {
|
||||
if ( tried.includes(model.id) ) continue;
|
||||
|
||||
return {
|
||||
fallback_service_name: model.provider,
|
||||
fallback_model_name: model.id,
|
||||
};
|
||||
}
|
||||
|
||||
// No fallbacks available
|
||||
this.log.error('no fallbacks', {
|
||||
sorted_models,
|
||||
tried,
|
||||
});
|
||||
}
|
||||
|
||||
get_model_from_request (parameters) {
|
||||
const client_driver_call = Context.get('client_driver_call');
|
||||
let { intended_service } = client_driver_call;
|
||||
|
||||
let model = parameters.model;
|
||||
if ( ! model ) {
|
||||
const service = this.services.get(intended_service);
|
||||
if ( ! service.get_default_model ) {
|
||||
throw new Error('could not infer model from service');
|
||||
}
|
||||
model = service.get_default_model();
|
||||
if ( ! model ) {
|
||||
throw new Error('could not infer model from service');
|
||||
}
|
||||
}
|
||||
|
||||
return model;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = { AIChatService };
|
||||
|
@ -32,6 +32,10 @@ class ClaudeService extends BaseService {
|
||||
alias: true,
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Name of the model this service uses when a request does not specify one.
 * @returns {string}
 */
get_default_model () {
|
||||
return 'claude-3-5-sonnet-latest';
|
||||
}
|
||||
|
||||
static IMPLEMENTS = {
|
||||
['puter-chat-completion']: {
|
||||
@ -106,7 +110,7 @@ class ClaudeService extends BaseService {
|
||||
}, stream);
|
||||
(async () => {
|
||||
const completion = await this.anthropic.messages.stream({
|
||||
model: model ?? 'claude-3-5-sonnet-latest',
|
||||
model: model ?? this.get_default_model(),
|
||||
max_tokens: 1000,
|
||||
temperature: 0,
|
||||
system: PUTER_PROMPT + JSON.stringify(system_prompts),
|
||||
@ -129,7 +133,7 @@ class ClaudeService extends BaseService {
|
||||
}
|
||||
|
||||
const msg = await this.anthropic.messages.create({
|
||||
model: 'claude-3-5-sonnet-latest',
|
||||
model: model ?? this.get_default_model(),
|
||||
max_tokens: 1000,
|
||||
temperature: 0,
|
||||
system: PUTER_PROMPT + JSON.stringify(system_prompts),
|
||||
|
@ -20,6 +20,10 @@ class GroqAIService extends BaseService {
|
||||
alias: true,
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Name of the model this service uses when a request does not specify one.
 * @returns {string}
 */
get_default_model () {
|
||||
return 'llama-3.1-8b-instant';
|
||||
}
|
||||
|
||||
static IMPLEMENTS = {
|
||||
'puter-chat-completion': {
|
||||
@ -37,6 +41,8 @@ class GroqAIService extends BaseService {
|
||||
if ( ! message.role ) message.role = 'user';
|
||||
}
|
||||
|
||||
model = model ?? this.get_default_model();
|
||||
|
||||
const completion = await this.client.chat.completions.create({
|
||||
messages,
|
||||
model,
|
||||
|
@ -130,6 +130,9 @@ class MistralAIService extends BaseService {
|
||||
}
|
||||
// return resp.data;
|
||||
}
|
||||
/**
 * Name of the model this service uses when a request does not specify one.
 * @returns {string}
 */
get_default_model () {
|
||||
return 'mistral-large-latest';
|
||||
}
|
||||
static IMPLEMENTS = {
|
||||
'puter-chat-completion': {
|
||||
async models () {
|
||||
@ -153,7 +156,7 @@ class MistralAIService extends BaseService {
|
||||
chunked: true,
|
||||
}, stream);
|
||||
const completion = await this.client.chat.stream({
|
||||
model: model ?? 'mistral-large-latest',
|
||||
model: model ?? this.get_default_model(),
|
||||
messages,
|
||||
});
|
||||
(async () => {
|
||||
@ -179,7 +182,7 @@ class MistralAIService extends BaseService {
|
||||
|
||||
try {
|
||||
const completion = await this.client.chat.complete({
|
||||
model: model ?? 'mistral-large-latest',
|
||||
model: model ?? this.get_default_model(),
|
||||
messages,
|
||||
});
|
||||
// Expected case when mistralai/client-ts#23 is fixed
|
||||
|
@ -21,6 +21,10 @@ class OpenAICompletionService extends BaseService {
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Name of the model this service uses when a request does not specify one.
 * @returns {string}
 */
get_default_model () {
|
||||
return 'gpt-4o-mini';
|
||||
}
|
||||
|
||||
static IMPLEMENTS = {
|
||||
['puter-chat-completion']: {
|
||||
async list () {
|
||||
@ -106,7 +110,7 @@ class OpenAICompletionService extends BaseService {
|
||||
throw new Error('`messages` must be an array');
|
||||
}
|
||||
|
||||
model = model ?? 'gpt-4o-mini';
|
||||
model = model ?? this.get_default_model();
|
||||
|
||||
for ( let i = 0; i < messages.length; i++ ) {
|
||||
let msg = messages[i];
|
||||
|
@ -19,11 +19,16 @@ class TogetherAIService extends BaseService {
|
||||
this.kvkey = this.modules.uuidv4();
|
||||
|
||||
const svc_aiChat = this.services.get('ai-chat');
|
||||
console.log('registering provider', this.service_name);
|
||||
svc_aiChat.register_provider({
|
||||
service_name: this.service_name,
|
||||
alias: true,
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Name of the model this service uses when a request does not specify one.
 * @returns {string}
 */
get_default_model () {
|
||||
return 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo';
|
||||
}
|
||||
|
||||
static IMPLEMENTS = {
|
||||
['puter-chat-completion']: {
|
||||
@ -36,10 +41,12 @@ class TogetherAIService extends BaseService {
|
||||
return models.map(model => model.id);
|
||||
},
|
||||
async complete ({ messages, stream, model }) {
|
||||
console.log('model?', model);
|
||||
if ( model === 'model-fallback-test-1' ) {
|
||||
throw new Error('Model Fallback Test 1');
|
||||
}
|
||||
|
||||
const completion = await this.together.chat.completions.create({
|
||||
model: model ??
|
||||
'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo',
|
||||
model: model ?? this.get_default_model(),
|
||||
messages: messages,
|
||||
stream,
|
||||
});
|
||||
@ -93,6 +100,17 @@ class TogetherAIService extends BaseService {
|
||||
},
|
||||
});
|
||||
}
|
||||
models.push({
|
||||
id: 'model-fallback-test-1',
|
||||
name: 'Model Fallback Test 1',
|
||||
context: 1000,
|
||||
cost: {
|
||||
currency: 'usd-cents',
|
||||
tokens: 1_000_000,
|
||||
input: 10,
|
||||
output: 10,
|
||||
},
|
||||
});
|
||||
this.modules.kv.set(
|
||||
`${this.kvkey}:models`, models, { EX: 5*60 });
|
||||
return models;
|
||||
|
@ -39,6 +39,10 @@ class XAIService extends BaseService {
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Name of the model this service uses when a request does not specify one.
 * @returns {string}
 */
get_default_model () {
|
||||
return 'grok-beta';
|
||||
}
|
||||
|
||||
static IMPLEMENTS = {
|
||||
['puter-chat-completion']: {
|
||||
async models () {
|
||||
@ -98,7 +102,7 @@ class XAIService extends BaseService {
|
||||
}, stream);
|
||||
(async () => {
|
||||
const completion = await this.anthropic.messages.stream({
|
||||
model: model ?? 'grok-beta',
|
||||
model: model ?? this.get_default_model(),
|
||||
max_tokens: 1000,
|
||||
temperature: 0,
|
||||
system: this.get_system_prompt() +
|
||||
@ -122,7 +126,7 @@ class XAIService extends BaseService {
|
||||
}
|
||||
|
||||
const msg = await this.anthropic.messages.create({
|
||||
model: model ?? 'grok-beta',
|
||||
model: model ?? this.get_default_model(),
|
||||
max_tokens: 1000,
|
||||
temperature: 0,
|
||||
system: this.get_system_prompt() +
|
||||
|
@ -116,6 +116,12 @@ class DriverService extends BaseService {
|
||||
return await this._call(o);
|
||||
} catch ( e ) {
|
||||
this.log.error('Driver error response: ' + e.toString());
|
||||
if ( ! (e instanceof APIError) ) {
|
||||
this.errors.report('driver', {
|
||||
source: e,
|
||||
trace: true,
|
||||
});
|
||||
}
|
||||
return this._driver_response_from_error(e);
|
||||
}
|
||||
}
|
||||
@ -159,6 +165,7 @@ class DriverService extends BaseService {
|
||||
|
||||
const client_driver_call = {
|
||||
intended_service: driver,
|
||||
response_metadata: {},
|
||||
test_mode,
|
||||
};
|
||||
driver = this.service_aliases[driver] ?? driver;
|
||||
@ -180,13 +187,15 @@ class DriverService extends BaseService {
|
||||
return await Context.sub({
|
||||
client_driver_call,
|
||||
}).arun(async () => {
|
||||
return await this.call_new_({
|
||||
const result = await this.call_new_({
|
||||
actor,
|
||||
service,
|
||||
service_name: driver,
|
||||
iface, method, args: processed_args,
|
||||
skip_usage,
|
||||
});
|
||||
result.metadata = client_driver_call.response_metadata;
|
||||
return result;
|
||||
});
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user