Files
FunctionalScaffold/tests/test_job_manager.py
Roog (顾新培) 87ed8c071c main:新增并发控制功能
变更内容:
- 增加 `max_concurrent_jobs` 配置项,支持设置最大并发任务数。
- 为 `JobManager` 添加信号量控制实现任务并发限制。
- 新增获取任务并发状态的接口 `/jobs/concurrency/status`。
- 编写并发控制功能相关的测试用例
2026-02-02 17:11:52 +08:00

500 lines
17 KiB
Python

"""异步任务管理器测试"""
import asyncio
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import status
from src.functional_scaffold.core.job_manager import (
JobManager,
get_job_manager,
shutdown_job_manager,
)
from src.functional_scaffold.api.models import JobStatus
class TestJobManager:
    """Unit tests for JobManager helpers that need no Redis or HTTP backend."""

    @pytest.fixture
    def mock_redis(self):
        """Provide an AsyncMock standing in for the Redis client."""
        return AsyncMock(
            ping=AsyncMock(return_value=True),
            hset=AsyncMock(),
            hgetall=AsyncMock(return_value={}),
            expire=AsyncMock(),
            close=AsyncMock(),
        )

    @pytest.fixture
    def mock_http_client(self):
        """Provide an AsyncMock standing in for the outbound HTTP client."""
        return AsyncMock(post=AsyncMock(), aclose=AsyncMock())

    @pytest.mark.asyncio
    async def test_generate_job_id(self):
        """Generated job IDs are 12 lowercase hexadecimal characters."""
        job_id = JobManager()._generate_job_id()

        assert len(job_id) == 12
        assert set(job_id) <= set("0123456789abcdef")

    @pytest.mark.asyncio
    async def test_get_timestamp(self):
        """Timestamps are ISO-8601 strings carrying a UTC offset or 'Z'."""
        stamp = JobManager()._get_timestamp()

        assert "T" in stamp
        assert stamp.endswith(("+00:00", "Z"))

    @pytest.mark.asyncio
    async def test_get_available_algorithms(self):
        """PrimeChecker is listed once algorithms have been registered."""
        manager = JobManager()
        manager._register_algorithms()

        assert "PrimeChecker" in manager.get_available_algorithms()

    @pytest.mark.asyncio
    async def test_is_available_without_redis(self):
        """A manager with no Redis client reports itself as unavailable."""
        assert JobManager().is_available() is False
