mirror of
https://github.com/TeamWiseFlow/wiseflow.git
synced 2025-02-02 18:28:46 +08:00
add openaiSDK
This commit is contained in:
parent
cb32415d28
commit
c5cebcce36
10
README.md
10
README.md
@ -16,11 +16,19 @@ Email:35252986@qq.com
|
|||||||
- 数据库管理
|
- 数据库管理
|
||||||
- 支持针对特定站点的自定义爬虫集成,并提供本地定时扫描任务……
|
- 支持针对特定站点的自定义爬虫集成,并提供本地定时扫描任务……
|
||||||
|
|
||||||
|
## change log
|
||||||
|
|
||||||
|
【2024.5.8】增加对openai SDK的支持,现在可以通过调用llms.openai_wrapper使用所有兼容openai SDK的大模型服务,具体见 [client/backend/llms/README.md](client/backend/llms/README.md)
|
||||||
|
|
||||||
**产品介绍视频:**
|
**产品介绍视频:**
|
||||||
|
|
||||||
[![wiseflow repo demo](https://res.cloudinary.com/marcomontalbano/image/upload/v1714005731/video_to_markdown/images/youtube--80KqYgE8utE-c05b58ac6eb4c4700831b2b3070cd403.jpg)](https://www.bilibili.com/video/BV17F4m1w7Ed/?share_source=copy_web&vd_source=5ad458dc9dae823257e82e48e0751e25 "wiseflow repo demo")
|
[![wiseflow repo demo](https://res.cloudinary.com/marcomontalbano/image/upload/v1714005731/video_to_markdown/images/youtube--80KqYgE8utE-c05b58ac6eb4c4700831b2b3070cd403.jpg)](https://www.bilibili.com/video/BV17F4m1w7Ed/?share_source=copy_web&vd_source=5ad458dc9dae823257e82e48e0751e25 "wiseflow repo demo")
|
||||||
|
|
||||||
打不开看这里:https://www.bilibili.com/video/BV17F4m1w7Ed/?share_source=copy_web&vd_source=5ad458dc9dae823257e82e48e0751e25
|
**打不开看这里**
|
||||||
|
|
||||||
|
Youtube:https://www.youtube.com/watch?v=80KqYgE8utE&t=8s
|
||||||
|
|
||||||
|
b站:https://www.bilibili.com/video/BV17F4m1w7Ed/?share_source=copy_web&vd_source=5ad458dc9dae823257e82e48e0751e25
|
||||||
|
|
||||||
## getting started
|
## getting started
|
||||||
|
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
## 使用阿里灵积提供的API接口服务
|
## 考虑到很多LLM提供商都兼容openai SDK,所以本项目目前提供dashscope和openai两种wrapper,不管使用哪个都请提前设定
|
||||||
|
|
||||||
export DASHSCOPE_API_KEY=
|
export LLM_API_KEY=
|
||||||
|
|
||||||
|
二者的用法也是完全一样的
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from llms.dashscope_wrapper import dashscope_llm
|
from llms.dashscope_wrapper import dashscope_llm
|
||||||
@ -9,14 +11,15 @@ result = dashscope_llm([{'role': 'system', 'content': '''}, {'role': 'user', 'co
|
|||||||
logger=logger)
|
logger=logger)
|
||||||
```
|
```
|
||||||
|
|
||||||
## 使用智谱提供的API接口服务(暂时只支持glm4)
|
*若要使用openai wrapper调用非openai服务,请提前设定
|
||||||
|
|
||||||
export ZHIPUAI_API_KEY=
|
export LLM_API_BASE=
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from llms.zhipu_wrapper import zhipuai_llm
|
from llms.openai_wrapper import openai_llm
|
||||||
|
|
||||||
result = zhipuai_llm([{'role': 'system', 'content': ''}, {'role': 'user', 'content': ''}], logger=logger)
|
result = openai_llm([{'role': 'system', 'content': '''}, {'role': 'user', 'content': '''}], 'deepseek-chat',
|
||||||
|
logger=logger)
|
||||||
```
|
```
|
||||||
|
|
||||||
## 对于本地部署模型的支持
|
## 对于本地部署模型的支持
|
||||||
|
@ -5,37 +5,32 @@ import time
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
import dashscope
|
import dashscope
|
||||||
import random
|
import random
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
def dashscope_llm(messages: list,
|
DASHSCOPE_KEY = os.getenv("LLM_API_KEY")
|
||||||
model: str,
|
if not DASHSCOPE_KEY:
|
||||||
seed: int = 1234,
|
raise ValueError("请指定LLM_API_KEY的环境变量")
|
||||||
max_tokens: int = 2000,
|
dashscope.api_key = DASHSCOPE_KEY
|
||||||
temperature: float = 1,
|
|
||||||
stop: list = None,
|
|
||||||
enable_search: bool = False,
|
def dashscope_llm(messages: list, model: str, logger=None, **kwargs) -> str:
|
||||||
logger=None) -> str:
|
|
||||||
|
|
||||||
if logger:
|
if logger:
|
||||||
logger.debug(f'messages:\n {messages}')
|
logger.debug(f'messages:\n {messages}')
|
||||||
logger.debug(f'params:\n model: {model}, max_tokens: {max_tokens}, temperature: {temperature}, stop: {stop},'
|
logger.debug(f'model: {model}')
|
||||||
f'enable_search: {enable_search}, seed: {seed}')
|
logger.debug(f'kwargs:\n {kwargs}')
|
||||||
|
|
||||||
for i in range(3):
|
|
||||||
response = dashscope.Generation.call(
|
response = dashscope.Generation.call(
|
||||||
model=model,
|
|
||||||
messages=messages,
|
messages=messages,
|
||||||
max_tokens=max_tokens,
|
model=model,
|
||||||
temperature=temperature,
|
|
||||||
stop=stop,
|
|
||||||
enable_search=enable_search,
|
|
||||||
seed=seed,
|
|
||||||
result_format='message', # set the result to be "message" format.
|
result_format='message', # set the result to be "message" format.
|
||||||
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for i in range(2):
|
||||||
if response.status_code == HTTPStatus.OK:
|
if response.status_code == HTTPStatus.OK:
|
||||||
break
|
break
|
||||||
|
|
||||||
if response.message == "Input data may contain inappropriate content.":
|
if response.message == "Input data may contain inappropriate content.":
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -45,7 +40,13 @@ def dashscope_llm(messages: list,
|
|||||||
print(f"request failed. code: {response.code}, message:{response.message}\nretrying...")
|
print(f"request failed. code: {response.code}, message:{response.message}\nretrying...")
|
||||||
|
|
||||||
time.sleep(1 + i*30)
|
time.sleep(1 + i*30)
|
||||||
seed = random.randint(1, 10000)
|
kwargs['seed'] = random.randint(1, 10000)
|
||||||
|
response = dashscope.Generation.call(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
result_format='message', # set the result to be "message" format.
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
if response.status_code != HTTPStatus.OK:
|
if response.status_code != HTTPStatus.OK:
|
||||||
if logger:
|
if logger:
|
||||||
|
@ -1,48 +1,60 @@
|
|||||||
|
'''
|
||||||
|
除了openai外,很多大模型提供商也都使用openai的SDK,对于这一类可以统一使用本wrapper
|
||||||
|
这里演示使用deepseek提供的DeepSeek-V2
|
||||||
|
'''
|
||||||
|
|
||||||
|
import random
|
||||||
import os
|
import os
|
||||||
|
from openai import OpenAI
|
||||||
from zhipuai import ZhipuAI
|
import time
|
||||||
|
|
||||||
zhipu_token = os.environ.get('ZHIPUAI_API_KEY', "")
|
|
||||||
if not zhipu_token:
|
|
||||||
raise ValueError('请设置环境变量ZHIPUAI_API_KEY')
|
|
||||||
|
|
||||||
client = ZhipuAI(api_key=zhipu_token) # 填写您自己的APIKey
|
|
||||||
|
|
||||||
|
|
||||||
def zhipuai_llm(messages: list,
|
token = os.environ.get('LLM_API_KEY', "")
|
||||||
model: str,
|
if not token:
|
||||||
seed: int = 1234,
|
raise ValueError('请设置环境变量LLM_API_KEY')
|
||||||
max_tokens: int = 2000,
|
|
||||||
temperature: float = 0.8,
|
base_url = os.environ.get('LLM_API_BASE', "")
|
||||||
stop: list = None,
|
|
||||||
enable_search: bool = False,
|
client = OpenAI(api_key=token, base_url=base_url)
|
||||||
logger=None) -> str:
|
|
||||||
|
|
||||||
|
def openai_llm(messages: list, model: str, logger=None, **kwargs) -> str:
|
||||||
|
|
||||||
if logger:
|
if logger:
|
||||||
logger.debug(f'messages:\n {messages}')
|
logger.debug(f'messages:\n {messages}')
|
||||||
logger.debug(f'params:\n model: {model}, max_tokens: {max_tokens}, temperature: {temperature}, stop: {stop},'
|
logger.debug(f'model: {model}')
|
||||||
f'enable_search: {enable_search}, seed: {seed}')
|
logger.debug(f'kwargs:\n {kwargs}')
|
||||||
|
|
||||||
for i in range(3):
|
response = client.chat.completions.create(messages=messages, model=model, **kwargs)
|
||||||
try:
|
|
||||||
response = client.chat.completions.create(
|
for i in range(2):
|
||||||
model="glm-4", # 填写需要调用的模型名称
|
|
||||||
seed=seed,
|
|
||||||
messages=messages,
|
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
)
|
|
||||||
if response and response.choices:
|
if response and response.choices:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
|
||||||
if logger:
|
|
||||||
logger.error(f'error:\n {e}')
|
|
||||||
else:
|
|
||||||
print(e)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if logger:
|
if logger:
|
||||||
logger.debug(f'result:\n {response}')
|
logger.warning(f"request failed. code: {response}\nretrying...")
|
||||||
|
else:
|
||||||
|
print(f"request failed. code: {response}\nretrying...")
|
||||||
|
|
||||||
|
time.sleep(1 + i * 30)
|
||||||
|
kwargs['seed'] = random.randint(1, 10000)
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response or not response.choices:
|
||||||
|
if logger:
|
||||||
|
logger.warning(
|
||||||
|
f"request failed. code: {response}\nabort after multiple retries...")
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"request failed. code: {response}\naborted after multiple retries...")
|
||||||
|
return ''
|
||||||
|
|
||||||
|
if logger:
|
||||||
|
logger.debug(f'result:\n {response.choices[0]}')
|
||||||
logger.debug(f'usage:\n {response.usage}')
|
logger.debug(f'usage:\n {response.usage}')
|
||||||
|
|
||||||
return response.choices[0].message.content
|
return response.choices[0].message.content
|
||||||
@ -76,5 +88,5 @@ Hackers that breached Las Vegas casinos rely on violent threats, research shows'
|
|||||||
|
|
||||||
data = [{'role': 'user', 'content': user_content}]
|
data = [{'role': 'user', 'content': user_content}]
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
pprint(zhipuai_llm(data, 'glm-4'))
|
pprint(openai_llm(data, "deepseek-chat"))
|
||||||
print(f'time cost: {time.time() - start_time}')
|
print(f'time cost: {time.time() - start_time}')
|
@ -2,6 +2,7 @@ fastapi
|
|||||||
pydantic
|
pydantic
|
||||||
uvicorn
|
uvicorn
|
||||||
dashscope #optional(使用阿里灵积时安装)
|
dashscope #optional(使用阿里灵积时安装)
|
||||||
|
openai #optional(使用兼容openai sdk的llm服务时安装)
|
||||||
volcengine #optional(使用火山翻译时安装)
|
volcengine #optional(使用火山翻译时安装)
|
||||||
python-docx
|
python-docx
|
||||||
BCEmbedding==0.1.3
|
BCEmbedding==0.1.3
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
export DASHSCOPE_API_KEY=""
|
export LLM_API_KEY=""
|
||||||
|
export LLM_API_BASE="" ##使用本地模型服务或者使用openai_wrapper调用非openai服务时用
|
||||||
export VOLC_KEY="AK|SK"
|
export VOLC_KEY="AK|SK"
|
||||||
|
|
||||||
#**for embeddig model**
|
#**for embeddig model**
|
||||||
|
Loading…
Reference in New Issue
Block a user