2026-03-22 23:22:30 +08:00

370 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import random
from datetime import date, datetime
from typing import Any, Dict, List, Type

from tortoise.expressions import Q

from app.models.keyword import BossKeyword, QcwyKeyword, ZhilianKeyword
class KeywordController:
def __init__(self) -> None:
self._model_map: Dict[str, Type] = {
"boss": BossKeyword,
"qcwy": QcwyKeyword,
"zhilian": ZhilianKeyword,
}
async def get_available(
self, source: str, limit: int = 1, reserve: bool = True, crawler_id: str = ""
) -> Dict[str, Any]:
"""获取可用关键词,优先返回断点续爬和失败重试的关键词
优先级:
1. partial断点续爬
2. failed 且 retry_count < 3失败重试
3. 全新未使用关键词
"""
model = self._ensure_model(source)
today = date.today()
now = datetime.now()
# 优先级 1: 断点续爬 (partial)
partial_q = Q(crawl_status="partial", last_requested_date=today)
# 优先级 2: 失败重试 (failed, retry < 3)
failed_q = Q(crawl_status="failed", last_requested_date=today, retry_count__lt=3)
# 优先级 3: 全新关键词
fresh_q = Q(last_requested_date__not=today) | Q(last_requested_date=None)
items = []
for priority, query, is_fresh in [
("partial", partial_q, False),
("failed", failed_q, False),
("fresh", fresh_q, True),
]:
count = await model.filter(query).count()
if count == 0:
continue
take = max(1, min(limit - len(items), count))
if take <= 0:
break
try:
offset = random.randint(0, max(0, count - take))
candidates = await model.filter(query).offset(offset).limit(take).only("id")
candidate_ids = [r.id for r in candidates]
if not candidate_ids:
continue
update_fields = {
"last_requested_at": now,
"crawl_status": "crawling",
"crawl_started_at": now,
"crawler_id": crawler_id,
}
if is_fresh:
update_fields["last_requested_date"] = today
update_fields["last_completed_page"] = 0
update_fields["total_pages"] = 0
update_fields["jobs_found"] = 0
update_fields["error_message"] = ""
update_fields["retry_count"] = 0
if reserve:
await model.filter(id__in=candidate_ids).update(**update_fields)
records = await model.filter(id__in=candidate_ids).limit(take)
for r in records:
items.append({
"id": r.id,
"city": r.city,
"job": r.job,
"last_completed_page": r.last_completed_page,
"crawl_status": r.crawl_status,
})
if len(items) >= limit:
break
except Exception:
continue
total_available = await model.filter(
partial_q | failed_q | fresh_q
).count()
return {
"code": 200,
"message": "查询可用检索条件成功",
"data": {
"items": items,
"total": total_available,
"limit": limit,
},
}
async def report_page_progress(
self,
source: str,
keyword_id: int,
page: int,
total_pages: int = 0,
jobs_found: int = 0,
) -> Dict[str, Any]:
"""爬虫汇报单页完成进度"""
model = self._ensure_model(source)
obj = await model.filter(id=keyword_id).first()
if not obj:
return {"code": 404, "message": "关键词不存在"}
obj.last_completed_page = page
if total_pages > 0:
obj.total_pages = total_pages
obj.jobs_found = (obj.jobs_found or 0) + jobs_found
await obj.save(update_fields=["last_completed_page", "total_pages", "jobs_found"])
return {
"code": 200,
"message": "进度更新成功",
"data": {
"keyword_id": keyword_id,
"last_completed_page": obj.last_completed_page,
"total_pages": obj.total_pages,
"jobs_found": obj.jobs_found,
},
}
async def report_crawl_complete(
self,
source: str,
keyword_id: int,
status: str,
error_message: str = "",
) -> Dict[str, Any]:
"""爬虫汇报爬取完成或失败"""
model = self._ensure_model(source)
obj = await model.filter(id=keyword_id).first()
if not obj:
return {"code": 404, "message": "关键词不存在"}
if status not in ("completed", "failed"):
return {"code": 400, "message": "status 仅支持 completed/failed"}
obj.crawl_status = status
obj.error_message = error_message
update_fields = ["crawl_status", "error_message"]
if status == "failed":
obj.retry_count = (obj.retry_count or 0) + 1
update_fields.append("retry_count")
await obj.save(update_fields=update_fields)
return {
"code": 200,
"message": f"爬取状态已更新为 {status}",
"data": {
"keyword_id": keyword_id,
"crawl_status": obj.crawl_status,
"retry_count": obj.retry_count,
},
}
async def get_stats(self, source: str, on_date: date | None = None) -> Dict[str, Any]:
"""统计指定平台关键词使用和爬取状态"""
model = self._ensure_model(source)
d = on_date or date.today()
total = await model.all().count()
used = await model.filter(last_requested_date=d).count()
unused = max(0, total - used)
crawling = await model.filter(crawl_status="crawling", last_requested_date=d).count()
completed = await model.filter(crawl_status="completed", last_requested_date=d).count()
failed = await model.filter(crawl_status="failed", last_requested_date=d).count()
partial = await model.filter(crawl_status="partial", last_requested_date=d).count()
return {
"code": 200,
"message": "统计成功",
"data": {
"date": str(d),
"total": total,
"used": used,
"unused": unused,
"crawl_status": {
"crawling": crawling,
"completed": completed,
"failed": failed,
"partial": partial,
},
},
}
async def mark_used(self, source: str, ids: List[int]) -> Dict[str, Any]:
"""将检索条件标记为今日已使用"""
model = self._ensure_model(source)
updated = 0
now = datetime.now()
today = date.today()
for rid in ids:
obj = await model.filter(id=rid).first()
if obj is None:
continue
if obj.last_requested_date == today:
continue
obj.last_requested_date = today
obj.last_requested_at = now
await obj.save()
updated += 1
return {
"code": 200,
"message": "状态更新完成",
"data": {
"updated": updated,
"ids": ids,
"date": str(today),
},
}
async def list_keywords(
self,
source: str,
page: int = 1,
page_size: int = 20,
city: str | None = None,
job: str | None = None,
) -> Dict[str, Any]:
"""获取关键词列表"""
model = self._ensure_model(source)
queryset = model.all()
if city:
queryset = queryset.filter(city__icontains=city)
if job:
queryset = queryset.filter(job__icontains=job)
total = await queryset.count()
queryset = queryset.order_by("-id").offset((page - 1) * page_size).limit(page_size)
items = await queryset.values(
"id",
"city",
"job",
"last_requested_date",
"last_requested_at",
"crawl_status",
"last_completed_page",
"total_pages",
"jobs_found",
"retry_count",
"created_at",
"updated_at",
)
return {
"code": 200,
"message": "获取成功",
"data": items,
"total": total,
"page": page,
"page_size": page_size,
}
async def create_keyword(self, source: str, obj_in: Any) -> Dict[str, Any]:
"""创建关键词"""
model = self._ensure_model(source)
exists = await model.filter(city=obj_in.city, job=obj_in.job).exists()
if exists:
return {"code": 400, "message": "该关键词组合已存在"}
obj = await model.create(**obj_in.model_dump())
data = {
"id": obj.id,
"city": obj.city,
"job": obj.job,
"last_requested_date": obj.last_requested_date,
"last_requested_at": obj.last_requested_at,
"created_at": obj.created_at,
"updated_at": obj.updated_at,
}
return {"code": 200, "message": "创建成功", "data": data}
async def update_keyword(self, source: str, id: int, obj_in: Any) -> Dict[str, Any]:
"""更新关键词"""
model = self._ensure_model(source)
obj = await model.filter(id=id).first()
if not obj:
return {"code": 404, "message": "记录不存在"}
update_data = obj_in.model_dump(exclude_unset=True)
if update_data:
if "city" in update_data or "job" in update_data:
city = update_data.get("city", obj.city)
job = update_data.get("job", obj.job)
exists = await model.filter(city=city, job=job).exclude(id=id).exists()
if exists:
return {"code": 400, "message": "该关键词组合已存在"}
await obj.update_from_dict(update_data)
await obj.save()
data = {
"id": obj.id,
"city": obj.city,
"job": obj.job,
"last_requested_date": obj.last_requested_date,
"last_requested_at": obj.last_requested_at,
"created_at": obj.created_at,
"updated_at": obj.updated_at,
}
return {"code": 200, "message": "更新成功", "data": data}
async def delete_keyword(self, source: str, id: int) -> Dict[str, Any]:
"""删除关键词"""
model = self._ensure_model(source)
obj = await model.filter(id=id).first()
if not obj:
return {"code": 404, "message": "记录不存在"}
await obj.delete()
return {
"code": 200,
"message": "删除成功",
}
async def get_overview_stats(self) -> Dict[str, Any]:
"""获取所有平台的统计概览"""
today = date.today()
stats = {}
for source, model in self._model_map.items():
total = await model.all().count()
used = await model.filter(last_requested_date=today).count()
crawling = await model.filter(crawl_status="crawling", last_requested_date=today).count()
completed = await model.filter(crawl_status="completed", last_requested_date=today).count()
failed = await model.filter(crawl_status="failed", last_requested_date=today).count()
partial_count = await model.filter(crawl_status="partial", last_requested_date=today).count()
stats[source] = {
"total": total,
"used": used,
"unused": max(0, total - used),
"crawl_status": {
"crawling": crawling,
"completed": completed,
"failed": failed,
"partial": partial_count,
},
}
return {
"code": 200,
"message": "获取概览统计成功",
"data": stats,
}
def _ensure_model(self, source: str) -> Type:
"""根据平台标识返回对应模型类型"""
model = self._model_map.get(source)
if not model:
raise ValueError("不支持的平台标识")
return model