diff --git a/src/backend/src/modules/puterai/AIInterfaceService.js b/src/backend/src/modules/puterai/AIInterfaceService.js index 9deb15e2..25729103 100644 --- a/src/backend/src/modules/puterai/AIInterfaceService.js +++ b/src/backend/src/modules/puterai/AIInterfaceService.js @@ -33,8 +33,9 @@ class AIInterfaceService extends BaseService { parameters: { messages: { type: 'json' }, vision: { type: 'flag' }, + stream: { type: 'flag' }, }, - result: { type: 'json' } + result: { type: 'json' }, } } }); diff --git a/src/backend/src/modules/puterai/OpenAICompletionService.js b/src/backend/src/modules/puterai/OpenAICompletionService.js index 5a6cfd26..f9a9af3a 100644 --- a/src/backend/src/modules/puterai/OpenAICompletionService.js +++ b/src/backend/src/modules/puterai/OpenAICompletionService.js @@ -1,7 +1,10 @@ +const { PassThrough } = require('stream'); const APIError = require('../../api/APIError'); const BaseService = require('../../services/BaseService'); +const { TypedValue } = require('../../services/drivers/meta/Runtime'); const { Context } = require('../../util/context'); const SmolUtil = require('../../util/smolutil'); +const { nou } = require('../../util/langutil'); class OpenAICompletionService extends BaseService { static MODULES = { @@ -20,7 +23,7 @@ class OpenAICompletionService extends BaseService { static IMPLEMENTS = { ['puter-chat-completion']: { - async complete ({ messages, test_mode }) { + async complete ({ messages, test_mode, stream }) { if ( test_mode ) { const { LoremIpsum } = require('lorem-ipsum'); const li = new LoremIpsum({ @@ -50,6 +53,7 @@ class OpenAICompletionService extends BaseService { return await this.complete(messages, { model, moderation: true, + stream, }); } } @@ -76,7 +80,7 @@ class OpenAICompletionService extends BaseService { }; } - async complete (messages, { moderation, model }) { + async complete (messages, { stream, moderation, model }) { // Validate messages if ( ! 
Array.isArray(messages) ) { throw new Error('`messages` must be an array'); @@ -199,7 +203,38 @@ class OpenAICompletionService extends BaseService { messages: messages, model: model, max_tokens, + stream, }); + + if ( stream ) { + const out_stream = new PassThrough(); + const retval = new TypedValue({ + $: 'stream', + content_type: 'application/x-ndjson', + chunked: true, + }, out_stream); + // Pump chunks in the background so the stream can be returned to the caller right away. + (async () => { + try { + for await ( const chunk of completion ) { + if ( chunk.choices.length < 1 ) continue; + if ( chunk.choices[0].finish_reason ) break; + if ( nou(chunk.choices[0].delta.content) ) continue; + const str = JSON.stringify({ + text: chunk.choices[0].delta.content + }); + out_stream.write(str + '\n'); + } + } catch ( e ) { + this.log.info('streamed completion failed: ' + e.message); + } finally { + // End on every exit path; otherwise the piped HTTP response never finishes. + out_stream.end(); + } + })(); + return retval; + } + this.log.info('how many choices?: ' + completion.choices.length); @@ -244,7 +279,7 @@ class OpenAICompletionService extends BaseService { throw new Error('message is not allowed'); } } - + return completion.choices[0]; } } diff --git a/src/backend/src/routers/drivers/call.js b/src/backend/src/routers/drivers/call.js index 4c66a855..3236b1b8 100644 --- a/src/backend/src/routers/drivers/call.js +++ b/src/backend/src/routers/drivers/call.js @@ -84,7 +84,7 @@ module.exports = eggspress('/drivers/call', { // consider the case where a driver method implements a // stream transformation, thus the stream from the request isn't // consumed until the response is being sent. 
- + _respond(res, result); // What we _can_ do is await the request promise while responding @@ -95,8 +95,11 @@ module.exports = eggspress('/drivers/call', { const _respond = (res, result) => { if ( result.result instanceof TypedValue ) { const tv = result.result; if ( TypeSpec.adapt({ $: 'stream' }).equals(tv.type) ) { res.set('Content-Type', tv.type.raw.content_type); + if ( tv.type.raw.chunked ) { + res.set('Transfer-Encoding', 'chunked'); + } tv.value.pipe(res); return; } diff --git a/src/backend/src/services/drivers/CoercionService.js b/src/backend/src/services/drivers/CoercionService.js index c3784e17..5caceb09 100644 --- a/src/backend/src/services/drivers/CoercionService.js +++ b/src/backend/src/services/drivers/CoercionService.js @@ -88,7 +88,7 @@ class CoercionService extends BaseService { return coerced; } - return undefined; + return typed_value; } }