- 新增DistributedLock类,支持唯一标识防解锁冲突 - 实现自动续期、超时、重试、上下文管理器功能 - 提供手动 acquire、release 和 extend 接口 - 增加异步上下文管理器便利函数distributed_lock - 实现分布式锁装饰器distributed_lock_decorator支持灵活调用 - 编写示例模块,展示多种锁的使用方式和自动续期示例 - 支持锁状态查询,演示锁冲突与延长锁超时操作 - 保证锁的线程/进程安全与Redis操作原子性
334 lines
9.9 KiB
Python
334 lines
9.9 KiB
Python
"""
|
||
分布式锁实现,基于 Redis。
|
||
|
||
提供基于 Redis 的分布式锁机制,支持自动续期、超时、上下文管理器等功能。
|
||
"""
|
||
|
||
import asyncio
|
||
import uuid
|
||
from typing import Optional, Any
|
||
from contextlib import asynccontextmanager
|
||
from redis.asyncio import Redis
|
||
from core.redis import get_redis
|
||
|
||
|
||
class DistributedLock:
|
||
"""
|
||
基于 Redis 的分布式锁实现。
|
||
|
||
特性:
|
||
- 使用唯一标识符防止误解锁
|
||
- 支持自动续期机制
|
||
- 支持超时和重试
|
||
- 支持上下文管理器
|
||
- 线程/进程安全
|
||
|
||
示例:
|
||
# 方式1: 使用上下文管理器(推荐)
|
||
async with DistributedLock(redis, "my_resource"):
|
||
# 执行需要互斥的操作
|
||
await do_something()
|
||
|
||
# 方式2: 手动加锁解锁
|
||
lock = DistributedLock(redis, "my_resource")
|
||
if await lock.acquire():
|
||
try:
|
||
await do_something()
|
||
finally:
|
||
await lock.release()
|
||
|
||
# 方式3: 使用装饰器
|
||
@distributed_lock("my_resource")
|
||
async def my_function():
|
||
await do_something()
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
redis: Redis,
|
||
lock_name: str,
|
||
timeout: int = 30,
|
||
retry_times: int = 0,
|
||
retry_delay: float = 0.1,
|
||
auto_renewal: bool = False,
|
||
renewal_interval: Optional[int] = None,
|
||
):
|
||
"""
|
||
初始化分布式锁。
|
||
|
||
Args:
|
||
redis: Redis 客户端实例
|
||
lock_name: 锁的名称(资源标识符)
|
||
timeout: 锁的超时时间(秒),防止死锁,默认 30 秒
|
||
retry_times: 获取锁失败时的重试次数,0 表示不重试
|
||
retry_delay: 重试间隔时间(秒),默认 0.1 秒
|
||
auto_renewal: 是否启用自动续期,默认 False
|
||
renewal_interval: 自动续期间隔(秒),默认为 timeout 的 1/3
|
||
"""
|
||
self.redis = redis
|
||
self.lock_name = f"distributed_lock:{lock_name}"
|
||
self.timeout = timeout
|
||
self.retry_times = retry_times
|
||
self.retry_delay = retry_delay
|
||
self.auto_renewal = auto_renewal
|
||
self.renewal_interval = renewal_interval or max(1, timeout // 3)
|
||
|
||
# 使用 UUID 作为锁的唯一标识符,防止误解锁
|
||
self.identifier = str(uuid.uuid4())
|
||
|
||
# 自动续期任务
|
||
self._renewal_task: Optional[asyncio.Task] = None
|
||
self._is_locked = False
|
||
|
||
async def acquire(self, blocking: bool = True) -> bool:
|
||
"""
|
||
获取锁。
|
||
|
||
Args:
|
||
blocking: 是否阻塞等待,默认 True
|
||
|
||
Returns:
|
||
bool: 成功获取锁返回 True,否则返回 False
|
||
"""
|
||
retry_count = 0
|
||
|
||
while True:
|
||
# 尝试获取锁:使用 SET NX EX 命令(原子操作)
|
||
acquired = await self.redis.set(
|
||
self.lock_name,
|
||
self.identifier,
|
||
nx=True, # Only set if not exists
|
||
ex=self.timeout, # Set expiry
|
||
)
|
||
|
||
if acquired:
|
||
self._is_locked = True
|
||
|
||
# 如果启用自动续期,启动续期任务
|
||
if self.auto_renewal:
|
||
self._start_renewal()
|
||
|
||
return True
|
||
|
||
# 如果不阻塞或达到重试次数,返回失败
|
||
if not blocking or retry_count >= self.retry_times:
|
||
return False
|
||
|
||
# 等待后重试
|
||
retry_count += 1
|
||
await asyncio.sleep(self.retry_delay)
|
||
|
||
async def release(self) -> bool:
|
||
"""
|
||
释放锁。
|
||
|
||
使用 Lua 脚本确保只有锁的持有者才能释放锁(原子操作)。
|
||
|
||
Returns:
|
||
bool: 成功释放返回 True,否则返回 False
|
||
"""
|
||
if not self._is_locked:
|
||
return False
|
||
|
||
# 停止自动续期任务
|
||
if self._renewal_task and not self._renewal_task.done():
|
||
self._renewal_task.cancel()
|
||
try:
|
||
await self._renewal_task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
|
||
# 使用 Lua 脚本确保只有锁的持有者才能释放锁
|
||
lua_script = """
|
||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||
return redis.call("del", KEYS[1])
|
||
else
|
||
return 0
|
||
end
|
||
"""
|
||
|
||
result = await self.redis.eval(lua_script, 1, self.lock_name, self.identifier) # type: ignore
|
||
self._is_locked = False
|
||
return bool(result)
|
||
|
||
async def extend(self, additional_time: Optional[int] = None) -> bool:
|
||
"""
|
||
延长锁的持有时间。
|
||
|
||
Args:
|
||
additional_time: 额外延长的时间(秒),默认使用 timeout
|
||
|
||
Returns:
|
||
bool: 成功延长返回 True,否则返回 False
|
||
"""
|
||
if not self._is_locked:
|
||
return False
|
||
|
||
extend_time = additional_time or self.timeout
|
||
|
||
# 使用 Lua 脚本确保只有锁的持有者才能延长时间
|
||
lua_script = """
|
||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||
return redis.call("expire", KEYS[1], ARGV[2])
|
||
else
|
||
return 0
|
||
end
|
||
"""
|
||
|
||
result = await self.redis.eval( # type: ignore
|
||
lua_script,
|
||
1,
|
||
self.lock_name,
|
||
self.identifier,
|
||
extend_time,
|
||
)
|
||
return bool(result)
|
||
|
||
async def is_locked_by_me(self) -> bool:
|
||
"""
|
||
检查锁是否由当前实例持有。
|
||
|
||
Returns:
|
||
bool: 由当前实例持有返回 True,否则返回 False
|
||
"""
|
||
if not self._is_locked:
|
||
return False
|
||
|
||
value = await self.redis.get(self.lock_name)
|
||
return value == self.identifier
|
||
|
||
async def is_locked_by_anyone(self) -> bool:
|
||
"""
|
||
检查锁是否被任何实例持有。
|
||
|
||
Returns:
|
||
bool: 锁被持有返回 True,否则返回 False
|
||
"""
|
||
return await self.redis.exists(self.lock_name) > 0
|
||
|
||
def _start_renewal(self) -> None:
|
||
"""启动自动续期任务。"""
|
||
self._renewal_task = asyncio.create_task(self._renewal_loop())
|
||
|
||
async def _renewal_loop(self) -> None:
|
||
"""自动续期循环。"""
|
||
try:
|
||
while self._is_locked:
|
||
await asyncio.sleep(self.renewal_interval)
|
||
if self._is_locked:
|
||
success = await self.extend()
|
||
if not success:
|
||
# 续期失败,可能锁已被其他进程获取
|
||
self._is_locked = False
|
||
break
|
||
except asyncio.CancelledError:
|
||
pass
|
||
|
||
async def __aenter__(self):
|
||
"""上下文管理器入口。"""
|
||
acquired = await self.acquire()
|
||
if not acquired:
|
||
raise RuntimeError(f"Failed to acquire lock: {self.lock_name}")
|
||
return self
|
||
|
||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||
"""上下文管理器出口。"""
|
||
await self.release()
|
||
return False
|
||
|
||
|
||
@asynccontextmanager
|
||
async def distributed_lock(
|
||
lock_name: str,
|
||
timeout: int = 30,
|
||
retry_times: int = 3,
|
||
retry_delay: float = 0.1,
|
||
auto_renewal: bool = False,
|
||
redis: Optional[Redis] = None,
|
||
):
|
||
"""
|
||
分布式锁上下文管理器(便捷函数)。
|
||
|
||
Args:
|
||
lock_name: 锁的名称
|
||
timeout: 锁的超时时间(秒)
|
||
retry_times: 重试次数
|
||
retry_delay: 重试间隔(秒)
|
||
auto_renewal: 是否自动续期
|
||
redis: Redis 客户端,如果不提供则使用默认客户端
|
||
|
||
示例:
|
||
async with distributed_lock("my_resource", timeout=60):
|
||
await do_something()
|
||
"""
|
||
if redis is None:
|
||
redis = await get_redis()
|
||
|
||
lock = DistributedLock(
|
||
redis=redis,
|
||
lock_name=lock_name,
|
||
timeout=timeout,
|
||
retry_times=retry_times,
|
||
retry_delay=retry_delay,
|
||
auto_renewal=auto_renewal,
|
||
)
|
||
|
||
async with lock:
|
||
yield lock
|
||
|
||
|
||
def distributed_lock_decorator(
|
||
lock_name: Optional[str] = None,
|
||
timeout: int = 30,
|
||
retry_times: int = 3,
|
||
retry_delay: float = 0.1,
|
||
auto_renewal: bool = False,
|
||
):
|
||
"""
|
||
分布式锁装饰器。
|
||
|
||
Args:
|
||
lock_name: 锁的名称,如果不提供则使用函数名
|
||
timeout: 锁的超时时间(秒)
|
||
retry_times: 重试次数
|
||
retry_delay: 重试间隔(秒)
|
||
auto_renewal: 是否自动续期
|
||
|
||
示例:
|
||
# 使用函数名作为锁名
|
||
@distributed_lock_decorator()
|
||
async def my_function():
|
||
await do_something()
|
||
|
||
# 指定锁名
|
||
@distributed_lock_decorator("custom_lock_name")
|
||
async def my_function():
|
||
await do_something()
|
||
|
||
# 指定参数
|
||
@distributed_lock_decorator("custom_lock", timeout=60, auto_renewal=True)
|
||
async def my_function():
|
||
await do_something()
|
||
"""
|
||
def decorator(func):
|
||
async def wrapper(*args, **kwargs):
|
||
# 确定锁名称
|
||
actual_lock_name = lock_name or f"{func.__module__}.{func.__name__}"
|
||
|
||
redis = await get_redis()
|
||
lock = DistributedLock(
|
||
redis=redis,
|
||
lock_name=actual_lock_name,
|
||
timeout=timeout,
|
||
retry_times=retry_times,
|
||
retry_delay=retry_delay,
|
||
auto_renewal=auto_renewal,
|
||
)
|
||
|
||
async with lock:
|
||
return await func(*args, **kwargs)
|
||
|
||
return wrapper
|
||
|
||
return decorator
|