diff --git a/README.md b/README.md
index 564a4b9..1b021c9 100644
--- a/README.md
+++ b/README.md
@@ -16,11 +16,19 @@ Email:35252986@qq.com
 - Database management
 - Supports custom crawler integration for specific sites, with local scheduled scanning tasks...
 
+## change log
+
+[2024.5.8] Added support for the openai SDK. Any LLM service compatible with the openai SDK can now be used by calling llms.openai_wrapper; see [client/backend/llms/README.md](client/backend/llms/README.md) for details.
+
 **Product intro video:**
 
 [![wiseflow repo demo](https://res.cloudinary.com/marcomontalbano/image/upload/v1714005731/video_to_markdown/images/youtube--80KqYgE8utE-c05b58ac6eb4c4700831b2b3070cd403.jpg)](https://www.bilibili.com/video/BV17F4m1w7Ed/?share_source=copy_web&vd_source=5ad458dc9dae823257e82e48e0751e25 "wiseflow repo demo")
 
-If the video does not open, see here: https://www.bilibili.com/video/BV17F4m1w7Ed/?share_source=copy_web&vd_source=5ad458dc9dae823257e82e48e0751e25
+**If the video does not open, try these links:**
+
+YouTube: https://www.youtube.com/watch?v=80KqYgE8utE&t=8s
+
+Bilibili: https://www.bilibili.com/video/BV17F4m1w7Ed/?share_source=copy_web&vd_source=5ad458dc9dae823257e82e48e0751e25
 
 ## getting started
 
diff --git a/client/backend/llms/README.md b/client/backend/llms/README.md
index e915538..21eb7f8 100644
--- a/client/backend/llms/README.md
+++ b/client/backend/llms/README.md
@@ -1,6 +1,8 @@
-## Using the API service provided by Alibaba DashScope
+## Since many LLM providers are compatible with the openai SDK, the project currently provides two wrappers, dashscope and openai. Whichever you use, set this in advance:
 
-export DASHSCOPE_API_KEY=
+export LLM_API_KEY=
+
+The two wrappers are used in exactly the same way.
 
 ```python
 from llms.dashscope_wrapper import dashscope_llm
@@ -9,14 +11,15 @@
 result = dashscope_llm([{'role': 'system', 'content': ''}, {'role': 'user', 'content': ''}], model,
                        logger=logger)
 ```
 
-## Using the API service provided by Zhipu (currently only glm4 is supported)
+To call a non-openai service through the openai wrapper, also set this in advance:
 
-export ZHIPUAI_API_KEY=
+export LLM_API_BASE=
 
 ```python
-from llms.zhipu_wrapper import zhipuai_llm
+from llms.openai_wrapper import openai_llm
 
-result = zhipuai_llm([{'role': 'system', 'content': ''}, {'role': 'user', 'content': ''}], logger=logger)
+result = openai_llm([{'role': 'system', 'content': ''}, {'role': 'user', 'content': ''}], 'deepseek-chat',
+                    logger=logger)
 ```
 
 ## Support for locally deployed models
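For reference, a minimal end-to-end sketch of the unified wrapper interface introduced above. Only the `openai_llm(messages, model, logger=None, **kwargs)` signature comes from the patch itself; the prompt contents, logger setup, and generation parameters are illustrative assumptions:

```python
# Minimal usage sketch. Assumes LLM_API_KEY (and, for non-openai services,
# LLM_API_BASE) have been exported as described in the README above.
import logging

from llms.openai_wrapper import openai_llm

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger('wiseflow')  # illustrative logger name

messages = [
    {'role': 'system', 'content': 'You are a concise assistant.'},  # illustrative prompts
    {'role': 'user', 'content': 'Summarize what wiseflow does in one sentence.'},
]

# Extra generation parameters are forwarded to the SDK through **kwargs.
result = openai_llm(messages, 'deepseek-chat', logger=logger, temperature=0.7, max_tokens=512)
print(result)
```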
diff --git a/client/backend/llms/dashscope_wrapper.py b/client/backend/llms/dashscope_wrapper.py
index a5b8e89..61ccd81 100644
--- a/client/backend/llms/dashscope_wrapper.py
+++ b/client/backend/llms/dashscope_wrapper.py
@@ -5,37 +5,32 @@ import time
 from http import HTTPStatus
 import dashscope
 import random
+import os
 
-def dashscope_llm(messages: list,
-                  model: str,
-                  seed: int = 1234,
-                  max_tokens: int = 2000,
-                  temperature: float = 1,
-                  stop: list = None,
-                  enable_search: bool = False,
-                  logger=None) -> str:
+DASHSCOPE_KEY = os.getenv("LLM_API_KEY")
+if not DASHSCOPE_KEY:
+    raise ValueError("Please set the LLM_API_KEY environment variable")
+dashscope.api_key = DASHSCOPE_KEY
+
+
+def dashscope_llm(messages: list, model: str, logger=None, **kwargs) -> str:
     if logger:
         logger.debug(f'messages:\n {messages}')
-        logger.debug(f'params:\n model: {model}, max_tokens: {max_tokens}, temperature: {temperature}, stop: {stop},'
-                     f'enable_search: {enable_search}, seed: {seed}')
+        logger.debug(f'model: {model}')
+        logger.debug(f'kwargs:\n {kwargs}')
 
-    for i in range(3):
-        response = dashscope.Generation.call(
-            model=model,
-            messages=messages,
-            max_tokens=max_tokens,
-            temperature=temperature,
-            stop=stop,
-            enable_search=enable_search,
-            seed=seed,
-            result_format='message',  # set the result to be "message" format.
-        )
+    response = dashscope.Generation.call(
+        messages=messages,
+        model=model,
+        result_format='message',  # set the result to be "message" format.
+        **kwargs
+    )
 
+    for i in range(2):
         if response.status_code == HTTPStatus.OK:
             break
-
         if response.message == "Input data may contain inappropriate content.":
             break
@@ -45,7 +40,13 @@ def dashscope_llm(messages: list,
             print(f"request failed. code: {response.code}, message:{response.message}\nretrying...")
 
         time.sleep(1 + i*30)
-        seed = random.randint(1, 10000)
+        kwargs['seed'] = random.randint(1, 10000)
+        response = dashscope.Generation.call(
+            messages=messages,
+            model=model,
+            result_format='message',  # set the result to be "message" format.
+            **kwargs
+        )
 
     if response.status_code != HTTPStatus.OK:
         if logger:
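The refactor above moves the old explicit arguments (seed, max_tokens, temperature, stop, enable_search) into **kwargs and reseeds before each retry. A usage sketch under that assumption; the model name and parameter values are illustrative:

```python
# Minimal usage sketch. Assumes LLM_API_KEY is exported; 'qwen-max' and the
# parameter values below are illustrative, not project defaults.
from llms.dashscope_wrapper import dashscope_llm

messages = [{'role': 'user', 'content': 'Hello'}]

# Anything dashscope.Generation.call accepts can be forwarded as a kwarg;
# on a failed attempt the wrapper injects a fresh random 'seed' into kwargs
# and retries with everything else unchanged.
answer = dashscope_llm(messages, 'qwen-max', temperature=0.7, max_tokens=2000)
print(answer)
```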
diff --git a/client/backend/llms/zhipuai_wrapper.py b/client/backend/llms/openai_wrapper.py
similarity index 52%
rename from client/backend/llms/zhipuai_wrapper.py
rename to client/backend/llms/openai_wrapper.py
index 4ca4704..253c936 100644
--- a/client/backend/llms/zhipuai_wrapper.py
+++ b/client/backend/llms/openai_wrapper.py
@@ -1,48 +1,60 @@
+'''
+Besides openai itself, many LLM providers also use the openai SDK; this wrapper can be used uniformly for all of them.
+Here we demonstrate DeepSeek-V2, provided by deepseek.
+'''
+
+import random
 import os
-
-from zhipuai import ZhipuAI
-
-zhipu_token = os.environ.get('ZHIPUAI_API_KEY', "")
-if not zhipu_token:
-    raise ValueError('Please set the ZHIPUAI_API_KEY environment variable')
-
-client = ZhipuAI(api_key=zhipu_token)  # fill in your own API key
+from openai import OpenAI
+import time
 
-def zhipuai_llm(messages: list,
-                model: str,
-                seed: int = 1234,
-                max_tokens: int = 2000,
-                temperature: float = 0.8,
-                stop: list = None,
-                enable_search: bool = False,
-                logger=None) -> str:
+token = os.environ.get('LLM_API_KEY', "")
+if not token:
+    raise ValueError('Please set the LLM_API_KEY environment variable')
+
+base_url = os.environ.get('LLM_API_BASE', "")
+
+client = OpenAI(api_key=token, base_url=base_url)
+
+
+def openai_llm(messages: list, model: str, logger=None, **kwargs) -> str:
     if logger:
         logger.debug(f'messages:\n {messages}')
-        logger.debug(f'params:\n model: {model}, max_tokens: {max_tokens}, temperature: {temperature}, stop: {stop},'
-                     f'enable_search: {enable_search}, seed: {seed}')
+        logger.debug(f'model: {model}')
+        logger.debug(f'kwargs:\n {kwargs}')
 
-    for i in range(3):
-        try:
-            response = client.chat.completions.create(
-                model="glm-4",  # fill in the name of the model to call
-                seed=seed,
-                messages=messages,
-                temperature=temperature,
-                max_tokens=max_tokens,
-            )
-            if response and response.choices:
-                break
-        except Exception as e:
-            if logger:
-                logger.error(f'error:\n {e}')
-            else:
-                print(e)
-            continue
+    response = client.chat.completions.create(messages=messages, model=model, **kwargs)
+
+    for i in range(2):
+        if response and response.choices:
+            break
+
+        if logger:
+            logger.warning(f"request failed. code: {response}\nretrying...")
+        else:
+            print(f"request failed. code: {response}\nretrying...")
+
+        time.sleep(1 + i * 30)
+        kwargs['seed'] = random.randint(1, 10000)
+        response = client.chat.completions.create(
+            messages=messages,
+            model=model,
+            **kwargs
+        )
+
+    if not response or not response.choices:
+        if logger:
+            logger.warning(f"request failed. code: {response}\naborted after multiple retries...")
+        else:
+            print(f"request failed. code: {response}\naborted after multiple retries...")
+        return ''
 
     if logger:
-        logger.debug(f'result:\n {response}')
+        logger.debug(f'result:\n {response.choices[0]}')
         logger.debug(f'usage:\n {response.usage}')
 
     return response.choices[0].message.content
@@ -76,5 +88,5 @@ Hackers that breached Las Vegas casinos rely on violent threats, research shows'
     data = [{'role': 'user', 'content': user_content}]
     start_time = time.time()
-    pprint(zhipuai_llm(data, 'glm-4'))
+    pprint(openai_llm(data, "deepseek-chat"))
     print(f'time cost: {time.time() - start_time}')
diff --git a/client/backend/requirements.txt b/client/backend/requirements.txt
index 6a3baf6..f6ad15f 100644
--- a/client/backend/requirements.txt
+++ b/client/backend/requirements.txt
@@ -2,6 +2,7 @@ fastapi
 pydantic
 uvicorn
 dashscope #optional (install when using Alibaba DashScope)
+openai #optional (install when using an LLM service compatible with the openai SDK)
 volcengine #optional (install when using Volcano Engine translation)
 python-docx
 BCEmbedding==0.1.3
diff --git a/client/env_sample b/client/env_sample
index 2e4fe41..8c937d9 100755
--- a/client/env_sample
+++ b/client/env_sample
@@ -1,4 +1,5 @@
-export DASHSCOPE_API_KEY=""
+export LLM_API_KEY=""
+export LLM_API_BASE="" ##used with a local model service, or when calling a non-openai service via openai_wrapper
 export VOLC_KEY="AK|SK" #**for embedding model**
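As the env_sample change indicates, LLM_API_BASE is what points openai_wrapper at a locally deployed, openai-compatible endpoint. A smoke-test sketch; the URL, key, and model name are placeholders for whatever your local server actually exposes:

```python
# Smoke-test sketch for a locally deployed openai-compatible service.
# URL, key, and model name are placeholders, not project defaults.
import os

os.environ['LLM_API_KEY'] = 'sk-local-placeholder'       # many local servers ignore the key, but the wrapper requires one
os.environ['LLM_API_BASE'] = 'http://localhost:8000/v1'  # placeholder local endpoint

from llms.openai_wrapper import openai_llm  # import only after the env vars are set

print(openai_llm([{'role': 'user', 'content': 'ping'}], 'local-model'))
```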