370 lines
13 KiB
Python
370 lines
13 KiB
Python
import logging
import random
from datetime import date, datetime
from typing import Any, Dict, List, Type

from tortoise.expressions import Q

from app.models.keyword import BossKeyword, QcwyKeyword, ZhilianKeyword
|
||
|
||
|
||
class KeywordController:
|
||
def __init__(self) -> None:
|
||
self._model_map: Dict[str, Type] = {
|
||
"boss": BossKeyword,
|
||
"qcwy": QcwyKeyword,
|
||
"zhilian": ZhilianKeyword,
|
||
}
|
||
|
||
async def get_available(
|
||
self, source: str, limit: int = 1, reserve: bool = True, crawler_id: str = ""
|
||
) -> Dict[str, Any]:
|
||
"""获取可用关键词,优先返回断点续爬和失败重试的关键词
|
||
|
||
优先级:
|
||
1. partial(断点续爬)
|
||
2. failed 且 retry_count < 3(失败重试)
|
||
3. 全新未使用关键词
|
||
"""
|
||
model = self._ensure_model(source)
|
||
today = date.today()
|
||
now = datetime.now()
|
||
|
||
# 优先级 1: 断点续爬 (partial)
|
||
partial_q = Q(crawl_status="partial", last_requested_date=today)
|
||
# 优先级 2: 失败重试 (failed, retry < 3)
|
||
failed_q = Q(crawl_status="failed", last_requested_date=today, retry_count__lt=3)
|
||
# 优先级 3: 全新关键词
|
||
fresh_q = Q(last_requested_date__not=today) | Q(last_requested_date=None)
|
||
|
||
items = []
|
||
|
||
for priority, query, is_fresh in [
|
||
("partial", partial_q, False),
|
||
("failed", failed_q, False),
|
||
("fresh", fresh_q, True),
|
||
]:
|
||
count = await model.filter(query).count()
|
||
if count == 0:
|
||
continue
|
||
|
||
take = max(1, min(limit - len(items), count))
|
||
if take <= 0:
|
||
break
|
||
|
||
try:
|
||
offset = random.randint(0, max(0, count - take))
|
||
candidates = await model.filter(query).offset(offset).limit(take).only("id")
|
||
candidate_ids = [r.id for r in candidates]
|
||
|
||
if not candidate_ids:
|
||
continue
|
||
|
||
update_fields = {
|
||
"last_requested_at": now,
|
||
"crawl_status": "crawling",
|
||
"crawl_started_at": now,
|
||
"crawler_id": crawler_id,
|
||
}
|
||
|
||
if is_fresh:
|
||
update_fields["last_requested_date"] = today
|
||
update_fields["last_completed_page"] = 0
|
||
update_fields["total_pages"] = 0
|
||
update_fields["jobs_found"] = 0
|
||
update_fields["error_message"] = ""
|
||
update_fields["retry_count"] = 0
|
||
|
||
if reserve:
|
||
await model.filter(id__in=candidate_ids).update(**update_fields)
|
||
|
||
records = await model.filter(id__in=candidate_ids).limit(take)
|
||
for r in records:
|
||
items.append({
|
||
"id": r.id,
|
||
"city": r.city,
|
||
"job": r.job,
|
||
"last_completed_page": r.last_completed_page,
|
||
"crawl_status": r.crawl_status,
|
||
})
|
||
|
||
if len(items) >= limit:
|
||
break
|
||
except Exception:
|
||
continue
|
||
|
||
total_available = await model.filter(
|
||
partial_q | failed_q | fresh_q
|
||
).count()
|
||
|
||
return {
|
||
"code": 200,
|
||
"message": "查询可用检索条件成功",
|
||
"data": {
|
||
"items": items,
|
||
"total": total_available,
|
||
"limit": limit,
|
||
},
|
||
}
|
||
|
||
async def report_page_progress(
|
||
self,
|
||
source: str,
|
||
keyword_id: int,
|
||
page: int,
|
||
total_pages: int = 0,
|
||
jobs_found: int = 0,
|
||
) -> Dict[str, Any]:
|
||
"""爬虫汇报单页完成进度"""
|
||
model = self._ensure_model(source)
|
||
obj = await model.filter(id=keyword_id).first()
|
||
if not obj:
|
||
return {"code": 404, "message": "关键词不存在"}
|
||
|
||
obj.last_completed_page = page
|
||
if total_pages > 0:
|
||
obj.total_pages = total_pages
|
||
obj.jobs_found = (obj.jobs_found or 0) + jobs_found
|
||
await obj.save(update_fields=["last_completed_page", "total_pages", "jobs_found"])
|
||
|
||
return {
|
||
"code": 200,
|
||
"message": "进度更新成功",
|
||
"data": {
|
||
"keyword_id": keyword_id,
|
||
"last_completed_page": obj.last_completed_page,
|
||
"total_pages": obj.total_pages,
|
||
"jobs_found": obj.jobs_found,
|
||
},
|
||
}
|
||
|
||
async def report_crawl_complete(
|
||
self,
|
||
source: str,
|
||
keyword_id: int,
|
||
status: str,
|
||
error_message: str = "",
|
||
) -> Dict[str, Any]:
|
||
"""爬虫汇报爬取完成或失败"""
|
||
model = self._ensure_model(source)
|
||
obj = await model.filter(id=keyword_id).first()
|
||
if not obj:
|
||
return {"code": 404, "message": "关键词不存在"}
|
||
|
||
if status not in ("completed", "failed"):
|
||
return {"code": 400, "message": "status 仅支持 completed/failed"}
|
||
|
||
obj.crawl_status = status
|
||
obj.error_message = error_message
|
||
update_fields = ["crawl_status", "error_message"]
|
||
|
||
if status == "failed":
|
||
obj.retry_count = (obj.retry_count or 0) + 1
|
||
update_fields.append("retry_count")
|
||
|
||
await obj.save(update_fields=update_fields)
|
||
|
||
return {
|
||
"code": 200,
|
||
"message": f"爬取状态已更新为 {status}",
|
||
"data": {
|
||
"keyword_id": keyword_id,
|
||
"crawl_status": obj.crawl_status,
|
||
"retry_count": obj.retry_count,
|
||
},
|
||
}
|
||
|
||
async def get_stats(self, source: str, on_date: date | None = None) -> Dict[str, Any]:
|
||
"""统计指定平台关键词使用和爬取状态"""
|
||
model = self._ensure_model(source)
|
||
d = on_date or date.today()
|
||
total = await model.all().count()
|
||
used = await model.filter(last_requested_date=d).count()
|
||
unused = max(0, total - used)
|
||
|
||
crawling = await model.filter(crawl_status="crawling", last_requested_date=d).count()
|
||
completed = await model.filter(crawl_status="completed", last_requested_date=d).count()
|
||
failed = await model.filter(crawl_status="failed", last_requested_date=d).count()
|
||
partial = await model.filter(crawl_status="partial", last_requested_date=d).count()
|
||
|
||
return {
|
||
"code": 200,
|
||
"message": "统计成功",
|
||
"data": {
|
||
"date": str(d),
|
||
"total": total,
|
||
"used": used,
|
||
"unused": unused,
|
||
"crawl_status": {
|
||
"crawling": crawling,
|
||
"completed": completed,
|
||
"failed": failed,
|
||
"partial": partial,
|
||
},
|
||
},
|
||
}
|
||
|
||
async def mark_used(self, source: str, ids: List[int]) -> Dict[str, Any]:
|
||
"""将检索条件标记为今日已使用"""
|
||
model = self._ensure_model(source)
|
||
updated = 0
|
||
now = datetime.now()
|
||
today = date.today()
|
||
for rid in ids:
|
||
obj = await model.filter(id=rid).first()
|
||
if obj is None:
|
||
continue
|
||
if obj.last_requested_date == today:
|
||
continue
|
||
obj.last_requested_date = today
|
||
obj.last_requested_at = now
|
||
await obj.save()
|
||
updated += 1
|
||
return {
|
||
"code": 200,
|
||
"message": "状态更新完成",
|
||
"data": {
|
||
"updated": updated,
|
||
"ids": ids,
|
||
"date": str(today),
|
||
},
|
||
}
|
||
|
||
async def list_keywords(
|
||
self,
|
||
source: str,
|
||
page: int = 1,
|
||
page_size: int = 20,
|
||
city: str | None = None,
|
||
job: str | None = None,
|
||
) -> Dict[str, Any]:
|
||
"""获取关键词列表"""
|
||
model = self._ensure_model(source)
|
||
queryset = model.all()
|
||
if city:
|
||
queryset = queryset.filter(city__icontains=city)
|
||
if job:
|
||
queryset = queryset.filter(job__icontains=job)
|
||
|
||
total = await queryset.count()
|
||
queryset = queryset.order_by("-id").offset((page - 1) * page_size).limit(page_size)
|
||
items = await queryset.values(
|
||
"id",
|
||
"city",
|
||
"job",
|
||
"last_requested_date",
|
||
"last_requested_at",
|
||
"crawl_status",
|
||
"last_completed_page",
|
||
"total_pages",
|
||
"jobs_found",
|
||
"retry_count",
|
||
"created_at",
|
||
"updated_at",
|
||
)
|
||
|
||
return {
|
||
"code": 200,
|
||
"message": "获取成功",
|
||
"data": items,
|
||
"total": total,
|
||
"page": page,
|
||
"page_size": page_size,
|
||
}
|
||
|
||
async def create_keyword(self, source: str, obj_in: Any) -> Dict[str, Any]:
|
||
"""创建关键词"""
|
||
model = self._ensure_model(source)
|
||
exists = await model.filter(city=obj_in.city, job=obj_in.job).exists()
|
||
if exists:
|
||
return {"code": 400, "message": "该关键词组合已存在"}
|
||
|
||
obj = await model.create(**obj_in.model_dump())
|
||
data = {
|
||
"id": obj.id,
|
||
"city": obj.city,
|
||
"job": obj.job,
|
||
"last_requested_date": obj.last_requested_date,
|
||
"last_requested_at": obj.last_requested_at,
|
||
"created_at": obj.created_at,
|
||
"updated_at": obj.updated_at,
|
||
}
|
||
return {"code": 200, "message": "创建成功", "data": data}
|
||
|
||
async def update_keyword(self, source: str, id: int, obj_in: Any) -> Dict[str, Any]:
|
||
"""更新关键词"""
|
||
model = self._ensure_model(source)
|
||
obj = await model.filter(id=id).first()
|
||
if not obj:
|
||
return {"code": 404, "message": "记录不存在"}
|
||
|
||
update_data = obj_in.model_dump(exclude_unset=True)
|
||
if update_data:
|
||
if "city" in update_data or "job" in update_data:
|
||
city = update_data.get("city", obj.city)
|
||
job = update_data.get("job", obj.job)
|
||
exists = await model.filter(city=city, job=job).exclude(id=id).exists()
|
||
if exists:
|
||
return {"code": 400, "message": "该关键词组合已存在"}
|
||
|
||
await obj.update_from_dict(update_data)
|
||
await obj.save()
|
||
|
||
data = {
|
||
"id": obj.id,
|
||
"city": obj.city,
|
||
"job": obj.job,
|
||
"last_requested_date": obj.last_requested_date,
|
||
"last_requested_at": obj.last_requested_at,
|
||
"created_at": obj.created_at,
|
||
"updated_at": obj.updated_at,
|
||
}
|
||
return {"code": 200, "message": "更新成功", "data": data}
|
||
|
||
async def delete_keyword(self, source: str, id: int) -> Dict[str, Any]:
|
||
"""删除关键词"""
|
||
model = self._ensure_model(source)
|
||
obj = await model.filter(id=id).first()
|
||
if not obj:
|
||
return {"code": 404, "message": "记录不存在"}
|
||
|
||
await obj.delete()
|
||
return {
|
||
"code": 200,
|
||
"message": "删除成功",
|
||
}
|
||
|
||
async def get_overview_stats(self) -> Dict[str, Any]:
|
||
"""获取所有平台的统计概览"""
|
||
today = date.today()
|
||
stats = {}
|
||
for source, model in self._model_map.items():
|
||
total = await model.all().count()
|
||
used = await model.filter(last_requested_date=today).count()
|
||
crawling = await model.filter(crawl_status="crawling", last_requested_date=today).count()
|
||
completed = await model.filter(crawl_status="completed", last_requested_date=today).count()
|
||
failed = await model.filter(crawl_status="failed", last_requested_date=today).count()
|
||
partial_count = await model.filter(crawl_status="partial", last_requested_date=today).count()
|
||
stats[source] = {
|
||
"total": total,
|
||
"used": used,
|
||
"unused": max(0, total - used),
|
||
"crawl_status": {
|
||
"crawling": crawling,
|
||
"completed": completed,
|
||
"failed": failed,
|
||
"partial": partial_count,
|
||
},
|
||
}
|
||
return {
|
||
"code": 200,
|
||
"message": "获取概览统计成功",
|
||
"data": stats,
|
||
}
|
||
|
||
def _ensure_model(self, source: str) -> Type:
|
||
"""根据平台标识返回对应模型类型"""
|
||
model = self._model_map.get(source)
|
||
if not model:
|
||
raise ValueError("不支持的平台标识")
|
||
return model
|