333 lines
11 KiB
Python
333 lines
11 KiB
Python
from datetime import date, datetime
|
||
import random
|
||
from typing import Any, Dict, List, Type
|
||
|
||
from tortoise.expressions import Q
|
||
|
||
from app.core.crud import CRUDBase
|
||
from app.models.keyword import BossKeyword, QcwyKeyword, ZhilianKeyword
|
||
|
||
|
||
class KeywordController:
|
||
def __init__(self) -> None:
|
||
self._model_map: Dict[str, Type] = {
|
||
"boss": BossKeyword,
|
||
"qcwy": QcwyKeyword,
|
||
"zhilian": ZhilianKeyword,
|
||
}
|
||
|
||
async def get_available(self, source: str, limit: int = 1, reserve: bool = True) -> Dict[str, Any]:
|
||
"""获取当天未使用的检索条件(城市+岗位)
|
||
|
||
参数:
|
||
source: 平台标识,取值为 boss|qcwy|zhilian
|
||
limit: 返回数量上限
|
||
reserve: 是否立即标记为已使用
|
||
|
||
返回:
|
||
包含 items/total/limit 的字典结构
|
||
|
||
注意:使用原子操作避免并发时的竞态条件
|
||
"""
|
||
model = self._ensure_model(source)
|
||
today = date.today()
|
||
now = datetime.now()
|
||
|
||
# 先统计总数
|
||
search = Q(last_requested_date__not=today) | Q(last_requested_date=None)
|
||
total = await model.filter(search).count()
|
||
items = []
|
||
|
||
if total > 0 and reserve:
|
||
# 使用原子操作:先更新,再查询已更新的记录
|
||
# 这样可以避免查询和标记之间的竞态条件
|
||
take = max(1, min(limit, total))
|
||
|
||
try:
|
||
# 获取一批未使用的记录ID(随机选择)
|
||
candidate_records = await model.filter(search).offset(
|
||
random.randint(0, max(0, total - take))
|
||
).limit(take).only('id')
|
||
|
||
candidate_ids = [r.id for r in candidate_records]
|
||
|
||
if candidate_ids:
|
||
# 原子性地更新这些记录(只更新未使用的)
|
||
# 使用数据库的原子UPDATE操作
|
||
updated_count = await model.filter(
|
||
id__in=candidate_ids
|
||
).filter(
|
||
Q(last_requested_date__isnull=True) | Q(last_requested_date__not=today)
|
||
).update(
|
||
last_requested_date=today,
|
||
last_requested_at=now
|
||
)
|
||
|
||
# 查询成功更新的记录
|
||
if updated_count > 0:
|
||
records = await model.filter(
|
||
id__in=candidate_ids,
|
||
last_requested_date=today
|
||
).limit(updated_count)
|
||
items = [{"id": r.id, "city": r.city, "job": r.job} for r in records]
|
||
except Exception as e:
|
||
# 如果原子操作失败,回退到原来的方法
|
||
import logging
|
||
logging.warning(f"原子操作失败,回退到原方法: {e}")
|
||
take = max(1, min(limit, total))
|
||
start = 0 if total == take else random.randint(0, total - take)
|
||
records = await model.filter(search).offset(start).limit(take)
|
||
items = [{"id": r.id, "city": r.city, "job": r.job} for r in records]
|
||
if reserve:
|
||
ids = [r.id for r in records]
|
||
await self.mark_used(source, ids)
|
||
elif total > 0:
|
||
# 如果不需要reserve,直接查询
|
||
take = max(1, min(limit, total))
|
||
start = 0 if total == take else random.randint(0, total - take)
|
||
records = await model.filter(search).offset(start).limit(take)
|
||
items = [{"id": r.id, "city": r.city, "job": r.job} for r in records]
|
||
|
||
return {
|
||
"code": 200,
|
||
"message": "查询可用检索条件成功",
|
||
"data": {
|
||
"items": items,
|
||
"total": total,
|
||
"limit": limit,
|
||
},
|
||
}
|
||
|
||
async def get_stats(self, source: str, on_date: date | None = None) -> Dict[str, Any]:
|
||
"""统计指定平台在某日期的使用与未使用数量
|
||
|
||
参数:
|
||
source: 平台标识,取值为 boss|qcwy|zhilian
|
||
on_date: 统计日期,不传则为今天
|
||
|
||
返回:
|
||
包含 total/used/unused 的字典结构
|
||
"""
|
||
model = self._ensure_model(source)
|
||
d = on_date or date.today()
|
||
total = await model.all().count()
|
||
used = await model.filter(last_requested_date=d).count()
|
||
unused = max(0, total - used)
|
||
return {
|
||
"code": 200,
|
||
"message": "统计成功",
|
||
"data": {
|
||
"date": str(d),
|
||
"total": total,
|
||
"used": used,
|
||
"unused": unused,
|
||
},
|
||
}
|
||
|
||
async def mark_used(self, source: str, ids: List[int]) -> Dict[str, Any]:
|
||
"""将检索条件标记为今日已使用
|
||
|
||
参数:
|
||
source: 平台标识,取值为 boss|qcwy|zhilian
|
||
ids: 需要标记的记录主键ID列表
|
||
|
||
返回:
|
||
更新结果,包括成功条数与日期
|
||
"""
|
||
model = self._ensure_model(source)
|
||
updated = 0
|
||
now = datetime.now()
|
||
today = date.today()
|
||
for rid in ids:
|
||
obj = await model.filter(id=rid).first()
|
||
if obj is None:
|
||
continue
|
||
if obj.last_requested_date == today:
|
||
continue
|
||
obj.last_requested_date = today
|
||
obj.last_requested_at = now
|
||
await obj.save()
|
||
updated += 1
|
||
return {
|
||
"code": 200,
|
||
"message": "状态更新完成",
|
||
"data": {
|
||
"updated": updated,
|
||
"ids": ids,
|
||
"date": str(today),
|
||
},
|
||
}
|
||
|
||
async def list_keywords(
|
||
self,
|
||
source: str,
|
||
page: int = 1,
|
||
page_size: int = 20,
|
||
city: str | None = None,
|
||
job: str | None = None,
|
||
) -> Dict[str, Any]:
|
||
"""获取关键词列表
|
||
|
||
参数:
|
||
source: 平台标识
|
||
page: 页码
|
||
page_size: 每页数量
|
||
city: 城市过滤
|
||
job: 职位过滤
|
||
|
||
返回:
|
||
包含列表数据和分页信息的字典
|
||
"""
|
||
model = self._ensure_model(source)
|
||
queryset = model.all()
|
||
if city:
|
||
queryset = queryset.filter(city__icontains=city)
|
||
if job:
|
||
queryset = queryset.filter(job__icontains=job)
|
||
|
||
total = await queryset.count()
|
||
queryset = queryset.order_by("-id").offset((page - 1) * page_size).limit(page_size)
|
||
items = await queryset.values(
|
||
"id",
|
||
"city",
|
||
"job",
|
||
"last_requested_date",
|
||
"last_requested_at",
|
||
"created_at",
|
||
"updated_at",
|
||
)
|
||
|
||
return {
|
||
"code": 200,
|
||
"message": "获取成功",
|
||
"data": items,
|
||
"total": total,
|
||
"page": page,
|
||
"page_size": page_size,
|
||
}
|
||
|
||
async def create_keyword(self, source: str, obj_in: Any) -> Dict[str, Any]:
|
||
"""创建关键词
|
||
|
||
参数:
|
||
source: 平台标识
|
||
obj_in: 创建数据对象
|
||
|
||
返回:
|
||
创建结果
|
||
"""
|
||
model = self._ensure_model(source)
|
||
# Check if already exists
|
||
exists = await model.filter(city=obj_in.city, job=obj_in.job).exists()
|
||
if exists:
|
||
return {"code": 400, "message": "该关键词组合已存在"}
|
||
|
||
obj = await model.create(**obj_in.model_dump())
|
||
data = {
|
||
"id": obj.id,
|
||
"city": obj.city,
|
||
"job": obj.job,
|
||
"last_requested_date": obj.last_requested_date,
|
||
"last_requested_at": obj.last_requested_at,
|
||
"created_at": obj.created_at,
|
||
"updated_at": obj.updated_at,
|
||
}
|
||
return {"code": 200, "message": "创建成功", "data": data}
|
||
|
||
async def update_keyword(self, source: str, id: int, obj_in: Any) -> Dict[str, Any]:
|
||
"""更新关键词
|
||
|
||
参数:
|
||
source: 平台标识
|
||
id: 记录ID
|
||
obj_in: 更新数据对象
|
||
|
||
返回:
|
||
更新结果
|
||
"""
|
||
model = self._ensure_model(source)
|
||
obj = await model.filter(id=id).first()
|
||
if not obj:
|
||
return {"code": 404, "message": "记录不存在"}
|
||
|
||
update_data = obj_in.model_dump(exclude_unset=True)
|
||
if update_data:
|
||
# Check for duplicates if updating city or job
|
||
if "city" in update_data or "job" in update_data:
|
||
city = update_data.get("city", obj.city)
|
||
job = update_data.get("job", obj.job)
|
||
exists = await model.filter(city=city, job=job).exclude(id=id).exists()
|
||
if exists:
|
||
return {"code": 400, "message": "该关键词组合已存在"}
|
||
|
||
await obj.update_from_dict(update_data)
|
||
await obj.save()
|
||
|
||
data = {
|
||
"id": obj.id,
|
||
"city": obj.city,
|
||
"job": obj.job,
|
||
"last_requested_date": obj.last_requested_date,
|
||
"last_requested_at": obj.last_requested_at,
|
||
"created_at": obj.created_at,
|
||
"updated_at": obj.updated_at,
|
||
}
|
||
return {"code": 200, "message": "更新成功", "data": data}
|
||
|
||
async def delete_keyword(self, source: str, id: int) -> Dict[str, Any]:
|
||
"""删除关键词
|
||
|
||
参数:
|
||
source: 平台标识
|
||
id: 记录ID
|
||
|
||
返回:
|
||
删除结果
|
||
"""
|
||
model = self._ensure_model(source)
|
||
obj = await model.filter(id=id).first()
|
||
if not obj:
|
||
return {"code": 404, "message": "记录不存在"}
|
||
|
||
await obj.delete()
|
||
return {
|
||
"code": 200,
|
||
"message": "删除成功",
|
||
}
|
||
|
||
async def get_overview_stats(self) -> Dict[str, Any]:
|
||
"""获取所有平台的统计概览
|
||
|
||
返回:
|
||
包含各平台统计数据的字典
|
||
"""
|
||
today = date.today()
|
||
stats = {}
|
||
for source, model in self._model_map.items():
|
||
total = await model.all().count()
|
||
used = await model.filter(last_requested_date=today).count()
|
||
stats[source] = {
|
||
"total": total,
|
||
"used": used,
|
||
"unused": max(0, total - used),
|
||
}
|
||
return {
|
||
"code": 200,
|
||
"message": "获取概览统计成功",
|
||
"data": stats,
|
||
}
|
||
|
||
def _ensure_model(self, source: str) -> Type:
|
||
"""根据平台标识返回对应模型类型
|
||
|
||
参数:
|
||
source: 平台标识,取值为 boss|qcwy|zhilian
|
||
|
||
返回:
|
||
对应的 Tortoise ORM 模型类型
|
||
"""
|
||
model = self._model_map.get(source)
|
||
if not model:
|
||
raise ValueError("不支持的平台标识")
|
||
return model
|