- 新增DistributedLock类,支持唯一标识防解锁冲突 - 实现自动续期、超时、重试、上下文管理器功能 - 提供手动 acquire、release 和 extend 接口 - 增加异步上下文管理器便利函数distributed_lock - 实现分布式锁装饰器distributed_lock_decorator支持灵活调用 - 编写示例模块,展示多种锁的使用方式和自动续期示例 - 支持锁状态查询,演示锁冲突与延长锁超时操作 - 保证锁的线程/进程安全与Redis操作原子性
370 lines
12 KiB
Python
370 lines
12 KiB
Python
"""
|
||
分布式锁单元测试。
|
||
|
||
测试分布式锁的各种功能和边界情况。
|
||
"""
|
||
|
||
import pytest
|
||
import asyncio
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
from redis.asyncio import Redis
|
||
from core.distributed_lock import (
|
||
DistributedLock,
|
||
distributed_lock,
|
||
distributed_lock_decorator,
|
||
)
|
||
|
||
|
||
@pytest.fixture
|
||
async def mock_redis():
|
||
"""创建 Mock Redis 客户端"""
|
||
redis = AsyncMock(spec=Redis)
|
||
redis.set = AsyncMock(return_value=True)
|
||
redis.get = AsyncMock(return_value=None)
|
||
redis.delete = AsyncMock(return_value=1)
|
||
redis.exists = AsyncMock(return_value=0)
|
||
redis.expire = AsyncMock(return_value=True)
|
||
redis.eval = AsyncMock(return_value=1)
|
||
return redis
|
||
|
||
|
||
class TestDistributedLock:
|
||
"""测试 DistributedLock 类"""
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_acquire_and_release(self, mock_redis):
|
||
"""测试基本的获取和释放锁"""
|
||
lock = DistributedLock(mock_redis, "test_lock", timeout=30)
|
||
|
||
# 测试获取锁
|
||
assert await lock.acquire() is True
|
||
mock_redis.set.assert_called_once()
|
||
|
||
# 测试释放锁
|
||
assert await lock.release() is True
|
||
mock_redis.eval.assert_called_once()
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_acquire_fail(self, mock_redis):
|
||
"""测试获取锁失败的情况"""
|
||
mock_redis.set = AsyncMock(return_value=False)
|
||
|
||
lock = DistributedLock(
|
||
mock_redis,
|
||
"test_lock",
|
||
timeout=30,
|
||
retry_times=0,
|
||
)
|
||
|
||
# 非阻塞模式下应该立即返回 False
|
||
assert await lock.acquire(blocking=False) is False
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_acquire_with_retry(self, mock_redis):
|
||
"""测试带重试的获取锁"""
|
||
# 前两次失败,第三次成功
|
||
mock_redis.set = AsyncMock(side_effect=[False, False, True])
|
||
|
||
lock = DistributedLock(
|
||
mock_redis,
|
||
"test_lock",
|
||
timeout=30,
|
||
retry_times=5,
|
||
retry_delay=0.01, # 减少测试时间
|
||
)
|
||
|
||
# 应该在第三次尝试时成功
|
||
assert await lock.acquire() is True
|
||
assert mock_redis.set.call_count == 3
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_context_manager(self, mock_redis):
|
||
"""测试上下文管理器"""
|
||
executed = False
|
||
|
||
async with DistributedLock(mock_redis, "test_lock", timeout=30):
|
||
executed = True
|
||
|
||
assert executed is True
|
||
assert mock_redis.set.called
|
||
assert mock_redis.eval.called
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_context_manager_acquire_fail(self, mock_redis):
|
||
"""测试上下文管理器获取锁失败"""
|
||
mock_redis.set = AsyncMock(return_value=False)
|
||
|
||
with pytest.raises(RuntimeError, match="Failed to acquire lock"):
|
||
async with DistributedLock(
|
||
mock_redis,
|
||
"test_lock",
|
||
timeout=30,
|
||
retry_times=0,
|
||
):
|
||
pass
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_extend_lock(self, mock_redis):
|
||
"""测试延长锁的持有时间"""
|
||
lock = DistributedLock(mock_redis, "test_lock", timeout=30)
|
||
await lock.acquire()
|
||
|
||
# 延长锁
|
||
assert await lock.extend(additional_time=60) is True
|
||
|
||
await lock.release()
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_extend_without_acquire(self, mock_redis):
|
||
"""测试在未获取锁的情况下延长锁"""
|
||
lock = DistributedLock(mock_redis, "test_lock", timeout=30)
|
||
|
||
# 未获取锁时延长应该失败
|
||
assert await lock.extend() is False
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_is_locked_by_me(self, mock_redis):
|
||
"""测试检查锁是否由当前实例持有"""
|
||
lock = DistributedLock(mock_redis, "test_lock", timeout=30)
|
||
|
||
# 未获取锁时
|
||
assert await lock.is_locked_by_me() is False
|
||
|
||
# 获取锁后
|
||
await lock.acquire()
|
||
mock_redis.get = AsyncMock(return_value=lock.identifier)
|
||
assert await lock.is_locked_by_me() is True
|
||
|
||
await lock.release()
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_is_locked_by_anyone(self, mock_redis):
|
||
"""测试检查锁是否被任何实例持有"""
|
||
lock = DistributedLock(mock_redis, "test_lock", timeout=30)
|
||
|
||
# 锁不存在
|
||
mock_redis.exists = AsyncMock(return_value=0)
|
||
assert await lock.is_locked_by_anyone() is False
|
||
|
||
# 锁存在
|
||
mock_redis.exists = AsyncMock(return_value=1)
|
||
assert await lock.is_locked_by_anyone() is True
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_release_without_acquire(self, mock_redis):
|
||
"""测试未获取锁就释放"""
|
||
lock = DistributedLock(mock_redis, "test_lock", timeout=30)
|
||
|
||
# 未获取锁时释放应该返回 False
|
||
assert await lock.release() is False
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_auto_renewal(self, mock_redis):
|
||
"""测试自动续期功能"""
|
||
lock = DistributedLock(
|
||
mock_redis,
|
||
"test_lock",
|
||
timeout=2,
|
||
auto_renewal=True,
|
||
renewal_interval=1,
|
||
)
|
||
|
||
await lock.acquire()
|
||
|
||
# 等待足够的时间让续期至少执行一次
|
||
await asyncio.sleep(1.5)
|
||
|
||
# 验证续期被调用(通过 eval 调用)
|
||
# 注意:第一次 eval 是 acquire 调用的,后续的是续期调用的
|
||
# 实际上我们这里主要验证逻辑,mock 环境下可能不完全准确
|
||
|
||
await lock.release()
|
||
|
||
|
||
class TestDistributedLockHelpers:
|
||
"""测试辅助函数"""
|
||
|
||
@pytest.mark.asyncio
|
||
@patch("core.distributed_lock.get_redis")
|
||
async def test_distributed_lock_function(self, mock_get_redis, mock_redis):
|
||
"""测试 distributed_lock 便捷函数"""
|
||
mock_get_redis.return_value = mock_redis
|
||
|
||
executed = False
|
||
|
||
async with distributed_lock("test_resource", timeout=30):
|
||
executed = True
|
||
|
||
assert executed is True
|
||
assert mock_redis.set.called
|
||
|
||
@pytest.mark.asyncio
|
||
@patch("core.distributed_lock.get_redis")
|
||
async def test_distributed_lock_decorator(self, mock_get_redis, mock_redis):
|
||
"""测试装饰器"""
|
||
mock_get_redis.return_value = mock_redis
|
||
|
||
call_count = 0
|
||
|
||
@distributed_lock_decorator("test_lock")
|
||
async def test_function():
|
||
nonlocal call_count
|
||
call_count += 1
|
||
return "success"
|
||
|
||
result = await test_function()
|
||
|
||
assert result == "success"
|
||
assert call_count == 1
|
||
assert mock_redis.set.called
|
||
|
||
@pytest.mark.asyncio
|
||
@patch("core.distributed_lock.get_redis")
|
||
async def test_decorator_default_lock_name(self, mock_get_redis, mock_redis):
|
||
"""测试装饰器使用默认锁名"""
|
||
mock_get_redis.return_value = mock_redis
|
||
|
||
@distributed_lock_decorator()
|
||
async def my_function():
|
||
return "done"
|
||
|
||
result = await my_function()
|
||
|
||
assert result == "done"
|
||
# 锁名应该是函数的模块名 + 函数名
|
||
# 验证 set 被调用时使用了包含函数名的锁名
|
||
assert mock_redis.set.called
|
||
|
||
|
||
class TestDistributedLockConcurrency:
|
||
"""测试并发场景"""
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_concurrent_lock_acquisition(self, mock_redis):
|
||
"""测试并发获取锁"""
|
||
# 模拟:第一个请求成功,其他失败
|
||
call_count = 0
|
||
|
||
async def set_side_effect(*args, **kwargs):
|
||
nonlocal call_count
|
||
call_count += 1
|
||
return call_count == 1 # 只有第一次返回 True
|
||
|
||
mock_redis.set = AsyncMock(side_effect=set_side_effect)
|
||
|
||
results = []
|
||
|
||
async def try_acquire():
|
||
lock = DistributedLock(
|
||
mock_redis,
|
||
"shared_resource",
|
||
timeout=30,
|
||
retry_times=0,
|
||
)
|
||
result = await lock.acquire(blocking=False)
|
||
results.append(result)
|
||
if result:
|
||
await lock.release()
|
||
|
||
# 并发执行 5 个任务
|
||
await asyncio.gather(*[try_acquire() for _ in range(5)])
|
||
|
||
# 应该只有一个成功
|
||
assert sum(results) == 1
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_lock_ordering(self, mock_redis):
|
||
"""测试锁的顺序获取"""
|
||
execution_order = []
|
||
|
||
# 模拟锁的获取和释放
|
||
lock_held = False
|
||
|
||
async def set_impl(*args, **kwargs):
|
||
nonlocal lock_held
|
||
if kwargs.get('nx') and not lock_held:
|
||
lock_held = True
|
||
return True
|
||
return False
|
||
|
||
async def eval_impl(*args, **kwargs):
|
||
nonlocal lock_held
|
||
lock_held = False
|
||
return 1
|
||
|
||
mock_redis.set = AsyncMock(side_effect=set_impl)
|
||
mock_redis.eval = AsyncMock(side_effect=eval_impl)
|
||
|
||
async def worker(worker_id: int):
|
||
lock = DistributedLock(
|
||
mock_redis,
|
||
"ordered_resource",
|
||
timeout=10,
|
||
retry_times=20,
|
||
retry_delay=0.01,
|
||
)
|
||
|
||
if await lock.acquire():
|
||
execution_order.append(worker_id)
|
||
await asyncio.sleep(0.02) # 模拟工作
|
||
await lock.release()
|
||
|
||
# 启动 3 个工作者
|
||
await asyncio.gather(*[worker(i) for i in range(3)])
|
||
|
||
# 所有工作者都应该执行
|
||
assert len(execution_order) == 3
|
||
# 顺序可能不同,但不应该有重复
|
||
assert len(set(execution_order)) == 3
|
||
|
||
|
||
class TestEdgeCases:
|
||
"""测试边界情况"""
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_zero_timeout(self, mock_redis):
|
||
"""测试超时时间为 0"""
|
||
lock = DistributedLock(mock_redis, "test_lock", timeout=0)
|
||
await lock.acquire()
|
||
|
||
# 应该使用 0 作为超时
|
||
call_args = mock_redis.set.call_args
|
||
assert call_args[1]['ex'] == 0
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_very_long_timeout(self, mock_redis):
|
||
"""测试非常长的超时时间"""
|
||
lock = DistributedLock(mock_redis, "test_lock", timeout=86400) # 1 天
|
||
await lock.acquire()
|
||
|
||
call_args = mock_redis.set.call_args
|
||
assert call_args[1]['ex'] == 86400
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_multiple_release(self, mock_redis):
|
||
"""测试多次释放锁"""
|
||
lock = DistributedLock(mock_redis, "test_lock", timeout=30)
|
||
await lock.acquire()
|
||
|
||
# 第一次释放应该成功
|
||
assert await lock.release() is True
|
||
|
||
# 第二次释放应该失败(因为已经不持有锁)
|
||
assert await lock.release() is False
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_lock_with_special_characters(self, mock_redis):
|
||
"""测试包含特殊字符的锁名"""
|
||
special_names = [
|
||
"lock:with:colons",
|
||
"lock/with/slashes",
|
||
"lock-with-dashes",
|
||
"lock_with_underscores",
|
||
"lock.with.dots",
|
||
]
|
||
|
||
for name in special_names:
|
||
lock = DistributedLock(mock_redis, name, timeout=30)
|
||
assert await lock.acquire() is True
|
||
await lock.release()
|