"""异步任务管理模块 基于 Redis 的异步任务管理,支持任务创建、执行、状态查询和 Webhook 回调。 """ import asyncio import json import logging import secrets from datetime import datetime, timezone from typing import Any, Dict, List, Optional, Type import httpx import redis.asyncio as aioredis from ..algorithms.base import BaseAlgorithm from ..config import settings from ..core.metrics_unified import incr, observe logger = logging.getLogger(__name__) class JobManager: """异步任务管理器""" def __init__(self): self._redis_client: Optional[aioredis.Redis] = None self._algorithm_registry: Dict[str, Type[BaseAlgorithm]] = {} self._http_client: Optional[httpx.AsyncClient] = None self._semaphore: Optional[asyncio.Semaphore] = None self._max_concurrent_jobs: int = 0 async def initialize(self) -> None: """初始化 Redis 连接和 HTTP 客户端""" # 初始化 Redis 异步连接 try: self._redis_client = aioredis.Redis( host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, password=settings.redis_password if settings.redis_password else None, decode_responses=True, socket_connect_timeout=5, socket_timeout=5, ) # 测试连接 await self._redis_client.ping() logger.info(f"任务管理器 Redis 连接成功: {settings.redis_host}:{settings.redis_port}") except Exception as e: logger.error(f"任务管理器 Redis 连接失败: {e}") self._redis_client = None # 初始化 HTTP 客户端 self._http_client = httpx.AsyncClient(timeout=settings.webhook_timeout) # 初始化并发控制信号量 self._max_concurrent_jobs = settings.max_concurrent_jobs self._semaphore = asyncio.Semaphore(self._max_concurrent_jobs) logger.info(f"任务并发限制已设置: {self._max_concurrent_jobs}") # 注册算法 self._register_algorithms() async def shutdown(self) -> None: """关闭连接""" if self._redis_client: await self._redis_client.close() logger.info("任务管理器 Redis 连接已关闭") if self._http_client: await self._http_client.aclose() logger.info("任务管理器 HTTP 客户端已关闭") def _register_algorithms(self) -> None: """注册可用的算法类""" from ..algorithms import __all__ as algorithm_names from .. import algorithms as algorithms_module for name in algorithm_names: cls = getattr(algorithms_module, name, None) if cls and isinstance(cls, type) and issubclass(cls, BaseAlgorithm): if cls is not BaseAlgorithm: self._algorithm_registry[name] = cls logger.debug(f"已注册算法: {name}") logger.info(f"已注册 {len(self._algorithm_registry)} 个算法") def get_available_algorithms(self) -> List[str]: """获取可用算法列表""" return list(self._algorithm_registry.keys()) def _generate_job_id(self) -> str: """生成 12 位十六进制任务 ID""" return secrets.token_hex(6) def _get_timestamp(self) -> str: """获取 ISO 8601 格式时间戳""" return datetime.now(timezone.utc).isoformat() async def create_job( self, algorithm: str, params: Dict[str, Any], webhook: Optional[str] = None, request_id: Optional[str] = None, ) -> str: """创建新任务,返回 job_id Args: algorithm: 算法名称 params: 算法参数 webhook: 回调 URL(可选) request_id: 关联的请求 ID(可选) Returns: str: 任务 ID Raises: RuntimeError: Redis 不可用时抛出 ValueError: 算法不存在时抛出 """ if not self._redis_client: raise RuntimeError("Redis 不可用,无法创建任务") if algorithm not in self._algorithm_registry: raise ValueError(f"算法 '{algorithm}' 不存在") job_id = self._generate_job_id() created_at = self._get_timestamp() # 构建任务数据 job_data = { "status": "pending", "algorithm": algorithm, "params": json.dumps(params), "webhook": webhook or "", "request_id": request_id or "", "created_at": created_at, "started_at": "", "completed_at": "", "result": "", "error": "", "metadata": "", } # 存储到 Redis key = f"job:{job_id}" await self._redis_client.hset(key, mapping=job_data) # 记录指标 incr("jobs_created_total", {"algorithm": algorithm}) logger.info(f"任务已创建: job_id={job_id}, algorithm={algorithm}") return job_id async def get_job(self, job_id: str) -> Optional[Dict[str, Any]]: """获取任务信息 Args: job_id: 任务 ID Returns: 任务信息字典,不存在时返回 None """ if not self._redis_client: return None key = f"job:{job_id}" job_data = await self._redis_client.hgetall(key) if not job_data: return None # 解析 JSON 字段 result = { "job_id": job_id, "status": job_data.get("status", ""), "algorithm": job_data.get("algorithm", ""), "created_at": job_data.get("created_at", ""), "started_at": job_data.get("started_at") or None, "completed_at": job_data.get("completed_at") or None, "result": None, "error": job_data.get("error") or None, "metadata": None, } # 解析 result if job_data.get("result"): try: result["result"] = json.loads(job_data["result"]) except json.JSONDecodeError: result["result"] = None # 解析 metadata if job_data.get("metadata"): try: result["metadata"] = json.loads(job_data["metadata"]) except json.JSONDecodeError: result["metadata"] = None return result async def execute_job(self, job_id: str) -> None: """执行任务(在后台任务中调用) Args: job_id: 任务 ID """ if not self._redis_client: logger.error(f"Redis 不可用,无法执行任务: {job_id}") return if not self._semaphore: logger.error(f"并发控制未初始化,无法执行任务: {job_id}") return key = f"job:{job_id}" job_data = await self._redis_client.hgetall(key) if not job_data: logger.error(f"任务不存在: {job_id}") return algorithm_name = job_data.get("algorithm", "") webhook_url = job_data.get("webhook", "") # 解析参数 try: params = json.loads(job_data.get("params", "{}")) except json.JSONDecodeError: params = {} # 使用信号量控制并发 async with self._semaphore: # 更新状态为 running started_at = self._get_timestamp() await self._redis_client.hset(key, mapping={"status": "running", "started_at": started_at}) logger.info(f"开始执行任务: job_id={job_id}, algorithm={algorithm_name}") import time start_time = time.time() status = "completed" result_data = None error_msg = None metadata = None try: # 获取算法类并执行 algorithm_cls = self._algorithm_registry.get(algorithm_name) if not algorithm_cls: raise ValueError(f"算法 '{algorithm_name}' 不存在") algorithm = algorithm_cls() # 根据算法类型传递参数 if algorithm_name == "PrimeChecker": execution_result = algorithm.execute(params.get("number", 0)) else: # 通用参数传递 execution_result = algorithm.execute(**params) if execution_result.get("success"): result_data = execution_result.get("result", {}) metadata = execution_result.get("metadata", {}) else: status = "failed" error_msg = execution_result.get("error", "算法执行失败") metadata = execution_result.get("metadata", {}) except Exception as e: status = "failed" error_msg = str(e) logger.error(f"任务执行失败: job_id={job_id}, error={e}", exc_info=True) # 计算执行时间 elapsed_time = time.time() - start_time completed_at = self._get_timestamp() # 更新任务状态 update_data = { "status": status, "completed_at": completed_at, "result": json.dumps(result_data) if result_data else "", "error": error_msg or "", "metadata": json.dumps(metadata) if metadata else "", } await self._redis_client.hset(key, mapping=update_data) # 设置 TTL await self._redis_client.expire(key, settings.job_result_ttl) # 记录指标 incr("jobs_completed_total", {"algorithm": algorithm_name, "status": status}) observe("job_execution_duration_seconds", {"algorithm": algorithm_name}, elapsed_time) logger.info(f"任务执行完成: job_id={job_id}, status={status}, elapsed={elapsed_time:.3f}s") # 发送 Webhook 回调 if webhook_url: await self._send_webhook(job_id, webhook_url) async def _send_webhook(self, job_id: str, webhook_url: str) -> None: """发送 Webhook 回调(带重试) Args: job_id: 任务 ID webhook_url: 回调 URL """ if not self._http_client: logger.warning("HTTP 客户端不可用,无法发送 Webhook") return # 获取任务数据 job_data = await self.get_job(job_id) if not job_data: logger.error(f"无法获取任务数据用于 Webhook: {job_id}") return # 构建回调负载 payload = { "job_id": job_data["job_id"], "status": job_data["status"], "algorithm": job_data["algorithm"], "result": job_data["result"], "error": job_data["error"], "metadata": job_data["metadata"], "completed_at": job_data["completed_at"], } # 重试间隔(指数退避) retry_delays = [1, 5, 15] max_retries = settings.webhook_max_retries for attempt in range(max_retries): try: response = await self._http_client.post( webhook_url, json=payload, headers={"Content-Type": "application/json"}, ) if response.status_code < 400: incr("webhook_deliveries_total", {"status": "success"}) logger.info( f"Webhook 发送成功: job_id={job_id}, url={webhook_url}, " f"status_code={response.status_code}" ) return else: logger.warning( f"Webhook 响应错误: job_id={job_id}, status_code={response.status_code}" ) except Exception as e: logger.warning( f"Webhook 发送失败 (尝试 {attempt + 1}/{max_retries}): " f"job_id={job_id}, error={e}" ) # 等待后重试 if attempt < max_retries - 1: delay = retry_delays[min(attempt, len(retry_delays) - 1)] await asyncio.sleep(delay) # 所有重试都失败 incr("webhook_deliveries_total", {"status": "failed"}) logger.error(f"Webhook 发送最终失败: job_id={job_id}, url={webhook_url}") def is_available(self) -> bool: """检查任务管理器是否可用""" return self._redis_client is not None async def enqueue_job(self, job_id: str) -> bool: """将任务加入队列 Args: job_id: 任务 ID Returns: bool: 是否成功入队 """ if not self._redis_client: logger.error(f"Redis 不可用,无法入队任务: {job_id}") return False try: await self._redis_client.lpush(settings.job_queue_key, job_id) logger.info(f"任务已入队: job_id={job_id}") return True except Exception as e: logger.error(f"任务入队失败: job_id={job_id}, error={e}") return False async def dequeue_job(self, timeout: int = 5) -> Optional[str]: """从队列获取任务(阻塞式) Args: timeout: 阻塞超时时间(秒) Returns: Optional[str]: 任务 ID,超时返回 None """ if not self._redis_client: return None try: result = await self._redis_client.brpop(settings.job_queue_key, timeout=timeout) if result: # brpop 返回 (key, value) 元组 return result[1] return None except Exception as e: logger.error(f"任务出队失败: error={e}") return None async def acquire_job_lock(self, job_id: str) -> bool: """获取任务执行锁(分布式锁) Args: job_id: 任务 ID Returns: bool: 是否成功获取锁 """ if not self._redis_client: return False lock_key = f"job:lock:{job_id}" try: acquired = await self._redis_client.set( lock_key, "locked", nx=True, ex=settings.job_lock_ttl ) if acquired: logger.debug(f"获取任务锁成功: job_id={job_id}") return acquired is not None except Exception as e: logger.error(f"获取任务锁失败: job_id={job_id}, error={e}") return False async def release_job_lock(self, job_id: str) -> bool: """释放任务执行锁 Args: job_id: 任务 ID Returns: bool: 是否成功释放锁 """ if not self._redis_client: return False lock_key = f"job:lock:{job_id}" try: await self._redis_client.delete(lock_key) logger.debug(f"释放任务锁成功: job_id={job_id}") return True except Exception as e: logger.error(f"释放任务锁失败: job_id={job_id}, error={e}") return False async def increment_concurrency(self) -> int: """增加全局并发计数 Returns: int: 增加后的并发数 """ if not self._redis_client: return 0 try: count = await self._redis_client.incr(settings.job_concurrency_key) return count except Exception as e: logger.error(f"增加并发计数失败: error={e}") return 0 async def decrement_concurrency(self) -> int: """减少全局并发计数 Returns: int: 减少后的并发数 """ if not self._redis_client: return 0 try: count = await self._redis_client.decr(settings.job_concurrency_key) # 防止计数变为负数 if count < 0: await self._redis_client.set(settings.job_concurrency_key, 0) return 0 return count except Exception as e: logger.error(f"减少并发计数失败: error={e}") return 0 async def get_global_concurrency(self) -> int: """获取当前全局并发数 Returns: int: 当前并发数 """ if not self._redis_client: return 0 try: count = await self._redis_client.get(settings.job_concurrency_key) return int(count) if count else 0 except Exception as e: logger.error(f"获取并发计数失败: error={e}") return 0 async def can_execute(self) -> bool: """检查是否可以执行新任务(全局并发控制) Returns: bool: 是否可以执行 """ current = await self.get_global_concurrency() return current < settings.max_concurrent_jobs async def get_job_retry_count(self, job_id: str) -> int: """获取任务重试次数 Args: job_id: 任务 ID Returns: int: 重试次数 """ if not self._redis_client: return 0 key = f"job:{job_id}" try: retry_count = await self._redis_client.hget(key, "retry_count") return int(retry_count) if retry_count else 0 except Exception: return 0 async def increment_job_retry(self, job_id: str) -> int: """增加任务重试次数 Args: job_id: 任务 ID Returns: int: 增加后的重试次数 """ if not self._redis_client: return 0 key = f"job:{job_id}" try: await self._redis_client.hincrby(key, "retry_count", 1) retry_count = await self._redis_client.hget(key, "retry_count") return int(retry_count) if retry_count else 1 except Exception as e: logger.error(f"增加重试次数失败: job_id={job_id}, error={e}") return 0 def get_concurrency_status(self) -> Dict[str, int]: """获取并发状态 Returns: Dict[str, int]: 包含以下键的字典 - max_concurrent: 最大并发数 - available_slots: 可用槽位数 - running_jobs: 当前运行中的任务数 """ if not self._semaphore: return { "max_concurrent": 0, "available_slots": 0, "running_jobs": 0, } max_concurrent = self._max_concurrent_jobs available_slots = self._semaphore._value running_jobs = max_concurrent - available_slots return { "max_concurrent": max_concurrent, "available_slots": available_slots, "running_jobs": running_jobs, } # 全局单例 _job_manager: Optional[JobManager] = None async def get_job_manager() -> JobManager: """获取任务管理器单例""" global _job_manager if _job_manager is None: _job_manager = JobManager() await _job_manager.initialize() return _job_manager async def shutdown_job_manager() -> None: """关闭任务管理器""" global _job_manager if _job_manager is not None: await _job_manager.shutdown() _job_manager = None