wiseflow/dashboard/get_report.py
2024-06-15 15:41:31 +08:00

228 lines
8.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import random
import re
import os
from core.backend import dashscope_llm
from docx import Document
from docx.oxml.ns import qn
from docx.shared import Pt, RGBColor
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT
from datetime import datetime
from general_utils import isChinesePunctuation
from general_utils import get_logger_level
from loguru import logger
from pb_api import PbTalker
project_dir = os.environ.get("PROJECT_DIR", "")
# os.makedirs('') raises FileNotFoundError (even with exist_ok=True), so only
# create the directory when PROJECT_DIR is actually configured; otherwise the
# log file is written relative to the current working directory.
if project_dir:
    os.makedirs(project_dir, exist_ok=True)
logger_file = os.path.join(project_dir, 'backend_service.log')
dsw_log = get_logger_level()
logger.add(
    logger_file,
    level=dsw_log,
    backtrace=True,
    diagnose=True,
    rotation="50 MB"
)

pb = PbTalker(logger)

# Character budget for the article context placed in the prompt.
# qwen-72b-chat accepts up to ~30k of input; leaving room for the rest of the
# prompt, the article content should stay under 30000 characters.
# With qwen-max (max input ~6k) this would have to drop to 6000, but then many
# articles could no longer be analysed.
# A locally deployed qwen-14b may only accept ~4k of input, in which case this
# whole approach may not work at all.
max_input_tokens = 30000

# Load the activated role-play config (analyst persona + report type) from the
# 'roleplays' collection; prompt on the console for anything missing and
# persist it back.
role_config = pb.read(collection_name='roleplays', filter='activated=True')
_role_config_id = ''
if role_config:
    character = role_config[0]['character']
    report_type = role_config[0]['report_type']
    _role_config_id = role_config[0]['id']
else:
    character, report_type = '', ''

if not character:
    character = input('\033[0;32m 请为首席情报官指定角色设定eg. 来自中国的网络安全情报专家):\033[0m\n')
    _role_config_id = pb.add(collection_name='roleplays', body={'character': character, 'activated': True})

if not _role_config_id:
    raise Exception('pls check pb data无法获取角色设定')

if not report_type:
    report_type = input('\033[0;32m 请为首席情报官指定报告类型eg. 网络安全情报):\033[0m\n')
    _ = pb.update(collection_name='roleplays', id=_role_config_id, body={'report_type': report_type})
def get_report(insigt: str, articles: list[dict], memory: str, topics: list[str], comment: str, docx_file: str) -> (bool, str):
    """Generate (or revise) a report with the LLM and save it as a .docx file.

    The LLM is asked to follow a bracketed schema built from ``topics``::

        【标题】<title>
        【综述】
        【一、<topic 1>】
        【二、<topic 2>】 ...

    Args:
        insigt: The insight/lead the report is about. A trailing Chinese
            punctuation mark is dropped.
        articles: Source articles. Each dict is read for ``content``,
            ``abstract`` and ``title`` (prompt context) and for ``title``,
            ``publish_time`` and ``url`` (the appendix of the .docx).
        memory: Text of a previously generated report ('' if none).
        topics: ``topics[0]`` is the report title; the remaining non-blank
            entries become section headings.
        comment: Reviewer comment. Together with ``memory`` (and when
            ``memory`` already contains every requested section) it switches
            to revision mode; otherwise it is appended to the generation rules.
        docx_file: Path the .docx report is written to.

    Returns:
        ``(True, report_text)`` on success, where ``report_text`` is the LLM
        output starting at its first '【' (the whole output if there is none).
        ``(False, '')`` when the LLM returns nothing, two attempts never
        contain every requested section, or no section can be parsed.
    """
    if insigt and isChinesePunctuation(insigt[-1]):
        insigt = insigt[:-1]

    title, sections, schema = _split_topics_into_schema(topics)
    system_prompt, user_prompt = _build_report_prompts(insigt, articles, memory, sections, comment, schema)

    # Each "【...】" schema line minus its opening bracket, e.g. '标题】<title>',
    # '综述】', '一、<topic>】'. These are matched against the LLM output.
    expected = [line[1:] for line in schema.split('\n\n') if line.startswith('【')]

    # Title and summary (the first two entries) are not required in the output.
    ok, result = _generate_report_text(system_prompt, user_prompt, expected[2:])
    if not ok:
        # Either the LLM is not working (empty result) or retries never
        # produced every requested section.
        if not result:
            logger.warning('report-process-error: LLM out of work!')
        else:
            logger.warning('report-process-error: cannot generate, change topics and insight, then re-try')
        return False, ''

    contents, bodies = _parse_report_sections(result, expected)
    if not bodies:
        logger.warning('report-process-error: cannot generate, change topics and insight, then re-try')
        return False, ''
    if '标题' not in bodies:
        bodies['标题'] = _fallback_report_title(contents)

    _write_report_docx(docx_file, title or bodies['标题'], bodies, articles)

    # find() returns -1 when there is no '【'; never slice from -1 (that would
    # keep only the last character).
    start = result.find('【')
    return True, result[max(start, 0):]


