import json
import logging
from datetime import datetime, timezone
from typing import Any, AsyncIterator, Optional

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response

from app.models.metrics import IpUploadStats

logger = logging.getLogger(__name__)

# Prefix of the platform-agnostic upload endpoint; its platform comes from the
# request args ("platform" key).
_UNIVERSAL_PREFIX = "/api/v1/universal/data"

# Platform-specific upload endpoints, mapped to the source name used when the
# request args do not carry an explicit "platform".
_PLATFORM_PREFIXES = (
    ("/api/v1/boss/", "boss"),
    ("/api/v1/qcwy/", "qcwy"),
    ("/api/v1/zhilian/", "zhilian"),
)

# Every path prefix whose requests are counted (str.startswith accepts a tuple).
_TRACKED_PREFIXES = (_UNIVERSAL_PREFIX,) + tuple(prefix for prefix, _ in _PLATFORM_PREFIXES)


class IpTrackingMiddleware(BaseHTTPMiddleware):
    """Record per-source, per-IP, per-day upload counts for the upload endpoints.

    Stats bookkeeping is best-effort: any failure while computing or persisting
    the counters is logged and never changes the response sent to the client.
    """

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Run the downstream app, then update upload stats for tracked paths.

        Exceptions raised downstream propagate unchanged to outer middleware /
        exception handlers.
        """
        response = await call_next(request)

        path = request.url.path
        if not path.startswith(_TRACKED_PREFIXES):
            return response

        # ``call_next`` returns a streaming response without a ``.body``
        # attribute, so the JSON body has to be buffered here to read the
        # reported "success" count. This runs outside the best-effort block on
        # purpose: swallowing a mid-stream failure would leave a half-consumed
        # response to be sent to the client.
        body: Optional[bytes] = None
        if self._is_json_response(response) and hasattr(response, "body_iterator"):
            body = await self._buffer_body(response)

        try:
            args = getattr(request.state, "request_args", None)
            if not isinstance(args, dict):
                args = {}
            source = args.get("platform") or self._platform_from_path(path)
            ip = self._extract_ip(request)
            count = self._estimate_count(args, response, body)
            if source and ip and count:
                await self._update_stats(source, ip, count)
        except Exception:
            # Best-effort: never fail the request because of stats bookkeeping.
            logger.warning("Failed to record IP upload stats for %s", path, exc_info=True)
        return response

    @staticmethod
    def _platform_from_path(path: str) -> str:
        """Return the platform implied by a platform-specific URL prefix, or ""."""
        return next(
            (platform for prefix, platform in _PLATFORM_PREFIXES if path.startswith(prefix)),
            "",
        )

    @staticmethod
    def _is_json_response(response: Response) -> bool:
        """Return True when the response declares an ``application/json`` body."""
        content_type = response.headers.get("content-type", "")
        return content_type.split(";", 1)[0].strip().lower() == "application/json"

    @staticmethod
    async def _buffer_body(response: Response) -> bytes:
        """Drain ``response.body_iterator`` and swap in a replay of the same bytes.

        Status code, headers and background task of ``response`` are left
        untouched, so the client receives exactly what the app produced.
        The whole body is held in memory; this is only called for JSON responses
        of the tracked upload endpoints (presumably small — confirm).

        Returns:
            The complete response body.
        """
        chunks = []
        async for chunk in response.body_iterator:
            chunks.append(chunk.encode(response.charset) if isinstance(chunk, str) else bytes(chunk))
        body = b"".join(chunks)

        async def _replay() -> AsyncIterator[bytes]:
            yield body

        response.body_iterator = _replay()
        return body

    def _extract_ip(self, request: Request) -> str:
        """Return the client IP, preferring proxy headers over the socket peer.

        Order: first entry of X-Forwarded-For, then X-Real-IP, then the
        connection's peer address; "" when none is available.

        NOTE(review): X-Forwarded-For / X-Real-IP are client-controlled unless a
        trusted reverse proxy overwrites them, so a client could attribute
        uploads to an arbitrary IP — confirm the deployment strips these headers.
        """
        # Starlette header lookups are case-insensitive, so one spelling suffices.
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()
        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip.strip()
        return request.client.host if request.client else ""

    def _estimate_count(self, args: dict, response: Response, body: Optional[bytes] = None) -> int:
        """Estimate how many records this request uploaded.

        Args:
            args: Request arguments stashed on ``request.state.request_args``.
            response: The downstream response.
            body: Buffered response body; when None, ``response.body`` is used if
                the response has one.

        Returns:
            ``data.success`` from a JSON response body when present (synchronous
            endpoints). Otherwise an estimate from the request: ``len(data_list)``
            for a list, 1 when a single ``data`` item was sent, else 0.
        """
        if body is None:
            body = getattr(response, "body", None)

        success = self._success_from_body(body)
        if success is not None:
            return success

        # Asynchronous endpoints, or responses without details: estimate from
        # the request payload.
        if not isinstance(args, dict):
            return 0
        data_list = args.get("data_list")
        if isinstance(data_list, list):
            return len(data_list)
        if "data" in args:
            return 1
        return 0

    @staticmethod
    def _success_from_body(body: Optional[bytes]) -> Optional[int]:
        """Return ``int(payload["data"]["success"])`` from a JSON body, or None if absent/invalid."""
        if not body:
            return None
        try:
            payload = json.loads(body)
        except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
            return None
        if not isinstance(payload, dict):
            return None
        data = payload.get("data")
        if not isinstance(data, dict) or "success" not in data:
            return None
        try:
            return int(data["success"])
        except (TypeError, ValueError):
            return None

    async def _update_stats(self, source: str, ip: str, inc: int) -> None:
        """Add ``inc`` to today's counter row for (source, ip), creating it if absent.

        Also refreshes ``last_report_at`` and resets ``status`` to "normal".
        "Today" is the UTC calendar date.

        NOTE(review): this is a non-atomic read-modify-write. Concurrent requests
        from the same IP can lose increments. Two first-of-the-day requests may
        both reach ``create``; the outcome depends on the model's unique
        constraints. Consider an atomic ``F("upload_count") + inc`` update
        (presumably Tortoise ORM) — confirm.
        """
        # Timezone-aware datetime so it matches the datetime type stored in the database.
        now = datetime.now(timezone.utc)
        today = now.date()
        obj = await IpUploadStats.get_or_none(source=source, ip=ip, date=today)
        if obj:
            obj.upload_count = obj.upload_count + inc
            obj.last_report_at = now
            if getattr(obj, "status", "normal") != "normal":
                obj.status = "normal"
            await obj.save()
        else:
            await IpUploadStats.create(
                source=source,
                ip=ip,
                date=today,
                upload_count=inc,
                last_report_at=now,
                status="normal",
            )