const BaseService = require("../../services/BaseService");
const { Context } = require("../../util/context");

/**
 * Front service for the 'puter-chat-completion' interface.
 *
 * Other services announce themselves through register_provider(). This
 * service then:
 *   - registers each provider's service name with the 'driver' service as
 *     an alias of 'ai-chat' (see _init);
 *   - merges the model lists of all providers at boot
 *     (see '__on_boot.consolidation');
 *   - forwards interface calls to the service named by the
 *     'client_driver_call' context value.
 */
class AIChatService extends BaseService {
    /**
     * Creates the empty provider and model registries.
     */
    _construct () {
        // Provider specs, in the order register_provider() received them.
        this.providers = [];

        // Concatenation of every provider's list() result.
        this.simple_model_list = [];
        // Every provider's models() entries, each with a `provider` field added.
        this.detail_model_list = [];
        // model id -> annotated model, or an array of annotated models when
        // more than one provider reports the same id.
        this.detail_model_map = {};
    }

    /**
     * Registers an alias with the driver service for every provider that
     * has been registered so far.
     *
     * NOTE(review): providers that call register_provider() after this
     * method has run do not get an alias here; service init order is not
     * visible from this file, so confirm providers initialise first.
     */
    _init () {
        const svc_driver = this.services.get('driver');

        for ( const provider of this.providers ) {
            svc_driver.register_service_alias('ai-chat', provider.service_name);
        }
    }

    /**
     * Boot hook: asks every registered provider for its models and fills
     * simple_model_list, detail_model_list and detail_model_map.
     */
    async ['__on_boot.consolidation'] () {
        // TODO: get models and pricing for each model
        for ( const provider of this.providers ) {
            const delegate = this.services.get(provider.service_name)
                .as('puter-chat-completion');

            // Populate simple model list
            const model_names = await delegate.list();
            this.simple_model_list.push(...model_names);

            // Populate detail model list and map
            const models = await delegate.models();
            const annotated_models = models.map(model => ({
                ...model,
                provider: provider.service_name,
            }));
            this.detail_model_list.push(...annotated_models);
            for ( const model of annotated_models ) {
                this.add_to_detail_model_map_(model);
            }
        }
    }

    /**
     * Records one annotated model in detail_model_map. The first model seen
     * for an id is stored directly; later models with the same id turn the
     * entry into an array holding all of them.
     *
     * @param {object} model - annotated model; `model.id` is the map key
     */
    add_to_detail_model_map_ (model) {
        const map = this.detail_model_map;

        // Own-property check: a truthiness test on a plain object would
        // mistake inherited members (an id such as 'constructor') for an
        // existing entry.
        if ( ! Object.prototype.hasOwnProperty.call(map, model.id) ) {
            map[model.id] = model;
            return;
        }

        const existing = map[model.id];
        if ( Array.isArray(existing) ) {
            existing.push(model);
            return;
        }

        // Second model for this id: replace the single entry with an array.
        map[model.id] = [existing, model];
    }

    /**
     * Adds a provider.
     *
     * @param {object} spec - provider description; `spec.service_name` is
     *   the only field read by this class, other fields are stored as-is
     */
    register_provider (spec) {
        this.providers.push(spec);
    }

