Files
FunctionalScaffold/src/functional_scaffold/core/job_manager.py
Roog (顾新培) 265e8d1e3d main:支持 Worker 模式运行并优化任务管理
变更内容:
- 在 `Dockerfile` 和 `docker-compose.yml` 中添加 Worker 模式支持,包含运行模式 `RUN_MODE` 的配置。
- 更新 API 路由,改为将任务入队处理,并由 Worker 执行。
- 在 JobManager 中新增任务队列及分布式锁功能,支持任务的入队、出队、执行控制以及重试机制。
- 添加全局并发控制逻辑,避免任务超额运行。
- 扩展单元测试,覆盖任务队列、锁机制和并发控制的各类场景。
- 在 Serverless 配置中分别为 API 和 Worker 添加独立服务定义。

提升任务调度灵活性,增强系统可靠性与扩展性。
2026-02-03 18:38:08 +08:00

610 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""异步任务管理模块
基于 Redis 的异步任务管理,支持任务创建、执行、状态查询和 Webhook 回调。
"""
import asyncio
import json
import logging
import secrets
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Type
import httpx
import redis.asyncio as aioredis
from ..algorithms.base import BaseAlgorithm
from ..config import settings
from ..core.metrics_unified import incr, observe
logger = logging.getLogger(__name__)
class JobManager:
"""异步任务管理器"""
def __init__(self):
self._redis_client: Optional[aioredis.Redis] = None
self._algorithm_registry: Dict[str, Type[BaseAlgorithm]] = {}
self._http_client: Optional[httpx.AsyncClient] = None
self._semaphore: Optional[asyncio.Semaphore] = None
self._max_concurrent_jobs: int = 0
async def initialize(self) -> None:
"""初始化 Redis 连接和 HTTP 客户端"""
# 初始化 Redis 异步连接
try:
self._redis_client = aioredis.Redis(
host=settings.redis_host,
port=settings.redis_port,
db=settings.redis_db,
password=settings.redis_password if settings.redis_password else None,
decode_responses=True,
socket_connect_timeout=5,
socket_timeout=5,
)
# 测试连接
await self._redis_client.ping()
logger.info(f"任务管理器 Redis 连接成功: {settings.redis_host}:{settings.redis_port}")
except Exception as e:
logger.error(f"任务管理器 Redis 连接失败: {e}")
self._redis_client = None
# 初始化 HTTP 客户端
self._http_client = httpx.AsyncClient(timeout=settings.webhook_timeout)
# 初始化并发控制信号量
self._max_concurrent_jobs = settings.max_concurrent_jobs
self._semaphore = asyncio.Semaphore(self._max_concurrent_jobs)
logger.info(f"任务并发限制已设置: {self._max_concurrent_jobs}")
# 注册算法
self._register_algorithms()
async def shutdown(self) -> None:
"""关闭连接"""
if self._redis_client:
await self._redis_client.close()
logger.info("任务管理器 Redis 连接已关闭")
if self._http_client:
await self._http_client.aclose()
logger.info("任务管理器 HTTP 客户端已关闭")
def _register_algorithms(self) -> None:
"""注册可用的算法类"""
from ..algorithms import __all__ as algorithm_names
from .. import algorithms as algorithms_module
for name in algorithm_names:
cls = getattr(algorithms_module, name, None)
if cls and isinstance(cls, type) and issubclass(cls, BaseAlgorithm):
if cls is not BaseAlgorithm:
self._algorithm_registry[name] = cls
logger.debug(f"已注册算法: {name}")
logger.info(f"已注册 {len(self._algorithm_registry)} 个算法")
def get_available_algorithms(self) -> List[str]:
"""获取可用算法列表"""
return list(self._algorithm_registry.keys())
def _generate_job_id(self) -> str:
"""生成 12 位十六进制任务 ID"""
return secrets.token_hex(6)
def _get_timestamp(self) -> str:
"""获取 ISO 8601 格式时间戳"""
return datetime.now(timezone.utc).isoformat()
async def create_job(
self,
algorithm: str,
params: Dict[str, Any],
webhook: Optional[str] = None,
request_id: Optional[str] = None,
) -> str:
"""创建新任务,返回 job_id
Args:
algorithm: 算法名称
params: 算法参数
webhook: 回调 URL可选
request_id: 关联的请求 ID可选
Returns:
str: 任务 ID
Raises:
RuntimeError: Redis 不可用时抛出
ValueError: 算法不存在时抛出
"""
if not self._redis_client:
raise RuntimeError("Redis 不可用,无法创建任务")
if algorithm not in self._algorithm_registry:
raise ValueError(f"算法 '{algorithm}' 不存在")
job_id = self._generate_job_id()
created_at = self._get_timestamp()
# 构建任务数据
job_data = {
"status": "pending",
"algorithm": algorithm,
"params": json.dumps(params),
"webhook": webhook or "",
"request_id": request_id or "",
"created_at": created_at,
"started_at": "",
"completed_at": "",
"result": "",
"error": "",
"metadata": "",
}
# 存储到 Redis
key = f"job:{job_id}"
await self._redis_client.hset(key, mapping=job_data)
# 记录指标
incr("jobs_created_total", {"algorithm": algorithm})
logger.info(f"任务已创建: job_id={job_id}, algorithm={algorithm}")
return job_id
async def get_job(self, job_id: str) -> Optional[Dict[str, Any]]:
"""获取任务信息
Args:
job_id: 任务 ID
Returns:
任务信息字典,不存在时返回 None
"""
if not self._redis_client:
return None
key = f"job:{job_id}"
job_data = await self._redis_client.hgetall(key)
if not job_data:
return None
# 解析 JSON 字段
result = {
"job_id": job_id,
"status": job_data.get("status", ""),
"algorithm": job_data.get("algorithm", ""),
"created_at": job_data.get("created_at", ""),
"started_at": job_data.get("started_at") or None,
"completed_at": job_data.get("completed_at") or None,
"result": None,
"error": job_data.get("error") or None,
"metadata": None,
}
# 解析 result
if job_data.get("result"):
try:
result["result"] = json.loads(job_data["result"])
except json.JSONDecodeError:
result["result"] = None
# 解析 metadata
if job_data.get("metadata"):
try:
result["metadata"] = json.loads(job_data["metadata"])
except json.JSONDecodeError:
result["metadata"] = None
return result
async def execute_job(self, job_id: str) -> None:
"""执行任务(在后台任务中调用)
Args:
job_id: 任务 ID
"""
if not self._redis_client:
logger.error(f"Redis 不可用,无法执行任务: {job_id}")
return
if not self._semaphore:
logger.error(f"并发控制未初始化,无法执行任务: {job_id}")
return
key = f"job:{job_id}"
job_data = await self._redis_client.hgetall(key)
if not job_data:
logger.error(f"任务不存在: {job_id}")
return
algorithm_name = job_data.get("algorithm", "")
webhook_url = job_data.get("webhook", "")
# 解析参数
try:
params = json.loads(job_data.get("params", "{}"))
except json.JSONDecodeError:
params = {}
# 使用信号量控制并发
async with self._semaphore:
# 更新状态为 running
started_at = self._get_timestamp()
await self._redis_client.hset(key, mapping={"status": "running", "started_at": started_at})
logger.info(f"开始执行任务: job_id={job_id}, algorithm={algorithm_name}")
import time
start_time = time.time()
status = "completed"
result_data = None
error_msg = None
metadata = None
try:
# 获取算法类并执行
algorithm_cls = self._algorithm_registry.get(algorithm_name)
if not algorithm_cls:
raise ValueError(f"算法 '{algorithm_name}' 不存在")
algorithm = algorithm_cls()
# 根据算法类型传递参数
if algorithm_name == "PrimeChecker":
execution_result = algorithm.execute(params.get("number", 0))
else:
# 通用参数传递
execution_result = algorithm.execute(**params)
if execution_result.get("success"):
result_data = execution_result.get("result", {})
metadata = execution_result.get("metadata", {})
else:
status = "failed"
error_msg = execution_result.get("error", "算法执行失败")
metadata = execution_result.get("metadata", {})
except Exception as e:
status = "failed"
error_msg = str(e)
logger.error(f"任务执行失败: job_id={job_id}, error={e}", exc_info=True)
# 计算执行时间
elapsed_time = time.time() - start_time
completed_at = self._get_timestamp()
# 更新任务状态
update_data = {
"status": status,
"completed_at": completed_at,
"result": json.dumps(result_data) if result_data else "",
"error": error_msg or "",
"metadata": json.dumps(metadata) if metadata else "",
}
await self._redis_client.hset(key, mapping=update_data)
# 设置 TTL
await self._redis_client.expire(key, settings.job_result_ttl)
# 记录指标
incr("jobs_completed_total", {"algorithm": algorithm_name, "status": status})
observe("job_execution_duration_seconds", {"algorithm": algorithm_name}, elapsed_time)
logger.info(f"任务执行完成: job_id={job_id}, status={status}, elapsed={elapsed_time:.3f}s")
# 发送 Webhook 回调
if webhook_url:
await self._send_webhook(job_id, webhook_url)
async def _send_webhook(self, job_id: str, webhook_url: str) -> None:
"""发送 Webhook 回调(带重试)
Args:
job_id: 任务 ID
webhook_url: 回调 URL
"""
if not self._http_client:
logger.warning("HTTP 客户端不可用,无法发送 Webhook")
return
# 获取任务数据
job_data = await self.get_job(job_id)
if not job_data:
logger.error(f"无法获取任务数据用于 Webhook: {job_id}")
return
# 构建回调负载
payload = {
"job_id": job_data["job_id"],
"status": job_data["status"],
"algorithm": job_data["algorithm"],
"result": job_data["result"],
"error": job_data["error"],
"metadata": job_data["metadata"],
"completed_at": job_data["completed_at"],
}
# 重试间隔(指数退避)
retry_delays = [1, 5, 15]
max_retries = settings.webhook_max_retries
for attempt in range(max_retries):
try:
response = await self._http_client.post(
webhook_url,
json=payload,
headers={"Content-Type": "application/json"},
)
if response.status_code < 400:
incr("webhook_deliveries_total", {"status": "success"})
logger.info(
f"Webhook 发送成功: job_id={job_id}, url={webhook_url}, "
f"status_code={response.status_code}"
)
return
else:
logger.warning(
f"Webhook 响应错误: job_id={job_id}, status_code={response.status_code}"
)
except Exception as e:
logger.warning(
f"Webhook 发送失败 (尝试 {attempt + 1}/{max_retries}): "
f"job_id={job_id}, error={e}"
)
# 等待后重试
if attempt < max_retries - 1:
delay = retry_delays[min(attempt, len(retry_delays) - 1)]
await asyncio.sleep(delay)
# 所有重试都失败
incr("webhook_deliveries_total", {"status": "failed"})
logger.error(f"Webhook 发送最终失败: job_id={job_id}, url={webhook_url}")
def is_available(self) -> bool:
"""检查任务管理器是否可用"""
return self._redis_client is not None
async def enqueue_job(self, job_id: str) -> bool:
"""将任务加入队列
Args:
job_id: 任务 ID
Returns:
bool: 是否成功入队
"""
if not self._redis_client:
logger.error(f"Redis 不可用,无法入队任务: {job_id}")
return False
try:
await self._redis_client.lpush(settings.job_queue_key, job_id)
logger.info(f"任务已入队: job_id={job_id}")
return True
except Exception as e:
logger.error(f"任务入队失败: job_id={job_id}, error={e}")
return False
async def dequeue_job(self, timeout: int = 5) -> Optional[str]:
"""从队列获取任务(阻塞式)
Args:
timeout: 阻塞超时时间(秒)
Returns:
Optional[str]: 任务 ID超时返回 None
"""
if not self._redis_client:
return None
try:
result = await self._redis_client.brpop(settings.job_queue_key, timeout=timeout)
if result:
# brpop 返回 (key, value) 元组
return result[1]
return None
except Exception as e:
logger.error(f"任务出队失败: error={e}")
return None
async def acquire_job_lock(self, job_id: str) -> bool:
"""获取任务执行锁(分布式锁)
Args:
job_id: 任务 ID
Returns:
bool: 是否成功获取锁
"""
if not self._redis_client:
return False
lock_key = f"job:lock:{job_id}"
try:
acquired = await self._redis_client.set(
lock_key, "locked", nx=True, ex=settings.job_lock_ttl
)
if acquired:
logger.debug(f"获取任务锁成功: job_id={job_id}")
return acquired is not None
except Exception as e:
logger.error(f"获取任务锁失败: job_id={job_id}, error={e}")
return False
async def release_job_lock(self, job_id: str) -> bool:
"""释放任务执行锁
Args:
job_id: 任务 ID
Returns:
bool: 是否成功释放锁
"""
if not self._redis_client:
return False
lock_key = f"job:lock:{job_id}"
try:
await self._redis_client.delete(lock_key)
logger.debug(f"释放任务锁成功: job_id={job_id}")
return True
except Exception as e:
logger.error(f"释放任务锁失败: job_id={job_id}, error={e}")
return False
async def increment_concurrency(self) -> int:
"""增加全局并发计数
Returns:
int: 增加后的并发数
"""
if not self._redis_client:
return 0
try:
count = await self._redis_client.incr(settings.job_concurrency_key)
return count
except Exception as e:
logger.error(f"增加并发计数失败: error={e}")
return 0
async def decrement_concurrency(self) -> int:
"""减少全局并发计数
Returns:
int: 减少后的并发数
"""
if not self._redis_client:
return 0
try:
count = await self._redis_client.decr(settings.job_concurrency_key)
# 防止计数变为负数
if count < 0:
await self._redis_client.set(settings.job_concurrency_key, 0)
return 0
return count
except Exception as e:
logger.error(f"减少并发计数失败: error={e}")
return 0
async def get_global_concurrency(self) -> int:
"""获取当前全局并发数
Returns:
int: 当前并发数
"""
if not self._redis_client:
return 0
try:
count = await self._redis_client.get(settings.job_concurrency_key)
return int(count) if count else 0
except Exception as e:
logger.error(f"获取并发计数失败: error={e}")
return 0
async def can_execute(self) -> bool:
"""检查是否可以执行新任务(全局并发控制)
Returns:
bool: 是否可以执行
"""
current = await self.get_global_concurrency()
return current < settings.max_concurrent_jobs
async def get_job_retry_count(self, job_id: str) -> int:
"""获取任务重试次数
Args:
job_id: 任务 ID
Returns:
int: 重试次数
"""
if not self._redis_client:
return 0
key = f"job:{job_id}"
try:
retry_count = await self._redis_client.hget(key, "retry_count")
return int(retry_count) if retry_count else 0
except Exception:
return 0
async def increment_job_retry(self, job_id: str) -> int:
"""增加任务重试次数
Args:
job_id: 任务 ID
Returns:
int: 增加后的重试次数
"""
if not self._redis_client:
return 0
key = f"job:{job_id}"
try:
await self._redis_client.hincrby(key, "retry_count", 1)
retry_count = await self._redis_client.hget(key, "retry_count")
return int(retry_count) if retry_count else 1
except Exception as e:
logger.error(f"增加重试次数失败: job_id={job_id}, error={e}")
return 0
def get_concurrency_status(self) -> Dict[str, int]:
"""获取并发状态
Returns:
Dict[str, int]: 包含以下键的字典
- max_concurrent: 最大并发数
- available_slots: 可用槽位数
- running_jobs: 当前运行中的任务数
"""
if not self._semaphore:
return {
"max_concurrent": 0,
"available_slots": 0,
"running_jobs": 0,
}
max_concurrent = self._max_concurrent_jobs
available_slots = self._semaphore._value
running_jobs = max_concurrent - available_slots
return {
"max_concurrent": max_concurrent,
"available_slots": available_slots,
"running_jobs": running_jobs,
}
# 全局单例
_job_manager: Optional[JobManager] = None
async def get_job_manager() -> JobManager:
"""获取任务管理器单例"""
global _job_manager
if _job_manager is None:
_job_manager = JobManager()
await _job_manager.initialize()
return _job_manager
async def shutdown_job_manager() -> None:
"""关闭任务管理器"""
global _job_manager
if _job_manager is not None:
await _job_manager.shutdown()
_job_manager = None