class TestJobManagerWithMocks:
    """JobManager tests that replace the Redis client with an AsyncMock."""

    @pytest.mark.asyncio
    async def test_create_job(self):
        """create_job returns a 12-character job ID and calls Redis hset exactly once."""
        manager = JobManager()
        # Stand-in for Redis so no server is required.
        mock_redis = AsyncMock()
        mock_redis.hset = AsyncMock()
        manager._redis_client = mock_redis
        manager._register_algorithms()

        job_id = await manager.create_job(
            algorithm="PrimeChecker",
            params={"number": 17},
            webhook="https://example.com/callback",
            request_id="test-request-id",
        )

        assert len(job_id) == 12
        mock_redis.hset.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_job_invalid_algorithm(self):
        """create_job raises ValueError for an algorithm that is not registered."""
        manager = JobManager()
        manager._redis_client = AsyncMock()
        manager._register_algorithms()

        with pytest.raises(ValueError, match="不存在"):
            await manager.create_job(
                algorithm="NonExistentAlgorithm",
                params={},
            )

    @pytest.mark.asyncio
    async def test_create_job_redis_unavailable(self):
        """create_job raises RuntimeError when no Redis client has been set."""
        manager = JobManager()
        manager._register_algorithms()

        with pytest.raises(RuntimeError, match="Redis 不可用"):
            await manager.create_job(
                algorithm="PrimeChecker",
                params={"number": 17},
            )

    @pytest.mark.asyncio
    async def test_get_job(self):
        """get_job returns the stored fields with the JSON `result` decoded."""
        manager = JobManager()
        mock_redis = AsyncMock()
        # Redis hashes hold strings, so result/metadata are JSON-encoded here.
        mock_redis.hgetall = AsyncMock(
            return_value={
                "status": "completed",
                "algorithm": "PrimeChecker",
                "created_at": "2026-02-02T10:00:00+00:00",
                "started_at": "2026-02-02T10:00:01+00:00",
                "completed_at": "2026-02-02T10:00:02+00:00",
                "result": '{"number": 17, "is_prime": true}',
                "error": "",
                "metadata": '{"elapsed_time": 0.001}',
            }
        )
        manager._redis_client = mock_redis

        job_data = await manager.get_job("test-job-id")

        assert job_data is not None
        assert job_data["job_id"] == "test-job-id"
        assert job_data["status"] == "completed"
        assert job_data["algorithm"] == "PrimeChecker"
        assert job_data["result"]["number"] == 17
        assert job_data["result"]["is_prime"] is True

    @pytest.mark.asyncio
    async def test_get_job_not_found(self):
        """get_job returns None when Redis yields an empty hash."""
        manager = JobManager()
        mock_redis = AsyncMock()
        mock_redis.hgetall = AsyncMock(return_value={})
        manager._redis_client = mock_redis

        assert await manager.get_job("non-existent-job") is None

    @pytest.mark.asyncio
    async def test_execute_job(self):
        """execute_job issues at least two hset updates and one expire call."""
        manager = JobManager()
        mock_redis = AsyncMock()
        mock_redis.hgetall = AsyncMock(
            return_value={
                "status": "pending",
                "algorithm": "PrimeChecker",
                "params": '{"number": 17}',
                "webhook": "",
                "request_id": "test-request-id",
                "created_at": "2026-02-02T10:00:00+00:00",
            }
        )
        mock_redis.hset = AsyncMock()
        mock_redis.expire = AsyncMock()
        manager._redis_client = mock_redis
        manager._register_algorithms()
        # Set up the concurrency semaphore directly, since execute_job presumably
        # acquires it. Uses the module-level asyncio import.
        manager._semaphore = asyncio.Semaphore(10)

        await manager.execute_job("test-job-id")

        # At least one update for "running" and one for the final state.
        assert mock_redis.hset.call_count >= 2
        mock_redis.expire.assert_called_once()