def _zh_ordinal(n: int) -> str:
    """Return the Chinese numeral for 1 <= n <= 99 ('一', '十', '十二', '二十一'); str(n) otherwise."""
    digits = '零一二三四五六七八九'
    if 1 <= n <= 9:
        return digits[n]
    if 10 <= n <= 99:
        tens, units = divmod(n, 10)
        return ('' if tens == 1 else digits[tens]) + '十' + (digits[units] if units else '')
    return str(n)


def _split_topics_into_schema(topics: list[str]) -> tuple[str, list[str], str]:
    """Split ``topics`` into (title, section topics, output schema for the LLM)."""
    if not topics:
        title, sections = '', []
    elif len(topics) == 1:
        title, sections = topics[0], []
    else:
        title = topics[0]
        sections = [s.strip() for s in topics[1:] if s.strip()]

    schema = f'【标题】{title}\n\n【综述】\n\n'
    for i, topic in enumerate(sections, start=1):
        schema += f'【{_zh_ordinal(i)}、{topic}】\n\n'
    return title, sections, schema


def _build_report_prompts(insigt: str, articles: list[dict], memory: str, sections: list[str],
                          comment: str, schema: str) -> tuple[str, str]:
    """Return (system_prompt, user_prompt) for either revising ``memory`` or generating a new report."""
    # Revision mode: a previous report plus a comment, and the previous report
    # already has every requested section (headings look like '【一、<topic>】').
    if memory and comment:
        paragraphs = re.findall("、(.*?)】", memory)
        if set(sections) <= set(paragraphs):
            logger.debug("no change in Topics, need modified the report")
            system_prompt = f'''你是一名{character},你近日向上级提交了一份{report_type}报告,如下是报告原文。接下来你将收到来自上级部门的修改意见,请据此修改你的报告:
报告原文:
"""{memory}"""
'''
            user_prompt = f'上级部门修改意见:"""{comment}"""'
            return system_prompt, user_prompt

    logger.debug("need generate the report")
    texts = ''
    for article in articles:
        # Prefer the full content, then the abstract, then the title.
        body = article.get('content') or article.get('abstract') or article.get('title')
        texts += f"<article>{body}</article>\n"
        # The article that crosses the character budget is still included.
        if len(texts) > max_input_tokens:
            break
    logger.debug(f"articles context length: {len(texts)}")

    system_prompt = f'''你是一名{character},在近期的工作中我们从所关注的网站中发现了一条重要的{report_type}线索线索和相关文章用XML标签分隔如下
情报线索: """{insigt} """
相关文章:
{texts}
现在请基于这些信息按要求输出专业的书面报告。'''
    if comment:
        user_prompt = (f'1、不管原始资料是什么语言你必须使用简体中文输出报告除非是人名、组织和机构的名称、缩写'
                       f'2、对事实的陈述务必基于所提供的相关文章绝对不可以臆想3、{comment}\n')
    else:
        user_prompt = ('1、不管原始资料是什么语言你必须使用简体中文输出报告除非是人名、组织和机构的名称、缩写'
                       '2、对事实的陈述务必基于所提供的相关文章绝对不可以臆想。')
    user_prompt += f'\n请按如下格式输出你的报告:\n{schema}'
    return system_prompt, user_prompt


