121 lines
4.0 KiB
Python
121 lines
4.0 KiB
Python
import shutil
|
||
import tempfile
|
||
import time
|
||
import uuid
|
||
from contextlib import asynccontextmanager
|
||
from pathlib import Path
|
||
|
||
from loguru import logger
|
||
|
||
|
||
class DistributedLock:
    """Distributed lock that prefers Redis and degrades to a file lock with a TTL.

    Redis is used when ``redis.asyncio`` is importable and
    ``settings.REDIS_HOST`` is configured. Otherwise, or once Redis stops
    answering pings, a directory-based lock under ``tempfile.gettempdir()`` is
    used instead.

    NOTE(review): the fallback directory lives in the (normally host-local)
    temp dir, and processes that fell back to it do not see locks held in
    Redis (and vice versa). Mutual exclusion therefore only holds while all
    contenders use the same backend.

    Args:
        name: Lock name; used in the Redis key ``lock:{name}`` and in the
            fallback directory name ``jobdata_lock_{name}``.
        ttl_seconds: Lock lifetime in seconds. It is the Redis key expiry
            (``EX``) and the age after which a file lock is treated as stale.
    """

    # Atomic compare-and-delete: only delete the key if it still holds our
    # token, so we never remove a lock that expired and was re-acquired by
    # someone else between a GET and a DEL.
    _RELEASE_SCRIPT = (
        "if redis.call('get', KEYS[1]) == ARGV[1] then "
        "return redis.call('del', KEYS[1]) "
        "else return 0 end"
    )

    def __init__(self, name: str, ttl_seconds: int = 600):
        self.name = name
        self.ttl = ttl_seconds
        # Unique per instance; identifies this holder in Redis and in the
        # file lock's "owner" file.
        self.token = str(uuid.uuid4())
        self._use_redis = False
        self._redis = None
        # Backend that actually granted the lock: "redis", "file" or None.
        self._backend = None
        self._file_path = Path(tempfile.gettempdir()) / f"jobdata_lock_{self.name}"
        self._init_redis()

    def _init_redis(self) -> None:
        """Create the async Redis client when redis is installed and REDIS_HOST is set.

        Leaves the instance in file-lock mode (``_use_redis`` False) otherwise.
        """
        try:
            import redis.asyncio as aioredis

            from app.settings.config import settings
        except ImportError:
            # redis not installed or project settings unavailable: file-lock mode.
            self._use_redis = False
            return

        try:
            host = getattr(settings, "REDIS_HOST", None) or ""
            if not host:
                return
            self._redis = aioredis.Redis(
                host=host,
                port=getattr(settings, "REDIS_PORT", 6379),
                db=getattr(settings, "REDIS_DB", 0),
                password=getattr(settings, "REDIS_PASS", None) or None,
                socket_timeout=3,
            )
            self._use_redis = True
        except Exception as exc:
            logger.warning(
                f"Could not set up Redis for lock '{self.name}' ({exc!r}); "
                f"using file lock."
            )
            self._redis = None
            self._use_redis = False

    async def _ping_redis(self) -> bool:
        """Return True if Redis answers a PING.

        On failure, switch this instance to file-lock mode for good.
        """
        if not self._redis:
            return False
        try:
            return bool(await self._redis.ping())
        except Exception as exc:
            logger.warning(
                f"Redis ping failed for lock '{self.name}' ({exc!r}); "
                f"falling back to file lock."
            )
            self._use_redis = False
            return False

    async def acquire(self) -> bool:
        """Try once, without waiting, to acquire the lock.

        Returns:
            True if this instance now holds the lock, False if it is held
            elsewhere or could not be taken.
        """
        if self._use_redis and self._redis is not None:
            if await self._ping_redis():
                try:
                    acquired = bool(
                        await self._redis.set(
                            f"lock:{self.name}", self.token, nx=True, ex=self.ttl
                        )
                    )
                except Exception as exc:
                    logger.warning(
                        f"Redis SET failed for lock '{self.name}' ({exc!r}); "
                        f"falling back to file lock."
                    )
                else:
                    if acquired:
                        self._backend = "redis"
                    return acquired

        if self._try_file_lock():
            self._backend = "file"
            return True
        return False

    def _try_file_lock(self) -> bool:
        """Try to take the TTL-checked file lock at an absolute temp-dir path.

        ``mkdir`` is atomic, so whoever creates ``self._file_path`` holds the
        lock. If the directory already exists and is older than ``self.ttl``,
        it is treated as stale (for example, the holder crashed), removed, and
        re-taken.
        """
        lock_dir = self._file_path
        try:
            lock_dir.mkdir()
        except FileExistsError:
            created = self._file_lock_created_at()
            if created is None:
                # Released between our mkdir and our read: report busy and
                # let the caller retry.
                return False
            age = time.time() - created
            if age <= self.ttl:
                return False
            logger.warning(
                f"Stale file lock detected for '{self.name}', "
                f"age={age:.0f}s > ttl={self.ttl}s. Cleaning up."
            )
            # NOTE(review): not atomic. Two processes that both judge the lock
            # stale can race here, and one may remove the other's freshly
            # re-created lock. Acceptable for a degraded fallback, but not a
            # strict guarantee.
            shutil.rmtree(lock_dir, ignore_errors=True)
            try:
                lock_dir.mkdir()
            except OSError:
                return False
        except OSError as exc:
            logger.warning(f"Cannot create file lock '{lock_dir}' ({exc!r}).")
            return False
        return self._write_file_lock_meta()

    def _write_file_lock_meta(self) -> bool:
        """Record the creation time ("meta") and our token ("owner") in the new lock dir.

        If writing fails, the just-created directory is removed and False is
        returned, so a half-initialised lock is never left behind.
        """
        try:
            # "meta" keeps the legacy format (a bare timestamp) so older
            # readers can still parse it; ownership goes in a separate file.
            (self._file_path / "meta").write_text(str(time.time()))
            (self._file_path / "owner").write_text(self.token)
        except OSError as exc:
            logger.warning(
                f"Cannot initialise file lock '{self._file_path}' ({exc!r})."
            )
            shutil.rmtree(self._file_path, ignore_errors=True)
            return False
        return True

    def _file_lock_created_at(self):
        """Return the file lock's creation timestamp, or None if the lock dir is gone.

        Falls back to the directory mtime when "meta" is missing or
        unparseable, so a lock whose holder died before writing it still
        expires.
        """
        try:
            return float((self._file_path / "meta").read_text())
        except (OSError, ValueError):
            pass
        try:
            return self._file_path.stat().st_mtime
        except OSError:
            return None

    async def release(self) -> None:
        """Release the lock, but only if this instance currently holds it.

        Only the backend that granted the lock is touched, and ownership is
        checked by token. Releasing therefore never frees a lock held by
        someone else, and is a no-op when nothing was acquired. Errors are
        logged rather than raised (best effort).
        """
        backend, self._backend = self._backend, None
        if backend == "redis":
            await self._release_redis()
        elif backend == "file":
            self._release_file_lock()

    async def _release_redis(self) -> None:
        """Atomically delete the Redis key if it still holds our token."""
        if self._redis is None:
            return
        try:
            await self._redis.eval(
                self._RELEASE_SCRIPT, 1, f"lock:{self.name}", self.token
            )
        except Exception as exc:
            logger.warning(
                f"Failed to release Redis lock '{self.name}' ({exc!r}); "
                f"it will expire after {self.ttl}s."
            )

    def _release_file_lock(self) -> None:
        """Remove the lock directory if its "owner" file holds our token."""
        try:
            owner = (self._file_path / "owner").read_text()
        except FileNotFoundError:
            # Already gone, e.g. reclaimed as stale and then released.
            return
        except OSError as exc:
            logger.warning(
                f"Cannot read owner of file lock '{self._file_path}' ({exc!r}); "
                f"leaving it to expire."
            )
            return
        if owner != self.token:
            logger.warning(
                f"File lock '{self.name}' is no longer held by this instance "
                f"(expired and reclaimed?); not removing it."
            )
            return
        # NOTE(review): a tiny check-then-remove window remains; it only matters
        # if the lock expires and is re-taken within that window.
        shutil.rmtree(self._file_path, ignore_errors=True)

    @asynccontextmanager
    async def context(self):
        """Async context manager that yields whether the lock was acquired.

        The body always runs, so callers must check the yielded bool and skip
        the guarded work when it is False. The lock is released on exit only
        if it was acquired here.
        """
        acquired = await self.acquire()
        try:
            yield acquired
        finally:
            if acquired:
                await self.release()
|