wiseflow/test/get_info_test.py

138 lines
6.4 KiB
Python
Raw Normal View History

2024-12-18 22:45:20 +08:00
# -*- coding: utf-8 -*-
2025-01-17 23:28:22 +08:00
import os, sys
2024-12-18 22:45:20 +08:00
import json
import asyncio
2024-12-23 10:12:52 +08:00
import time
from datetime import datetime
2025-01-17 23:28:22 +08:00
# 将core目录添加到Python路径
core_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'core')
sys.path.append(core_path)
2025-01-02 22:05:51 +08:00
2025-01-17 23:28:22 +08:00
# 现在可以直接导入模块因为core目录已经在Python路径中
from scrapers import *
from agents.get_info import pre_process
from utils.general_utils import is_chinese
from agents.get_info import get_author_and_publish_date, get_info, get_more_related_urls
from agents.get_info_prompts import *
2025-01-02 22:05:51 +08:00
2025-01-16 23:31:04 +08:00
benchmark_model = 'Qwen/Qwen2.5-72B-Instruct'
2025-01-17 23:28:22 +08:00
models = ['Qwen/Qwen2.5-7B-Instruct', 'Qwen/Qwen2.5-14B-Instruct', 'Qwen/Qwen2.5-32B-Instruct', 'deepseek-ai/DeepSeek-V2.5']
async def main(sample: dict, include_ap: bool, prompts: list, focus_dict: dict, record_file: str):
link_dict, links_parts, contents = sample['link_dict'], sample['links_part'], sample['contents']
get_link_sys_prompt, get_link_suffix_prompt, get_info_sys_prompt, get_info_suffix_prompt = prompts
2025-01-16 23:31:04 +08:00
for model in [benchmark_model] + models:
2025-01-17 23:28:22 +08:00
links_texts = []
for _parts in links_parts:
links_texts.extend(_parts.split('\n\n'))
contents = sample['contents'].copy()
2025-01-02 22:05:51 +08:00
print(f"running {model} ...")
start_time = time.time()
2025-01-17 23:28:22 +08:00
if include_ap:
author, publish_date = await get_author_and_publish_date(contents[0], model, test_mode=True)
get_ap_time = time.time() - start_time
print(f"get author and publish date time: {get_ap_time}")
else:
author, publish_date = '', ''
get_ap_time = 0
start_time = time.time()
more_url = await get_more_related_urls(links_texts, link_dict, [get_link_sys_prompt, get_link_suffix_prompt, model], test_mode=True)
get_more_url_time = time.time() - start_time
print(f"get more related urls time: {get_more_url_time}")
start_time = time.time()
infos = await get_info(contents, link_dict, [get_info_sys_prompt, get_info_suffix_prompt, model], focus_dict, author, publish_date, test_mode=True)
get_info_time = time.time() - start_time
print(f"get info time: {get_info_time}")
2024-12-18 22:45:20 +08:00
2025-01-16 23:31:04 +08:00
if model == benchmark_model:
2025-01-17 23:28:22 +08:00
benchmark_result = more_url.copy()
diff = f'benchmark: {len(benchmark_result)} results'
2025-01-16 23:31:04 +08:00
else:
2025-01-17 23:28:22 +08:00
missing_in_cache = len(benchmark_result - more_url) # benchmark中有但cache中没有的
extra_in_cache = len(more_url - benchmark_result) # cache中有但benchmark中没有的
2025-01-16 23:31:04 +08:00
total_diff = missing_in_cache + extra_in_cache
diff = f'差异{total_diff}个(遗漏{missing_in_cache}个,多出{extra_in_cache}个)'
2025-01-17 23:28:22 +08:00
related_urls_to_record = '\n'.join(more_url)
infos_to_record = [f"{fi['tag']}: {fi['content']}" for fi in infos]
infos_to_record = '\n'.join(infos_to_record)
2024-12-18 22:45:20 +08:00
with open(record_file, 'a') as f:
2025-01-17 23:28:22 +08:00
f.write(f"model: {model}\n")
if include_ap:
f.write(f"get author and publish date time: {get_ap_time}\n")
f.write(f"author: {author}\n")
f.write(f"publish date: {publish_date}\n")
f.write(f"get more related urls time: {get_more_url_time}\n")
2025-01-16 23:31:04 +08:00
f.write(f"diff from benchmark: {diff}\n")
2025-01-17 23:28:22 +08:00
f.write(f"get info time: {get_info_time}\n")
f.write(f"related urls: \n{related_urls_to_record}\n")
f.write(f"final result: \n{infos_to_record}\n")
2024-12-18 22:45:20 +08:00
f.write('\n\n')
2025-01-17 23:28:22 +08:00
print('\n\n')
2024-12-18 22:45:20 +08:00
if __name__ == '__main__':
2025-01-02 22:05:51 +08:00
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--sample_dir', '-D', type=str, default='')
2025-01-17 23:28:22 +08:00
parser.add_argument('--include_ap', '-I', type=bool, default=False)
2025-01-02 22:05:51 +08:00
args = parser.parse_args()
sample_dir = args.sample_dir
2025-01-17 23:28:22 +08:00
include_ap = args.include_ap
2025-01-02 22:05:51 +08:00
if not os.path.exists(os.path.join(sample_dir, 'focus_point.json')):
raise ValueError(f'{sample_dir} focus_point.json not found')
focus_points = json.load(open(os.path.join(sample_dir, 'focus_point.json'), 'r'))
focus_statement = ''
for item in focus_points:
tag = item["focuspoint"]
expl = item["explanation"]
2025-01-04 13:57:12 +08:00
focus_statement = f"{focus_statement}//{tag}//\n"
2025-01-02 22:05:51 +08:00
if expl:
2025-01-17 23:28:22 +08:00
if is_chinese(expl):
focus_statement = f"{focus_statement}解释:{expl}\n"
else:
focus_statement = f"{focus_statement}Explanation: {expl}\n"
focus_dict = {item["focuspoint"]: item["focuspoint"] for item in focus_points}
date_stamp = datetime.now().strftime('%Y-%m-%d')
if is_chinese(focus_statement):
get_link_sys_prompt = get_link_system.replace('{focus_statement}', focus_statement)
get_link_sys_prompt = f"今天的日期是{date_stamp}{get_link_sys_prompt}"
get_link_suffix_prompt = get_link_suffix
get_info_sys_prompt = get_info_system.replace('{focus_statement}', focus_statement)
get_info_sys_prompt = f"今天的日期是{date_stamp}{get_info_sys_prompt}"
get_info_suffix_prompt = get_info_suffix
else:
get_link_sys_prompt = get_link_system_en.replace('{focus_statement}', focus_statement)
get_link_sys_prompt = f"today is {date_stamp}, {get_link_sys_prompt}"
get_link_suffix_prompt = get_link_suffix_en
get_info_sys_prompt = get_info_system_en.replace('{focus_statement}', focus_statement)
get_info_sys_prompt = f"today is {date_stamp}, {get_info_sys_prompt}"
get_info_suffix_prompt = get_info_suffix_en
2025-01-02 22:05:51 +08:00
2025-01-17 23:28:22 +08:00
prompts = [get_link_sys_prompt, get_link_suffix_prompt, get_info_sys_prompt, get_info_suffix_prompt]
2025-01-02 22:05:51 +08:00
2025-01-04 13:57:12 +08:00
time_stamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
record_file = os.path.join(sample_dir, f'record-{time_stamp}.txt')
with open(record_file, 'w') as f:
f.write(f"focus statement: \n{focus_statement}\n\n")
2025-01-17 23:28:22 +08:00
for file in os.listdir(sample_dir):
if not file.endswith('_processed.json'):
2024-12-18 22:45:20 +08:00
continue
2025-01-17 23:28:22 +08:00
sample = json.load(open(os.path.join(sample_dir, file), 'r'))
if 'links_part' not in sample or 'link_dict' not in sample or 'contents' not in sample:
print(f'{file} not valid sample, skip')
continue
2025-01-04 13:57:12 +08:00
with open(record_file, 'a') as f:
2025-01-17 23:28:22 +08:00
f.write(f"raw materials: {file}\n\n")
print(f'start testing {file}')
asyncio.run(main(sample, include_ap, prompts, focus_dict, record_file))