# wiseflow/client/backend/work_process.py
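"""Background work pipeline for the wiseflow client backend.

ServiceProcesser scrapes the configured sources, writes new articles into the
pb (PocketBase) database, extracts insights from them, has non-Chinese
titles/abstracts translated, and optionally archives page snapshots.
"""
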
import os
import json
import requests
from get_logger import get_logger
from datetime import datetime, timedelta, date
from scrapers import scraper_map
from scrapers.general_scraper import general_scraper
from pb_api import PbTalker
from urllib.parse import urlparse
from get_insight import get_insight
from general_utils import is_chinese
from tranlsation_volcengine import text_translate  # sic: the module file is named 'tranlsation'
import concurrent.futures

# Mainly used on the first crawl to avoid fetching articles that are too old;
# database articles older than this many days are also no longer used to match insights.
expiration_date = datetime.now() - timedelta(days=90)
expiration_date = expiration_date.date()
expiration_str = expiration_date.strftime("%Y%m%d")


class ServiceProcesser:
    def __init__(self, name: str = 'Work_Processor', record_snapshot: bool = False):
        self.name = name
        self.project_dir = os.environ.get("PROJECT_DIR", "")
        # 1. base initialization
        self.cache_url = os.path.join(self.project_dir, name)
        os.makedirs(self.cache_url, exist_ok=True)
        self.logger = get_logger(name=self.name, file=os.path.join(self.project_dir, f'{self.name}.log'))
        self.pb = PbTalker(self.logger)
        # 2. load the llm
        # self.llm = LocalLlmWrapper()  # if you use the local-llm
        if record_snapshot:
            snap_short_server = os.environ.get('SNAPSHOTS_SERVER', '')
            if not snap_short_server:
                raise Exception('SNAPSHOTS_SERVER is not set.')
            self.snap_short_server = f"http://{snap_short_server}"
        else:
            self.snap_short_server = None
        self.logger.info(f'{self.name} init success.')
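
    # One full work cycle: scrape -> store articles -> extract insights ->
    # translate non-Chinese titles/abstracts -> (optionally) snapshot pages.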
    def __call__(self, expiration: date = expiration_date, sites: list[str] = None):
        # clear the cache first
        self.logger.info(f'wake, prepare to work, now is {datetime.now()}')
        cache = {}
        self.logger.debug(f'clear cache -- {cache}')
        # Read all existing article URLs from the pb database. publish_time is stored
        # as an int (YYYYMMDD); all things considered this is the easiest format to
        # work with, even if a bit crude. Derive the filter from the expiration
        # argument so callers can override the module-level default.
        expiration_str = expiration.strftime("%Y%m%d")
        existing_articles = self.pb.read(collection_name='articles', fields=['id', 'title', 'url'],
                                         filter=f'publish_time>{expiration_str}')
        # all_title maps each known title to its record id (used for dedup and insight
        # matching); existings collects known URLs so the scrapers can skip them.
        all_title = {}
        existings = []
        for article in existing_articles:
            all_title[article['title']] = article['id']
            existings.append(article['url'])
        # Sources to scan: default to everything registered in scraper_map; a source
        # without a dedicated scraper is handled by the general scraper below.
        sources = sites if sites else list(scraper_map.keys())
        new_articles = []
        # Scrape all sources concurrently: use the dedicated scraper when one is
        # registered for the site, otherwise fall back to the general scraper.
        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            futures = []
            for site in sources:
                if site in scraper_map:
                    futures.append(executor.submit(scraper_map[site], expiration, existings))
                else:
                    futures.append(executor.submit(general_scraper, site, expiration, existings))
            concurrent.futures.wait(futures)
            for future in futures:
                try:
                    new_articles.extend(future.result())
                except Exception as e:
                    self.logger.error(f'error when scraping -- {e}')
        for value in new_articles:
            if not value:
                continue
            # Derive a short site name from the URL host, e.g. 'www.foo.com' -> 'foo',
            # and prefix abstract/content with it ('xx 报道' = 'reported by xx').
            from_site = urlparse(value['url']).netloc
            from_site = from_site.replace('www.', '')
            from_site = from_site.split('.')[0]
            if value['abstract']:
                value['abstract'] = f"({from_site} 报道){value['abstract']}"
            value['content'] = f"({from_site} 报道){value['content']}"
            value['images'] = json.dumps(value['images'])
            article_id = self.pb.add(collection_name='articles', body=value)
            if article_id:
                cache[article_id] = value
                all_title[value['title']] = article_id
            else:
                self.logger.warning(f'add article {value} failed, writing to cache_file')
                with open(os.path.join(self.cache_url, 'cache_articles.json'), 'a', encoding='utf-8') as f:
                    json.dump(value, f, ensure_ascii=False, indent=4)
        if not cache:
            self.logger.warning(f'no new articles. now is {datetime.now()}')
            return
        # insight pipeline
        new_insights = get_insight(cache, all_title, logger=self.logger)
        if new_insights:
            for insight in new_insights:
                if not insight['content']:
                    continue
                insight_id = self.pb.add(collection_name='insights', body=insight)
                if not insight_id:
                    self.logger.warning(f'write insight {insight} to pb failed, writing to cache_file')
                    with open(os.path.join(self.cache_url, 'cache_insights.json'), 'a', encoding='utf-8') as f:
                        json.dump(insight, f, ensure_ascii=False, indent=4)
                # Translate the title/abstract of any related article that is not
                # already in Chinese and has no translation on record.
                for article_id in insight['articles']:
                    raw_article = self.pb.read(collection_name='articles', fields=['abstract', 'title', 'translation_result'],
                                               filter=f'id="{article_id}"')
                    if not raw_article or not raw_article[0]:
                        self.logger.warning(f'get article {article_id} failed, skipping')
                        continue
                    if raw_article[0]['translation_result']:
                        continue
                    if is_chinese(raw_article[0]['title']):
                        continue
                    translate_text = text_translate([raw_article[0]['title'], raw_article[0]['abstract']],
                                                    target_language='zh', logger=self.logger)
                    if translate_text:
                        related_id = self.pb.add(collection_name='article_translation',
                                                 body={'title': translate_text[0], 'abstract': translate_text[1], 'raw': article_id})
                        if not related_id:
                            self.logger.warning(f'write article_translation {article_id} failed')
                        else:
                            _ = self.pb.update(collection_name='articles', id=article_id, body={'translation_result': related_id})
                            if not _:
                                self.logger.warning(f'update article {article_id} failed')
                    else:
                        self.logger.warning(f'translate article {article_id} failed')
        else:
            # Fallback: use each article's title as an insight when no insights
            # were generated and the batch is small enough.
            if len(cache) < 25:
                self.logger.info('generate_insights-warning: no insights and fewer than 25 articles, so using article titles as insights')
                for key, value in cache.items():
                    if value['title']:
                        if is_chinese(value['title']):
                            # wrap in a list so the indexing below matches the translated case
                            text_for_insight = [value['title']]
                        else:
                            text_for_insight = text_translate([value['title']], logger=self.logger)
                        if text_for_insight:
                            insight_id = self.pb.add(collection_name='insights',
                                                     body={'content': text_for_insight[0], 'articles': [key]})
                            if not insight_id:
                                self.logger.warning(f'write insight {text_for_insight[0]} to pb failed, writing to cache_file')
                                with open(os.path.join(self.cache_url, 'cache_insights.json'), 'a',
                                          encoding='utf-8') as f:
                                    json.dump({'content': text_for_insight[0], 'articles': [key]}, f, ensure_ascii=False, indent=4)
            else:
                self.logger.warning('generate_insights-error: cannot generate insights, please retry')
        self.logger.info(f'work done, now is {datetime.now()}')
        if self.snap_short_server:
            self.logger.info(f'now starting article snapshot with {self.snap_short_server}')
            for key, value in cache.items():
                if value['url']:
                    try:
                        # the snapshot service returns the path of the archived file
                        snapshot = requests.get(f"{self.snap_short_server}/zip", params={'url': value['url']}, timeout=60)
                        with open(snapshot.text, 'rb') as file:
                            _ = self.pb.upload('articles', key, 'snapshot', key, file)
                    except Exception as e:
                        self.logger.warning(f'error when snapshot {value["url"]}, {e}')
            self.logger.info(f'now snapshot done, now is {datetime.now()}')
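

# A minimal usage sketch: assumes PROJECT_DIR is set and, if record_snapshot=True,
# that SNAPSHOTS_SERVER points at a reachable snapshot service. The site URL below
# is a hypothetical placeholder handled by the general scraper.
if __name__ == '__main__':
    sp = ServiceProcesser()
    # scrape one ad-hoc site and run the full article/insight pipeline on it
    sp(sites=['https://example.com/news'])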