main:新增并发控制功能

变更内容:
- 增加 `max_concurrent_jobs` 配置项,支持设置最大并发任务数。
- 为 `JobManager` 添加信号量控制实现任务并发限制。
- 新增获取任务并发状态的接口 `/jobs/concurrency/status`。
- 编写并发控制功能相关的测试用例。
This commit is contained in:
2026-02-02 17:11:52 +08:00
parent 57b276d038
commit 87ed8c071c
5 changed files with 265 additions and 54 deletions

View File

@@ -152,3 +152,21 @@ class JobStatusResponse(BaseModel):
result: Optional[Dict[str, Any]] = Field(None, description="执行结果(仅完成时返回)")
error: Optional[str] = Field(None, description="错误信息(仅失败时返回)")
metadata: Optional[Dict[str, Any]] = Field(None, description="元数据信息")
class ConcurrencyStatusResponse(BaseModel):
    # Response schema for GET /jobs/concurrency/status.
    # NOTE(review): in pydantic the class docstring becomes the OpenAPI model
    # description, so it is kept verbatim rather than translated.
    """并发状态响应"""
    model_config = ConfigDict(
        json_schema_extra={
            # Example payload shown in the generated OpenAPI docs.
            "example": {
                "max_concurrent": 10,
                "available_slots": 7,
                "running_jobs": 3,
            }
        }
    )
    # Configured upper bound on concurrently running jobs.
    max_concurrent: int = Field(..., description="最大并发任务数")
    # Execution slots currently free.
    available_slots: int = Field(..., description="当前可用槽位数")
    # Jobs currently executing (max_concurrent - available_slots).
    running_jobs: int = Field(..., description="当前运行中的任务数")

View File

@@ -15,6 +15,7 @@ from .models import (
JobCreateResponse,
JobStatusResponse,
JobStatus,
ConcurrencyStatusResponse,
)
from .dependencies import get_request_id
from ..algorithms.prime_checker import PrimeChecker
@@ -292,3 +293,57 @@ async def get_job_status(job_id: str):
"message": str(e),
},
)
@router.get(
    "/jobs/concurrency/status",
    response_model=ConcurrencyStatusResponse,
    summary="查询并发状态",
    description="查询任务管理器的并发执行状态",
    responses={
        200: {"description": "成功", "model": ConcurrencyStatusResponse},
        503: {"description": "服务不可用", "model": ErrorResponse},
    },
)
async def get_concurrency_status():
    """Report the job manager's current concurrency state.

    Returns the configured maximum concurrency, the number of free
    execution slots, and the count of running jobs. Responds with 503
    when the job manager is unavailable and 500 on unexpected errors.
    """
    try:
        manager = await get_job_manager()
        if not manager.is_available():
            # Surface unavailability as 503 so clients can back off and retry.
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail={
                    "error": "SERVICE_UNAVAILABLE",
                    "message": "任务管理器不可用",
                },
            )
        snapshot = manager.get_concurrency_status()
        return ConcurrencyStatusResponse(
            max_concurrent=snapshot["max_concurrent"],
            available_slots=snapshot["available_slots"],
            running_jobs=snapshot["running_jobs"],
        )
    except HTTPException:
        # Re-raise deliberate HTTP errors untouched (e.g. the 503 above).
        raise
    except Exception as exc:
        logger.error(f"查询并发状态失败: {str(exc)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={
                "error": "INTERNAL_ERROR",
                "message": str(exc),
            },
        )

View File

@@ -53,6 +53,7 @@ class Settings(BaseSettings):
job_result_ttl: int = 1800 # 结果缓存时间(秒),默认 30 分钟
webhook_max_retries: int = 3 # Webhook 最大重试次数
webhook_timeout: int = 10 # Webhook 超时时间(秒)
max_concurrent_jobs: int = 10 # 最大并发任务数
# 全局配置实例

View File