class TestJobsAPI:
    """Endpoint tests for /jobs with the routes' JobManager replaced by a mock."""

    # Dotted path of the get_job_manager function referenced by the routes module.
    _GET_JOB_MANAGER = "src.functional_scaffold.api.routes.get_job_manager"

    @staticmethod
    def _fake_manager(available=True):
        """Return a MagicMock manager whose is_available() yields *available*."""
        fake = MagicMock()
        fake.is_available.return_value = available
        return fake

    @classmethod
    def _patch_manager(cls, fake):
        """Patch get_job_manager with an AsyncMock whose awaited result is *fake*."""
        return patch(cls._GET_JOB_MANAGER, new_callable=AsyncMock, return_value=fake)

    def test_create_job_success(self, client):
        """POST /jobs for a known algorithm answers 202 with the pending job."""
        fake = self._fake_manager()
        fake.get_available_algorithms.return_value = ["PrimeChecker"]
        fake.create_job = AsyncMock(return_value="abc123def456")
        fake.get_job = AsyncMock(
            return_value={
                "job_id": "abc123def456",
                "status": "pending",
                "algorithm": "PrimeChecker",
                "created_at": "2026-02-02T10:00:00+00:00",
            }
        )
        fake.execute_job = AsyncMock()

        with self._patch_manager(fake):
            response = client.post(
                "/jobs",
                json={"algorithm": "PrimeChecker", "params": {"number": 17}},
            )

        assert response.status_code == status.HTTP_202_ACCEPTED
        body = response.json()
        assert body["job_id"] == "abc123def456"
        assert body["status"] == "pending"
        assert body["message"] == "任务已创建"

    def test_create_job_algorithm_not_found(self, client):
        """POST /jobs with an unknown algorithm answers 404 ALGORITHM_NOT_FOUND."""
        fake = self._fake_manager()
        fake.get_available_algorithms.return_value = ["PrimeChecker"]

        with self._patch_manager(fake):
            response = client.post(
                "/jobs",
                json={"algorithm": "NonExistentAlgorithm", "params": {}},
            )

        assert response.status_code == status.HTTP_404_NOT_FOUND
        assert response.json()["detail"]["error"] == "ALGORITHM_NOT_FOUND"

    def test_create_job_service_unavailable(self, client):
        """POST /jobs answers 503 when the manager reports itself unavailable."""
        with self._patch_manager(self._fake_manager(available=False)):
            response = client.post(
                "/jobs",
                json={"algorithm": "PrimeChecker", "params": {"number": 17}},
            )

        assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE

    def test_get_job_status_success(self, client):
        """GET /jobs/{id} answers 200 with the stored job, including its result."""
        fake = self._fake_manager()
        fake.get_job = AsyncMock(
            return_value={
                "job_id": "abc123def456",
                "status": "completed",
                "algorithm": "PrimeChecker",
                "created_at": "2026-02-02T10:00:00+00:00",
                "started_at": "2026-02-02T10:00:01+00:00",
                "completed_at": "2026-02-02T10:00:02+00:00",
                "result": {"number": 17, "is_prime": True},
                "error": None,
                "metadata": {"elapsed_time": 0.001},
            }
        )

        with self._patch_manager(fake):
            response = client.get("/jobs/abc123def456")

        assert response.status_code == status.HTTP_200_OK
        body = response.json()
        assert body["job_id"] == "abc123def456"
        assert body["status"] == "completed"
        assert body["result"]["is_prime"] is True

    def test_get_job_status_not_found(self, client):
        """GET /jobs/{id} answers 404 JOB_NOT_FOUND when get_job returns None."""
        fake = self._fake_manager()
        fake.get_job = AsyncMock(return_value=None)

        with self._patch_manager(fake):
            response = client.get("/jobs/non-existent-job")

        assert response.status_code == status.HTTP_404_NOT_FOUND
        assert response.json()["detail"]["error"] == "JOB_NOT_FOUND"

    def test_get_job_status_service_unavailable(self, client):
        """GET /jobs/{id} answers 503 when the manager reports itself unavailable."""
        with self._patch_manager(self._fake_manager(available=False)):
            response = client.get("/jobs/abc123def456")

        assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
class TestWebhook:
    """Tests for JobManager._send_webhook with mocked Redis and HTTP clients."""

    @staticmethod
    def _manager_with_job(job_fields):
        """Return a JobManager whose mocked Redis hgetall resolves to *job_fields*."""
        manager = JobManager()
        redis_stub = AsyncMock()
        redis_stub.hgetall = AsyncMock(return_value=job_fields)
        manager._redis_client = redis_stub
        return manager

    @pytest.mark.asyncio
    async def test_send_webhook_success(self):
        """A 200 response results in exactly one POST to the webhook URL with a JSON body."""
        manager = self._manager_with_job(
            {
                "status": "completed",
                "algorithm": "PrimeChecker",
                "created_at": "2026-02-02T10:00:00+00:00",
                "completed_at": "2026-02-02T10:00:02+00:00",
                "result": '{"number": 17, "is_prime": true}',
                "error": "",
                "metadata": '{"elapsed_time": 0.001}',
            }
        )
        ok_response = MagicMock()
        ok_response.status_code = 200
        http_stub = AsyncMock()
        http_stub.post = AsyncMock(return_value=ok_response)
        manager._http_client = http_stub

        await manager._send_webhook("test-job-id", "https://example.com/callback")

        http_stub.post.assert_called_once()
        sent = http_stub.post.call_args
        assert sent.args[0] == "https://example.com/callback"
        assert "json" in sent.kwargs

    @pytest.mark.asyncio
    async def test_send_webhook_retry_on_failure(self):
        """When every POST raises, the expected attempt count equals webhook_max_retries."""
        manager = self._manager_with_job(
            {
                "status": "completed",
                "algorithm": "PrimeChecker",
                "created_at": "2026-02-02T10:00:00+00:00",
                "completed_at": "2026-02-02T10:00:02+00:00",
                "result": "{}",
                "error": "",
                "metadata": "{}",
            }
        )
        http_stub = AsyncMock()
        http_stub.post = AsyncMock(side_effect=Exception("Connection error"))
        manager._http_client = http_stub

        # Cap retries at 2 and use a 1s timeout. NOTE(review): only these two
        # settings are overridden; any delay between retries inside
        # _send_webhook is not shortened by this patch.
        with patch("src.functional_scaffold.core.job_manager.settings") as fake_settings:
            fake_settings.webhook_max_retries = 2
            fake_settings.webhook_timeout = 1
            await manager._send_webhook("test-job-id", "https://example.com/callback")

        assert http_stub.post.call_count == 2
