feat: add image generation driver to puterai module

This commit is contained in:
KernelDeimos 2024-08-01 17:18:30 -04:00 committed by Eric Dubé
parent 4e3bd1831e
commit fb26fdbc56
4 changed files with 153 additions and 0 deletions

View File

@ -38,6 +38,38 @@ class AIInterfaceService extends BaseService {
}
}
});
col_interfaces.set('puter-image-generation', {
description: 'AI Image Generation.',
methods: {
generate: {
description: 'Generate an image from a prompt.',
parameters: {
prompt: { type: 'string' },
},
result_choices: [
{
names: ['image'],
type: {
$: 'stream',
content_type: 'image',
}
},
{
names: ['url'],
type: {
$: 'string:url:web',
content_type: 'image',
}
},
],
result: {
description: 'URL of the generated image.',
type: 'string'
}
}
}
});
}
}

View File

@ -0,0 +1,100 @@
const BaseService = require("../../services/BaseService");
const { TypedValue } = require("../../services/drivers/meta/Runtime");
const { Context } = require("../../util/context");
class OpenAIImageGenerationService extends BaseService {
static MODULES = {
openai: require('openai'),
}
async _init () {
const sk_key =
this.config?.openai?.secret_key ??
this.global_config.openai?.secret_key;
this.openai = new this.modules.openai.OpenAI({
apiKey: sk_key
});
}
static IMPLEMENTS = {
['puter-image-generation']: {
async generate ({ prompt, test_mode }) {
const url = await this.generate(prompt, {
ratio: this.constructor.RATIO_SQUARE,
});
if ( test_mode ) {
return new TypedValue({
$: 'string:url:web',
content_type: 'image',
}, 'https://puter-sample-data.puter.site/image_example.png');
}
const image = new TypedValue({
$: 'string:url:web',
content_type: 'image'
}, url);
return image;
}
}
};
static RATIO_SQUARE = { w: 1024, h: 1024 };
static RATIO_PORTRAIT = { w: 1024, h: 1792 };
static RATIO_LANDSCAPE = { w: 1792, h: 1024 };
async generate (prompt, {
ratio,
model,
}) {
if ( typeof prompt !== 'string' ) {
throw new Error('`prompt` must be a string');
}
if ( ! ratio || ! this._validate_ratio(ratio) ) {
throw new Error('`ratio` must be a valid ratio');
}
model = model ?? 'dall-e-3';
const user_private_uid = Context.get('actor')?.private_uid ?? 'UNKNOWN';
if ( user_private_uid === 'UNKNOWN' ) {
this.errors.report('chat-completion-service:unknown-user', {
message: 'failed to get a user ID for an OpenAI request',
alarm: true,
trace: true,
});
}
const result =
await this.openai.images.generate({
user: user_private_uid,
prompt,
size: `${ratio.w}x${ratio.h}`,
});
const spending_meta = {
model,
size: `${ratio.w}x${ratio.h}`,
};
const svc_spending = Context.get('services').get('spending');
svc_spending.record_spending('openai', 'image-generation', spending_meta);
const url = result.data?.[0]?.url;
return url;
}
_validate_ratio (ratio) {
return false
|| ratio === this.constructor.RATIO_SQUARE
|| ratio === this.constructor.RATIO_PORTRAIT
|| ratio === this.constructor.RATIO_LANDSCAPE
;
}
}
module.exports = {
OpenAIImageGenerationService,
};

View File

@ -12,6 +12,9 @@ class PuterAIModule extends AdvancedBase {
const { OpenAICompletionService } = require('./OpenAICompletionService');
services.registerService('openai-completion', OpenAICompletionService);
const { OpenAIImageGenerationService } = require('./OpenAIImageGenerationService');
services.registerService('openai-image-generation', OpenAIImageGenerationService);
}
}

View File

@ -41,4 +41,22 @@ await (await fetch("http://api.puter.localhost:4100/drivers/call", {
}),
"method": "POST",
})).json();
```
```javascript
URL.createObjectURL(await (await fetch("http://api.puter.localhost:4100/drivers/call", {
"headers": {
"Content-Type": "application/json",
"Authorization": `Bearer ${puter.authToken}`,
},
"body": JSON.stringify({
interface: 'puter-image-generation',
driver: 'openai-image-generation',
method: 'generate',
args: {
prompt: 'photorealistic teapot made of swiss cheese',
}
}),
"method": "POST",
})).blob());
```