2024-12-18 22:45:20 +08:00
|
|
|
|
# -*- coding: utf-8 -*-
|
2025-01-02 22:05:51 +08:00
|
|
|
|
import os, re, sys
|
2024-12-18 22:45:20 +08:00
|
|
|
|
import json
|
|
|
|
|
import asyncio
|
2024-12-23 10:12:52 +08:00
|
|
|
|
import time
|
|
|
|
|
from prompts import *
|
2025-01-16 10:56:57 +08:00
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
2025-01-02 22:05:51 +08:00
|
|
|
|
# Put the repository root (the parent of this script's folder) on sys.path so
# that the project-local `core` package can be imported when this file is run
# directly as a script.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root, _ = os.path.split(current_dir)
sys.path.append(project_root)

# This import must come after the sys.path tweak above.
from core.llms.openai_wrapper import openai_llm as llm  # noqa: E402

# Reference model: the link set every other model extracts is diffed against
# the set this model extracts.
benchmark_model = 'Qwen/Qwen2.5-72B-Instruct'

# Models under evaluation; main() runs benchmark_model first, then these in order.
models = [
    'Qwen/Qwen2.5-7B-Instruct',
    'Qwen/Qwen2.5-14B-Instruct',
    'Qwen/Qwen2.5-32B-Instruct',
    'deepseek-ai/DeepSeek-V2.5',
    'internlm/internlm2_5-20b-chat',
]
|
|
|
|
|
from collections import deque

# The answer is expected inside a triple-quoted block in the LLM output.
_ANSWER_RE = re.compile(r'\"\"\"(.*?)\"\"\"', re.DOTALL)
# Citation markers such as "[4]" that should point back into the source text.
_LINK_RE = re.compile(r'\[\d+\]')


def _split_segments(texts):
    """Return a deque of all blank-line-separated segments of every text, in order."""
    segments = deque()
    for text in texts:
        segments.extend(text.split('\n\n'))
    return segments


def _extract_links(llm_output, text_batch):
    """Collect the citation markers from the last triple-quoted block of *llm_output*.

    Returns ``(links, hallucinations)``: ``links`` are the markers that occur
    literally in *text_batch* (in order of appearance), ``hallucinations`` is
    the number of markers that do not. A ``None`` or empty output, or one
    without a triple-quoted block, yields ``([], 0)``.
    """
    answers = _ANSWER_RE.findall(llm_output or '')
    if not answers:
        return [], 0
    links = []
    hallucinations = 0
    for link in _LINK_RE.findall(answers[-1]):
        if link not in text_batch:
            hallucinations += 1
            print(f'\n**not in text_batch: {link}**\n')
            continue
        links.append(link)
    return links, hallucinations


def _diff_summary(cache, benchmark):
    """Describe how the link set *cache* differs from the *benchmark* link set."""
    missing_in_cache = len(benchmark - cache)  # in the benchmark but not in cache
    extra_in_cache = len(cache - benchmark)  # in cache but not in the benchmark
    total_diff = missing_in_cache + extra_in_cache
    return f'差异{total_diff}个(遗漏{missing_in_cache}个,多出{extra_in_cache}个)'


