""" 分布式锁实现,基于 Redis。 提供基于 Redis 的分布式锁机制,支持自动续期、超时、上下文管理器等功能。 """ import asyncio import uuid from typing import Optional, Any from contextlib import asynccontextmanager from redis.asyncio import Redis from core.redis import get_redis class DistributedLock: """ 基于 Redis 的分布式锁实现。 特性: - 使用唯一标识符防止误解锁 - 支持自动续期机制 - 支持超时和重试 - 支持上下文管理器 - 线程/进程安全 示例: # 方式1: 使用上下文管理器(推荐) async with DistributedLock(redis, "my_resource"): # 执行需要互斥的操作 await do_something() # 方式2: 手动加锁解锁 lock = DistributedLock(redis, "my_resource") if await lock.acquire(): try: await do_something() finally: await lock.release() # 方式3: 使用装饰器 @distributed_lock("my_resource") async def my_function(): await do_something() """ def __init__( self, redis: Redis, lock_name: str, timeout: int = 30, retry_times: int = 0, retry_delay: float = 0.1, auto_renewal: bool = False, renewal_interval: Optional[int] = None, ): """ 初始化分布式锁。 Args: redis: Redis 客户端实例 lock_name: 锁的名称(资源标识符) timeout: 锁的超时时间(秒),防止死锁,默认 30 秒 retry_times: 获取锁失败时的重试次数,0 表示不重试 retry_delay: 重试间隔时间(秒),默认 0.1 秒 auto_renewal: 是否启用自动续期,默认 False renewal_interval: 自动续期间隔(秒),默认为 timeout 的 1/3 """ self.redis = redis self.lock_name = f"distributed_lock:{lock_name}" self.timeout = timeout self.retry_times = retry_times self.retry_delay = retry_delay self.auto_renewal = auto_renewal self.renewal_interval = renewal_interval or max(1, timeout // 3) # 使用 UUID 作为锁的唯一标识符,防止误解锁 self.identifier = str(uuid.uuid4()) # 自动续期任务 self._renewal_task: Optional[asyncio.Task] = None self._is_locked = False async def acquire(self, blocking: bool = True) -> bool: """ 获取锁。 Args: blocking: 是否阻塞等待,默认 True Returns: bool: 成功获取锁返回 True,否则返回 False """ retry_count = 0 while True: # 尝试获取锁:使用 SET NX EX 命令(原子操作) acquired = await self.redis.set( self.lock_name, self.identifier, nx=True, # Only set if not exists ex=self.timeout, # Set expiry ) if acquired: self._is_locked = True # 如果启用自动续期,启动续期任务 if self.auto_renewal: self._start_renewal() return True # 如果不阻塞或达到重试次数,返回失败 if not blocking or retry_count >= self.retry_times: return False # 等待后重试 retry_count += 1 await asyncio.sleep(self.retry_delay) async def release(self) -> bool: """ 释放锁。 使用 Lua 脚本确保只有锁的持有者才能释放锁(原子操作)。 Returns: bool: 成功释放返回 True,否则返回 False """ if not self._is_locked: return False # 停止自动续期任务 if self._renewal_task and not self._renewal_task.done(): self._renewal_task.cancel() try: await self._renewal_task except asyncio.CancelledError: pass # 使用 Lua 脚本确保只有锁的持有者才能释放锁 lua_script = """ if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("del", KEYS[1]) else return 0 end """ result = await self.redis.eval(lua_script, 1, self.lock_name, self.identifier) # type: ignore self._is_locked = False return bool(result) async def extend(self, additional_time: Optional[int] = None) -> bool: """ 延长锁的持有时间。 Args: additional_time: 额外延长的时间(秒),默认使用 timeout Returns: bool: 成功延长返回 True,否则返回 False """ if not self._is_locked: return False extend_time = additional_time or self.timeout # 使用 Lua 脚本确保只有锁的持有者才能延长时间 lua_script = """ if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("expire", KEYS[1], ARGV[2]) else return 0 end """ result = await self.redis.eval( # type: ignore lua_script, 1, self.lock_name, self.identifier, extend_time, ) return bool(result) async def is_locked_by_me(self) -> bool: """ 检查锁是否由当前实例持有。 Returns: bool: 由当前实例持有返回 True,否则返回 False """ if not self._is_locked: return False value = await self.redis.get(self.lock_name) return value == self.identifier async def is_locked_by_anyone(self) -> bool: """ 检查锁是否被任何实例持有。 Returns: bool: 锁被持有返回 True,否则返回 False """ return await self.redis.exists(self.lock_name) > 0 def _start_renewal(self) -> None: """启动自动续期任务。""" self._renewal_task = asyncio.create_task(self._renewal_loop()) async def _renewal_loop(self) -> None: """自动续期循环。""" try: while self._is_locked: await asyncio.sleep(self.renewal_interval) if self._is_locked: success = await self.extend() if not success: # 续期失败,可能锁已被其他进程获取 self._is_locked = False break except asyncio.CancelledError: pass async def __aenter__(self): """上下文管理器入口。""" acquired = await self.acquire() if not acquired: raise RuntimeError(f"Failed to acquire lock: {self.lock_name}") return self async def __aexit__(self, exc_type, exc_val, exc_tb): """上下文管理器出口。""" await self.release() return False @asynccontextmanager async def distributed_lock( lock_name: str, timeout: int = 30, retry_times: int = 3, retry_delay: float = 0.1, auto_renewal: bool = False, redis: Optional[Redis] = None, ): """ 分布式锁上下文管理器(便捷函数)。 Args: lock_name: 锁的名称 timeout: 锁的超时时间(秒) retry_times: 重试次数 retry_delay: 重试间隔(秒) auto_renewal: 是否自动续期 redis: Redis 客户端,如果不提供则使用默认客户端 示例: async with distributed_lock("my_resource", timeout=60): await do_something() """ if redis is None: redis = await get_redis() lock = DistributedLock( redis=redis, lock_name=lock_name, timeout=timeout, retry_times=retry_times, retry_delay=retry_delay, auto_renewal=auto_renewal, ) async with lock: yield lock def distributed_lock_decorator( lock_name: Optional[str] = None, timeout: int = 30, retry_times: int = 3, retry_delay: float = 0.1, auto_renewal: bool = False, ): """ 分布式锁装饰器。 Args: lock_name: 锁的名称,如果不提供则使用函数名 timeout: 锁的超时时间(秒) retry_times: 重试次数 retry_delay: 重试间隔(秒) auto_renewal: 是否自动续期 示例: # 使用函数名作为锁名 @distributed_lock_decorator() async def my_function(): await do_something() # 指定锁名 @distributed_lock_decorator("custom_lock_name") async def my_function(): await do_something() # 指定参数 @distributed_lock_decorator("custom_lock", timeout=60, auto_renewal=True) async def my_function(): await do_something() """ def decorator(func): async def wrapper(*args, **kwargs): # 确定锁名称 actual_lock_name = lock_name or f"{func.__module__}.{func.__name__}" redis = await get_redis() lock = DistributedLock( redis=redis, lock_name=actual_lock_name, timeout=timeout, retry_times=retry_times, retry_delay=retry_delay, auto_renewal=auto_renewal, ) async with lock: return await func(*args, **kwargs) return wrapper return decorator