    // The methods below use `this` as the service instance
    // (this.services, this.get_delegate, ...).
    static IMPLEMENTS = {
        ['driver-capabilities']: {
            /**
             * @returns {boolean} true only for the 'complete' method of
             *   'puter-chat-completion'
             */
            supports_test_mode (iface, method_name) {
                return iface === 'puter-chat-completion' &&
                    method_name === 'complete';
            },
        },
        ['puter-chat-completion']: {
            /**
             * Detailed models of the intended service, or the aggregated
             * list when this service itself was addressed.
             */
            async models () {
                const delegate = this.get_delegate();
                if ( ! delegate ) return await this.models_();
                return await delegate.models();
            },
            /**
             * Model names of the intended service, or the aggregated list
             * when this service itself was addressed.
             */
            async list () {
                const delegate = this.get_delegate();
                if ( ! delegate ) return await this.list_();
                return await delegate.list();
            },
            /**
             * Forwards a completion request through the driver service to
             * the intended service; in test mode the request goes to
             * 'fake-chat' instead.
             *
             * @param {object} parameters - passed through unchanged as `args`
             * @returns the `result` of the forwarded call, tagged with
             *   `via_ai_chat_service = true` when it is an object
             * @throws {Error} when there is no 'client_driver_call' context
             *   value, or when the resolved target is this service itself
             */
            async complete (parameters) {
                const client_driver_call = Context.get('client_driver_call');
                if ( ! client_driver_call ) {
                    throw new Error(
                        'ai-chat complete requires a client_driver_call context'
                    );
                }

                const intended_service = client_driver_call.test_mode
                    ? 'fake-chat'
                    : client_driver_call.intended_service;

                if ( intended_service === this.service_name ) {
                    throw new Error('Calling ai-chat directly is not yet supported');
                }

                const svc_driver = this.services.get('driver');
                const ret = await svc_driver.call_new_({
                    actor: Context.get('actor'),
                    service_name: intended_service,
                    iface: 'puter-chat-completion',
                    method: 'complete',
                    args: parameters,
                });

                // Only objects can carry the marker; assigning a property on
                // null/undefined or (in strict mode) a primitive would throw.
                const result = ret.result;
                if ( result !== null && typeof result === 'object' ) {
                    result.via_ai_chat_service = true;
                }
                return result;
            },
        },
    };

    /**
     * @returns {Promise<object[]>} aggregated annotated models of all providers
     */
    async models_ () {
        return this.detail_model_list;
    }

    /**
     * @returns {Promise<Array>} aggregated list() results of all providers
     */
    async list_ () {
        return this.simple_model_list;
    }

    /**
     * Resolves the 'puter-chat-completion' implementation of the service the
     * current driver call was addressed to.
     *
     * @returns the delegate, or undefined when there is no
     *   'client_driver_call' context value or the call was addressed to
     *   this service itself
     */
    get_delegate () {
        const client_driver_call = Context.get('client_driver_call');
        if ( ! client_driver_call ) return undefined;

        const { intended_service } = client_driver_call;
        if ( intended_service === this.service_name ) {
            return undefined;
        }

        return this.services.get(intended_service).as('puter-chat-completion');
    }
}

