refactor: central controller for all LLM services

Adds AIChatService, an implementor of puter-chat-completion which can
delegate to other implementors (implementors that have registered with
AIChatService at initialization) based on details of the request.

Makes AIChatService the test implementation. AIChatService then
delegates to FakeChatService when in test mode.

Adds `models()` method to puter-chat-completion. This method, instead of
returning only the names of supported models, includes other details
such as the cost and maximum output size.

Implements `models()` on Claude and XAI.

Registers Claude and XAI with AIChatService.
This commit is contained in:
KernelDeimos 2024-11-22 15:52:37 -05:00
parent 791f7748c7
commit aa3dcea462
10 changed files with 276 additions and 25 deletions

View File

@ -0,0 +1,130 @@
const BaseService = require("../../services/BaseService");
const { Context } = require("../../util/context");
class AIChatService extends BaseService {
_construct () {
this.providers = [];
this.simple_model_list = [];
this.detail_model_list = [];
this.detail_model_map = {};
}
_init () {
const svc_driver = this.services.get('driver')
for ( const provider of this.providers ) {
svc_driver.register_service_alias('ai-chat', provider.service_name);
}
}
async ['__on_boot.consolidation'] () {
// TODO: get models and pricing for each model
for ( const provider of this.providers ) {
const delegate = this.services.get(provider.service_name)
.as('puter-chat-completion');
// Populate simple model list
{
const models = await delegate.list();
this.simple_model_list.push(...models);
}
// Populate detail model list and map
{
const models = await delegate.models();
const annotated_models = [];
for ( const model of models ) {
annotated_models.push({
...model,
provider: provider.service_name,
});
}
this.detail_model_list.push(...annotated_models);
for ( const model of annotated_models ) {
if ( this.detail_model_map[model.id] ) {
let array = this.detail_model_map[model.id];
// replace with array
if ( ! Array.isArray(array) ) {
array = [array];
this.detail_model_map[model.id] = array;
}
array.push(model);
continue;
}
this.detail_model_map[model.id] = model;
}
}
}
}
register_provider (spec) {
this.providers.push(spec);
}
static IMPLEMENTS = {
['driver-capabilities']: {
supports_test_mode (iface, method_name) {
return iface === 'puter-chat-completion' &&
method_name === 'complete';
}
},
['puter-chat-completion']: {
async models () {
const delegate = this.get_delegate();
if ( ! delegate ) return await this.models_();
return await delegate.models();
},
async list () {
const delegate = this.get_delegate();
if ( ! delegate ) return await this.list_();
return await delegate.list();
},
async complete (parameters) {
const client_driver_call = Context.get('client_driver_call');
const { test_mode } = client_driver_call;
let { intended_service } = client_driver_call;
if ( test_mode ) {
intended_service = 'fake-chat';
}
if ( intended_service === this.service_name ) {
throw new Error('Calling ai-chat directly is not yet supported');
}
const svc_driver = this.services.get('driver');
const ret = await svc_driver.call_new_({
actor: Context.get('actor'),
service_name: intended_service,
iface: 'puter-chat-completion',
method: 'complete',
args: parameters,
});
ret.result.via_ai_chat_service = true;
return ret.result;
}
}
}
async models_ () {
return this.detail_model_list;
}
async list_ () {
return this.simple_model_list;
}
get_delegate () {
const client_driver_call = Context.get('client_driver_call');
if ( client_driver_call.intended_service === this.service_name ) {
return undefined;
}
console.log('getting service', client_driver_call.intended_service);
const service = this.services.get(client_driver_call.intended_service);
return service.as('puter-chat-completion');
}
}
module.exports = { AIChatService };

View File

