# JobData/app/core/ip_tracking.py
#
# Note: this file may contain Unicode characters that can be visually confused
# with similar-looking ASCII characters (e.g. in non-English comments).

import json
import logging
from datetime import datetime, timezone
from typing import Any, Optional

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response

from app.models.metrics import IpUploadStats


class IpTrackingMiddleware(BaseHTTPMiddleware):
    """Record per-platform, per-IP daily upload counts for data-upload endpoints.

    After the downstream app has produced a response, requests whose path
    starts with one of ``_TRACKED_PREFIXES`` are attributed to a source
    platform and a client IP. The estimated number of uploaded records is then
    added to that day's ``IpUploadStats`` row. Recording is best-effort: any
    failure is logged and never changes the response returned to the client.
    """

    # Path prefixes of the endpoints whose uploads are counted.
    _TRACKED_PREFIXES = (
        "/api/v1/universal/data",
        "/api/v1/boss/",
        "/api/v1/qcwy/",
        "/api/v1/zhilian/",
    )

    # Platform inferred from the URL when the request args carry no "platform".
    # The universal endpoint has no entry here, so it relies on that argument.
    _PREFIX_SOURCES = (
        ("/api/v1/boss/", "boss"),
        ("/api/v1/qcwy/", "qcwy"),
        ("/api/v1/zhilian/", "zhilian"),
    )

    _logger = logging.getLogger(__name__)

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Run the downstream app, then record upload statistics on a best-effort basis.

        Exceptions raised downstream propagate unchanged so that other
        middleware or exception handlers can deal with them. Failures while
        recording statistics are logged and swallowed. The downstream response
        is always returned as-is.
        """
        response = await call_next(request)
        try:
            await self._record_upload(request, response)
        except Exception:
            # Statistics must never break the real request, but don't hide bugs.
            self._logger.exception(
                "Failed to record IP upload stats for %s", request.url.path
            )
        return response

    async def _record_upload(self, request: Request, response: Response) -> None:
        """Attribute a tracked request to (source, ip) and add its record count."""
        path = request.url.path
        if not path.startswith(self._TRACKED_PREFIXES):
            return
        # request_args is set elsewhere (presumably the parsed request body);
        # tolerate it being absent or None.
        args = getattr(request.state, "request_args", None) or {}
        source = args.get("platform") or self._source_from_path(path)
        ip = self._extract_ip(request)
        count = self._estimate_count(args, response)
        # count > 0 (not just truthy) so a bogus negative value never decrements.
        if source and ip and count > 0:
            await self._update_stats(source, ip, count)

    @classmethod
    def _source_from_path(cls, path: str) -> str:
        """Return the platform implied by the URL prefix, or "" if none matches."""
        for prefix, source in cls._PREFIX_SOURCES:
            if path.startswith(prefix):
                return source
        return ""

    def _extract_ip(self, request: Request) -> str:
        """Return the client IP, preferring proxy headers over the socket peer.

        Order: first non-empty X-Forwarded-For entry, then X-Real-IP, then
        ``request.client.host``. Returns "" when none is available. Starlette
        header lookup is case-insensitive, so one lookup per header suffices.

        NOTE(review): these headers are client-controllable unless a trusted
        reverse proxy overwrites them; confirm the deployment does so.
        """
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:
                return first_hop
        real_ip = request.headers.get("x-real-ip")
        if real_ip and real_ip.strip():
            return real_ip.strip()
        return request.client.host if request.client else ""

    def _estimate_count(self, args: dict, response: Response) -> int:
        """Estimate how many records this request uploaded.

        Synchronous endpoints: use ``data.success`` from the JSON response body.
        Asynchronous endpoints, or responses without usable detail: estimate
        from the request args. ``len(data_list)`` is used when it is a list,
        and 1 is used when a single ``data`` item is present. Otherwise 0.
        """
        success = self._success_from_body(getattr(response, "body", None))
        if success is not None:
            return success
        data_list = args.get("data_list")
        if isinstance(data_list, list):
            return len(data_list)
        if "data" in args:
            return 1
        return 0

    @staticmethod
    def _success_from_body(body: Any) -> Optional[int]:
        """Return ``int(payload["data"]["success"])`` from a JSON body, else None.

        None means "no usable count in the body" (empty or missing body,
        non-JSON content, unexpected shape, or a non-numeric value). The caller
        then falls back to the request args.

        NOTE(review): Starlette's BaseHTTPMiddleware ``call_next`` usually
        returns a streaming response without ``.body``. In that case this
        always returns None. Confirm whether the body path is ever taken.
        """
        if not body:
            return None
        try:
            payload = json.loads(body)
        except (TypeError, ValueError):
            # Not JSON (JSONDecodeError/UnicodeDecodeError are ValueErrors).
            return None
        data = payload.get("data") if isinstance(payload, dict) else None
        if not isinstance(data, dict) or "success" not in data:
            return None
        try:
            return int(data["success"])
        except (TypeError, ValueError):
            return None

    async def _update_stats(self, source: str, ip: str, inc: int) -> None:
        """Add ``inc`` to today's (UTC) upload counter for (source, ip).

        Creates the row if it does not exist. Always stamps ``last_report_at``
        with the current UTC time and sets ``status`` to "normal".

        NOTE(review): this is a non-atomic read-modify-write. Concurrent
        requests for the same (source, ip, date) can lose increments or race on
        ``create``. Consider an atomic ``F("upload_count") + inc`` update or
        ``get_or_create`` inside a transaction.
        """
        # Use a timezone-aware datetime to stay consistent with the datetime
        # type stored in the database.
        now = datetime.now(timezone.utc)
        today = now.date()
        stats = await IpUploadStats.get_or_none(source=source, ip=ip, date=today)
        if stats is None:
            await IpUploadStats.create(
                source=source,
                ip=ip,
                date=today,
                upload_count=inc,
                last_report_at=now,
                status="normal",
            )
            return
        stats.upload_count = stats.upload_count + inc
        stats.last_report_at = now
        if getattr(stats, "status", "normal") != "normal":
            stats.status = "normal"
        await stats.save()