main:新增并发控制功能
变更内容: - 增加 `max_concurrent_jobs` 配置项,支持设置最大并发任务数。 - 为 `JobManager` 添加信号量控制实现任务并发限制。 - 新增获取任务并发状态的接口 `/jobs/concurrency/status`。 - 编写并发控制功能相关的测试用
This commit is contained in:
@@ -152,3 +152,21 @@ class JobStatusResponse(BaseModel):
|
|||||||
result: Optional[Dict[str, Any]] = Field(None, description="执行结果(仅完成时返回)")
|
result: Optional[Dict[str, Any]] = Field(None, description="执行结果(仅完成时返回)")
|
||||||
error: Optional[str] = Field(None, description="错误信息(仅失败时返回)")
|
error: Optional[str] = Field(None, description="错误信息(仅失败时返回)")
|
||||||
metadata: Optional[Dict[str, Any]] = Field(None, description="元数据信息")
|
metadata: Optional[Dict[str, Any]] = Field(None, description="元数据信息")
|
||||||
|
|
||||||
|
|
||||||
|
class ConcurrencyStatusResponse(BaseModel):
|
||||||
|
"""并发状态响应"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
json_schema_extra={
|
||||||
|
"example": {
|
||||||
|
"max_concurrent": 10,
|
||||||
|
"available_slots": 7,
|
||||||
|
"running_jobs": 3,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
max_concurrent: int = Field(..., description="最大并发任务数")
|
||||||
|
available_slots: int = Field(..., description="当前可用槽位数")
|
||||||
|
running_jobs: int = Field(..., description="当前运行中的任务数")
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from .models import (
|
|||||||
JobCreateResponse,
|
JobCreateResponse,
|
||||||
JobStatusResponse,
|
JobStatusResponse,
|
||||||
JobStatus,
|
JobStatus,
|
||||||
|
ConcurrencyStatusResponse,
|
||||||
)
|
)
|
||||||
from .dependencies import get_request_id
|
from .dependencies import get_request_id
|
||||||
from ..algorithms.prime_checker import PrimeChecker
|
from ..algorithms.prime_checker import PrimeChecker
|
||||||
@@ -292,3 +293,57 @@ async def get_job_status(job_id: str):
|
|||||||
"message": str(e),
|
"message": str(e),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/jobs/concurrency/status",
|
||||||
|
response_model=ConcurrencyStatusResponse,
|
||||||
|
summary="查询并发状态",
|
||||||
|
description="查询任务管理器的并发执行状态",
|
||||||
|
responses={
|
||||||
|
200: {"description": "成功", "model": ConcurrencyStatusResponse},
|
||||||
|
503: {"description": "服务不可用", "model": ErrorResponse},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_concurrency_status():
|
||||||
|
"""
|
||||||
|
查询并发状态
|
||||||
|
|
||||||
|
返回当前任务管理器的并发执行状态,包括:
|
||||||
|
- 最大并发任务数
|
||||||
|
- 当前可用槽位数
|
||||||
|
- 当前运行中的任务数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
job_manager = await get_job_manager()
|
||||||
|
|
||||||
|
# 检查任务管理器是否可用
|
||||||
|
if not job_manager.is_available():
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail={
|
||||||
|
"error": "SERVICE_UNAVAILABLE",
|
||||||
|
"message": "任务管理器不可用",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
concurrency_status = job_manager.get_concurrency_status()
|
||||||
|
|
||||||
|
return ConcurrencyStatusResponse(
|
||||||
|
max_concurrent=concurrency_status["max_concurrent"],
|
||||||
|
available_slots=concurrency_status["available_slots"],
|
||||||
|
running_jobs=concurrency_status["running_jobs"],
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询并发状态失败: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail={
|
||||||
|
"error": "INTERNAL_ERROR",
|
||||||
|
"message": str(e),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ class Settings(BaseSettings):
|
|||||||
job_result_ttl: int = 1800 # 结果缓存时间(秒),默认 30 分钟
|
job_result_ttl: int = 1800 # 结果缓存时间(秒),默认 30 分钟
|
||||||
webhook_max_retries: int = 3 # Webhook 最大重试次数
|
webhook_max_retries: int = 3 # Webhook 最大重试次数
|
||||||
webhook_timeout: int = 10 # Webhook 超时时间(秒)
|
webhook_timeout: int = 10 # Webhook 超时时间(秒)
|
||||||
|
max_concurrent_jobs: int = 10 # 最大并发任务数
|
||||||
|
|
||||||
|
|
||||||
# 全局配置实例
|
# 全局配置实例
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ class JobManager:
|
|||||||
self._redis_client: Optional[aioredis.Redis] = None
|
self._redis_client: Optional[aioredis.Redis] = None
|
||||||
self._algorithm_registry: Dict[str, Type[BaseAlgorithm]] = {}
|
self._algorithm_registry: Dict[str, Type[BaseAlgorithm]] = {}
|
||||||
self._http_client: Optional[httpx.AsyncClient] = None
|
self._http_client: Optional[httpx.AsyncClient] = None
|
||||||
|
self._semaphore: Optional[asyncio.Semaphore] = None
|
||||||
|
self._max_concurrent_jobs: int = 0
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
"""初始化 Redis 连接和 HTTP 客户端"""
|
"""初始化 Redis 连接和 HTTP 客户端"""
|
||||||
@@ -51,6 +53,11 @@ class JobManager:
|
|||||||
# 初始化 HTTP 客户端
|
# 初始化 HTTP 客户端
|
||||||
self._http_client = httpx.AsyncClient(timeout=settings.webhook_timeout)
|
self._http_client = httpx.AsyncClient(timeout=settings.webhook_timeout)
|
||||||
|
|
||||||
|
# 初始化并发控制信号量
|
||||||
|
self._max_concurrent_jobs = settings.max_concurrent_jobs
|
||||||
|
self._semaphore = asyncio.Semaphore(self._max_concurrent_jobs)
|
||||||
|
logger.info(f"任务并发限制已设置: {self._max_concurrent_jobs}")
|
||||||
|
|
||||||
# 注册算法
|
# 注册算法
|
||||||
self._register_algorithms()
|
self._register_algorithms()
|
||||||
|
|
||||||
@@ -203,6 +210,10 @@ class JobManager:
|
|||||||
logger.error(f"Redis 不可用,无法执行任务: {job_id}")
|
logger.error(f"Redis 不可用,无法执行任务: {job_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not self._semaphore:
|
||||||
|
logger.error(f"并发控制未初始化,无法执行任务: {job_id}")
|
||||||
|
return
|
||||||
|
|
||||||
key = f"job:{job_id}"
|
key = f"job:{job_id}"
|
||||||
job_data = await self._redis_client.hgetall(key)
|
job_data = await self._redis_client.hgetall(key)
|
||||||
|
|
||||||
@@ -219,6 +230,8 @@ class JobManager:
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
params = {}
|
params = {}
|
||||||
|
|
||||||
|
# 使用信号量控制并发
|
||||||
|
async with self._semaphore:
|
||||||
# 更新状态为 running
|
# 更新状态为 running
|
||||||
started_at = self._get_timestamp()
|
started_at = self._get_timestamp()
|
||||||
await self._redis_client.hset(key, mapping={"status": "running", "started_at": started_at})
|
await self._redis_client.hset(key, mapping={"status": "running", "started_at": started_at})
|
||||||
@@ -359,6 +372,32 @@ class JobManager:
|
|||||||
"""检查任务管理器是否可用"""
|
"""检查任务管理器是否可用"""
|
||||||
return self._redis_client is not None
|
return self._redis_client is not None
|
||||||
|
|
||||||
|
def get_concurrency_status(self) -> Dict[str, int]:
|
||||||
|
"""获取并发状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, int]: 包含以下键的字典
|
||||||
|
- max_concurrent: 最大并发数
|
||||||
|
- available_slots: 可用槽位数
|
||||||
|
- running_jobs: 当前运行中的任务数
|
||||||
|
"""
|
||||||
|
if not self._semaphore:
|
||||||
|
return {
|
||||||
|
"max_concurrent": 0,
|
||||||
|
"available_slots": 0,
|
||||||
|
"running_jobs": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
max_concurrent = self._max_concurrent_jobs
|
||||||
|
available_slots = self._semaphore._value
|
||||||
|
running_jobs = max_concurrent - available_slots
|
||||||
|
|
||||||
|
return {
|
||||||
|
"max_concurrent": max_concurrent,
|
||||||
|
"available_slots": available_slots,
|
||||||
|
"running_jobs": running_jobs,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# 全局单例
|
# 全局单例
|
||||||
_job_manager: Optional[JobManager] = None
|
_job_manager: Optional[JobManager] = None
|
||||||
|
|||||||
@@ -186,6 +186,10 @@ class TestJobManagerWithMocks:
|
|||||||
manager._redis_client = mock_redis
|
manager._redis_client = mock_redis
|
||||||
manager._register_algorithms()
|
manager._register_algorithms()
|
||||||
|
|
||||||
|
# 初始化 semaphore
|
||||||
|
import asyncio
|
||||||
|
manager._semaphore = asyncio.Semaphore(10)
|
||||||
|
|
||||||
await manager.execute_job("test-job-id")
|
await manager.execute_job("test-job-id")
|
||||||
|
|
||||||
# 验证状态更新被调用
|
# 验证状态更新被调用
|
||||||
@@ -399,3 +403,97 @@ class TestWebhook:
|
|||||||
|
|
||||||
# 验证重试次数
|
# 验证重试次数
|
||||||
assert mock_http.post.call_count == 2
|
assert mock_http.post.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestConcurrencyControl:
|
||||||
|
"""测试并发控制功能"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_concurrency_status(self):
|
||||||
|
"""测试获取并发状态"""
|
||||||
|
manager = JobManager()
|
||||||
|
|
||||||
|
# 初始化 semaphore
|
||||||
|
manager._max_concurrent_jobs = 10
|
||||||
|
manager._semaphore = asyncio.Semaphore(10)
|
||||||
|
|
||||||
|
status = manager.get_concurrency_status()
|
||||||
|
|
||||||
|
assert status["max_concurrent"] == 10
|
||||||
|
assert status["available_slots"] == 10
|
||||||
|
assert status["running_jobs"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_concurrency_status_without_semaphore(self):
|
||||||
|
"""测试未初始化 semaphore 时获取并发状态"""
|
||||||
|
manager = JobManager()
|
||||||
|
|
||||||
|
status = manager.get_concurrency_status()
|
||||||
|
|
||||||
|
assert status["max_concurrent"] == 0
|
||||||
|
assert status["available_slots"] == 0
|
||||||
|
assert status["running_jobs"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrency_limit(self):
|
||||||
|
"""测试并发限制是否生效"""
|
||||||
|
manager = JobManager()
|
||||||
|
|
||||||
|
# 设置较小的并发限制
|
||||||
|
manager._max_concurrent_jobs = 2
|
||||||
|
manager._semaphore = asyncio.Semaphore(2)
|
||||||
|
|
||||||
|
# 模拟 Redis
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.hgetall = AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"status": "pending",
|
||||||
|
"algorithm": "PrimeChecker",
|
||||||
|
"params": '{"number": 17}',
|
||||||
|
"webhook": "",
|
||||||
|
"request_id": "test-request-id",
|
||||||
|
"created_at": "2026-02-02T10:00:00+00:00",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
mock_redis.hset = AsyncMock()
|
||||||
|
mock_redis.expire = AsyncMock()
|
||||||
|
manager._redis_client = mock_redis
|
||||||
|
manager._register_algorithms()
|
||||||
|
|
||||||
|
# 创建一个慢速任务
|
||||||
|
async def slow_execute():
|
||||||
|
async with manager._semaphore:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
# 启动 3 个任务
|
||||||
|
tasks = [asyncio.create_task(slow_execute()) for _ in range(3)]
|
||||||
|
|
||||||
|
# 等待一小段时间,让前两个任务获取 semaphore
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
# 检查并发状态
|
||||||
|
status = manager.get_concurrency_status()
|
||||||
|
assert status["running_jobs"] == 2 # 只有 2 个任务在运行
|
||||||
|
assert status["available_slots"] == 0 # 没有可用槽位
|
||||||
|
|
||||||
|
# 等待所有任务完成
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# 检查最终状态
|
||||||
|
status = manager.get_concurrency_status()
|
||||||
|
assert status["running_jobs"] == 0
|
||||||
|
assert status["available_slots"] == 2
|
||||||
|
|
||||||
|
def test_concurrency_status_api(self, client):
|
||||||
|
"""测试并发状态 API 端点"""
|
||||||
|
response = client.get("/jobs/concurrency/status")
|
||||||
|
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
assert "max_concurrent" in data
|
||||||
|
assert "available_slots" in data
|
||||||
|
assert "running_jobs" in data
|
||||||
|
assert isinstance(data["max_concurrent"], int)
|
||||||
|
assert isinstance(data["available_slots"], int)
|
||||||
|
assert isinstance(data["running_jobs"], int)
|
||||||
|
|||||||
Reference in New Issue
Block a user