变更内容: - 在 `Dockerfile` 和 `docker-compose.yml` 中添加 Worker 模式支持,包含运行模式 `RUN_MODE` 的配置。 - 更新 API 路由,改为将任务入队处理,并由 Worker 执行。 - 在 JobManager 中新增任务队列及分布式锁功能,支持任务的入队、出队、执行控制以及重试机制。 - 添加全局并发控制逻辑,避免任务超额运行。 - 扩展单元测试,覆盖任务队列、锁机制和并发控制的各类场景。 - 在 Serverless 配置中分别为 API 和 Worker 添加独立服务定义。 提升任务调度灵活性,增强系统可靠性与扩展性。
610 lines
19 KiB
Python
610 lines
19 KiB
Python
"""异步任务管理模块
|
||
|
||
基于 Redis 的异步任务管理,支持任务创建、执行、状态查询和 Webhook 回调。
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import secrets
|
||
from datetime import datetime, timezone
|
||
from typing import Any, Dict, List, Optional, Type
|
||
|
||
import httpx
|
||
import redis.asyncio as aioredis
|
||
|
||
from ..algorithms.base import BaseAlgorithm
|
||
from ..config import settings
|
||
from ..core.metrics_unified import incr, observe
|
||
|
||
# Module-level logger, named after this module per standard logging convention.
logger = logging.getLogger(__name__)
||
|
||
|
||
class JobManager:
|
||
"""异步任务管理器"""
|
||
|
||
def __init__(self):
|
||
self._redis_client: Optional[aioredis.Redis] = None
|
||
self._algorithm_registry: Dict[str, Type[BaseAlgorithm]] = {}
|
||
self._http_client: Optional[httpx.AsyncClient] = None
|
||
self._semaphore: Optional[asyncio.Semaphore] = None
|
||
self._max_concurrent_jobs: int = 0
|
||
|
||
    async def initialize(self) -> None:
        """Create the Redis connection, HTTP client, semaphore, and registry.

        Redis failures are logged and tolerated: the manager remains usable in
        a degraded mode with ``_redis_client`` left as None (callers see it via
        :meth:`is_available`).
        """
        # Open the async Redis connection.
        try:
            self._redis_client = aioredis.Redis(
                host=settings.redis_host,
                port=settings.redis_port,
                db=settings.redis_db,
                password=settings.redis_password if settings.redis_password else None,
                decode_responses=True,
                socket_connect_timeout=5,
                socket_timeout=5,
            )
            # Ping eagerly so misconfiguration surfaces at startup, not on
            # first use.
            await self._redis_client.ping()
            logger.info(f"任务管理器 Redis 连接成功: {settings.redis_host}:{settings.redis_port}")
        except Exception as e:
            logger.error(f"任务管理器 Redis 连接失败: {e}")
            self._redis_client = None

        # HTTP client used for webhook delivery.
        self._http_client = httpx.AsyncClient(timeout=settings.webhook_timeout)

        # Per-process concurrency limit for job execution.
        self._max_concurrent_jobs = settings.max_concurrent_jobs
        self._semaphore = asyncio.Semaphore(self._max_concurrent_jobs)
        logger.info(f"任务并发限制已设置: {self._max_concurrent_jobs}")

        # Discover and register available algorithm classes.
        self._register_algorithms()
||
async def shutdown(self) -> None:
|
||
"""关闭连接"""
|
||
if self._redis_client:
|
||
await self._redis_client.close()
|
||
logger.info("任务管理器 Redis 连接已关闭")
|
||
|
||
if self._http_client:
|
||
await self._http_client.aclose()
|
||
logger.info("任务管理器 HTTP 客户端已关闭")
|
||
|
||
    def _register_algorithms(self) -> None:
        """Populate the algorithm registry from the algorithms package.

        Every name exported through the package's ``__all__`` that resolves to
        a concrete BaseAlgorithm subclass is registered under its export name.
        """
        # Imported inside the function — presumably to avoid an import cycle
        # at module load time; confirm against the package layout.
        from ..algorithms import __all__ as algorithm_names
        from .. import algorithms as algorithms_module

        for name in algorithm_names:
            cls = getattr(algorithms_module, name, None)
            # Only concrete BaseAlgorithm subclasses count; skip the abstract
            # base itself and any non-class exports.
            if cls and isinstance(cls, type) and issubclass(cls, BaseAlgorithm):
                if cls is not BaseAlgorithm:
                    self._algorithm_registry[name] = cls
                    logger.debug(f"已注册算法: {name}")

        logger.info(f"已注册 {len(self._algorithm_registry)} 个算法")
|
||
def get_available_algorithms(self) -> List[str]:
|
||
"""获取可用算法列表"""
|
||
return list(self._algorithm_registry.keys())
|
||
|
||
def _generate_job_id(self) -> str:
|
||
"""生成 12 位十六进制任务 ID"""
|
||
return secrets.token_hex(6)
|
||
|
||
def _get_timestamp(self) -> str:
|
||
"""获取 ISO 8601 格式时间戳"""
|
||
return datetime.now(timezone.utc).isoformat()
|
||
|
||
async def create_job(
|
||
self,
|
||
algorithm: str,
|
||
params: Dict[str, Any],
|
||
webhook: Optional[str] = None,
|
||
request_id: Optional[str] = None,
|
||
) -> str:
|
||
"""创建新任务,返回 job_id
|
||
|
||
Args:
|
||
algorithm: 算法名称
|
||
params: 算法参数
|
||
webhook: 回调 URL(可选)
|
||
request_id: 关联的请求 ID(可选)
|
||
|
||
Returns:
|
||
str: 任务 ID
|
||
|
||
Raises:
|
||
RuntimeError: Redis 不可用时抛出
|
||
ValueError: 算法不存在时抛出
|
||
"""
|
||
if not self._redis_client:
|
||
raise RuntimeError("Redis 不可用,无法创建任务")
|
||
|
||
if algorithm not in self._algorithm_registry:
|
||
raise ValueError(f"算法 '{algorithm}' 不存在")
|
||
|
||
job_id = self._generate_job_id()
|
||
created_at = self._get_timestamp()
|
||
|
||
# 构建任务数据
|
||
job_data = {
|
||
"status": "pending",
|
||
"algorithm": algorithm,
|
||
"params": json.dumps(params),
|
||
"webhook": webhook or "",
|
||
"request_id": request_id or "",
|
||
"created_at": created_at,
|
||
"started_at": "",
|
||
"completed_at": "",
|
||
"result": "",
|
||
"error": "",
|
||
"metadata": "",
|
||
}
|
||
|
||
# 存储到 Redis
|
||
key = f"job:{job_id}"
|
||
await self._redis_client.hset(key, mapping=job_data)
|
||
|
||
# 记录指标
|
||
incr("jobs_created_total", {"algorithm": algorithm})
|
||
|
||
logger.info(f"任务已创建: job_id={job_id}, algorithm={algorithm}")
|
||
return job_id
|
||
|
||
async def get_job(self, job_id: str) -> Optional[Dict[str, Any]]:
|
||
"""获取任务信息
|
||
|
||
Args:
|
||
job_id: 任务 ID
|
||
|
||
Returns:
|
||
任务信息字典,不存在时返回 None
|
||
"""
|
||
if not self._redis_client:
|
||
return None
|
||
|
||
key = f"job:{job_id}"
|
||
job_data = await self._redis_client.hgetall(key)
|
||
|
||
if not job_data:
|
||
return None
|
||
|
||
# 解析 JSON 字段
|
||
result = {
|
||
"job_id": job_id,
|
||
"status": job_data.get("status", ""),
|
||
"algorithm": job_data.get("algorithm", ""),
|
||
"created_at": job_data.get("created_at", ""),
|
||
"started_at": job_data.get("started_at") or None,
|
||
"completed_at": job_data.get("completed_at") or None,
|
||
"result": None,
|
||
"error": job_data.get("error") or None,
|
||
"metadata": None,
|
||
}
|
||
|
||
# 解析 result
|
||
if job_data.get("result"):
|
||
try:
|
||
result["result"] = json.loads(job_data["result"])
|
||
except json.JSONDecodeError:
|
||
result["result"] = None
|
||
|
||
# 解析 metadata
|
||
if job_data.get("metadata"):
|
||
try:
|
||
result["metadata"] = json.loads(job_data["metadata"])
|
||
except json.JSONDecodeError:
|
||
result["metadata"] = None
|
||
|
||
return result
|
||
|
||
    async def execute_job(self, job_id: str) -> None:
        """Run a job to completion (intended to be scheduled as a background task).

        Loads the job hash from Redis, executes the registered algorithm under
        the per-process semaphore, writes the outcome back with a result TTL,
        records metrics, and finally delivers the webhook callback when one is
        configured.

        Args:
            job_id: Job identifier.
        """
        if not self._redis_client:
            logger.error(f"Redis 不可用,无法执行任务: {job_id}")
            return

        if not self._semaphore:
            logger.error(f"并发控制未初始化,无法执行任务: {job_id}")
            return

        key = f"job:{job_id}"
        job_data = await self._redis_client.hgetall(key)

        if not job_data:
            logger.error(f"任务不存在: {job_id}")
            return

        algorithm_name = job_data.get("algorithm", "")
        webhook_url = job_data.get("webhook", "")

        # Parse the stored JSON parameter blob; fall back to no parameters.
        try:
            params = json.loads(job_data.get("params", "{}"))
        except json.JSONDecodeError:
            params = {}

        # Bound concurrent executions within this process via the semaphore.
        async with self._semaphore:
            # Mark the job as running before starting actual work.
            started_at = self._get_timestamp()
            await self._redis_client.hset(key, mapping={"status": "running", "started_at": started_at})

            logger.info(f"开始执行任务: job_id={job_id}, algorithm={algorithm_name}")

            import time

            start_time = time.time()
            status = "completed"
            result_data = None
            error_msg = None
            metadata = None

            try:
                # Resolve and instantiate the algorithm class.
                algorithm_cls = self._algorithm_registry.get(algorithm_name)
                if not algorithm_cls:
                    raise ValueError(f"算法 '{algorithm_name}' 不存在")

                algorithm = algorithm_cls()

                # Dispatch parameters per the algorithm's calling convention.
                if algorithm_name == "PrimeChecker":
                    # PrimeChecker takes a single positional number.
                    execution_result = algorithm.execute(params.get("number", 0))
                else:
                    # Generic keyword-argument passing.
                    execution_result = algorithm.execute(**params)

                if execution_result.get("success"):
                    result_data = execution_result.get("result", {})
                    metadata = execution_result.get("metadata", {})
                else:
                    status = "failed"
                    error_msg = execution_result.get("error", "算法执行失败")
                    metadata = execution_result.get("metadata", {})

            except Exception as e:
                status = "failed"
                error_msg = str(e)
                logger.error(f"任务执行失败: job_id={job_id}, error={e}", exc_info=True)

            # Wall-clock execution time for the metrics below.
            elapsed_time = time.time() - start_time
            completed_at = self._get_timestamp()

            # Persist the final state. NOTE(review): a falsy result (e.g. {})
            # is stored as "" and will read back as None via get_job — confirm
            # that is acceptable for algorithms with legitimately empty results.
            update_data = {
                "status": status,
                "completed_at": completed_at,
                "result": json.dumps(result_data) if result_data else "",
                "error": error_msg or "",
                "metadata": json.dumps(metadata) if metadata else "",
            }
            await self._redis_client.hset(key, mapping=update_data)

            # Expire the record so finished jobs do not accumulate in Redis.
            await self._redis_client.expire(key, settings.job_result_ttl)

            # Record completion metrics.
            incr("jobs_completed_total", {"algorithm": algorithm_name, "status": status})
            observe("job_execution_duration_seconds", {"algorithm": algorithm_name}, elapsed_time)

            logger.info(f"任务执行完成: job_id={job_id}, status={status}, elapsed={elapsed_time:.3f}s")

            # Deliver the webhook callback, if one was configured.
            if webhook_url:
                await self._send_webhook(job_id, webhook_url)
|
||
    async def _send_webhook(self, job_id: str, webhook_url: str) -> None:
        """POST the job's final state to a webhook URL, with retries.

        Retries up to ``settings.webhook_max_retries`` attempts using a capped
        backoff schedule; delivery success/failure is recorded in metrics.

        Args:
            job_id: Job identifier.
            webhook_url: Callback URL to POST the payload to.
        """
        if not self._http_client:
            logger.warning("HTTP 客户端不可用,无法发送 Webhook")
            return

        # Re-read the job so the payload reflects the final persisted state.
        job_data = await self.get_job(job_id)
        if not job_data:
            logger.error(f"无法获取任务数据用于 Webhook: {job_id}")
            return

        # Callback payload delivered to the subscriber.
        payload = {
            "job_id": job_data["job_id"],
            "status": job_data["status"],
            "algorithm": job_data["algorithm"],
            "result": job_data["result"],
            "error": job_data["error"],
            "metadata": job_data["metadata"],
            "completed_at": job_data["completed_at"],
        }

        # Backoff schedule in seconds; attempts beyond its length reuse the
        # last delay.
        retry_delays = [1, 5, 15]
        max_retries = settings.webhook_max_retries

        for attempt in range(max_retries):
            try:
                response = await self._http_client.post(
                    webhook_url,
                    json=payload,
                    headers={"Content-Type": "application/json"},
                )

                # Any non-4xx/5xx status counts as delivered.
                if response.status_code < 400:
                    incr("webhook_deliveries_total", {"status": "success"})
                    logger.info(
                        f"Webhook 发送成功: job_id={job_id}, url={webhook_url}, "
                        f"status_code={response.status_code}"
                    )
                    return
                else:
                    logger.warning(
                        f"Webhook 响应错误: job_id={job_id}, status_code={response.status_code}"
                    )

            except Exception as e:
                logger.warning(
                    f"Webhook 发送失败 (尝试 {attempt + 1}/{max_retries}): "
                    f"job_id={job_id}, error={e}"
                )

            # Sleep before the next attempt (skipped after the final one).
            if attempt < max_retries - 1:
                delay = retry_delays[min(attempt, len(retry_delays) - 1)]
                await asyncio.sleep(delay)

        # Every attempt failed.
        incr("webhook_deliveries_total", {"status": "failed"})
        logger.error(f"Webhook 发送最终失败: job_id={job_id}, url={webhook_url}")
|
||
def is_available(self) -> bool:
|
||
"""检查任务管理器是否可用"""
|
||
return self._redis_client is not None
|
||
|
||
async def enqueue_job(self, job_id: str) -> bool:
|
||
"""将任务加入队列
|
||
|
||
Args:
|
||
job_id: 任务 ID
|
||
|
||
Returns:
|
||
bool: 是否成功入队
|
||
"""
|
||
if not self._redis_client:
|
||
logger.error(f"Redis 不可用,无法入队任务: {job_id}")
|
||
return False
|
||
|
||
try:
|
||
await self._redis_client.lpush(settings.job_queue_key, job_id)
|
||
logger.info(f"任务已入队: job_id={job_id}")
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"任务入队失败: job_id={job_id}, error={e}")
|
||
return False
|
||
|
||
async def dequeue_job(self, timeout: int = 5) -> Optional[str]:
|
||
"""从队列获取任务(阻塞式)
|
||
|
||
Args:
|
||
timeout: 阻塞超时时间(秒)
|
||
|
||
Returns:
|
||
Optional[str]: 任务 ID,超时返回 None
|
||
"""
|
||
if not self._redis_client:
|
||
return None
|
||
|
||
try:
|
||
result = await self._redis_client.brpop(settings.job_queue_key, timeout=timeout)
|
||
if result:
|
||
# brpop 返回 (key, value) 元组
|
||
return result[1]
|
||
return None
|
||
except Exception as e:
|
||
logger.error(f"任务出队失败: error={e}")
|
||
return None
|
||
|
||
async def acquire_job_lock(self, job_id: str) -> bool:
|
||
"""获取任务执行锁(分布式锁)
|
||
|
||
Args:
|
||
job_id: 任务 ID
|
||
|
||
Returns:
|
||
bool: 是否成功获取锁
|
||
"""
|
||
if not self._redis_client:
|
||
return False
|
||
|
||
lock_key = f"job:lock:{job_id}"
|
||
try:
|
||
acquired = await self._redis_client.set(
|
||
lock_key, "locked", nx=True, ex=settings.job_lock_ttl
|
||
)
|
||
if acquired:
|
||
logger.debug(f"获取任务锁成功: job_id={job_id}")
|
||
return acquired is not None
|
||
except Exception as e:
|
||
logger.error(f"获取任务锁失败: job_id={job_id}, error={e}")
|
||
return False
|
||
|
||
async def release_job_lock(self, job_id: str) -> bool:
|
||
"""释放任务执行锁
|
||
|
||
Args:
|
||
job_id: 任务 ID
|
||
|
||
Returns:
|
||
bool: 是否成功释放锁
|
||
"""
|
||
if not self._redis_client:
|
||
return False
|
||
|
||
lock_key = f"job:lock:{job_id}"
|
||
try:
|
||
await self._redis_client.delete(lock_key)
|
||
logger.debug(f"释放任务锁成功: job_id={job_id}")
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"释放任务锁失败: job_id={job_id}, error={e}")
|
||
return False
|
||
|
||
async def increment_concurrency(self) -> int:
|
||
"""增加全局并发计数
|
||
|
||
Returns:
|
||
int: 增加后的并发数
|
||
"""
|
||
if not self._redis_client:
|
||
return 0
|
||
|
||
try:
|
||
count = await self._redis_client.incr(settings.job_concurrency_key)
|
||
return count
|
||
except Exception as e:
|
||
logger.error(f"增加并发计数失败: error={e}")
|
||
return 0
|
||
|
||
    async def decrement_concurrency(self) -> int:
        """Decrement the global running-job counter.

        Returns:
            int: The counter value after the decrement, clamped at 0; 0 on
            failure.
        """
        if not self._redis_client:
            return 0

        try:
            count = await self._redis_client.decr(settings.job_concurrency_key)
            # Clamp negative values back to zero. NOTE(review): the DECR + SET
            # pair is not atomic, so a concurrent INCR in between can be lost —
            # presumably acceptable because the counter is advisory; confirm.
            if count < 0:
                await self._redis_client.set(settings.job_concurrency_key, 0)
                return 0
            return count
        except Exception as e:
            logger.error(f"减少并发计数失败: error={e}")
            return 0
|
||
async def get_global_concurrency(self) -> int:
|
||
"""获取当前全局并发数
|
||
|
||
Returns:
|
||
int: 当前并发数
|
||
"""
|
||
if not self._redis_client:
|
||
return 0
|
||
|
||
try:
|
||
count = await self._redis_client.get(settings.job_concurrency_key)
|
||
return int(count) if count else 0
|
||
except Exception as e:
|
||
logger.error(f"获取并发计数失败: error={e}")
|
||
return 0
|
||
|
||
async def can_execute(self) -> bool:
|
||
"""检查是否可以执行新任务(全局并发控制)
|
||
|
||
Returns:
|
||
bool: 是否可以执行
|
||
"""
|
||
current = await self.get_global_concurrency()
|
||
return current < settings.max_concurrent_jobs
|
||
|
||
async def get_job_retry_count(self, job_id: str) -> int:
|
||
"""获取任务重试次数
|
||
|
||
Args:
|
||
job_id: 任务 ID
|
||
|
||
Returns:
|
||
int: 重试次数
|
||
"""
|
||
if not self._redis_client:
|
||
return 0
|
||
|
||
key = f"job:{job_id}"
|
||
try:
|
||
retry_count = await self._redis_client.hget(key, "retry_count")
|
||
return int(retry_count) if retry_count else 0
|
||
except Exception:
|
||
return 0
|
||
|
||
async def increment_job_retry(self, job_id: str) -> int:
|
||
"""增加任务重试次数
|
||
|
||
Args:
|
||
job_id: 任务 ID
|
||
|
||
Returns:
|
||
int: 增加后的重试次数
|
||
"""
|
||
if not self._redis_client:
|
||
return 0
|
||
|
||
key = f"job:{job_id}"
|
||
try:
|
||
await self._redis_client.hincrby(key, "retry_count", 1)
|
||
retry_count = await self._redis_client.hget(key, "retry_count")
|
||
return int(retry_count) if retry_count else 1
|
||
except Exception as e:
|
||
logger.error(f"增加重试次数失败: job_id={job_id}, error={e}")
|
||
return 0
|
||
|
||
def get_concurrency_status(self) -> Dict[str, int]:
|
||
"""获取并发状态
|
||
|
||
Returns:
|
||
Dict[str, int]: 包含以下键的字典
|
||
- max_concurrent: 最大并发数
|
||
- available_slots: 可用槽位数
|
||
- running_jobs: 当前运行中的任务数
|
||
"""
|
||
if not self._semaphore:
|
||
return {
|
||
"max_concurrent": 0,
|
||
"available_slots": 0,
|
||
"running_jobs": 0,
|
||
}
|
||
|
||
max_concurrent = self._max_concurrent_jobs
|
||
available_slots = self._semaphore._value
|
||
running_jobs = max_concurrent - available_slots
|
||
|
||
return {
|
||
"max_concurrent": max_concurrent,
|
||
"available_slots": available_slots,
|
||
"running_jobs": running_jobs,
|
||
}
|
||
|
||
|
||
# Process-wide singleton instance, created lazily by get_job_manager().
_job_manager: Optional[JobManager] = None
||
|
||
async def get_job_manager() -> JobManager:
    """Return the process-wide JobManager singleton, creating it on first use.

    NOTE(review): the instance is published to the module global before
    initialize() finishes, so concurrent first callers may observe a partially
    initialized manager — confirm callers serialize startup.
    """
    global _job_manager
    if _job_manager is None:
        _job_manager = JobManager()
        await _job_manager.initialize()
    return _job_manager
|
||
|
||
async def shutdown_job_manager() -> None:
    """Shut down and discard the JobManager singleton, if one exists."""
    global _job_manager
    if _job_manager is not None:
        await _job_manager.shutdown()
        _job_manager = None