def _generate_report_text(system_prompt: str, user_prompt: str, required: list[str],
                          attempts: int = 2) -> tuple[bool, str]:
    """Call the LLM up to ``attempts`` times until the output is > 50 chars and contains every ``required`` item.

    Returns (True, output) on success, else (False, last output).
    """
    result = ''
    for _ in range(attempts):
        result = dashscope_llm([{'role': 'system', 'content': system_prompt}, {'role': 'user', 'content': user_prompt}],
                               'qwen1.5-72b-chat', seed=random.randint(1, 10000), logger=logger)
        logger.debug(f"raw result:\n{result}")
        # Guard against a falsy (e.g. None) result before measuring it.
        if result and len(result) > 50 and all(item in result for item in required):
            return True, result
        logger.debug("result not good, re-generating...")
    return False, result


def _parse_report_sections(result: str, expected: list[str]) -> tuple[list[str], dict[str, str]]:
    """Split the LLM output on '【' and map each expected section name to its text.

    Returns the raw fragments and a {section name: body} dict.
    """
    contents = result.split('【')
    pending = list(expected)
    bodies = {}
    for text in contents:
        for item in pending:
            if text.startswith(item):
                # Safe despite mutating the list being iterated: we break right after.
                pending.remove(item)
                # partition (not split + unpack) tolerates '】' inside the body text.
                key, _, value = text.partition('】')
                value = value.strip()
                if value and isChinesePunctuation(value[0]):
                    value = value[1:]
                bodies[key] = value.strip()
                break
    return contents, bodies


def _fallback_report_title(contents: list[str]) -> str:
    """Guess a title when no '【标题】<title>' section matched the expected title.

    Uses the text before the first '】' in the leading fragment, or else in
    the first bracketed fragment (e.g. output shaped like '<preamble>【Title】...').
    Returns '' if neither fragment contains '】'.
    """
    if '】' in contents[0]:
        return contents[0].split('】')[0].strip()
    if len(contents) > 1 and '】' in contents[1]:
        return contents[1].split('】')[0].strip()
    return ''


def _add_report_heading(doc, text: str, level: int):
    """Add a heading whose single run is black Cambria; return (heading, run)."""
    heading = doc.add_heading(level=level)
    run = heading.add_run(text)
    run.font.name = u'Cambria'
    run.font.color.rgb = RGBColor(0, 0, 0)
    return heading, run


def _write_report_docx(docx_file: str, title: str, bodies: dict[str, str], articles: list[dict]) -> None:
    """Write the report to ``docx_file``.

    Writes a centred title and generation time, the summary, one level-2
    section per remaining entry of ``bodies``, and a numbered list of the
    source articles. ``bodies`` must contain '标题'. It is mutated: the title
    and summary entries are removed.
    """
    doc = Document()
    normal = doc.styles['Normal']
    normal.font.name = u'宋体'
    normal._element.rPr.rFonts.set(qn('w:eastAsia'), u'宋体')
    normal.font.size = Pt(12)
    normal.font.color.rgb = RGBColor(0, 0, 0)

    # Title, timestamp and summary first.
    heading, run = _add_report_heading(doc, title, level=1)
    heading.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
    run._element.rPr.rFonts.set(qn('w:eastAsia'), u'Cambria')
    doc.add_paragraph(
        f"\n生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    del bodies['标题']
    summary = bodies.pop('综述', None)
    if summary is not None:
        doc.add_paragraph(f"\t{summary}\n")

    # Remaining sections in the order they were parsed.
    for key, value in bodies.items():
        _add_report_heading(doc, key, level=2)
        doc.add_paragraph(f"{value}\n")

    # Appendix: the source web pages.
    _add_report_heading(doc, "附:原始信息网页", level=2)
    references = []
    for i, article in enumerate(articles, start=1):
        date_text = str(article['publish_time'])
        if len(date_text) == 8:
            # Expand 'YYYYMMDD' to 'YYYY-MM-DD'.
            date_text = f"{date_text[:4]}-{date_text[4:6]}-{date_text[6:]}"
        references.append(f"{i}、{article['title']}|{date_text}\n{article['url']} ")
    doc.add_paragraph("\n\n".join(references))
    doc.save(docx_file)