module.exports = { AIChatService };
a/src/backend/src/modules/puterai/AIInterfaceService.js b/src/backend/src/modules/puterai/AIInterfaceService.js index bd0d9083..17052cbd 100644 --- a/src/backend/src/modules/puterai/AIInterfaceService.js +++ b/src/backend/src/modules/puterai/AIInterfaceService.js @@ -28,6 +28,11 @@ class AIInterfaceService extends BaseService { col_interfaces.set('puter-chat-completion', { description: 'Chatbot.', methods: { + models: { + description: 'List supported models and their details.', + result: { type: 'json' }, + parameters: {}, + }, list: { description: 'List supported models', result: { type: 'json' }, diff --git a/src/backend/src/modules/puterai/AITestModeService.js b/src/backend/src/modules/puterai/AITestModeService.js index 64289aca..97dbfc38 100644 --- a/src/backend/src/modules/puterai/AITestModeService.js +++ b/src/backend/src/modules/puterai/AITestModeService.js @@ -3,7 +3,7 @@ const BaseService = require("../../services/BaseService"); class AITestModeService extends BaseService { async _init () { const svc_driver = this.services.get('driver'); - svc_driver.register_test_service('puter-chat-completion', 'openai-completion'); + svc_driver.register_test_service('puter-chat-completion', 'ai-chat'); } } diff --git a/src/backend/src/modules/puterai/ClaudeService.js b/src/backend/src/modules/puterai/ClaudeService.js index 33efd5a5..f302a1dd 100644 --- a/src/backend/src/modules/puterai/ClaudeService.js +++ b/src/backend/src/modules/puterai/ClaudeService.js @@ -22,17 +22,29 @@ class ClaudeService extends BaseService { this.anthropic = new Anthropic({ apiKey: this.config.apiKey }); + + const svc_aiChat = this.services.get('ai-chat'); + svc_aiChat.register_provider({ + service_name: this.service_name, + alias: true, + }); } static IMPLEMENTS = { ['puter-chat-completion']: { + async models () { + return await this.models_(); + }, async list () { - return [ - 'claude-3-5-sonnet-latest', - 'claude-3-5-sonnet-20241022', - 'claude-3-5-sonnet-20240620', - 
'claude-3-haiku-20240307', - ]; + const models = await this.models_(); + const model_names = []; + for ( const model of models ) { + model_names.push(model.id); + if ( model.aliases ) { + model_names.push(...model.aliases); + } + } + return model_names; }, async complete ({ messages, stream, model }) { const adapted_messages = []; @@ -112,6 +124,45 @@ class ClaudeService extends BaseService { } } } + + async models_ () { + return [ + { + id: 'claude-3-5-sonnet-20241022', + aliases: ['claude-3-5-sonnet-latest'], + cost: { + currency: 'usd-cents', + tokens: 1_000_000, + input: 300, + output: 1500, + }, + qualitative_speed: 'fast', + max_output: 8192, + training_cutoff: '2024-04', + }, + { + id: 'claude-3-5-sonnet-20240620', + succeeded_by: 'claude-3-5-sonnet-20241022', + cost: { + currency: 'usd-cents', + tokens: 1_000_000, + input: 300, + output: 1500, + }, + }, + { + id: 'claude-3-haiku-20240307', + // aliases: ['claude-3-haiku-latest'], + cost: { + currency: 'usd-cents', + tokens: 1_000_000, + input: 25, + output: 125, + }, + qualitative_speed: 'fastest', + }, + ]; + } } module.exports = { diff --git a/src/backend/src/modules/puterai/FakeChatService.js b/src/backend/src/modules/puterai/FakeChatService.js index 915d3900..500f9d7d 100644 --- a/src/backend/src/modules/puterai/FakeChatService.js +++ b/src/backend/src/modules/puterai/FakeChatService.js @@ -7,7 +7,19 @@ class FakeChatService extends BaseService { return ['fake']; }, async complete ({ messages, stream, model }) { + const { LoremIpsum } = require('lorem-ipsum'); + const li = new LoremIpsum({ + sentencesPerParagraph: { + max: 8, + min: 4 + }, + wordsPerSentence: { + max: 20, + min: 12 + }, + }); return { + "index": 0, message: { "id": "00000000-0000-0000-0000-000000000000", "type": "message", @@ -16,7 +28,9 @@ class FakeChatService extends BaseService { "content": [ { "type": "text", - "text": "I am a fake AI, I don't know how to respond to anything." 
+ "text": li.generateParagraphs( + Math.floor(Math.random() * 3) + 1 + ) } ], "stop_reason": "end_turn", @@ -25,7 +39,9 @@ class FakeChatService extends BaseService { "input_tokens": 0, "output_tokens": 1 } - } + }, + "logprobs": null, + "finish_reason": "stop" } } } diff --git a/src/backend/src/modules/puterai/OpenAICompletionService.js b/src/backend/src/modules/puterai/OpenAICompletionService.js index 5c4fc1f0..7c170fe3 100644 --- a/src/backend/src/modules/puterai/OpenAICompletionService.js +++ b/src/backend/src/modules/puterai/OpenAICompletionService.js @@ -22,12 +22,6 @@ class OpenAICompletionService extends BaseService { } static IMPLEMENTS = { - ['driver-capabilities']: { - supports_test_mode (iface, method_name) { - return iface === 'puter-chat-completion' && - method_name === 'complete'; - } - }, ['puter-chat-completion']: { async list () { return [ diff --git a/src/backend/src/modules/puterai/PuterAIModule.js b/src/backend/src/modules/puterai/PuterAIModule.js index 81f8279b..8844592d 100644 --- a/src/backend/src/modules/puterai/PuterAIModule.js +++ b/src/backend/src/modules/puterai/PuterAIModule.js @@ -57,6 +57,9 @@ class PuterAIModule extends AdvancedBase { // services.registerService('claude', ClaudeEnoughService); } + const { AIChatService } = require('./AIChatService'); + services.registerService('ai-chat', AIChatService); + const { FakeChatService } = require('./FakeChatService'); services.registerService('fake-chat', FakeChatService); diff --git a/src/backend/src/modules/puterai/XAIService.js b/src/backend/src/modules/puterai/XAIService.js index e5a8febb..44f58426 100644 --- a/src/backend/src/modules/puterai/XAIService.js +++ b/src/backend/src/modules/puterai/XAIService.js @@ -31,14 +31,29 @@ class XAIService extends BaseService { apiKey: this.global_config.services.xai.apiKey, baseURL: 'https://api.x.ai' }); + + const svc_aiChat = this.services.get('ai-chat'); + svc_aiChat.register_provider({ + service_name: this.service_name, + alias: true, + }); } 
static IMPLEMENTS = { ['puter-chat-completion']: { + async models () { + return await this.models_(); + }, async list () { - return [ - 'grok-beta', - ]; + const models = await this.models_(); + const model_names = []; + for ( const model of models ) { + model_names.push(model.id); + if ( model.aliases ) { + model_names.push(...model.aliases); + } + } + return model_names; }, async complete ({ messages, stream, model }) { model = this.adapt_model(model); @@ -121,6 +136,21 @@ class XAIService extends BaseService { } } } + + async models_ () { + return [ + { + id: 'grok-beta', + name: 'Grok Beta', + cost: { + currency: 'usd-cents', + tokens: 1_000_000, + input: 500, + output: 1500, + }, + } + ]; + } } module.exports = { diff --git a/src/backend/src/services/drivers/DriverService.js b/src/backend/src/services/drivers/DriverService.js index c9220b67..52a5ed7d 100644 --- a/src/backend/src/services/drivers/DriverService.js +++ b/src/backend/src/services/drivers/DriverService.js @@ -38,6 +38,7 @@ class DriverService extends BaseService { this.drivers = {}; this.interface_to_implementation = {}; this.interface_to_test_service = {}; + this.service_aliases = {}; } async ['__on_registry.collections'] () { @@ -82,6 +83,10 @@ class DriverService extends BaseService { register_test_service (interface_name, service_name) { this.interface_to_test_service[interface_name] = service_name; } + + register_service_alias (service_name, alias) { + this.service_aliases[alias] = service_name; + } get_interface (interface_name) { const o = {}; @@ -152,6 +157,12 @@ class DriverService extends BaseService { driver = this.interface_to_test_service[iface]; } + const client_driver_call = { + intended_service: driver, + test_mode, + }; + driver = this.service_aliases[driver] ?? 
driver; + const driver_service_exists = (() => { console.log('CHECKING FOR THIS', driver, iface); return this.services.has(driver) && @@ -165,13 +176,17 @@ class DriverService extends BaseService { if ( test_mode && caps && caps.supports_test_mode(iface, method) ) { skip_usage = true; } - - return await this.call_new_({ - actor, - service, - service_name: driver, - iface, method, args: processed_args, - skip_usage, + + return await Context.sub({ + client_driver_call, + }).arun(async () => { + return await this.call_new_({ + actor, + service, + service_name: driver, + iface, method, args: processed_args, + skip_usage, + }); }); } @@ -261,6 +276,10 @@ class DriverService extends BaseService { iface, method, args, skip_usage, }) { + if ( ! service ) { + service = this.services.get(service_name); + } + const svc_permission = this.services.get('permission'); const reading = await svc_permission.scan( actor, diff --git a/src/backend/src/util/context.js b/src/backend/src/util/context.js index cf68cb4c..552c64cb 100644 --- a/src/backend/src/util/context.js +++ b/src/backend/src/util/context.js @@ -63,6 +63,9 @@ class Context { static arun (cb) { return this.get().arun(cb); } + static sub (values, opt_name) { + return this.get().sub(values, opt_name); + } get (k) { return this.values_[k]; }