# wiseflow/client/backend/get_insight.py

from embeddings import embed_model, reranker
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain.retrievers import ContextualCompressionRetriever
from llms.dashscope_wrapper import dashscope_llm
from general_utils import isChinesePunctuation, is_chinese
from tranlsation_volcengine import text_translate  # sic: the module file is spelled this way in the repo
import time
import re
import configparser

max_tokens = 4000
relation_threshold = 0.525  # minimum rerank relevance score for a title to count as related

config = configparser.ConfigParser()
config.read('../config.ini')
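
# For reference, a sketch of the '../config.ini' section this module expects.
# The key names are taken from the reads below; the values are illustrative
# placeholders, not the project's real prompt material:
#
#   [prompts]
#   character    = <role the LLM should assume, e.g. an intelligence analyst>
#   focus        = <what kind of leads to mine for>
#   focus_type   = <noun phrase naming the lead type>
#   good_sample1 = <a concrete, well-formed example lead>
#   good_sample2 = <a second concrete example lead>
#   bad_sample   = <a vague, category-style lead to avoid>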
# Practice shows that explicitly asking the LLM to mine leads "our country" should watch
# works poorly: it is easily misled by the news content and mistakes other countries for
# ours, probably because the articles themselves contain phrases like "our country".
# A step-by-step / inner-monologue output would mix two formats in one answer, which
# raises the difficulty and which qwen-max does not adapt to well. This could be split
# into two passes -- first list the leads, then match each lead to its article numbers --
# but in practice the cost/benefit of doing so is poor and it introduces new uncertainty.
# First-stage prompt, kept in Chinese because the qwen models are prompted in Chinese.
# Roughly: "You are a {character}. You will be given a list of news articles separated
# by XML tags. Analyze them and mine the leads that especially deserve {focus}. Leads
# must be concrete, not category-level summaries of similar news; read every article,
# output one lead per line, wrap the whole list in triple quotes, and answer only in
# Chinese whatever the input language."
_first_stage_prompt = f'''你是一名{config['prompts']['character']},你将被给到一个新闻列表,新闻文章用 XML 标签分隔。请对此进行分析,挖掘出特别值得{config['prompts']['focus']}线索。你给出的线索应该足够具体,而不是同类型新闻的归类描述,好的例子如:
"""{config['prompts']['good_sample1']}"""
不好的例子如:
"""{config['prompts']['bad_sample']}"""
请从头到尾仔细阅读每一条新闻的内容,不要遗漏,然后列出值得关注的线索,每条线索都用一句话进行描述,最终按一条一行的格式输出,并整体用三引号包裹,如下所示:
"""
{config['prompts']['good_sample1']}
{config['prompts']['good_sample2']}
"""
不管新闻列表是何种语言,请仅用中文输出分析结果。'''
# Insight-rewrite prompt. Roughly: "Given a list of news items separated by XML tags,
# mine the single most noteworthy {focus_type} lead: one concrete sentence, wrapped in
# triple quotes, answered only in Chinese."
_rewrite_insight_prompt = f'''你是一名{config['prompts']['character']},你将被给到一个新闻列表,新闻文章用 XML 标签分隔。请对此进行分析,从中挖掘出一条最值得关注的{config['prompts']['focus_type']}线索。你给出的线索应该足够具体,而不是同类型新闻的归类描述,好的例子如:
"""{config['prompts']['good_sample1']}"""
不好的例子如:
"""{config['prompts']['bad_sample']}"""
请保证只输出一条最值得关注的线索,线索请用一句话描述,并用三引号包裹输出,如下所示:
"""{config['prompts']['good_sample1']}"""
不管新闻列表是何种语言,请仅用中文输出分析结果。'''
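
# Both prompts ask the model to wrap its answer in triple quotes, e.g.:
#   """
#   <lead one>
#   <lead two>
#   """
# which is exactly what the r'\"\"\"(.*?)\"\"\"' extraction below depends on.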


def _parse_insight(article_text: str, cache: dict, logger=None) -> tuple[bool, dict]:
input_length = len(cache)
result = dashscope_llm([{'role': 'system', 'content': _first_stage_prompt}, {'role': 'user', 'content': article_text}],
'qwen1.5-72b-chat', logger=logger)
if result:
pattern = re.compile(r'\"\"\"(.*?)\"\"\"', re.DOTALL)
result = pattern.findall(result)
else:
        logger.warning('1st-stage LLM generation failed: no result')
if result:
try:
results = result[0].split('\n')
results = [_.strip() for _ in results if _.strip()]
to_del = []
to_add = []
            for element in results:
                # the LLM sometimes packs several leads onto one line; the separator
                # character in the original source was stripped as "ambiguous Unicode",
                # so a fullwidth comma is assumed here
                if ',' in element:
                    to_del.append(element)
                    to_add.extend(element.split(','))
            for element in to_del:
                results.remove(element)
            results.extend(to_add)
            results = list(set(results))
for text in results:
logger.debug(f'parse result: {text}')
                # special case for qwen-72b-chat, kept for reference (the character
                # class in the original pattern was stripped; a fullwidth comma is assumed):
                # potential_insight = re.sub(r'编号[^,]*', '', text)
potential_insight = text.strip()
if len(potential_insight) < 2:
logger.debug(f'parse failed: not enough potential_insight: {potential_insight}')
continue
if isChinesePunctuation(potential_insight[-1]):
potential_insight = potential_insight[:-1]
                if potential_insight not in cache:
                    cache[potential_insight] = []
except Exception as e:
logger.debug(f'parse failed: {e}')
    output_length = len(cache)
    # the returned flag is True when this pass added nothing new to the cache
    # (callers treat that as "these articles may not have been fully analyzed")
    if input_length == output_length:
        return True, cache
    return False, cache
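
# Illustrative call (hypothetical input; `cache` maps insight text -> related article ids):
#   done, cache = _parse_insight('<article>Country A held a joint naval drill ...</article>', {}, logger)
#   `done` is True when the batch added nothing new to the cache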


def _rewrite_insight(context: str, logger=None) -> tuple[bool, str]:
result = dashscope_llm([{'role': 'system', 'content': _rewrite_insight_prompt}, {'role': 'user', 'content': context}],
'qwen1.5-72b-chat', logger=logger)
if result:
pattern = re.compile(r'\"\"\"(.*?)\"\"\"', re.DOTALL)
result = pattern.findall(result)
else:
        logger.warning('insight-rewrite LLM generation failed: no result')
if not result:
return True, ''
try:
results = result[0].split('\n')
text = results[0].strip()
logger.debug(f'parse result: {text}')
if len(text) < 2:
logger.debug(f'parse failed: not enough potential_insight: {text}')
return True, ''
if isChinesePunctuation(text[-1]):
text = text[:-1]
except Exception as e:
logger.debug(f'parse failed: {e}')
return True, ''
return False, text


def get_insight(articles: dict, titles: dict, logger=None) -> list:
context = ''
cache = {}
    for value in articles.values():
        # prefer the abstract, then the title, then the full content
        if value['abstract']:
            text = value['abstract']
        elif value['title']:
            text = value['title']
        elif value['content']:
            text = value['content']
        else:
            continue
        # Long context is deliberately avoided here: Alibaba DashScope often rejects
        # input over sensitive words without reporting which ones, so the whole batch
        # would have to be dropped -- too risky. LLMs also tend to miss material in
        # the middle of very long contexts.
        context += f"<article>{text}</article>\n"
        # character count serves as a cheap proxy for token count
        if len(context) < max_tokens:
            continue
flag, cache = _parse_insight(context, cache, logger)
if flag:
logger.warning(f'following articles may not be completely analyzed: \n{context}')
context = ''
        # Frequent calls reportedly degrade performance, hence the 1 s pause after
        # each call. (With qwen-72b and qwen-max now called in rotation this may no
        # longer be necessary.)
        time.sleep(1)
if context:
flag, cache = _parse_insight(context, cache, logger)
if flag:
logger.warning(f'following articles may not be completely analyzed: \n{context}')
if not cache:
logger.warning('no insights found')
return []
    # second stage: match insights to article titles
    title_list = [Document(page_content=key, metadata={}) for key in titles]
    retriever = FAISS.from_documents(
        title_list, embed_model, distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT
    ).as_retriever(search_type="similarity",
                   search_kwargs={"score_threshold": relation_threshold, "k": 10})
    compression = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
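    # Shape of the retrieve-then-rerank results, assuming (as the reads below do)
    # that the reranker wrapper stores its score in Document.metadata:
    #   doc.page_content                -> an article title
    #   doc.metadata['relevance_score'] -> rerank score, higher means closer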
    for key in cache:
        logger.debug(f'searching related articles for insight: {key}')
        rerank_results = compression.get_relevant_documents(key)
        for doc in rerank_results:
            # results come back score-ordered; stop at the first below-threshold hit
            if doc.metadata['relevance_score'] < relation_threshold:
                break
            cache[key].append(titles[doc.page_content])
            if titles[doc.page_content] not in articles:
                articles[titles[doc.page_content]] = {'title': doc.page_content}
        logger.info(f'{key} - {cache[key]}')
    # third stage: merge insights whose supporting-article lists overlap heavily
    # (overlap rate >= 0.75 below), then, for merged insights backed by several
    # articles, have the LLM rewrite a single insight. In practice the titles
    # recalled for one insight can each look relevant while collectively pointing
    # to an insight from a different angle.
    def calculate_overlap(list1, list2):
        # overlap rate = |intersection| / size of the smaller list
        intersection_length = len(set(list1).intersection(set(list2)))
        overlap_rate = intersection_length / min(len(list1), len(list2))
        return overlap_rate >= 0.75
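    # e.g. calculate_overlap(['a', 'b', 'c'], ['b', 'c']): intersection size 2,
    # smaller list size 2 -> overlap rate 1.0 >= 0.75, so the two insights merge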
merged_dict = {}
for key, value in cache.items():
if not value:
continue
merged = False
for existing_key, existing_value in merged_dict.items():
if calculate_overlap(value, existing_value):
merged_dict[existing_key].extend(value)
merged = True
break
if not merged:
merged_dict[key] = value
cache = {}
for key, value in merged_dict.items():
value = list(set(value))
if len(value) > 1:
context = ''
for _id in value:
context += f"<article>{articles[_id]['title']}</article>\n"
if len(context) >= max_tokens:
break
if not context:
continue
flag, new_insight = _rewrite_insight(context, logger)
if flag:
                logger.warning(f'insight "{key}" may be wrong: rewrite failed, keeping it as is')
cache[key] = value
else:
                if cache:
                    # guard against near-duplicates: if the rewritten insight is very
                    # similar to one already cached, fold its articles into that entry
                    title_list = [Document(page_content=k, metadata={}) for k in cache]
                    retriever = FAISS.from_documents(
                        title_list, embed_model,
                        distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT
                    ).as_retriever(search_type="similarity",
                                   search_kwargs={"score_threshold": 0.85, "k": 1})
                    compression = ContextualCompressionRetriever(base_compressor=reranker,
                                                                 base_retriever=retriever)
rerank_results = compression.get_relevant_documents(new_insight)
if rerank_results and rerank_results[0].metadata['relevance_score'] > 0.85:
logger.debug(f"{new_insight} is too similar to {rerank_results[0].page_content}, merging")
cache[rerank_results[0].page_content].extend(value)
cache[rerank_results[0].page_content] = list(set(cache[rerank_results[0].page_content]))
else:
cache[new_insight] = value
else:
cache[new_insight] = value
else:
cache[key] = value
    # sorting (insights with more supporting articles first) -- currently disabled:
    # sorted_cache = sorted(cache.items(), key=lambda x: len(x[1]), reverse=True)
    logger.info('re-ranking result:')
new_cache = []
    for key, value in cache.items():
        # non-Chinese insights are translated into Chinese before being returned
        if not is_chinese(key):
translate_text = text_translate([key], target_language='zh', logger=logger)
if translate_text:
key = translate_text[0]
logger.info(f'{key} - {value}')
new_cache.append({'content': key, 'articles': value})
return new_cache
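

# A minimal usage sketch (illustrative only). The shapes of `articles` and `titles`
# are inferred from the reads above: `articles` maps an article id to a dict with
# 'abstract' / 'title' / 'content' fields, and `titles` maps a title string back to
# its article id. The demo ids and strings here are hypothetical.
if __name__ == '__main__':
    import logging

    logging.basicConfig(level=logging.DEBUG)
    demo_logger = logging.getLogger('get_insight_demo')
    demo_articles = {
        'id1': {'title': 'Country A and Country B hold a joint naval exercise',
                'abstract': 'Navies of Country A and Country B began a three-day joint drill.',
                'content': ''},
    }
    demo_titles = {'Country A and Country B hold a joint naval exercise': 'id1'}
    for insight in get_insight(demo_articles, demo_titles, logger=demo_logger):
        print(insight['content'], '->', insight['articles'])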