import configparser
import re
import time

from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_core.documents import Document

from embeddings import embed_model, reranker
from llms.dashscope_wrapper import dashscope_llm
from general_utils import isChinesePunctuation, is_chinese
# note: the translation helper module's filename is spelled 'tranlsation_volcengine',
# so the import has to match it
from tranlsation_volcengine import text_translate

max_tokens = 4000
relation_threshold = 0.525

config = configparser.ConfigParser()
config.read('../config.ini')
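
# ../config.ini is expected to provide a [prompts] section with the keys referenced
# by the f-string prompts below. The key names come from this file; the sample values
# are placeholders, not the project's real configuration:
#
#   [prompts]
#   character = <analyst persona the llm should adopt>
#   focus = <what kind of leads to mine for>
#   focus_type = <lead category used by the rewrite prompt>
#   good_sample1 = <example of a good, concrete lead>
#   good_sample2 = <a second good example>
#   bad_sample = <example of a bad, overly generic lead>
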
# Practice shows that explicitly asking the llm to mine leads concerning "our country"
# works poorly: it is easily misled by the news content and mistakes other countries for
# ours (likely because the articles themselves contain wording like "our country").
# A step-by-step "inner monologue" prompt would require two output formats at once, which
# raises the difficulty; qwen-max does not adapt to it well. It could be split into two
# passes: first output the list of leads, then find the matching article numbers.
# In practice, though, the cost/benefit of that is poor and it introduces new uncertainty.
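# _first_stage_prompt (Chinese, paraphrased): "You are a {character}. You will be given a
# list of news articles delimited by XML tags. Mine them for leads especially worth
# {focus}. Leads must be concrete, not generic groupings of similar news (good and bad
# examples are injected from config). Read every article carefully, describe each lead in
# one sentence, output one lead per line with the whole list wrapped in triple quotes,
# and answer in Chinese regardless of the articles' language."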
_first_stage_prompt = f'''你是一名{config['prompts']['character']},你将被给到一个新闻列表,新闻文章用XML标签分隔。请对此进行分析,挖掘出特别值得{config['prompts']['focus']}线索。你给出的线索应该足够具体,而不是同类型新闻的归类描述,好的例子如:
"""{config['prompts']['good_sample1']}"""
不好的例子如:
"""{config['prompts']['bad_sample']}"""

请从头到尾仔细阅读每一条新闻的内容,不要遗漏,然后列出值得关注的线索,每条线索都用一句话进行描述,最终按一条一行的格式输出,并整体用三引号包裹,如下所示:
"""
{config['prompts']['good_sample1']}
{config['prompts']['good_sample2']}
"""

不管新闻列表是何种语言,请仅用中文输出分析结果。'''

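# _rewrite_insight_prompt (Chinese, paraphrased): same setup, but asks for exactly one
# result: "mine the single most noteworthy {focus_type} lead, describe it in one
# sentence, wrap it in triple quotes, and answer in Chinese."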
_rewrite_insight_prompt = f'''你是一名{config['prompts']['character']},你将被给到一个新闻列表,新闻文章用 XML 标签分隔。请对此进行分析,从中挖掘出一条最值得关注的{config['prompts']['focus_type']}线索。你给出的线索应该足够具体,而不是同类型新闻的归类描述,好的例子如:
"""{config['prompts']['good_sample1']}"""
不好的例子如:
"""{config['prompts']['bad_sample']}"""

请保证只输出一条最值得关注的线索,线索请用一句话描述,并用三引号包裹输出,如下所示:
"""{config['prompts']['good_sample1']}"""

不管新闻列表是何种语言,请仅用中文输出分析结果。'''

def _parse_insight(article_text: str, cache: dict, logger=None) -> tuple[bool, dict]:
    # returns (failed, cache): failed is True when no new insight was added to cache
    input_length = len(cache)
    result = dashscope_llm([{'role': 'system', 'content': _first_stage_prompt}, {'role': 'user', 'content': article_text}],
                           'qwen1.5-72b-chat', logger=logger)
    if result:
        pattern = re.compile(r'\"\"\"(.*?)\"\"\"', re.DOTALL)
        result = pattern.findall(result)
    else:
        logger.warning('1st-stage llm generate failed: no result')

    if result:
        try:
            results = result[0].split('\n')
            results = [_.strip() for _ in results if _.strip()]
            # some answers pack several leads into one line separated by semicolons
            to_del = []
            to_add = []
            for element in results:
                if ";" in element:
                    to_del.append(element)
                    to_add.extend(element.split(';'))
            for element in to_del:
                results.remove(element)
            results.extend(to_add)
            results = list(set(results))
            for text in results:
                logger.debug(f'parse result: {text}')
                # special case for qwen-72b-chat
                # potential_insight = re.sub(r'编号[^:]*:', '', text)
                potential_insight = text.strip()
                if len(potential_insight) < 2:
                    logger.debug(f'parse failed: not enough potential_insight: {potential_insight}')
                    continue
                # drop a trailing Chinese punctuation mark
                if isChinesePunctuation(potential_insight[-1]):
                    potential_insight = potential_insight[:-1]
                if potential_insight not in cache:
                    cache[potential_insight] = []
        except Exception as e:
            logger.debug(f'parse failed: {e}')

    output_length = len(cache)
    if input_length == output_length:
        return True, cache
    return False, cache

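# Parsing contract illustration (hypothetical model answer, not real output):
#
#   '分析如下\n"""\n线索一\n线索二;线索三\n"""'
#
# findall() extracts the triple-quoted body, split('\n') plus strip yields
# ['线索一', '线索二;线索三'], and the semicolon pass expands that to three leads.
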
def _rewrite_insight(context: str, logger=None) -> tuple[bool, str]:
    # returns (failed, text): failed is True when no usable insight was produced
    result = dashscope_llm([{'role': 'system', 'content': _rewrite_insight_prompt}, {'role': 'user', 'content': context}],
                           'qwen1.5-72b-chat', logger=logger)
    if result:
        pattern = re.compile(r'\"\"\"(.*?)\"\"\"', re.DOTALL)
        result = pattern.findall(result)
    else:
        logger.warning('insight rewrite process llm generate failed: no result')

    if not result:
        return True, ''
    try:
        results = result[0].split('\n')
        text = results[0].strip()
        logger.debug(f'parse result: {text}')
        if len(text) < 2:
            logger.debug(f'parse failed: not enough potential_insight: {text}')
            return True, ''
        if isChinesePunctuation(text[-1]):
            text = text[:-1]
    except Exception as e:
        logger.debug(f'parse failed: {e}')
        return True, ''
    return False, text

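# get_insight pipeline:
#   stage 1 - batch article abstracts/titles into <article> chunks and have the llm
#             extract candidate insights (_parse_insight)
#   stage 2 - match each insight back to article titles via FAISS plus the reranker
#   stage 3 - merge insights whose article lists largely overlap, rewrite merged groups
#             with the llm (_rewrite_insight), and fold in near-duplicate insights
#   finally - translate non-Chinese insights and return [{'content', 'articles'}, ...]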
def get_insight(articles: dict, titles: dict, logger=None) -> list:
    context = ''
    cache = {}
    for value in articles.values():
        # prefer the abstract, fall back to the title, then the full content
        if value['abstract']:
            text = value['abstract']
        elif value['title']:
            text = value['title']
        elif value['content']:
            text = value['content']
        else:
            continue

        # Long context is not used here because DashScope often rejects input for
        # sensitive words without reporting which ones, so the whole batch has to be
        # dropped; with long context that risk is too costly. Besides, an llm may
        # overlook material in the middle of a long context.
        context += f"<article>{text}</article>\n"
        # note: max_tokens is compared against the character count, a rough token proxy
        if len(context) < max_tokens:
            continue

        flag, cache = _parse_insight(context, cache, logger)
        if flag:
            logger.warning(f'following articles may not be completely analyzed: \n{context}')

        context = ''
        # Frequent calls reportedly degrade performance, so rest 1s after each call.
        # (With qwen-72b and max now called in rotation, this is no longer necessary.)
        time.sleep(1)

    if context:
        flag, cache = _parse_insight(context, cache, logger)
        if flag:
            logger.warning(f'following articles may not be completely analyzed: \n{context}')

    if not cache:
        logger.warning('no insights found')
        return []

    # second stage: match insights to article titles
    title_list = [Document(page_content=key, metadata={}) for key in titles.keys()]
    retriever = FAISS.from_documents(title_list, embed_model,
                                     distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT
                                     ).as_retriever(search_type="similarity",
                                                    search_kwargs={"score_threshold": relation_threshold, "k": 10})
    compression = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)

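    # The loop below assumes the reranker writes a 'relevance_score' into each returned
    # document's metadata and yields documents in descending score order, so the first
    # sub-threshold hit ends the scan; that matches LangChain's compression rerankers
    # but is an assumption about the `reranker` object imported above.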
    for key in cache.keys():
        logger.debug(f'searching related articles for insight: {key}')
        rerank_results = compression.get_relevant_documents(key)
        for i in range(len(rerank_results)):
            if rerank_results[i].metadata['relevance_score'] < relation_threshold:
                break
            cache[key].append(titles[rerank_results[i].page_content])
            if titles[rerank_results[i].page_content] not in articles:
                articles[titles[rerank_results[i].page_content]] = {'title': rerank_results[i].page_content}
        logger.info(f'{key} - {cache[key]}')

    # third stage: merge insights whose article lists overlap by 75% or more, then, for
    # groups with several articles, run the llm again to regenerate the insight.
    # In practice the titles recalled for a first-pass insight may each be highly
    # relevant while collectively pointing to an insight from a different angle.
    def calculate_overlap(list1, list2):
        # size of the intersection of the two lists
        intersection_length = len(set(list1).intersection(set(list2)))
        # overlap rate relative to the shorter list
        overlap_rate = intersection_length / min(len(list1), len(list2))
        return overlap_rate >= 0.75

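    # e.g. calculate_overlap(['a', 'b', 'c'], ['b', 'c']) -> 2/2 = 1.0  -> True
    #      calculate_overlap(['a', 'b', 'c'], ['c', 'd']) -> 1/2 = 0.5  -> False
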
    merged_dict = {}
    for key, value in cache.items():
        if not value:
            continue
        merged = False
        for existing_key, existing_value in merged_dict.items():
            if calculate_overlap(value, existing_value):
                merged_dict[existing_key].extend(value)
                merged = True
                break
        if not merged:
            merged_dict[key] = value

    cache = {}
    for key, value in merged_dict.items():
        value = list(set(value))
        if len(value) > 1:
            context = ''
            for _id in value:
                context += f"<article>{articles[_id]['title']}</article>\n"
                if len(context) >= max_tokens:
                    break
            if not context:
                continue

            flag, new_insight = _rewrite_insight(context, logger)
            if flag:
                logger.warning(f'insight {key} may be incorrect')
                cache[key] = value
            else:
                if cache:
                    # rebuild a small FAISS index over the insights collected so far and
                    # fold near-duplicates (relevance_score > 0.85) into the existing entry
                    title_list = [Document(page_content=key, metadata={}) for key in cache.keys()]
                    retriever = FAISS.from_documents(title_list, embed_model,
                                                     distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT
                                                     ).as_retriever(search_type="similarity",
                                                                    search_kwargs={"score_threshold": 0.85, "k": 1})
                    compression = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
                    rerank_results = compression.get_relevant_documents(new_insight)
                    if rerank_results and rerank_results[0].metadata['relevance_score'] > 0.85:
                        logger.debug(f"{new_insight} is too similar to {rerank_results[0].page_content}, merging")
                        cache[rerank_results[0].page_content].extend(value)
                        cache[rerank_results[0].page_content] = list(set(cache[rerank_results[0].page_content]))
                    else:
                        cache[new_insight] = value
                else:
                    cache[new_insight] = value
        else:
            cache[key] = value

    # sorting: insights matched to more articles should rank first
    # sorted_cache = sorted(cache.items(), key=lambda x: len(x[1]), reverse=True)
    logger.info('re-ranking result:')
    new_cache = []
    for key, value in cache.items():
        # translate non-Chinese insights into Chinese before reporting
        if not is_chinese(key):
            translate_text = text_translate([key], target_language='zh', logger=logger)
            if translate_text:
                key = translate_text[0]
        logger.info(f'{key} - {value}')
        new_cache.append({'content': key, 'articles': value})

    return new_cache
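
# Hypothetical call shapes, inferred from the accesses above (the ids, field values and
# logger are illustrative, not part of the project's documented API):
#
#   articles = {'id1': {'title': '...', 'abstract': '...', 'content': '...'}, ...}
#   titles = {'<article title text>': 'id1', ...}   # maps title text -> article id
#   insights = get_insight(articles, titles, logger=some_logger)
#   # -> [{'content': '<insight text>', 'articles': ['id1', ...]}, ...]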