dev: implement ai model fallback

KernelDeimos 2024-11-25 14:58:21 -05:00
parent 0d86233b45
commit 3f4efb9948
8 changed files with 250 additions and 37 deletions

View File

@@ -1,7 +1,14 @@
const BaseService = require("../../services/BaseService");
const { Context } = require("../../util/context");
const MAX_FALLBACKS = 3 + 1; // includes first attempt
class AIChatService extends BaseService {
static MODULES = {
kv: globalThis.kv,
uuidv4: require('uuid').v4,
}
_construct () {
this.providers = [];
@@ -10,19 +17,23 @@ class AIChatService extends BaseService {
this.detail_model_map = {};
}
_init () {
-const svc_driver = this.services.get('driver')
-for ( const provider of this.providers ) {
-svc_driver.register_service_alias('ai-chat', provider.service_name);
-}
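// Random namespace for this instance's KV entries; get_fallback_model
// below caches per-model fallback lists under `${this.kvkey}:fallbacks:...`.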
this.kvkey = this.modules.uuidv4();
}
async ['__on_boot.consolidation'] () {
{
const svc_driver = this.services.get('driver')
for ( const provider of this.providers ) {
svc_driver.register_service_alias('ai-chat',
provider.service_name);
}
}
// TODO: get models and pricing for each model
for ( const provider of this.providers ) {
const delegate = this.services.get(provider.service_name)
.as('puter-chat-completion');
// Populate simple model list
{
const models = await (async () => {
@@ -54,20 +65,30 @@ class AIChatService extends BaseService {
});
}
this.detail_model_list.push(...annotated_models);
-for ( const model of annotated_models ) {
-if ( this.detail_model_map[model.id] ) {
-let array = this.detail_model_map[model.id];
-// replace with array
-if ( ! Array.isArray(array) ) {
-array = [array];
-this.detail_model_map[model.id] = array;
-}
-array.push(model);
-continue;
-}
-this.detail_model_map[model.id] = model;
-}
const set_or_push = (key, model) => {
// Typical case: no conflict
if ( ! this.detail_model_map[key] ) {
this.detail_model_map[key] = model;
return;
}
// Conflict: model name will map to an array
let array = this.detail_model_map[key];
if ( ! Array.isArray(array) ) {
array = [array];
this.detail_model_map[key] = array;
}
array.push(model);
};
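// Illustration (hypothetical ids): if providers A and B both expose
// 'some-model', the map holds an array for that key:
//   detail_model_map['some-model'] => [modelFromA, modelFromB]
// while unambiguous ids and aliases map directly to a single object.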
for ( const model of annotated_models ) {
set_or_push(model.id, model);
if ( ! model.aliases ) continue;
for ( const alias of model.aliases ) {
set_or_push(alias, model);
}
}
}
}
@@ -97,7 +118,7 @@ class AIChatService extends BaseService {
},
async complete (parameters) {
const client_driver_call = Context.get('client_driver_call');
-let { test_mode, intended_service } = client_driver_call;
let { test_mode, intended_service, response_metadata } = client_driver_call;
this.log.noticeme('AIChatService.complete', { intended_service, parameters, test_mode });
const svc_event = this.services.get('event');
@@ -123,13 +144,78 @@
}
const svc_driver = this.services.get('driver');
-const ret = await svc_driver.call_new_({
-actor: Context.get('actor'),
-service_name: intended_service,
-iface: 'puter-chat-completion',
-method: 'complete',
-args: parameters,
-});
let ret, error, errors = [];
try {
ret = await svc_driver.call_new_({
actor: Context.get('actor'),
service_name: intended_service,
iface: 'puter-chat-completion',
method: 'complete',
args: parameters,
});
} catch (e) {
const tried = [];
let model = this.get_model_from_request(parameters);
// TODO: if conflicting models exist, also record the service name
tried.push(model);
error = e;
errors.push(e);
this.log.error('error calling service', {
intended_service,
model,
error: e,
});
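// Walk outward through successively "farther" models until one succeeds.
// Each failure is appended to `tried`, and the candidate list is capped
// at MAX_FALLBACKS, so this loop always terminates: either a fallback
// succeeds or no untried candidate remains and we throw.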
while ( !! error ) {
const fallback = this.get_fallback_model({
model, tried,
});
if ( ! fallback ) {
throw new Error('no fallback model available');
}
const {
fallback_service_name,
fallback_model_name,
} = fallback;
this.log.warn('model fallback', {
intended_service,
fallback_service_name,
fallback_model_name
});
try {
ret = await svc_driver.call_new_({
actor: Context.get('actor'),
service_name: fallback_service_name,
iface: 'puter-chat-completion',
method: 'complete',
args: {
...parameters,
model: fallback_model_name,
},
});
error = null;
response_metadata.fallback = {
service: fallback_service_name,
model: fallback_model_name,
tried: tried,
};
} catch (e) {
error = e;
errors.push(e);
tried.push(fallback_model_name);
this.log.error('error calling fallback', {
intended_service,
model,
error: e,
});
}
}
}
ret.result.via_ai_chat_service = true;
const username = Context.get('actor').type?.user?.username;
@@ -190,6 +276,85 @@
const service = this.services.get(client_driver_call.intended_service);
return service.as('puter-chat-completion');
}
/**
* Find an appropriate fallback model by sorting the list of models
* by the Euclidean distance between input/output token prices and
* selecting the first one that is not in the list of models already
* tried.
*
* @param {Object} spec
* @param {string} spec.model - id of the model that failed
* @param {string[]} spec.tried - ids of models already attempted
* @returns {{fallback_service_name: string, fallback_model_name: string}|undefined}
*/
get_fallback_model ({ model, tried }) {
let target_model = this.detail_model_map[model];
if ( ! target_model ) {
this.log.error('could not find model', { model });
throw new Error('could not find model');
}
if ( Array.isArray(target_model) ) {
// TODO: better conflict resolution
this.log.noticeme('conflict exists', { model, target_model });
target_model = target_model[0];
}
// First check KV for the sorted list
let sorted_models = this.modules.kv.get(
`${this.kvkey}:fallbacks:${model}`);
if ( ! sorted_models ) {
// Calculate the sorted list
const models = this.detail_model_list;
// Sort a copy: Array#sort mutates in place, and detail_model_list is
// shared state whose order should be preserved.
sorted_models = [...models].sort((a, b) => {
return Math.sqrt(
Math.pow(a.cost.input - target_model.cost.input, 2) +
Math.pow(a.cost.output - target_model.cost.output, 2)
) - Math.sqrt(
Math.pow(b.cost.input - target_model.cost.input, 2) +
Math.pow(b.cost.output - target_model.cost.output, 2)
);
});
sorted_models = sorted_models.slice(0, MAX_FALLBACKS);
this.modules.kv.set(
`${this.kvkey}:fallbacks:${model}`, sorted_models);
}
for ( const candidate of sorted_models ) {
if ( tried.includes(candidate.id) ) continue;
return {
fallback_service_name: candidate.provider,
fallback_model_name: candidate.id,
};
}
// No fallbacks available
this.log.error('no fallbacks', {
sorted_models,
tried,
});
}
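// Worked example of the distance sort above (hypothetical prices, in the
// same usd-cents units the model entries use): with a target of
// { input: 10, output: 30 }, a candidate at { input: 15, output: 60 } has
// distance sqrt(5^2 + 30^2) ≈ 30.4 and outranks one at
// { input: 300, output: 1500 }, distance ≈ 1498.3. The target itself is
// at distance 0, so it sorts first and is skipped via `tried`.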
get_model_from_request (parameters) {
const client_driver_call = Context.get('client_driver_call');
let { intended_service } = client_driver_call;
let model = parameters.model;
if ( ! model ) {
const service = this.services.get(intended_service);
if ( ! service.get_default_model ) {
throw new Error('could not infer model from service');
}
model = service.get_default_model();
if ( ! model ) {
throw new Error('could not infer model from service');
}
}
return model;
}
}
module.exports = { AIChatService };
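
(For reference, a minimal standalone sketch of the selection logic above,
using hypothetical model records; the names here are illustrative, not
code from this commit:)

    // pick the nearest untried model in (input price, output price) space
    const pickFallback = (target, models, tried, max = 4) => {
        const dist = m => Math.hypot(
            m.cost.input - target.cost.input,
            m.cost.output - target.cost.output,
        );
        return [...models]                      // copy; never mutate the shared list
            .sort((a, b) => dist(a) - dist(b))
            .slice(0, max)                      // mirrors MAX_FALLBACKS
            .find(m => ! tried.includes(m.id)) ?? null;
    };

    // Usage with made-up entries:
    const models = [
        { id: 'a', cost: { input: 10, output: 30 } },
        { id: 'b', cost: { input: 15, output: 60 } },
        { id: 'c', cost: { input: 300, output: 1500 } },
    ];
    pickFallback(models[0], models, ['a']);     // => the 'b' entry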

View File

@@ -32,6 +32,10 @@ class ClaudeService extends BaseService {
alias: true,
});
}
get_default_model () {
return 'claude-3-5-sonnet-latest';
}
static IMPLEMENTS = {
['puter-chat-completion']: {
@@ -106,7 +110,7 @@ class ClaudeService extends BaseService {
}, stream);
(async () => {
const completion = await this.anthropic.messages.stream({
-model: model ?? 'claude-3-5-sonnet-latest',
model: model ?? this.get_default_model(),
max_tokens: 1000,
temperature: 0,
system: PUTER_PROMPT + JSON.stringify(system_prompts),
@@ -129,7 +133,7 @@
}
const msg = await this.anthropic.messages.create({
-model: 'claude-3-5-sonnet-latest',
model: model ?? this.get_default_model(),
max_tokens: 1000,
temperature: 0,
system: PUTER_PROMPT + JSON.stringify(system_prompts),

View File

@@ -20,6 +20,10 @@ class GroqAIService extends BaseService {
alias: true,
});
}
get_default_model () {
return 'llama-3.1-8b-instant';
}
static IMPLEMENTS = {
'puter-chat-completion': {
@@ -37,6 +41,8 @@
if ( ! message.role ) message.role = 'user';
}
model = model ?? this.get_default_model();
const completion = await this.client.chat.completions.create({
messages,
model,

View File

@@ -130,6 +130,9 @@ class MistralAIService extends BaseService {
}
// return resp.data;
}
get_default_model () {
return 'mistral-large-latest';
}
static IMPLEMENTS = {
'puter-chat-completion': {
async models () {
@@ -153,7 +156,7 @@
chunked: true,
}, stream);
const completion = await this.client.chat.stream({
-model: model ?? 'mistral-large-latest',
model: model ?? this.get_default_model(),
messages,
});
(async () => {
@@ -179,7 +182,7 @@
try {
const completion = await this.client.chat.complete({
-model: model ?? 'mistral-large-latest',
model: model ?? this.get_default_model(),
messages,
});
// Expected case when mistralai/client-ts#23 is fixed

View File

@@ -21,6 +21,10 @@ class OpenAICompletionService extends BaseService {
});
}
get_default_model () {
return 'gpt-4o-mini';
}
static IMPLEMENTS = {
['puter-chat-completion']: {
async list () {
@@ -106,7 +110,7 @@
throw new Error('`messages` must be an array');
}
-model = model ?? 'gpt-4o-mini';
model = model ?? this.get_default_model();
for ( let i = 0; i < messages.length; i++ ) {
let msg = messages[i];

View File

@@ -19,11 +19,16 @@ class TogetherAIService extends BaseService {
this.kvkey = this.modules.uuidv4();
const svc_aiChat = this.services.get('ai-chat');
console.log('registering provider', this.service_name);
svc_aiChat.register_provider({
service_name: this.service_name,
alias: true,
});
}
get_default_model () {
return 'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo';
}
static IMPLEMENTS = {
['puter-chat-completion']: {
@@ -36,10 +41,12 @@ class TogetherAIService extends BaseService {
return models.map(model => model.id);
},
async complete ({ messages, stream, model }) {
console.log('model?', model);
if ( model === 'model-fallback-test-1' ) {
throw new Error('Model Fallback Test 1');
}
const completion = await this.together.chat.completions.create({
-model: model ??
-'meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo',
model: model ?? this.get_default_model(),
messages: messages,
stream,
});
@@ -93,6 +100,17 @@
},
});
}
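// Deliberately broken entry: complete() above throws whenever this id is
// requested, which lets the ai-chat fallback path be exercised end-to-end.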
models.push({
id: 'model-fallback-test-1',
name: 'Model Fallback Test 1',
context: 1000,
cost: {
currency: 'usd-cents',
tokens: 1_000_000,
input: 10,
output: 10,
},
});
this.modules.kv.set(
`${this.kvkey}:models`, models, { EX: 5*60 });
return models;

View File

@@ -39,6 +39,10 @@ class XAIService extends BaseService {
});
}
get_default_model () {
return 'grok-beta';
}
static IMPLEMENTS = {
['puter-chat-completion']: {
async models () {
@@ -98,7 +102,7 @@
}, stream);
(async () => {
const completion = await this.anthropic.messages.stream({
-model: model ?? 'grok-beta',
model: model ?? this.get_default_model(),
max_tokens: 1000,
temperature: 0,
system: this.get_system_prompt() +
@@ -122,7 +126,7 @@
}
const msg = await this.anthropic.messages.create({
-model: model ?? 'grok-beta',
model: model ?? this.get_default_model(),
max_tokens: 1000,
temperature: 0,
system: this.get_system_prompt() +

View File

@@ -116,6 +116,12 @@ class DriverService extends BaseService {
return await this._call(o);
} catch ( e ) {
this.log.error('Driver error response: ' + e.toString());
if ( ! (e instanceof APIError) ) {
this.errors.report('driver', {
source: e,
trace: true,
});
}
return this._driver_response_from_error(e);
}
}
@@ -159,6 +165,7 @@
const client_driver_call = {
intended_service: driver,
response_metadata: {},
test_mode,
};
driver = this.service_aliases[driver] ?? driver;
@@ -180,13 +187,15 @@
return await Context.sub({
client_driver_call,
}).arun(async () => {
-return await this.call_new_({
const result = await this.call_new_({
actor,
service,
service_name: driver,
iface, method, args: processed_args,
skip_usage,
});
result.metadata = client_driver_call.response_metadata;
return result;
});
}
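
With these changes, a driver response that recovered via fallback carries
the attempt history in its metadata, roughly of this shape (the values
here are hypothetical):

    // result.metadata, as populated from client_driver_call.response_metadata
    // {
    //     fallback: {
    //         service: 'some-provider-service',
    //         model: 'some-fallback-model',
    //         tried: [ 'model-that-failed' ],
    //     },
    // }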