class TestConcurrencyControl:
    """Tests for JobManager's semaphore-based concurrency limit."""

    @pytest.mark.asyncio
    async def test_get_concurrency_status(self):
        """With an idle semaphore every slot is reported as available."""
        manager = JobManager()
        # Configure the limiter directly instead of going through manager setup.
        manager._max_concurrent_jobs = 10
        manager._semaphore = asyncio.Semaphore(10)

        # Named `concurrency`, not `status`, so the module-level fastapi
        # `status` import is not shadowed inside this test.
        concurrency = manager.get_concurrency_status()

        assert concurrency["max_concurrent"] == 10
        assert concurrency["available_slots"] == 10
        assert concurrency["running_jobs"] == 0

    @pytest.mark.asyncio
    async def test_get_concurrency_status_without_semaphore(self):
        """Without a semaphore all concurrency figures are reported as zero."""
        manager = JobManager()

        concurrency = manager.get_concurrency_status()

        assert concurrency["max_concurrent"] == 0
        assert concurrency["available_slots"] == 0
        assert concurrency["running_jobs"] == 0

    @pytest.mark.asyncio
    async def test_concurrency_limit(self):
        """No more than the configured number of holders own the semaphore at once.

        NOTE(review): this drives the manager's semaphore directly rather than
        going through execute_job, so it checks the limiter and
        get_concurrency_status, not the job execution path.
        """
        limit = 2
        manager = JobManager()
        manager._max_concurrent_jobs = limit
        manager._semaphore = asyncio.Semaphore(limit)

        running = 0  # holders currently inside the semaphore
        peak = 0  # highest simultaneous holder count observed

        async def slow_execute():
            nonlocal running, peak
            async with manager._semaphore:
                running += 1
                peak = max(peak, running)
                await asyncio.sleep(0.1)
                running -= 1

        # One more task than the limit, so exactly one of them has to wait.
        tasks = [asyncio.create_task(slow_execute()) for _ in range(limit + 1)]

        # Yield long enough for the first `limit` tasks to acquire the semaphore.
        await asyncio.sleep(0.01)

        concurrency = manager.get_concurrency_status()
        assert concurrency["running_jobs"] == limit
        assert concurrency["available_slots"] == 0
        assert running == limit

        await asyncio.gather(*tasks)

        assert peak == limit
        concurrency = manager.get_concurrency_status()
        assert concurrency["running_jobs"] == 0
        assert concurrency["available_slots"] == limit

    def test_concurrency_status_api(self, client):
        """GET /jobs/concurrency/status answers 200 with integer concurrency fields."""
        response = client.get("/jobs/concurrency/status")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        for key in ("max_concurrent", "available_slots", "running_jobs"):
            assert key in data
            assert isinstance(data[key], int)