mirror of
https://github.com/TeamWiseFlow/wiseflow.git
synced 2025-01-23 02:20:20 +08:00
test for V0.3.7
This commit is contained in:
parent
77c3914d12
commit
43ae4dfb86
@ -12,39 +12,41 @@ sys.path.append(project_root)
|
||||
|
||||
from core.llms.openai_wrapper import openai_llm as llm
|
||||
|
||||
models = ['Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen2.5-14B-Instruct', 'Qwen/Qwen2.5-32B-Instruct', 'deepseek-ai/DeepSeek-V2.5']
|
||||
|
||||
async def main(texts: list[str], record_file: str, sys_prompt: str, focus_points: list):
|
||||
benchmark_model = 'Qwen/Qwen2.5-72B-Instruct'
|
||||
models = ['Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen2.5-14B-Instruct', 'Qwen/Qwen2.5-32B-Instruct', 'deepseek-ai/DeepSeek-V2.5', 'internlm/internlm2_5-20b-chat']
|
||||
async def main(texts: list[str], link_dict: dict, record_file: str, sys_prompt: str, focus_points: list):
|
||||
# first get more links
|
||||
judge_text = ''.join(texts)
|
||||
for model in models:
|
||||
_texts = texts.copy()
|
||||
print(f'sys_prompt: \n{sys_prompt}')
|
||||
benchmark_result = None
|
||||
for model in [benchmark_model] + models:
|
||||
_texts = []
|
||||
for text in texts:
|
||||
_texts.extend(text.split('\n\n'))
|
||||
print(f"running {model} ...")
|
||||
start_time = time.time()
|
||||
hallucination_times = 0
|
||||
text_batch = ''
|
||||
cache = []
|
||||
cache = set()
|
||||
while _texts:
|
||||
t = _texts.pop(0)
|
||||
text_batch = f'{text_batch}{t}# '
|
||||
if len(text_batch) > 100 or len(_texts) == 0:
|
||||
text_batch = f'{text_batch}{t}\n\n'
|
||||
if len(text_batch) > 512 or len(_texts) == 0:
|
||||
content = f'<text>\n{text_batch}</text>\n\n{get_info_suffix}'
|
||||
result = await llm(
|
||||
[{'role': 'system', 'content': sys_prompt}, {'role': 'user', 'content': content}],
|
||||
model=model, temperature=0.1)
|
||||
#print(f"llm output\n{result}")
|
||||
text_batch = ''
|
||||
print(f"llm output\n{result}\n")
|
||||
result = re.findall(r'\"\"\"(.*?)\"\"\"', result, re.DOTALL)
|
||||
if result: cache.append(result[-1])
|
||||
|
||||
infos = []
|
||||
for item in cache:
|
||||
segs = item.split('//')
|
||||
infos.extend([s.strip() for s in segs if s.strip()])
|
||||
for content in infos:
|
||||
if content not in judge_text:
|
||||
print(f'not in raw content:\n{content}')
|
||||
hallucination_times += 1
|
||||
if result:
|
||||
# 在result[-1]中找到所有类似[4]这样的片段
|
||||
links = re.findall(r'\[\d+\]', result[-1])
|
||||
for link in links:
|
||||
if link not in text_batch:
|
||||
hallucination_times += 1
|
||||
print(f'\n**not in text_batch: {link}**\n')
|
||||
continue
|
||||
cache.add(link)
|
||||
text_batch = ''
|
||||
|
||||
t1 = time.time()
|
||||
get_infos_time = t1 - start_time
|
||||
@ -52,13 +54,26 @@ async def main(texts: list[str], record_file: str, sys_prompt: str, focus_points
|
||||
print("*" * 12)
|
||||
print('\n\n')
|
||||
|
||||
infos_to_record = '\n'.join(infos)
|
||||
for link in cache:
|
||||
if link not in link_dict:
|
||||
print(f'\n**not in link_dict: {link}**\n')
|
||||
if model == benchmark_model:
|
||||
benchmark_result = cache.copy()
|
||||
diff = 'benchmark'
|
||||
else:
|
||||
# 计算当前cache与benchmark的差异
|
||||
missing_in_cache = len(benchmark_result - cache) # benchmark中有但cache中没有的
|
||||
extra_in_cache = len(cache - benchmark_result) # cache中有但benchmark中没有的
|
||||
total_diff = missing_in_cache + extra_in_cache
|
||||
diff = f'差异{total_diff}个(遗漏{missing_in_cache}个,多出{extra_in_cache}个)'
|
||||
|
||||
infos_to_record = '\n'.join(list(set(link_dict[link] for link in cache)))
|
||||
|
||||
with open(record_file, 'a') as f:
|
||||
f.write(f"llm model: {model}\n")
|
||||
f.write(f"process time: {get_infos_time} s\n")
|
||||
f.write(f"bad generate times: {hallucination_times}\n")
|
||||
f.write(f"total segments: {len(infos)}\n")
|
||||
f.write(f"diff from benchmark: {diff}\n")
|
||||
f.write(f"segments: \n{infos_to_record}\n")
|
||||
f.write("*" * 12)
|
||||
f.write('\n\n')
|
||||
@ -105,4 +120,4 @@ if __name__ == '__main__':
|
||||
|
||||
with open(record_file, 'a') as f:
|
||||
f.write(f"raw materials in: {dirs}\n\n")
|
||||
asyncio.run(main(sample['texts'], record_file, system_prompt, focus_points))
|
||||
asyncio.run(main(sample['links_part'], sample['link_dict'], record_file, system_prompt, focus_points))
|
||||
|
@ -1,19 +1,17 @@
|
||||
|
||||
get_info_system = '''你将被给到一段使用<text></text>标签包裹的网页文本,你的任务是从前到后仔细阅读文本,并摘抄与如下关注点相关的原文片段。关注点及其解释如下:
|
||||
get_info_system = '''你将被给到一段使用<text></text>标签包裹的网页文本,你的任务是从前到后仔细阅读文本,提取出与如下任一关注点相关的原文片段。关注点及其解释如下:
|
||||
|
||||
{focus_statement}\n
|
||||
在进行提取时,请遵循以下原则:
|
||||
- 理解关注点的含义以及进一步的解释(如有),确保提取的内容与关注点强相关并符合解释(如有)的范围
|
||||
- 在满足上面原则的前提下,摘抄出全部相关片段
|
||||
- 摘抄出的原文片段务必保持原文原样,包括标点符号都不要更改,尤其注意保留类似"[3]"这样的引用标记'''
|
||||
- 在满足上面原则的前提下,提取出全部可能相关的片段
|
||||
- 提取出的原文片段务必保留类似"[3]"这样的引用标记,后续的处理需要用到这些引用标记'''
|
||||
|
||||
get_info_suffix = '''请将摘抄出的原文片段用"//"分隔,并整体用三引号包裹后输出。三引号内不要有其他内容,如果文本中不包含任何与关注点相关的内容则保持三引号内为空。
|
||||
get_info_suffix = '''请逐条输出提取的原文片段,并整体用三引号包裹。三引号内除了提取出的原文片段外不要有其他内容,如果文本中不包含任何与关注点相关的内容则保持三引号内为空。
|
||||
如下是输出格式示例::
|
||||
"""
|
||||
原文片段1
|
||||
//
|
||||
原文片段2
|
||||
//
|
||||
...
|
||||
"""'''
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user