diff --git a/test/get_info_test.py b/test/get_info_test.py
index a3bec50..7defab8 100644
--- a/test/get_info_test.py
+++ b/test/get_info_test.py
@@ -12,39 +12,41 @@ sys.path.append(project_root)
from core.llms.openai_wrapper import openai_llm as llm
-models = ['Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen2.5-14B-Instruct', 'Qwen/Qwen2.5-32B-Instruct', 'deepseek-ai/DeepSeek-V2.5']
-
-async def main(texts: list[str], record_file: str, sys_prompt: str, focus_points: list):
+benchmark_model = 'Qwen/Qwen2.5-72B-Instruct'
+models = ['Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen2.5-14B-Instruct', 'Qwen/Qwen2.5-32B-Instruct', 'deepseek-ai/DeepSeek-V2.5', 'internlm/internlm2_5-20b-chat']
+async def main(texts: list[str], link_dict: dict, record_file: str, sys_prompt: str, focus_points: list):
# first get more links
- judge_text = ''.join(texts)
- for model in models:
- _texts = texts.copy()
+ print(f'sys_prompt: \n{sys_prompt}')
+ benchmark_result = None
+ for model in [benchmark_model] + models:
+ _texts = []
+ for text in texts:
+ _texts.extend(text.split('\n\n'))
print(f"running {model} ...")
start_time = time.time()
hallucination_times = 0
text_batch = ''
- cache = []
+ cache = set()
while _texts:
t = _texts.pop(0)
- text_batch = f'{text_batch}{t}# '
- if len(text_batch) > 100 or len(_texts) == 0:
+ text_batch = f'{text_batch}{t}\n\n'
+ if len(text_batch) > 512 or len(_texts) == 0:
content = f'\n{text_batch}\n\n{get_info_suffix}'
result = await llm(
[{'role': 'system', 'content': sys_prompt}, {'role': 'user', 'content': content}],
model=model, temperature=0.1)
- #print(f"llm output\n{result}")
- text_batch = ''
+ print(f"llm output\n{result}\n")
result = re.findall(r'\"\"\"(.*?)\"\"\"', result, re.DOTALL)
- if result: cache.append(result[-1])
-
- infos = []
- for item in cache:
- segs = item.split('//')
- infos.extend([s.strip() for s in segs if s.strip()])
- for content in infos:
- if content not in judge_text:
- print(f'not in raw content:\n{content}')
- hallucination_times += 1
+ if result:
+ # 在result[-1]中找到所有类似[4]这样的片段
+ links = re.findall(r'\[\d+\]', result[-1])
+ for link in links:
+ if link not in text_batch:
+ hallucination_times += 1
+ print(f'\n**not in text_batch: {link}**\n')
+ continue
+ cache.add(link)
+ text_batch = ''
t1 = time.time()
get_infos_time = t1 - start_time
@@ -52,13 +54,26 @@ async def main(texts: list[str], record_file: str, sys_prompt: str, focus_points
print("*" * 12)
print('\n\n')
- infos_to_record = '\n'.join(infos)
+ for link in cache:
+ if link not in link_dict:
+ print(f'\n**not in link_dict: {link}**\n')
+ if model == benchmark_model:
+ benchmark_result = cache.copy()
+ diff = 'benchmark'
+ else:
+ # 计算当前cache与benchmark的差异
+ missing_in_cache = len(benchmark_result - cache) # benchmark中有但cache中没有的
+ extra_in_cache = len(cache - benchmark_result) # cache中有但benchmark中没有的
+ total_diff = missing_in_cache + extra_in_cache
+ diff = f'差异{total_diff}个(遗漏{missing_in_cache}个,多出{extra_in_cache}个)'
+
+ infos_to_record = '\n'.join(list(set(link_dict[link] for link in cache)))
with open(record_file, 'a') as f:
f.write(f"llm model: {model}\n")
f.write(f"process time: {get_infos_time} s\n")
f.write(f"bad generate times: {hallucination_times}\n")
- f.write(f"total segments: {len(infos)}\n")
+ f.write(f"diff from benchmark: {diff}\n")
f.write(f"segments: \n{infos_to_record}\n")
f.write("*" * 12)
f.write('\n\n')
@@ -105,4 +120,4 @@ if __name__ == '__main__':
with open(record_file, 'a') as f:
f.write(f"raw materials in: {dirs}\n\n")
- asyncio.run(main(sample['texts'], record_file, system_prompt, focus_points))
+ asyncio.run(main(sample['links_part'], sample['link_dict'], record_file, system_prompt, focus_points))
diff --git a/test/prompts.py b/test/prompts.py
index 0622f3c..ad5005d 100644
--- a/test/prompts.py
+++ b/test/prompts.py
@@ -1,19 +1,17 @@
-get_info_system = '''你将被给到一段使用标签包裹的网页文本,你的任务是从前到后仔细阅读文本,并摘抄与如下关注点相关的原文片段。关注点及其解释如下:
+get_info_system = '''你将被给到一段使用标签包裹的网页文本,你的任务是从前到后仔细阅读文本,提取出与如下任一关注点相关的原文片段。关注点及其解释如下:
{focus_statement}\n
在进行提取时,请遵循以下原则:
- 理解关注点的含义以及进一步的解释(如有),确保提取的内容与关注点强相关并符合解释(如有)的范围
-- 在满足上面原则的前提下,摘抄出全部相关片段
-- 摘抄出的原文片段务必保持原文原样,包括标点符号都不要更改,尤其注意保留类似"[3]"这样的引用标记'''
+- 在满足上面原则的前提下,提取出全部可能相关的片段
+- 提取出的原文片段务必保留类似"[3]"这样的引用标记,后续的处理需要用到这些引用标记'''
-get_info_suffix = '''请将摘抄出的原文片段用"//"分隔,并整体用三引号包裹后输出。三引号内不要有其他内容,如果文本中不包含任何与关注点相关的内容则保持三引号内为空。
+get_info_suffix = '''请逐条输出提取的原文片段,并整体用三引号包裹。三引号内除了提取出的原文片段外不要有其他内容,如果文本中不包含任何与关注点相关的内容则保持三引号内为空。
如下是输出格式示例::
"""
原文片段1
-//
原文片段2
-//
...
"""'''