@ -28,6 +28,11 @@ class AIInterfaceService extends BaseService {
col_interfaces.set('puter-chat-completion', { col_interfaces.set('puter-chat-completion', {
description: 'Chatbot.', description: 'Chatbot.',
methods: { methods: {
models: {
description: 'List supported models and their details.',
result: { type: 'json' },
parameters: {},
},
list: { list: {
description: 'List supported models', description: 'List supported models',
result: { type: 'json' }, result: { type: 'json' },

View File

@ -3,7 +3,7 @@ const BaseService = require("../../services/BaseService");
class AITestModeService extends BaseService { class AITestModeService extends BaseService {
async _init () { async _init () {
const svc_driver = this.services.get('driver'); const svc_driver = this.services.get('driver');
svc_driver.register_test_service('puter-chat-completion', 'openai-completion'); svc_driver.register_test_service('puter-chat-completion', 'ai-chat');
} }
} }

View File

@ -22,17 +22,29 @@ class ClaudeService extends BaseService {
this.anthropic = new Anthropic({ this.anthropic = new Anthropic({
apiKey: this.config.apiKey apiKey: this.config.apiKey
}); });
const svc_aiChat = this.services.get('ai-chat');
svc_aiChat.register_provider({
service_name: this.service_name,
alias: true,
});
} }
static IMPLEMENTS = { static IMPLEMENTS = {
['puter-chat-completion']: { ['puter-chat-completion']: {
async models () {
return await this.models_();
},
async list () { async list () {
return [ const models = await this.models_();
'claude-3-5-sonnet-latest', const model_names = [];
'claude-3-5-sonnet-20241022', for ( const model of models ) {
'claude-3-5-sonnet-20240620', model_names.push(model.id);
'claude-3-haiku-20240307', if ( model.aliases ) {
]; model_names.push(...model.aliases);
}
}
return model_names;
}, },
async complete ({ messages, stream, model }) { async complete ({ messages, stream, model }) {
const adapted_messages = []; const adapted_messages = [];
@ -112,6 +124,45 @@ class ClaudeService extends BaseService {
} }
} }
} }
async models_ () {
return [
{
id: 'claude-3-5-sonnet-20241022',
aliases: ['claude-3-5-sonnet-latest'],
cost: {
currency: 'usd-cents',
tokens: 1_000_000,
input: 300,
output: 1500,
},
qualitative_speed: 'fast',
max_output: 8192,
training_cutoff: '2024-04',
},
{
id: 'claude-3-5-sonnet-20240620',
succeeded_by: 'claude-3-5-sonnet-20241022',
cost: {
currency: 'usd-cents',
tokens: 1_000_000,
input: 300,
output: 1500,
},
},
{
id: 'claude-3-haiku-20240307',
// aliases: ['claude-3-haiku-latest'],
cost: {
currency: 'usd-cents',
tokens: 1_000_000,
input: 25,
output: 125,
},
qualitative_speed: 'fastest',
},
];
}
} }
module.exports = { module.exports = {

View File

@ -7,7 +7,19 @@ class FakeChatService extends BaseService {
return ['fake']; return ['fake'];
}, },
async complete ({ messages, stream, model }) { async complete ({ messages, stream, model }) {
const { LoremIpsum } = require('lorem-ipsum');
const li = new LoremIpsum({
sentencesPerParagraph: {
max: 8,
min: 4
},
wordsPerSentence: {
max: 20,
min: 12
},
});
return { return {
"index": 0,
message: { message: {
"id": "00000000-0000-0000-0000-000000000000", "id": "00000000-0000-0000-0000-000000000000",
"type": "message", "type": "message",
@ -16,7 +28,9 @@ class FakeChatService extends BaseService {
"content": [ "content": [
{ {
"type": "text", "type": "text",
"text": "I am a fake AI, I don't know how to respond to anything." "text": li.generateParagraphs(
Math.floor(Math.random() * 3) + 1
)
} }
], ],
"stop_reason": "end_turn", "stop_reason": "end_turn",
@ -25,7 +39,9 @@ class FakeChatService extends BaseService {
"input_tokens": 0, "input_tokens": 0,
"output_tokens": 1 "output_tokens": 1
} }
} },
"logprobs": null,
"finish_reason": "stop"
} }
} }
} }

View File

@ -22,12 +22,6 @@ class OpenAICompletionService extends BaseService {
} }
static IMPLEMENTS = { static IMPLEMENTS = {
['driver-capabilities']: {
supports_test_mode (iface, method_name) {
return iface === 'puter-chat-completion' &&
method_name === 'complete';
}
},
['puter-chat-completion']: { ['puter-chat-completion']: {
async list () { async list () {
return [ return [

View File

@ -57,6 +57,9 @@ class PuterAIModule extends AdvancedBase {
// services.registerService('claude', ClaudeEnoughService); // services.registerService('claude', ClaudeEnoughService);
} }
const { AIChatService } = require('./AIChatService');
services.registerService('ai-chat', AIChatService);
const { FakeChatService } = require('./FakeChatService'); const { FakeChatService } = require('./FakeChatService');
services.registerService('fake-chat', FakeChatService); services.registerService('fake-chat', FakeChatService);

View File

@ -31,14 +31,29 @@ class XAIService extends BaseService {
apiKey: this.global_config.services.xai.apiKey, apiKey: this.global_config.services.xai.apiKey,
baseURL: 'https://api.x.ai' baseURL: 'https://api.x.ai'
}); });
const svc_aiChat = this.services.get('ai-chat');
svc_aiChat.register_provider({
service_name: this.service_name,
alias: true,
});
} }
static IMPLEMENTS = { static IMPLEMENTS = {
['puter-chat-completion']: { ['puter-chat-completion']: {
async models () {
return await this.models_();
},
async list () { async list () {
return [ const models = await this.models_();
'grok-beta', const model_names = [];
]; for ( const model of models ) {
model_names.push(model.id);
if ( model.aliases ) {
model_names.push(...model.aliases);
}
}
return model_names;
}, },
async complete ({ messages, stream, model }) { async complete ({ messages, stream, model }) {
model = this.adapt_model(model); model = this.adapt_model(model);
@ -121,6 +136,21 @@ class XAIService extends BaseService {
} }
} }
} }
async models_ () {
return [
{
id: 'grok-beta',
name: 'Grok Beta',
cost: {
currency: 'usd-cents',
tokens: 1_000_000,
input: 500,
output: 1500,
},
}
];
}
} }
module.exports = { module.exports = {

View File

@ -38,6 +38,7 @@ class DriverService extends BaseService {
this.drivers = {}; this.drivers = {};
this.interface_to_implementation = {}; this.interface_to_implementation = {};
this.interface_to_test_service = {}; this.interface_to_test_service = {};
this.service_aliases = {};
} }
async ['__on_registry.collections'] () { async ['__on_registry.collections'] () {
@ -82,6 +83,10 @@ class DriverService extends BaseService {
register_test_service (interface_name, service_name) { register_test_service (interface_name, service_name) {
this.interface_to_test_service[interface_name] = service_name; this.interface_to_test_service[interface_name] = service_name;
} }
register_service_alias (service_name, alias) {
this.service_aliases[alias] = service_name;
}
get_interface (interface_name) { get_interface (interface_name) {
const o = {}; const o = {};
@ -152,6 +157,12 @@ class DriverService extends BaseService {
driver = this.interface_to_test_service[iface]; driver = this.interface_to_test_service[iface];
} }
const client_driver_call = {
intended_service: driver,
test_mode,
};
driver = this.service_aliases[driver] ?? driver;
const driver_service_exists = (() => { const driver_service_exists = (() => {
console.log('CHECKING FOR THIS', driver, iface); console.log('CHECKING FOR THIS', driver, iface);
return this.services.has(driver) && return this.services.has(driver) &&
@ -165,13 +176,17 @@ class DriverService extends BaseService {
if ( test_mode && caps && caps.supports_test_mode(iface, method) ) { if ( test_mode && caps && caps.supports_test_mode(iface, method) ) {
skip_usage = true; skip_usage = true;
} }
return await this.call_new_({ return await Context.sub({
actor, client_driver_call,
service, }).arun(async () => {
service_name: driver, return await this.call_new_({
iface, method, args: processed_args, actor,
skip_usage, service,
service_name: driver,
iface, method, args: processed_args,
skip_usage,
});
}); });
} }
@ -261,6 +276,10 @@ class DriverService extends BaseService {
iface, method, args, iface, method, args,
skip_usage, skip_usage,
}) { }) {
if ( ! service ) {
service = this.services.get(service_name);
}
const svc_permission = this.services.get('permission'); const svc_permission = this.services.get('permission');
const reading = await svc_permission.scan( const reading = await svc_permission.scan(
actor, actor,

View File

@ -63,6 +63,9 @@ class Context {
static arun (cb) { static arun (cb) {
return this.get().arun(cb); return this.get().arun(cb);
} }
static sub (values, opt_name) {
return this.get().sub(values, opt_name);
}
get (k) { get (k) {
return this.values_[k]; return this.values_[k];
} }