async def main(texts: list[str], link_dict: dict, record_file: str, sys_prompt: str, focus_points: list):
    """Run every model over *texts* and append a comparison report to *record_file*.

    For ``benchmark_model`` followed by each entry of ``models``:

    1. The texts are split on blank lines and re-joined into batches. A batch
       is sent to the LLM once it exceeds 512 characters, or when input runs out.
    2. Citation markers (``[n]``) in the last triple-quoted block of each reply
       are kept if they occur in the batch; otherwise they count as "bad
       generations".
    3. The kept link set is compared with the benchmark model's set, and the
       model name, timing, bad-generation count, diff and the cited segments
       are appended to *record_file*.

    Args:
        texts: Source texts; each is split into segments on ``'\\n\\n'``.
        link_dict: Looked up by citation marker; its values are joined with
            newlines into the report. Presumably maps markers to the text
            segment they denote. Markers missing from it are reported and
            skipped.
        record_file: Path of the report file; opened in append mode.
        sys_prompt: System prompt sent with every request.
        focus_points: Accepted for interface compatibility; not used here.
    """
    # first get more links
    print(f'sys_prompt: \n{sys_prompt}')
    benchmark_result = None
    for model in [benchmark_model] + models:
        segments = _split_segments(texts)
        print(f"running {model} ...")
        start_time = time.time()
        hallucination_times = 0
        text_batch = ''
        cache = set()
        while segments:
            text_batch = f'{text_batch}{segments.popleft()}\n\n'
            # Keep accumulating until the batch is big enough or input is exhausted.
            if len(text_batch) <= 512 and segments:
                continue
            content = f'<text>\n{text_batch}</text>\n\n{get_info_suffix}'
            result = await llm(
                [{'role': 'system', 'content': sys_prompt}, {'role': 'user', 'content': content}],
                model=model, temperature=0.1)
            print(f"llm output\n{result}\n")
            links, bad = _extract_links(result, text_batch)
            hallucination_times += bad
            cache.update(links)
            text_batch = ''

        get_infos_time = time.time() - start_time
        print(f"get more infos time: {get_infos_time}")
        print("*" * 12)
        print('\n\n')

        for link in cache:
            if link not in link_dict:
                print(f'\n**not in link_dict: {link}**\n')

        if model == benchmark_model:
            benchmark_result = cache.copy()
            diff = 'benchmark'
        else:
            diff = _diff_summary(cache, benchmark_result)

        # Skip markers without a link_dict entry (already reported above)
        # instead of crashing with KeyError; sort for a reproducible report.
        infos_to_record = '\n'.join(sorted({link_dict[link] for link in cache if link in link_dict}))

        with open(record_file, 'a', encoding='utf-8') as f:
            f.write(f"llm model: {model}\n")
            f.write(f"process time: {get_infos_time} s\n")
            f.write(f"bad generate times: {hallucination_times}\n")
            f.write(f"diff from benchmark: {diff}\n")
            f.write(f"segments: \n{infos_to_record}\n")
            f.write("*" * 12)
            f.write('\n\n')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the system prompt from <sample_dir>/focus_point.json,
    # then run main() once per sub-directory of sample_dir that has a sample.json,
    # appending all results to a timestamped record file inside sample_dir.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--sample_dir', '-D', type=str, default='')
    args = parser.parse_args()

    # os.listdir('') raises FileNotFoundError, so treat an empty value as the cwd.
    sample_dir = args.sample_dir or '.'

    focus_point_file = os.path.join(sample_dir, 'focus_point.json')
    if not os.path.exists(focus_point_file):
        raise ValueError(f'{sample_dir} focus_point.json not found')
    # Expected: a list of objects with a "focuspoint" key and an optional "explanation".
    with open(focus_point_file, 'r', encoding='utf-8') as f:
        focus_points = json.load(f)

    focus_statement = ''
    for item in focus_points:
        tag = item["focuspoint"]
        expl = item.get("explanation")
        focus_statement = f"{focus_statement}//{tag}//\n"
        if expl:
            focus_statement = f"{focus_statement}解释:{expl}\n"

    # get_info_system comes from the prompts module (star-imported at the top).
    get_info_system = get_info_system.replace('{focus_statement}', focus_statement)
    system_prompt = f"今天的日期是{datetime.now().strftime('%Y-%m-%d')},{get_info_system}"
    focus_points = [item["focuspoint"] for item in focus_points]

    time_stamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
    record_file = os.path.join(sample_dir, f'record-{time_stamp}.txt')
    with open(record_file, 'w', encoding='utf-8') as f:
        f.write(f"focus statement: \n{focus_statement}\n\n")

    # Sorted so repeated runs process samples in the same order.
    for dirs in sorted(os.listdir(sample_dir)):
        _path = os.path.join(sample_dir, dirs)
        if not os.path.isdir(_path):
            continue
        print(f'start testing {_path}')
        sample_file = os.path.join(_path, 'sample.json')
        if not os.path.isfile(sample_file):
            print(f'{dirs} sample.json not found, skip')
            continue
        with open(sample_file, 'r', encoding='utf-8') as f:
            sample = json.load(f)

        with open(record_file, 'a', encoding='utf-8') as f:
            f.write(f"raw materials in: {dirs}\n\n")
        asyncio.run(main(sample['links_part'], sample['link_dict'], record_file, system_prompt, focus_points))
|