From 87ed8c071cb5f16ff93a7a91b1fe02cf7af60a40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roog=20=28=E9=A1=BE=E6=96=B0=E5=9F=B9=29?= Date: Mon, 2 Feb 2026 17:11:52 +0800 Subject: [PATCH] =?UTF-8?q?main:=E6=96=B0=E5=A2=9E=E5=B9=B6=E5=8F=91?= =?UTF-8?q?=E6=8E=A7=E5=88=B6=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 变更内容: - 增加 `max_concurrent_jobs` 配置项,支持设置最大并发任务数。 - 为 `JobManager` 添加信号量控制实现任务并发限制。 - 新增获取任务并发状态的接口 `/jobs/concurrency/status`。 - 编写并发控制功能相关的测试用 --- src/functional_scaffold/api/models.py | 18 +++ src/functional_scaffold/api/routes.py | 55 ++++++++ src/functional_scaffold/config.py | 1 + src/functional_scaffold/core/job_manager.py | 147 +++++++++++++------- tests/test_job_manager.py | 98 +++++++++++++ 5 files changed, 265 insertions(+), 54 deletions(-) diff --git a/src/functional_scaffold/api/models.py b/src/functional_scaffold/api/models.py index faf89ab..6b07556 100644 --- a/src/functional_scaffold/api/models.py +++ b/src/functional_scaffold/api/models.py @@ -152,3 +152,21 @@ class JobStatusResponse(BaseModel): result: Optional[Dict[str, Any]] = Field(None, description="执行结果(仅完成时返回)") error: Optional[str] = Field(None, description="错误信息(仅失败时返回)") metadata: Optional[Dict[str, Any]] = Field(None, description="元数据信息") + + +class ConcurrencyStatusResponse(BaseModel): + """并发状态响应""" + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "max_concurrent": 10, + "available_slots": 7, + "running_jobs": 3, + } + } + ) + + max_concurrent: int = Field(..., description="最大并发任务数") + available_slots: int = Field(..., description="当前可用槽位数") + running_jobs: int = Field(..., description="当前运行中的任务数") diff --git a/src/functional_scaffold/api/routes.py b/src/functional_scaffold/api/routes.py index d10c762..980828c 100644 --- a/src/functional_scaffold/api/routes.py +++ b/src/functional_scaffold/api/routes.py @@ -15,6 +15,7 @@ from .models import ( JobCreateResponse, JobStatusResponse, JobStatus, + ConcurrencyStatusResponse, ) from .dependencies import get_request_id from ..algorithms.prime_checker import PrimeChecker @@ -292,3 +293,57 @@ async def get_job_status(job_id: str): "message": str(e), }, ) + + +@router.get( + "/jobs/concurrency/status", + response_model=ConcurrencyStatusResponse, + summary="查询并发状态", + description="查询任务管理器的并发执行状态", + responses={ + 200: {"description": "成功", "model": ConcurrencyStatusResponse}, + 503: {"description": "服务不可用", "model": ErrorResponse}, + }, +) +async def get_concurrency_status(): + """ + 查询并发状态 + + 返回当前任务管理器的并发执行状态,包括: + - 最大并发任务数 + - 当前可用槽位数 + - 当前运行中的任务数 + """ + try: + job_manager = await get_job_manager() + + # 检查任务管理器是否可用 + if not job_manager.is_available(): + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail={ + "error": "SERVICE_UNAVAILABLE", + "message": "任务管理器不可用", + }, + ) + + concurrency_status = job_manager.get_concurrency_status() + + return ConcurrencyStatusResponse( + max_concurrent=concurrency_status["max_concurrent"], + available_slots=concurrency_status["available_slots"], + running_jobs=concurrency_status["running_jobs"], + ) + + except HTTPException: + raise + + except Exception as e: + logger.error(f"查询并发状态失败: {str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "error": "INTERNAL_ERROR", + "message": str(e), + }, + ) diff --git a/src/functional_scaffold/config.py b/src/functional_scaffold/config.py index 5f89561..dc1ed5a 100644 --- a/src/functional_scaffold/config.py +++ b/src/functional_scaffold/config.py @@ -53,6 +53,7 @@ class Settings(BaseSettings): job_result_ttl: int = 1800 # 结果缓存时间(秒),默认 30 分钟 webhook_max_retries: int = 3 # Webhook 最大重试次数 webhook_timeout: int = 10 # Webhook 超时时间(秒) + max_concurrent_jobs: int = 10 # 最大并发任务数 # 全局配置实例 diff --git a/src/functional_scaffold/core/job_manager.py b/src/functional_scaffold/core/job_manager.py index 6fc89f8..fcb86fd 100644 --- a/src/functional_scaffold/core/job_manager.py +++ b/src/functional_scaffold/core/job_manager.py @@ -27,6 +27,8 @@ class JobManager: self._redis_client: Optional[aioredis.Redis] = None self._algorithm_registry: Dict[str, Type[BaseAlgorithm]] = {} self._http_client: Optional[httpx.AsyncClient] = None + self._semaphore: Optional[asyncio.Semaphore] = None + self._max_concurrent_jobs: int = 0 async def initialize(self) -> None: """初始化 Redis 连接和 HTTP 客户端""" @@ -51,6 +53,11 @@ class JobManager: # 初始化 HTTP 客户端 self._http_client = httpx.AsyncClient(timeout=settings.webhook_timeout) + # 初始化并发控制信号量 + self._max_concurrent_jobs = settings.max_concurrent_jobs + self._semaphore = asyncio.Semaphore(self._max_concurrent_jobs) + logger.info(f"任务并发限制已设置: {self._max_concurrent_jobs}") + # 注册算法 self._register_algorithms() @@ -203,6 +210,10 @@ class JobManager: logger.error(f"Redis 不可用,无法执行任务: {job_id}") return + if not self._semaphore: + logger.error(f"并发控制未初始化,无法执行任务: {job_id}") + return + key = f"job:{job_id}" job_data = await self._redis_client.hgetall(key) @@ -219,74 +230,76 @@ class JobManager: except json.JSONDecodeError: params = {} - # 更新状态为 running - started_at = self._get_timestamp() - await self._redis_client.hset(key, mapping={"status": "running", "started_at": started_at}) + # 使用信号量控制并发 + async with self._semaphore: + # 更新状态为 running + started_at = self._get_timestamp() + await self._redis_client.hset(key, mapping={"status": "running", "started_at": started_at}) - logger.info(f"开始执行任务: job_id={job_id}, algorithm={algorithm_name}") + logger.info(f"开始执行任务: job_id={job_id}, algorithm={algorithm_name}") - import time + import time - start_time = time.time() - status = "completed" - result_data = None - error_msg = None - metadata = None + start_time = time.time() + status = "completed" + result_data = None + error_msg = None + metadata = None - try: - # 获取算法类并执行 - algorithm_cls = self._algorithm_registry.get(algorithm_name) - if not algorithm_cls: - raise ValueError(f"算法 '{algorithm_name}' 不存在") + try: + # 获取算法类并执行 + algorithm_cls = self._algorithm_registry.get(algorithm_name) + if not algorithm_cls: + raise ValueError(f"算法 '{algorithm_name}' 不存在") - algorithm = algorithm_cls() + algorithm = algorithm_cls() - # 根据算法类型传递参数 - if algorithm_name == "PrimeChecker": - execution_result = algorithm.execute(params.get("number", 0)) - else: - # 通用参数传递 - execution_result = algorithm.execute(**params) + # 根据算法类型传递参数 + if algorithm_name == "PrimeChecker": + execution_result = algorithm.execute(params.get("number", 0)) + else: + # 通用参数传递 + execution_result = algorithm.execute(**params) - if execution_result.get("success"): - result_data = execution_result.get("result", {}) - metadata = execution_result.get("metadata", {}) - else: + if execution_result.get("success"): + result_data = execution_result.get("result", {}) + metadata = execution_result.get("metadata", {}) + else: + status = "failed" + error_msg = execution_result.get("error", "算法执行失败") + metadata = execution_result.get("metadata", {}) + + except Exception as e: status = "failed" - error_msg = execution_result.get("error", "算法执行失败") - metadata = execution_result.get("metadata", {}) + error_msg = str(e) + logger.error(f"任务执行失败: job_id={job_id}, error={e}", exc_info=True) - except Exception as e: - status = "failed" - error_msg = str(e) - logger.error(f"任务执行失败: job_id={job_id}, error={e}", exc_info=True) + # 计算执行时间 + elapsed_time = time.time() - start_time + completed_at = self._get_timestamp() - # 计算执行时间 - elapsed_time = time.time() - start_time - completed_at = self._get_timestamp() + # 更新任务状态 + update_data = { + "status": status, + "completed_at": completed_at, + "result": json.dumps(result_data) if result_data else "", + "error": error_msg or "", + "metadata": json.dumps(metadata) if metadata else "", + } + await self._redis_client.hset(key, mapping=update_data) - # 更新任务状态 - update_data = { - "status": status, - "completed_at": completed_at, - "result": json.dumps(result_data) if result_data else "", - "error": error_msg or "", - "metadata": json.dumps(metadata) if metadata else "", - } - await self._redis_client.hset(key, mapping=update_data) + # 设置 TTL + await self._redis_client.expire(key, settings.job_result_ttl) - # 设置 TTL - await self._redis_client.expire(key, settings.job_result_ttl) + # 记录指标 + incr("jobs_completed_total", {"algorithm": algorithm_name, "status": status}) + observe("job_execution_duration_seconds", {"algorithm": algorithm_name}, elapsed_time) - # 记录指标 - incr("jobs_completed_total", {"algorithm": algorithm_name, "status": status}) - observe("job_execution_duration_seconds", {"algorithm": algorithm_name}, elapsed_time) + logger.info(f"任务执行完成: job_id={job_id}, status={status}, elapsed={elapsed_time:.3f}s") - logger.info(f"任务执行完成: job_id={job_id}, status={status}, elapsed={elapsed_time:.3f}s") - - # 发送 Webhook 回调 - if webhook_url: - await self._send_webhook(job_id, webhook_url) + # 发送 Webhook 回调 + if webhook_url: + await self._send_webhook(job_id, webhook_url) async def _send_webhook(self, job_id: str, webhook_url: str) -> None: """发送 Webhook 回调(带重试) @@ -359,6 +372,32 @@ class JobManager: """检查任务管理器是否可用""" return self._redis_client is not None + def get_concurrency_status(self) -> Dict[str, int]: + """获取并发状态 + + Returns: + Dict[str, int]: 包含以下键的字典 + - max_concurrent: 最大并发数 + - available_slots: 可用槽位数 + - running_jobs: 当前运行中的任务数 + """ + if not self._semaphore: + return { + "max_concurrent": 0, + "available_slots": 0, + "running_jobs": 0, + } + + max_concurrent = self._max_concurrent_jobs + available_slots = self._semaphore._value + running_jobs = max_concurrent - available_slots + + return { + "max_concurrent": max_concurrent, + "available_slots": available_slots, + "running_jobs": running_jobs, + } + # 全局单例 _job_manager: Optional[JobManager] = None diff --git a/tests/test_job_manager.py b/tests/test_job_manager.py index e717691..2fadd7c 100644 --- a/tests/test_job_manager.py +++ b/tests/test_job_manager.py @@ -186,6 +186,10 @@ class TestJobManagerWithMocks: manager._redis_client = mock_redis manager._register_algorithms() + # 初始化 semaphore + import asyncio + manager._semaphore = asyncio.Semaphore(10) + await manager.execute_job("test-job-id") # 验证状态更新被调用 @@ -399,3 +403,97 @@ class TestWebhook: # 验证重试次数 assert mock_http.post.call_count == 2 + + +class TestConcurrencyControl: + """测试并发控制功能""" + + @pytest.mark.asyncio + async def test_get_concurrency_status(self): + """测试获取并发状态""" + manager = JobManager() + + # 初始化 semaphore + manager._max_concurrent_jobs = 10 + manager._semaphore = asyncio.Semaphore(10) + + status = manager.get_concurrency_status() + + assert status["max_concurrent"] == 10 + assert status["available_slots"] == 10 + assert status["running_jobs"] == 0 + + @pytest.mark.asyncio + async def test_get_concurrency_status_without_semaphore(self): + """测试未初始化 semaphore 时获取并发状态""" + manager = JobManager() + + status = manager.get_concurrency_status() + + assert status["max_concurrent"] == 0 + assert status["available_slots"] == 0 + assert status["running_jobs"] == 0 + + @pytest.mark.asyncio + async def test_concurrency_limit(self): + """测试并发限制是否生效""" + manager = JobManager() + + # 设置较小的并发限制 + manager._max_concurrent_jobs = 2 + manager._semaphore = asyncio.Semaphore(2) + + # 模拟 Redis + mock_redis = AsyncMock() + mock_redis.hgetall = AsyncMock( + return_value={ + "status": "pending", + "algorithm": "PrimeChecker", + "params": '{"number": 17}', + "webhook": "", + "request_id": "test-request-id", + "created_at": "2026-02-02T10:00:00+00:00", + } + ) + mock_redis.hset = AsyncMock() + mock_redis.expire = AsyncMock() + manager._redis_client = mock_redis + manager._register_algorithms() + + # 创建一个慢速任务 + async def slow_execute(): + async with manager._semaphore: + await asyncio.sleep(0.1) + + # 启动 3 个任务 + tasks = [asyncio.create_task(slow_execute()) for _ in range(3)] + + # 等待一小段时间,让前两个任务获取 semaphore + await asyncio.sleep(0.01) + + # 检查并发状态 + status = manager.get_concurrency_status() + assert status["running_jobs"] == 2 # 只有 2 个任务在运行 + assert status["available_slots"] == 0 # 没有可用槽位 + + # 等待所有任务完成 + await asyncio.gather(*tasks) + + # 检查最终状态 + status = manager.get_concurrency_status() + assert status["running_jobs"] == 0 + assert status["available_slots"] == 2 + + def test_concurrency_status_api(self, client): + """测试并发状态 API 端点""" + response = client.get("/jobs/concurrency/status") + + assert response.status_code == status.HTTP_200_OK + data = response.json() + + assert "max_concurrent" in data + assert "available_slots" in data + assert "running_jobs" in data + assert isinstance(data["max_concurrent"], int) + assert isinstance(data["available_slots"], int) + assert isinstance(data["running_jobs"], int)