@@ -27,6 +27,8 @@ class JobManager:
self._redis_client: Optional[aioredis.Redis] = None
self._algorithm_registry: Dict[str, Type[BaseAlgorithm]] = {}
self._http_client: Optional[httpx.AsyncClient] = None
self._semaphore: Optional[asyncio.Semaphore] = None
self._max_concurrent_jobs: int = 0
async def initialize(self) -> None:
"""初始化 Redis 连接和 HTTP 客户端"""
@@ -51,6 +53,11 @@ class JobManager:
# 初始化 HTTP 客户端
self._http_client = httpx.AsyncClient(timeout=settings.webhook_timeout)
# 初始化并发控制信号量
self._max_concurrent_jobs = settings.max_concurrent_jobs
self._semaphore = asyncio.Semaphore(self._max_concurrent_jobs)
logger.info(f"任务并发限制已设置: {self._max_concurrent_jobs}")
# 注册算法
self._register_algorithms()
@@ -203,6 +210,10 @@ class JobManager:
logger.error(f"Redis 不可用,无法执行任务: {job_id}")
return
if not self._semaphore:
logger.error(f"并发控制未初始化,无法执行任务: {job_id}")
return
key = f"job:{job_id}"
job_data = await self._redis_client.hgetall(key)
@@ -219,74 +230,76 @@ class JobManager:
except json.JSONDecodeError:
params = {}
# 更新状态为 running
started_at = self._get_timestamp()
await self._redis_client.hset(key, mapping={"status": "running", "started_at": started_at})
# 使用信号量控制并发
async with self._semaphore:
# 更新状态为 running
started_at = self._get_timestamp()
await self._redis_client.hset(key, mapping={"status": "running", "started_at": started_at})
logger.info(f"开始执行任务: job_id={job_id}, algorithm={algorithm_name}")
logger.info(f"开始执行任务: job_id={job_id}, algorithm={algorithm_name}")
import time
import time
start_time = time.time()
status = "completed"
result_data = None
error_msg = None
metadata = None
start_time = time.time()
status = "completed"
result_data = None
error_msg = None
metadata = None
try:
# 获取算法类并执行
algorithm_cls = self._algorithm_registry.get(algorithm_name)
if not algorithm_cls:
raise ValueError(f"算法 '{algorithm_name}' 不存在")
try:
# 获取算法类并执行
algorithm_cls = self._algorithm_registry.get(algorithm_name)
if not algorithm_cls:
raise ValueError(f"算法 '{algorithm_name}' 不存在")
algorithm = algorithm_cls()
algorithm = algorithm_cls()
# 根据算法类型传递参数
if algorithm_name == "PrimeChecker":
execution_result = algorithm.execute(params.get("number", 0))
else:
# 通用参数传递
execution_result = algorithm.execute(**params)
# 根据算法类型传递参数
if algorithm_name == "PrimeChecker":
execution_result = algorithm.execute(params.get("number", 0))
else:
# 通用参数传递
execution_result = algorithm.execute(**params)
if execution_result.get("success"):
result_data = execution_result.get("result", {})
metadata = execution_result.get("metadata", {})
else:
if execution_result.get("success"):
result_data = execution_result.get("result", {})
metadata = execution_result.get("metadata", {})
else:
status = "failed"
error_msg = execution_result.get("error", "算法执行失败")
metadata = execution_result.get("metadata", {})
except Exception as e:
status = "failed"
error_msg = execution_result.get("error", "算法执行失败")
metadata = execution_result.get("metadata", {})
error_msg = str(e)
logger.error(f"任务执行失败: job_id={job_id}, error={e}", exc_info=True)
except Exception as e:
status = "failed"
error_msg = str(e)
logger.error(f"任务执行失败: job_id={job_id}, error={e}", exc_info=True)
# 计算执行时间
elapsed_time = time.time() - start_time
completed_at = self._get_timestamp()
# 计算执行时间
elapsed_time = time.time() - start_time
completed_at = self._get_timestamp()
# 更新任务状态
update_data = {
"status": status,
"completed_at": completed_at,
"result": json.dumps(result_data) if result_data else "",
"error": error_msg or "",
"metadata": json.dumps(metadata) if metadata else "",
}
await self._redis_client.hset(key, mapping=update_data)
# 更新任务状态
update_data = {
"status": status,
"completed_at": completed_at,
"result": json.dumps(result_data) if result_data else "",
"error": error_msg or "",
"metadata": json.dumps(metadata) if metadata else "",
}
await self._redis_client.hset(key, mapping=update_data)
# 设置 TTL
await self._redis_client.expire(key, settings.job_result_ttl)
# 设置 TTL
await self._redis_client.expire(key, settings.job_result_ttl)
# 记录指标
incr("jobs_completed_total", {"algorithm": algorithm_name, "status": status})
observe("job_execution_duration_seconds", {"algorithm": algorithm_name}, elapsed_time)
# 记录指标
incr("jobs_completed_total", {"algorithm": algorithm_name, "status": status})
observe("job_execution_duration_seconds", {"algorithm": algorithm_name}, elapsed_time)
logger.info(f"任务执行完成: job_id={job_id}, status={status}, elapsed={elapsed_time:.3f}s")
logger.info(f"任务执行完成: job_id={job_id}, status={status}, elapsed={elapsed_time:.3f}s")
# 发送 Webhook 回调
if webhook_url:
await self._send_webhook(job_id, webhook_url)
# 发送 Webhook 回调
if webhook_url:
await self._send_webhook(job_id, webhook_url)
async def _send_webhook(self, job_id: str, webhook_url: str) -> None:
"""发送 Webhook 回调(带重试)
@@ -359,6 +372,32 @@ class JobManager:
"""检查任务管理器是否可用"""
return self._redis_client is not None
def get_concurrency_status(self) -> Dict[str, int]:
"""获取并发状态
Returns:
Dict[str, int]: 包含以下键的字典
- max_concurrent: 最大并发数
- available_slots: 可用槽位数
- running_jobs: 当前运行中的任务数
"""
if not self._semaphore:
return {
"max_concurrent": 0,
"available_slots": 0,
"running_jobs": 0,
}
max_concurrent = self._max_concurrent_jobs
available_slots = self._semaphore._value
running_jobs = max_concurrent - available_slots
return {
"max_concurrent": max_concurrent,
"available_slots": available_slots,
"running_jobs": running_jobs,
}
# 全局单例
_job_manager: Optional[JobManager] = None

View File

@@ -186,6 +186,10 @@ class TestJobManagerWithMocks:
manager._redis_client = mock_redis
manager._register_algorithms()
# 初始化 semaphore
import asyncio
manager._semaphore = asyncio.Semaphore(10)
await manager.execute_job("test-job-id")
# 验证状态更新被调用
@@ -399,3 +403,97 @@ class TestWebhook:
# 验证重试次数
assert mock_http.post.call_count == 2
class TestConcurrencyControl:
    """Tests for the JobManager concurrency-control feature."""

    @pytest.mark.asyncio
    async def test_get_concurrency_status(self):
        """An initialized, idle manager reports all slots as available."""
        manager = JobManager()
        manager._max_concurrent_jobs = 10
        manager._semaphore = asyncio.Semaphore(10)

        result = manager.get_concurrency_status()

        assert result["max_concurrent"] == 10
        assert result["available_slots"] == 10
        assert result["running_jobs"] == 0

    @pytest.mark.asyncio
    async def test_get_concurrency_status_without_semaphore(self):
        """Before initialization the status reports all zeros."""
        manager = JobManager()

        result = manager.get_concurrency_status()

        assert result["max_concurrent"] == 0
        assert result["available_slots"] == 0
        assert result["running_jobs"] == 0

    @pytest.mark.asyncio
    async def test_concurrency_limit(self):
        """At most ``max_concurrent_jobs`` tasks can hold the semaphore.

        NOTE(review): this exercises the semaphore directly rather than
        going through ``execute_job``; the previous Redis-mock setup was
        never used by the test body and has been removed.
        """
        manager = JobManager()
        # Use a small limit so the third task has to queue.
        manager._max_concurrent_jobs = 2
        manager._semaphore = asyncio.Semaphore(2)

        async def hold_slot():
            # Simulate a slow job occupying one concurrency slot.
            async with manager._semaphore:
                await asyncio.sleep(0.1)

        tasks = [asyncio.create_task(hold_slot()) for _ in range(3)]
        # Give the first two tasks time to acquire the semaphore.
        await asyncio.sleep(0.01)

        snapshot = manager.get_concurrency_status()
        assert snapshot["running_jobs"] == 2  # third task is still queued
        assert snapshot["available_slots"] == 0

        await asyncio.gather(*tasks)

        snapshot = manager.get_concurrency_status()
        assert snapshot["running_jobs"] == 0
        assert snapshot["available_slots"] == 2

    def test_concurrency_status_api(self, client):
        """The /jobs/concurrency/status endpoint returns integer fields."""
        response = client.get("/jobs/concurrency/status")

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        for field in ("max_concurrent", "available_slots", "running_jobs"):
            assert field in data
            assert isinstance(data[field], int)