81 lines
3.4 KiB
Python
81 lines
3.4 KiB
Python
import json
import logging
from datetime import datetime, timezone

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response

from app.models.metrics import IpUploadStats
|
||
|
||
|
||
class IpTrackingMiddleware(BaseHTTPMiddleware):
    """Record per-source, per-IP, per-day upload counts for the data-upload endpoints.

    After the downstream app has produced a response, requests whose path starts
    with one of ``_TRACKED_PREFIXES`` are attributed to a source platform and a
    client IP. The number of uploaded records is then added to that day's
    ``IpUploadStats`` row. Recording is best-effort: a failure while computing
    or storing the statistics is logged and never changes the response sent to
    the client.
    """

    # Requests under these path prefixes are counted.
    _TRACKED_PREFIXES = (
        "/api/v1/universal/data",
        "/api/v1/boss/",
        "/api/v1/qcwy/",
        "/api/v1/zhilian/",
    )

    # Source name implied by the path when the request args carry no "platform".
    # The universal endpoint has no entry here, so it relies on "platform".
    _PATH_SOURCES = (
        ("/api/v1/boss/", "boss"),
        ("/api/v1/qcwy/", "qcwy"),
        ("/api/v1/zhilian/", "zhilian"),
    )

    _logger = logging.getLogger(__name__)

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Forward the request, then record upload stats for tracked, successful requests.

        Exceptions raised by the downstream app propagate unchanged, so outer
        middleware and exception handlers deal with them. Errors while
        recording stats are logged and swallowed.

        Returns:
            The downstream response. For JSON responses on tracked paths the
            same object is returned with its body buffered and replayable.
        """
        response = await call_next(request)

        path = request.url.path
        if not path.startswith(self._TRACKED_PREFIXES):
            return response
        # A rejected/failed request stored nothing, so it must not be counted.
        if response.status_code >= 400:
            return response

        # Deliberately outside the try: a failure while streaming the
        # downstream body is an application error, not a stats error.
        body = await self._buffer_json_body(response)

        try:
            args = getattr(request.state, "request_args", None) or {}
            source = args.get("platform") or self._source_from_path(path)
            ip = self._extract_ip(request)
            count = self._estimate_count(args, response, body)
            if source and ip and count:
                await self._update_stats(source, ip, count)
        except Exception:
            # Stats are best-effort; never let them break the request.
            self._logger.exception("Failed to record IP upload stats for %s", path)

        return response

    def _source_from_path(self, path: str) -> str:
        """Return the platform name implied by *path*, or "" if none applies."""
        for prefix, name in self._PATH_SOURCES:
            if path.startswith(prefix):
                return name
        return ""

    @staticmethod
    async def _buffer_json_body(response: Response) -> bytes:
        """Return the body of a JSON *response*, leaving the response still sendable.

        ``call_next`` returns a streaming response. Its payload is only
        reachable through ``body_iterator`` (it has no ``.body`` attribute), and
        reading the iterator consumes it. After buffering, the iterator is
        therefore replaced with one that replays the same bytes. Non-JSON
        responses are not buffered, which keeps large or binary streams
        streaming; for them ``b""`` is returned.
        """
        existing = getattr(response, "body", None)
        if existing:
            return bytes(existing)

        iterator = getattr(response, "body_iterator", None)
        content_type = response.headers.get("content-type", "").lower()
        if iterator is None or "json" not in content_type:
            return b""

        chunks = []
        async for chunk in iterator:
            if isinstance(chunk, (bytes, bytearray, memoryview)):
                chunks.append(bytes(chunk))
            else:
                chunks.append(chunk.encode(response.charset))
        body = b"".join(chunks)

        async def _replay():
            yield body

        response.body_iterator = _replay()
        return body

    def _extract_ip(self, request: Request) -> str:
        """Return the client IP, or "" if none can be determined.

        The sources are tried in order: the first X-Forwarded-For hop,
        X-Real-IP, then the socket peer address.
        """
        # Starlette header lookups are case-insensitive; one spelling is enough.
        # NOTE(review): both headers are client-controlled unless a trusted
        # reverse proxy overwrites them, so the recorded IP can be spoofed.
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:
                return first_hop

        real_ip = request.headers.get("x-real-ip")
        if real_ip and real_ip.strip():
            return real_ip.strip()

        return request.client.host if request.client else ""

    def _estimate_count(self, args: dict, response: Response, body: bytes = b"") -> int:
        """Estimate how many records the request successfully uploaded.

        Args:
            args: Request arguments stashed by the endpoint on
                ``request.state.request_args``.
            response: The downstream response. Its ``.body`` is used when
                *body* is empty.
            body: Buffered response body, if available.

        Returns:
            One of the following, checked in order:
            - ``data.success`` from a JSON body of the form
              ``{"data": {"success": N}}``, as reported by synchronous
              endpoints;
            - otherwise ``len(args["data_list"])`` when it is a list;
            - otherwise 1 if ``args`` has a ``"data"`` key;
            - otherwise 0.
        """
        raw = body or getattr(response, "body", b"")
        if raw:
            try:
                payload = json.loads(raw)
            except ValueError:
                # Non-JSON / undecodable body: fall back to the request args.
                payload = None
            data = payload.get("data") if isinstance(payload, dict) else None
            if isinstance(data, dict) and "success" in data:
                # The server reported an explicit count; trust it even when 0.
                try:
                    return int(data.get("success", 0))
                except (TypeError, ValueError):
                    return 0

        # Asynchronous endpoints, or no detailed response: estimate from the request.
        data_list = args.get("data_list")
        if isinstance(data_list, list):
            return len(data_list)
        if "data" in args:
            return 1
        return 0

    async def _update_stats(self, source: str, ip: str, inc: int) -> None:
        """Add *inc* to today's counter for (*source*, *ip*), creating the row if missing.

        This also refreshes ``last_report_at`` and resets ``status`` to "normal".
        """
        # Timezone-aware, to match the aware datetimes stored in the database.
        now = datetime.now(timezone.utc)
        # NOTE(review): the day bucket is the UTC date; confirm this matches
        # how daily reports are queried.
        today = now.date()

        # NOTE(review): this is read-modify-write. Concurrent requests for the
        # same (source, ip, date) can lose increments or race on create. An
        # atomic UPDATE (e.g. an F-expression increment) would avoid that.
        stats = await IpUploadStats.get_or_none(source=source, ip=ip, date=today)
        if stats is None:
            await IpUploadStats.create(
                source=source,
                ip=ip,
                date=today,
                upload_count=inc,
                last_report_at=now,
                status="normal",
            )
            return

        stats.upload_count = stats.upload_count + inc
        stats.last_report_at = now
        if getattr(stats, "status", "normal") != "normal":
            stats.status = "normal"
        await stats.save()