"""异步任务管理模块 基于 Redis 的异步任务管理,支持任务创建、执行、状态查询和 Webhook 回调。 """ import asyncio import json import logging import secrets import time from datetime import datetime, timezone from typing import Any, Dict, List, Optional, Type import httpx import redis.asyncio as aioredis from ..algorithms.base import BaseAlgorithm from ..config import settings from ..core.metrics_unified import incr, observe from ..core.tracing import set_request_id logger = logging.getLogger(__name__) class JobManager: """异步任务管理器""" # Lua 脚本:安全释放锁(验证 token) RELEASE_LOCK_SCRIPT = """ local current = redis.call('GET', KEYS[1]) if current == ARGV[1] then return redis.call('DEL', KEYS[1]) end return 0 """ # Lua 脚本:锁续租(验证 token 后延长 TTL) RENEW_LOCK_SCRIPT = """ local current = redis.call('GET', KEYS[1]) if current == ARGV[1] then return redis.call('EXPIRE', KEYS[1], ARGV[2]) end return 0 """ def __init__(self): self._redis_client: Optional[aioredis.Redis] = None self._algorithm_registry: Dict[str, Type[BaseAlgorithm]] = {} self._http_client: Optional[httpx.AsyncClient] = None self._semaphore: Optional[asyncio.Semaphore] = None self._max_concurrent_jobs: int = 0 async def initialize(self) -> None: """初始化 Redis 连接和 HTTP 客户端""" # 初始化 Redis 异步连接 try: self._redis_client = aioredis.Redis( host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, password=settings.redis_password if settings.redis_password else None, decode_responses=True, socket_connect_timeout=5, socket_timeout=5, ) # 测试连接 await self._redis_client.ping() logger.info(f"任务管理器 Redis 连接成功: {settings.redis_host}:{settings.redis_port}") except Exception as e: logger.error(f"任务管理器 Redis 连接失败: {e}") self._redis_client = None # 初始化 HTTP 客户端 self._http_client = httpx.AsyncClient(timeout=settings.webhook_timeout) # 初始化并发控制信号量 self._max_concurrent_jobs = settings.max_concurrent_jobs self._semaphore = asyncio.Semaphore(self._max_concurrent_jobs) logger.info(f"任务并发限制已设置: {self._max_concurrent_jobs}") # 注册算法 self._register_algorithms() async def shutdown(self) -> None: """关闭连接""" if self._redis_client: await self._redis_client.close() logger.info("任务管理器 Redis 连接已关闭") if self._http_client: await self._http_client.aclose() logger.info("任务管理器 HTTP 客户端已关闭") def _register_algorithms(self) -> None: """注册可用的算法类""" from ..algorithms import __all__ as algorithm_names from .. import algorithms as algorithms_module for name in algorithm_names: cls = getattr(algorithms_module, name, None) if cls and isinstance(cls, type) and issubclass(cls, BaseAlgorithm): if cls is not BaseAlgorithm: self._algorithm_registry[name] = cls logger.debug(f"已注册算法: {name}") logger.info(f"已注册 {len(self._algorithm_registry)} 个算法") def get_available_algorithms(self) -> List[str]: """获取可用算法列表""" return list(self._algorithm_registry.keys()) def _generate_job_id(self) -> str: """生成 12 位十六进制任务 ID""" return secrets.token_hex(6) def _get_timestamp(self) -> str: """获取 ISO 8601 格式时间戳""" return datetime.now(timezone.utc).isoformat() async def create_job( self, algorithm: str, params: Dict[str, Any], webhook: Optional[str] = None, request_id: Optional[str] = None, ) -> str: """创建新任务,返回 job_id Args: algorithm: 算法名称 params: 算法参数 webhook: 回调 URL(可选) request_id: 关联的请求 ID(可选) Returns: str: 任务 ID Raises: RuntimeError: Redis 不可用时抛出 ValueError: 算法不存在时抛出 """ if not self._redis_client: raise RuntimeError("Redis 不可用,无法创建任务") if algorithm not in self._algorithm_registry: raise ValueError(f"算法 '{algorithm}' 不存在") job_id = self._generate_job_id() created_at = self._get_timestamp() # 构建任务数据 job_data = { "status": "pending", "algorithm": algorithm, "params": json.dumps(params), "webhook": webhook or "", "request_id": request_id or "", "created_at": created_at, "started_at": "", "completed_at": "", "result": "", "error": "", "metadata": "", } # 存储到 Redis key = f"job:{job_id}" await self._redis_client.hset(key, mapping=job_data) # 记录指标 await incr("jobs_created_total", {"algorithm": algorithm}) logger.info(f"任务已创建: job_id={job_id}, algorithm={algorithm}") return job_id async def get_job(self, job_id: str) -> Optional[Dict[str, Any]]: """获取任务信息 Args: job_id: 任务 ID Returns: 任务信息字典,不存在时返回 None """ if not self._redis_client: return None key = f"job:{job_id}" job_data = await self._redis_client.hgetall(key) if not job_data: return None # 解析 JSON 字段 result = { "job_id": job_id, "status": job_data.get("status", ""), "algorithm": job_data.get("algorithm", ""), "request_id": job_data.get("request_id") or None, "created_at": job_data.get("created_at", ""), "started_at": job_data.get("started_at") or None, "completed_at": job_data.get("completed_at") or None, "result": None, "error": job_data.get("error") or None, "metadata": None, } # 解析 result if job_data.get("result"): try: result["result"] = json.loads(job_data["result"]) except json.JSONDecodeError: result["result"] = None # 解析 metadata if job_data.get("metadata"): try: result["metadata"] = json.loads(job_data["metadata"]) except json.JSONDecodeError: result["metadata"] = None return result async def execute_job(self, job_id: str) -> None: """执行任务(在后台任务中调用) Args: job_id: 任务 ID """ if not self._redis_client: logger.error(f"Redis 不可用,无法执行任务: {job_id}") return if not self._semaphore: logger.error(f"并发控制未初始化,无法执行任务: {job_id}") return key = f"job:{job_id}" job_data = await self._redis_client.hgetall(key) if not job_data: logger.error(f"任务不存在: {job_id}") return algorithm_name = job_data.get("algorithm", "") webhook_url = job_data.get("webhook", "") request_id = job_data.get("request_id", "") # 设置 request_id 上下文,确保日志中包含 request_id if request_id: set_request_id(request_id) # 解析参数 try: params = json.loads(job_data.get("params", "{}")) except json.JSONDecodeError: params = {} # 使用信号量控制并发 async with self._semaphore: # 更新状态为 running started_at = self._get_timestamp() await self._redis_client.hset( key, mapping={"status": "running", "started_at": started_at} ) logger.info(f"开始执行任务: job_id={job_id}, algorithm={algorithm_name}") import time start_time = time.time() status = "completed" result_data = None error_msg = None metadata = None try: # 获取算法类并执行 algorithm_cls = self._algorithm_registry.get(algorithm_name) if not algorithm_cls: raise ValueError(f"算法 '{algorithm_name}' 不存在") algorithm = algorithm_cls() # 根据算法类型传递参数 if algorithm_name == "PrimeChecker": execution_result = algorithm.execute(params.get("number", 0)) else: # 通用参数传递 execution_result = algorithm.execute(**params) if execution_result.get("success"): result_data = execution_result.get("result", {}) metadata = execution_result.get("metadata", {}) else: status = "failed" error_msg = execution_result.get("error", "算法执行失败") metadata = execution_result.get("metadata", {}) except Exception as e: status = "failed" error_msg = str(e) logger.error(f"任务执行失败: job_id={job_id}, error={e}", exc_info=True) # 计算执行时间 elapsed_time = time.time() - start_time completed_at = self._get_timestamp() # 更新任务状态 update_data = { "status": status, "completed_at": completed_at, "result": json.dumps(result_data) if result_data else "", "error": error_msg or "", "metadata": json.dumps(metadata) if metadata else "", } await self._redis_client.hset(key, mapping=update_data) # 设置 TTL await self._redis_client.expire(key, settings.job_result_ttl) # 记录指标 await incr("jobs_completed_total", {"algorithm": algorithm_name, "status": status}) await observe( "job_execution_duration_seconds", {"algorithm": algorithm_name}, elapsed_time ) logger.info( f"任务执行完成: job_id={job_id}, status={status}, elapsed={elapsed_time:.3f}s" ) # 发送 Webhook 回调 if webhook_url: await self._send_webhook(job_id, webhook_url) async def _send_webhook(self, job_id: str, webhook_url: str) -> None: """发送 Webhook 回调(带重试) Args: job_id: 任务 ID webhook_url: 回调 URL """ if not self._http_client: logger.warning("HTTP 客户端不可用,无法发送 Webhook") return # 获取任务数据 job_data = await self.get_job(job_id) if not job_data: logger.error(f"无法获取任务数据用于 Webhook: {job_id}") return # 构建回调负载 payload = { "job_id": job_data["job_id"], "status": job_data["status"], "algorithm": job_data["algorithm"], "result": job_data["result"], "error": job_data["error"], "metadata": job_data["metadata"], "completed_at": job_data["completed_at"], } # 重试间隔(指数退避) retry_delays = [1, 5, 15] max_retries = settings.webhook_max_retries for attempt in range(max_retries): try: response = await self._http_client.post( webhook_url, json=payload, headers={"Content-Type": "application/json"}, ) if response.status_code < 400: await incr("webhook_deliveries_total", {"status": "success"}) logger.info( f"Webhook 发送成功: job_id={job_id}, url={webhook_url}, " f"status_code={response.status_code}" ) return else: logger.warning( f"Webhook 响应错误: job_id={job_id}, status_code={response.status_code}" ) except Exception as e: logger.warning( f"Webhook 发送失败 (尝试 {attempt + 1}/{max_retries}): " f"job_id={job_id}, error={e}" ) # 等待后重试 if attempt < max_retries - 1: delay = retry_delays[min(attempt, len(retry_delays) - 1)] await asyncio.sleep(delay) # 所有重试都失败 await incr("webhook_deliveries_total", {"status": "failed"}) logger.error(f"Webhook 发送最终失败: job_id={job_id}, url={webhook_url}") def is_available(self) -> bool: """检查任务管理器是否可用""" return self._redis_client is not None async def enqueue_job(self, job_id: str) -> bool: """将任务加入队列 Args: job_id: 任务 ID Returns: bool: 是否成功入队 """ if not self._redis_client: logger.error(f"Redis 不可用,无法入队任务: {job_id}") return False try: await self._redis_client.lpush(settings.job_queue_key, job_id) logger.info(f"任务已入队: job_id={job_id}") return True except Exception as e: logger.error(f"任务入队失败: job_id={job_id}, error={e}") return False async def dequeue_job(self, timeout: int = 5) -> Optional[str]: """从队列获取任务(阻塞式,转移式出队) 使用 BLMOVE 原子性地将任务从 job:queue 移动到 job:processing, 防止 Worker 崩溃时任务丢失。 Args: timeout: 阻塞超时时间(秒) Returns: Optional[str]: 任务 ID,超时返回 None """ if not self._redis_client: return None try: # 使用 BLMOVE 原子性转移任务 job_id = await self._redis_client.blmove( settings.job_queue_key, # 源: job:queue settings.job_processing_key, # 目标: job:processing timeout, "RIGHT", "LEFT", ) if job_id: # 记录出队时间戳到 ZSET await self._redis_client.zadd(settings.job_processing_ts_key, {job_id: time.time()}) logger.debug(f"任务已转移到处理队列: {job_id}") return job_id except Exception as e: logger.error(f"任务出队失败: error={e}") return None async def acquire_job_lock(self, job_id: str) -> Optional[str]: """获取任务执行锁(分布式锁,带 Token) Args: job_id: 任务 ID Returns: Optional[str]: 成功时返回锁 token,失败返回 None """ if not self._redis_client: return None lock_key = f"job:lock:{job_id}" lock_token = secrets.token_hex(16) # 随机 token lock_ttl = settings.job_execution_timeout + settings.job_lock_buffer try: acquired = await self._redis_client.set(lock_key, lock_token, nx=True, ex=lock_ttl) if acquired: logger.debug(f"获取任务锁成功: job_id={job_id}") return lock_token return None except Exception as e: logger.error(f"获取任务锁失败: job_id={job_id}, error={e}") return None async def release_job_lock(self, job_id: str, lock_token: Optional[str] = None) -> bool: """释放任务执行锁(使用 Lua 脚本验证 token) Args: job_id: 任务 ID lock_token: 锁 token(用于验证所有权) Returns: bool: 是否成功释放锁 """ if not self._redis_client: return False lock_key = f"job:lock:{job_id}" try: if lock_token: # 使用 Lua 脚本安全释放锁 result = await self._redis_client.eval( self.RELEASE_LOCK_SCRIPT, 1, lock_key, lock_token ) if result == 1: logger.debug(f"释放任务锁成功: job_id={job_id}") return True else: logger.warning(f"释放任务锁失败(token 不匹配): job_id={job_id}") return False else: # 向后兼容:无 token 时直接删除 await self._redis_client.delete(lock_key) logger.debug(f"释放任务锁成功(无 token 验证): job_id={job_id}") return True except Exception as e: logger.error(f"释放任务锁失败: job_id={job_id}, error={e}") return False async def increment_concurrency(self) -> int: """增加全局并发计数 Returns: int: 增加后的并发数 """ if not self._redis_client: return 0 try: count = await self._redis_client.incr(settings.job_concurrency_key) return count except Exception as e: logger.error(f"增加并发计数失败: error={e}") return 0 async def decrement_concurrency(self) -> int: """减少全局并发计数 Returns: int: 减少后的并发数 """ if not self._redis_client: return 0 try: count = await self._redis_client.decr(settings.job_concurrency_key) # 防止计数变为负数 if count < 0: await self._redis_client.set(settings.job_concurrency_key, 0) return 0 return count except Exception as e: logger.error(f"减少并发计数失败: error={e}") return 0 async def get_global_concurrency(self) -> int: """获取当前全局并发数 Returns: int: 当前并发数 """ if not self._redis_client: return 0 try: count = await self._redis_client.get(settings.job_concurrency_key) return int(count) if count else 0 except Exception as e: logger.error(f"获取并发计数失败: error={e}") return 0 async def can_execute(self) -> bool: """检查是否可以执行新任务(全局并发控制) Returns: bool: 是否可以执行 """ current = await self.get_global_concurrency() return current < settings.max_concurrent_jobs async def get_job_retry_count(self, job_id: str) -> int: """获取任务重试次数 Args: job_id: 任务 ID Returns: int: 重试次数 """ if not self._redis_client: return 0 key = f"job:{job_id}" try: retry_count = await self._redis_client.hget(key, "retry_count") return int(retry_count) if retry_count else 0 except Exception: return 0 async def increment_job_retry(self, job_id: str) -> int: """增加任务重试次数 Args: job_id: 任务 ID Returns: int: 增加后的重试次数 """ if not self._redis_client: return 0 key = f"job:{job_id}" try: await self._redis_client.hincrby(key, "retry_count", 1) retry_count = await self._redis_client.hget(key, "retry_count") return int(retry_count) if retry_count else 1 except Exception as e: logger.error(f"增加重试次数失败: job_id={job_id}, error={e}") return 0 async def ack_job(self, job_id: str) -> bool: """确认任务完成(从处理队列移除) Args: job_id: 任务 ID Returns: bool: 是否成功确认 """ if not self._redis_client: return False try: async with self._redis_client.pipeline(transaction=True) as pipe: pipe.lrem(settings.job_processing_key, 1, job_id) pipe.zrem(settings.job_processing_ts_key, job_id) await pipe.execute() logger.debug(f"任务已确认完成: job_id={job_id}") return True except Exception as e: logger.error(f"确认任务失败: job_id={job_id}, error={e}") return False async def nack_job(self, job_id: str, requeue: bool = True) -> bool: """拒绝任务(从处理队列移除,根据重试次数决定重新入队或进死信队列) Args: job_id: 任务 ID requeue: 是否尝试重新入队 Returns: bool: 是否成功处理 """ if not self._redis_client: return False try: retry_count = await self.get_job_retry_count(job_id) async with self._redis_client.pipeline(transaction=True) as pipe: pipe.lrem(settings.job_processing_key, 1, job_id) pipe.zrem(settings.job_processing_ts_key, job_id) if requeue and retry_count < settings.job_max_retries: pipe.lpush(settings.job_queue_key, job_id) logger.info(f"任务重新入队: job_id={job_id}, retry_count={retry_count}") else: pipe.lpush(settings.job_dlq_key, job_id) logger.warning(f"任务进入死信队列: job_id={job_id}, retry_count={retry_count}") await pipe.execute() return True except Exception as e: logger.error(f"拒绝任务失败: job_id={job_id}, error={e}") return False async def renew_job_lock(self, job_id: str, lock_token: str) -> bool: """续租任务锁(延长 TTL) Args: job_id: 任务 ID lock_token: 锁 token Returns: bool: 是否成功续租 """ if not self._redis_client: return False lock_key = f"job:lock:{job_id}" lock_ttl = settings.job_execution_timeout + settings.job_lock_buffer try: result = await self._redis_client.eval( self.RENEW_LOCK_SCRIPT, 1, lock_key, lock_token, lock_ttl ) if result == 1: logger.debug(f"锁续租成功: job_id={job_id}") return True else: logger.warning(f"锁续租失败(token 不匹配或锁已过期): job_id={job_id}") return False except Exception as e: logger.error(f"锁续租失败: job_id={job_id}, error={e}") return False async def recover_stale_jobs(self) -> int: """回收超时任务 扫描 job:processing:ts ZSET,找出超时的任务, 根据重试次数决定重新入队或进死信队列。 Returns: int: 回收的任务数量 """ if not self._redis_client: return 0 timeout = settings.job_execution_timeout + settings.job_lock_buffer cutoff = time.time() - timeout try: # 获取超时任务列表 stale_jobs = await self._redis_client.zrangebyscore( settings.job_processing_ts_key, "-inf", cutoff ) recovered = 0 for job_id in stale_jobs: # 增加重试次数 await self.increment_job_retry(job_id) retry_count = await self.get_job_retry_count(job_id) async with self._redis_client.pipeline(transaction=True) as pipe: pipe.lrem(settings.job_processing_key, 1, job_id) pipe.zrem(settings.job_processing_ts_key, job_id) if retry_count < settings.job_max_retries: pipe.lpush(settings.job_queue_key, job_id) logger.info(f"超时任务重新入队: job_id={job_id}, retry_count={retry_count}") else: pipe.lpush(settings.job_dlq_key, job_id) logger.warning( f"超时任务进入死信队列: job_id={job_id}, retry_count={retry_count}" ) await pipe.execute() recovered += 1 if recovered > 0: logger.info(f"回收超时任务完成: 共 {recovered} 个") return recovered except Exception as e: logger.error(f"回收超时任务失败: error={e}") return 0 def get_concurrency_status(self) -> Dict[str, int]: """获取并发状态 Returns: Dict[str, int]: 包含以下键的字典 - max_concurrent: 最大并发数 - available_slots: 可用槽位数 - running_jobs: 当前运行中的任务数 """ if not self._semaphore: return { "max_concurrent": 0, "available_slots": 0, "running_jobs": 0, } max_concurrent = self._max_concurrent_jobs available_slots = self._semaphore._value running_jobs = max_concurrent - available_slots return { "max_concurrent": max_concurrent, "available_slots": available_slots, "running_jobs": running_jobs, } async def collect_queue_metrics(self) -> Dict[str, Any]: """收集队列监控指标 Returns: Dict[str, Any]: 包含以下键的字典 - queue_length: 待处理队列长度 - processing_length: 处理中队列长度 - dlq_length: 死信队列长度 - oldest_waiting_seconds: 最长等待时间(秒) """ if not self._redis_client: return { "queue_length": 0, "processing_length": 0, "dlq_length": 0, "oldest_waiting_seconds": 0, } try: # 使用 pipeline 批量获取队列长度 async with self._redis_client.pipeline(transaction=False) as pipe: pipe.llen(settings.job_queue_key) pipe.llen(settings.job_processing_key) pipe.llen(settings.job_dlq_key) pipe.zrange(settings.job_processing_ts_key, 0, 0, withscores=True) results = await pipe.execute() queue_length = results[0] or 0 processing_length = results[1] or 0 dlq_length = results[2] or 0 # 计算最长等待时间 oldest_waiting_seconds = 0 if results[3]: # results[3] 是 [(job_id, timestamp), ...] 格式 oldest_ts = results[3][0][1] oldest_waiting_seconds = time.time() - oldest_ts # 更新指标 from .metrics_unified import set as metrics_set await metrics_set("job_queue_length", {"queue": "pending"}, queue_length) await metrics_set("job_queue_length", {"queue": "processing"}, processing_length) await metrics_set("job_queue_length", {"queue": "dlq"}, dlq_length) await metrics_set("job_oldest_waiting_seconds", None, oldest_waiting_seconds) return { "queue_length": queue_length, "processing_length": processing_length, "dlq_length": dlq_length, "oldest_waiting_seconds": oldest_waiting_seconds, } except Exception as e: logger.error(f"收集队列指标失败: error={e}") return { "queue_length": 0, "processing_length": 0, "dlq_length": 0, "oldest_waiting_seconds": 0, } # 全局单例 _job_manager: Optional[JobManager] = None async def get_job_manager() -> JobManager: """获取任务管理器单例""" global _job_manager if _job_manager is None: _job_manager = JobManager() await _job_manager.initialize() return _job_manager async def shutdown_job_manager() -> None: """关闭任务管理器""" global _job_manager if _job_manager is not None: await _job_manager.shutdown() _job_manager = None